1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifdef INTEL_MKL 17 18#include <memory> 19#include <queue> 20#include <set> 21#include <string> 22#include <utility> 23#include <vector> 24 25#include "tensorflow/core/common_runtime/function.h" 26#include "tensorflow/core/common_runtime/optimization_registry.h" 27#include "tensorflow/core/framework/node_def_util.h" 28#include "tensorflow/core/graph/algorithm.h" 29#include "tensorflow/core/graph/graph.h" 30#include "tensorflow/core/graph/node_builder.h" 31#include "tensorflow/core/lib/core/status.h" 32#include "tensorflow/core/lib/gtl/map_util.h" 33#include "tensorflow/core/lib/hash/hash.h" 34#include "tensorflow/core/platform/logging.h" 35 36#include "tensorflow/core/graph/mkl_graph_util.h" 37#include "tensorflow/core/graph/mkl_tfconversion_pass.h" 38 39namespace tensorflow { 40 41// This pass inserts Mkl to Tf tensor conversion nodes (represented by C) 42// in the graph in between A and B, where A and B match any one 43// of the following cases: 44// 45// 1) A = a node that generates output in the Mkl format and, 46// B = a node that does not accept input in the Mkl format and, 47// A -> B (there is a direct edge between A and B, then 48// We will insert C such that A->C->B. 49// 50// 2) A = a node that generates output in the Mkl format and, 51// B = NULL (in other words, A is the last node in the graph), then 52// We will insert C such that A->C->B. (C will be the last node.) 53// 54// Note that case 1 applies to all outputs of A that are input to B. 55// In other words, the conversions will be required for every output 56// of A that is input to B. For example, let us say the output of A 57// is A1, A2, A3, of which A1 and A2 are in Mkl format, but A3 is not 58// in Mkl format, and all of them are input to B. In such case, we will 59// do the conversion for A1 and A2 only. We do not need to do any conversion 60// for A3. 61// 62// This pass relies on ops registering themselves about their Mkl compliance. 63// An Mkl-compliant op can accept inputs in the Mkl format, and produce outputs 64// in the Mkl format. Non-compliant ops accept inputs and outputs in the 65// TensorFlow format. 66// 67// ADDENDUM: For element-wise ops, we may or may not need a conversion to 68// take place before we hit the op. For this, we add a new op before each 69// element-wise MKL op to deal with the inputs, called _MklInputConversion. 70// This pass has been enhanced to add this capability. 71// 72// The _MklInputConversion op will check the inputs to the elementwise op and 73// make sure that either both are in MKL format or both are in TF format, 74// depending on their initial state and whether broadcast is needed or not. 75 76class MklToTfConversionPass : public GraphOptimizationPass { 77 public: 78 MklToTfConversionPass() {} 79 Status Run(const GraphOptimizationPassOptions& options); 80 81 // Insert layout conversion node in the graph pointed by g. 82 // Function scans the graph for candidate edges where we 83 // need to insert conversion nodes. 84 // 85 // @return true even if single conversion node is inserted; 86 // false, otherwise. 87 bool RunPass(std::unique_ptr<Graph>* g); 88 89 private: 90 // Is the input Op supported by Mkl-specific layout? 91 // 92 // @input op_name string of the op 93 // @input T Datatype to use for checking input op 94 // @return true if op is Mkl supported; false, otherwise. 95 inline bool IsMklSupportedOp(const string& op_name, DataType T) const { 96 return mkl_op_registry::IsMklOp(op_name, T); 97 } 98 99 // Is the input Op supported by Mkl-specific layout AND 100 // is it element-wise? 101 // 102 // @input op_name string of the op 103 // @input T Datatype to use for checking input op 104 // @return true if op is Mkl supported; false, otherwise. 105 inline bool IsMklElementWiseOp(const string& op_name, DataType T) const { 106 return mkl_op_registry::IsMklElementWiseOp(op_name, T); 107 } 108 109 // Insert layout conversion node on the edge pointed by 'e' from graph 'g'. 110 // 111 // Edge will be deleted once a call to this function is successful. 112 // Any attempt to use the edge after this call 113 // will lead to undefined behaviors. 114 // 115 // @return Success:OK() if insertion is successful, otherwise returns 116 // appropriate error status code. 117 Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*); 118 119 // For element-wise ops, we need to sanitize the inputs. For this, we add a 120 // new node at the input of the replacement element-wise node that checks 121 // the inputs and converts one/both of them as required. See the op code 122 // comments for details. 123 // 124 // Insert input conversion node as parent of 'n' from graph 'g'. 125 // 126 // @return Success:OK() if insertion is successful, otherwise returns 127 // appropriate error status code. 128 Status InsertInputConversionNode(std::unique_ptr<Graph>* g, Node*); 129}; 130 131// We register MklToTf insertion for phase 2 in post-partition grouping 132// because we register MklLayoutRewritePass for phase 1 in post-partition 133// grouping. We register this pass after partitioning so that we get a 134// complete picture of inputs and outputs of the nodes in the graphs. 135const OptimizationPassRegistry::Grouping kMklTfConvPassGroup = 136 OptimizationPassRegistry::POST_PARTITIONING; 137REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass); 138 139Status MklToTfConversionPass::InsertConversionNodeOnEdge( 140 std::unique_ptr<Graph>* g, Edge* e) { 141 CHECK_NOTNULL(e); 142 143 Node* src = e->src(); 144 Node* dst = e->dst(); 145 146 CHECK_NOTNULL(src); 147 CHECK_NOTNULL(dst); 148 149 Node* conversion_node = nullptr; 150 DataType src_datatype = DT_INVALID; 151 DataType dst_datatype = DT_INVALID; 152 string data_format; 153 154 TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype)); 155 bool dst_dtype_found = 156 GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK(); 157 // We compare source and destination datatypes only when both are found. 158 if (dst_dtype_found && (src_datatype != dst_datatype)) { 159 string err_msg = "T attribute of " + src->name() + " and " + dst->name() + 160 " do not match. Will not insert" + 161 " MklToTf node in such case."; 162 return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str()); 163 } 164 165 // Build the conversion node and specify src as input. 166 TF_CHECK_OK( 167 NodeBuilder((*g)->NewName("Mkl2Tf"), "_MklToTf") 168 .Input(src, e->src_output()) 169 .Input(src, DataIndexToMetaDataIndex( 170 e->src_output(), 171 src->num_outputs())) // Get an Mkl tensor slot 172 // from the Tf tensor slot. 173 .Device(src->def().device()) // We want to get conversion node 174 // on same device as source node. 175 .Attr("T", src_datatype) 176 .Finalize(&**g, &conversion_node)); 177 178 CHECK_NOTNULL(conversion_node); 179 if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK()) { 180 conversion_node->AddAttr("data_format", data_format); 181 } 182 183 // Get assigned device from source node and apply it to conversion node. 184 // We want conversion node to be on the same device as the source node. 185 conversion_node->set_assigned_device_name(src->assigned_device_name()); 186 187 // Set the Mkl op label for this op. 188 conversion_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel); 189 190 // Now that we have added edge from src->conversion_node, let's add edge from 191 // output of conversion_node to the dest node. Since conversion_node 192 // has only 1 output, the src_output of conversion_node is 0. 193 CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, dst, e->dst_input())); 194 195 VLOG(1) << "MklToTfConversionPass: Inserting Conversion node on: " 196 << src->type_string() << " and " << dst->type_string() 197 << " successful."; 198 199 // Remove src->dst edge now. 200 (*g)->RemoveEdge(e); 201 return Status::OK(); 202} 203 204Status MklToTfConversionPass::InsertInputConversionNode( 205 std::unique_ptr<Graph>* g, Node* n) { 206 CHECK_NOTNULL(n); 207 208 // Get the input nodes and edges 209 std::vector<const Edge*> edges; 210 TF_CHECK_OK(n->input_edges(&edges)); 211 if (edges.size() != 4) { 212 return Status(error::Code::INVALID_ARGUMENT, 213 "MKL Binary Element-wise op should have exactly 2 data" 214 " inputs and 2 metadata inputs"); 215 } 216 217 // Sanity check: ensure that both inputs are of the expected type, and the 218 // same type as input type 219 CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())), 220 BaseType(edges[1]->src()->output_type(edges[1]->src_output()))); 221 CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())), 222 BaseType(n->input_type(0))); 223 224 // Check ordering of edges 225 for (uint32 i = 0; i < 4; i++) { 226 CHECK_EQ((edges[i]->dst_input() == i), true); 227 } 228 229 // Build the conversion node and specify src as input. 230 Node* conversion_node = nullptr; 231 232 TF_CHECK_OK( 233 NodeBuilder((*g)->NewName("MklInputConversion"), "_MklInputConversion") 234 .Input(edges[0]->src(), edges[0]->src_output()) 235 .Input(edges[1]->src(), edges[1]->src_output()) 236 .Input(edges[2]->src(), edges[2]->src_output()) 237 .Input(edges[3]->src(), edges[3]->src_output()) 238 .Device(n->def().device()) 239 .Attr("T", n->input_type(0)) 240 .Finalize(&**g, &conversion_node)); 241 242 CHECK_NOTNULL(conversion_node); 243 244 // Change the destination of any control edges to the InputConversion node 245 if (edges.size() != n->in_edges().size()) { 246 std::vector<const Edge*> edges_to_remove; 247 for (const Edge* e : n->in_edges()) { 248 if (e->IsControlEdge()) { 249 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), conversion_node)); 250 edges_to_remove.push_back(e); 251 } 252 } 253 for (const Edge* e : edges_to_remove) { 254 (*g)->RemoveEdge(e); 255 } 256 } 257 258 string data_format; 259 if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) == 260 Status::OK()) { 261 conversion_node->AddAttr("data_format", data_format); 262 } 263 264 // Get assigned device from destination node and apply it to conversion node. 265 // We want conversion node to be on the same device as the destination node. 266 conversion_node->set_assigned_device_name(n->assigned_device_name()); 267 268 // Set the Mkl op label for this op. 269 conversion_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel); 270 271 // Now that we have added edges from src->conversion_node, let's add edge from 272 // output of conversion_node to the element-wise node. 273 CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, n, edges[0]->dst_input())); 274 CHECK_NOTNULL((*g)->AddEdge(conversion_node, 1, n, edges[1]->dst_input())); 275 CHECK_NOTNULL((*g)->AddEdge(conversion_node, 2, n, edges[2]->dst_input())); 276 CHECK_NOTNULL((*g)->AddEdge(conversion_node, 3, n, edges[3]->dst_input())); 277 278 VLOG(1) << "MklToTfConversionPass - InputConversion: Inserting input " 279 << "conversion node on: " << n->type_string() << " successful."; 280 281 // Remove src->dst edge now. 282 (*g)->RemoveEdge(edges[0]); 283 (*g)->RemoveEdge(edges[1]); 284 (*g)->RemoveEdge(edges[2]); 285 (*g)->RemoveEdge(edges[3]); 286 287 return Status::OK(); 288} 289 290bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) { 291 bool result = false; 292 293 CHECK_NOTNULL(g); 294 295 DumpGraph("Before MklToTfConversionPass", &**g); 296 297 // Since we are looking for an Mkl-supported op node immediately 298 // followed by a non-Mkl op node, we will just iterate over edge 299 // set of the graph. 300 // edge set whose source and destination are candidates for 301 // inserting conversion node 302 std::vector<Edge*> candidate_edges; 303 304 for (const Edge* e : (*g)->edges()) { 305 Node* src = e->src(); 306 Node* dst = e->dst(); 307 308 // We skip control edges. 309 if (e->IsControlEdge()) { 310 continue; 311 } 312 313 // We skip adding MklToTf on an edge between X->MklToTf or 314 // MklToTf->X, where X is any node. 315 if (src->type_string().compare("_MklToTf") == 0 || 316 dst->type_string().compare("_MklToTf") == 0) { 317 continue; 318 } 319 320 VLOG(1) << "MklToTfConversionPass: InsertConversionNodes: " 321 << src->type_string() << " and " << dst->type_string(); 322 323 // Let's get source and destination data type. 324 // We cannot check datatype on destination node because destination node 325 // may not be Mkl node. 326 DataType src_datatype; 327 DataType dst_datatype; 328 bool src_is_mkl_op = 329 (GetNodeAttr(src->def(), "T", &src_datatype) == Status::OK() && 330 IsMklSupportedOp(src->type_string(), src_datatype)); 331 bool dst_is_mkl_op = 332 (GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK() && 333 IsMklSupportedOp(dst->type_string(), dst_datatype)); 334 335 // Check if src with is Mkl-compliant, while dst is not Mkl-compliant. 336 if (src_is_mkl_op && !dst_is_mkl_op) { 337 VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name() 338 << " and " << dst->name() << " for inserting conversion nodes"; 339 candidate_edges.push_back(const_cast<Edge*>(e)); 340 } 341 } 342 343 // Process all candidate edges and insert conversion nodes on them. 344 for (Edge* e : candidate_edges) { 345 // Even if we insert conversion node on a single edge, we 346 // need to return true. 347 string src_name = e->src()->name(); 348 string dst_name = e->dst()->name(); 349 if (InsertConversionNodeOnEdge(g, e) == Status::OK()) { 350 VLOG(1) << "MklToTfConversionPass: Inserted conversion " 351 << "node on edge between " << src_name << " and " << dst_name; 352 result = true; 353 } 354 } 355 356 DumpGraph("After MklToTfConversionPass", &**g); 357 358 //--------------------------------------------------------------------------- 359 // Check all nodes and add an input-conversion-node if the node is an mkl 360 // element-wise node. 361 VLOG(1) << "Before running MklToTfConversionPass - InputConversion"; 362 363 std::vector<Node*> candidate_nodes; 364 std::vector<Node*> order; 365 GetReversePostOrder(**g, &order); // This will give us topological sort. 366 367 for (Node* n : order) { 368 // If node is not an op or it does not have a datatype, then skip. 369 DataType datatype; 370 if (!n->IsOp() || (GetNodeAttr(n->def(), "T", &datatype) != Status::OK())) { 371 continue; 372 } 373 if (IsMklElementWiseOp(n->type_string(), datatype)) { 374 // If the input node is an input-conversion op, skip 375 Node* input_node = nullptr; 376 TF_CHECK_OK(n->input_node(0, &input_node)); 377 DataType input_datatype; 378 if ((GetNodeAttr(n->def(), "T", &input_datatype) == Status::OK()) && 379 (input_node->type_string().compare("_MklInputConversion") == 0)) { 380 continue; 381 } 382 383 VLOG(1) << "MklToTfConversionPass: InputConversion: Scheduled node " 384 << n->name() << " for inserting input conversion node"; 385 candidate_nodes.push_back(const_cast<Node*>(n)); 386 } 387 } 388 389 // Process all candidate edges and insert conversion nodes on them. 390 for (Node* n : candidate_nodes) { 391 // Even if we insert conversion node on a single node, we 392 // need to return true. 393 if (InsertInputConversionNode(g, n) == Status::OK()) { 394 VLOG(1) << "MklToTfConversionPass: Inserted conversion " 395 << "on node " << n->name(); 396 result = true; 397 } 398 } 399 DumpGraph("After MklToTfConversionPass - InputConversion", &**g); 400 401 // We need to return true even if we insert one conversion node 402 // anywhere in the graph. 403 return result; 404} 405 406////////////////////////////////////////////////////////////////////////////// 407// Run function for the pass 408////////////////////////////////////////////////////////////////////////////// 409 410bool InsertMklToTfConversionNodes(std::unique_ptr<Graph>* g) { 411 return MklToTfConversionPass().RunPass(g); 412} 413 414Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) { 415 if (options.graph == nullptr && options.partition_graphs == nullptr) { 416 return Status::OK(); 417 } 418 419 auto process_graph = [&](std::unique_ptr<Graph>* g) { 420 // Get the ownership of graph 421 std::unique_ptr<Graph>* ng = std::move(g); 422 RunPass(ng); 423 // Return the ownership of graph back 424 g->reset(ng->release()); 425 }; 426 427 if (kMklTfConvPassGroup != OptimizationPassRegistry::POST_PARTITIONING) { 428 // For any pre-partitioning phase, graph is stored in options.graph. 429 process_graph(options.graph); 430 } else { 431 // For post partitioning phase, graphs are stored in 432 // options.partition_graphs. 433 for (auto& pg : *options.partition_graphs) { 434 process_graph(&pg.second); 435 } 436 } 437 438 return Status::OK(); 439} 440 441} // namespace tensorflow 442 443#endif 444