/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include <memory>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"

#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"

namespace tensorflow {

// This pass inserts Mkl-to-Tf tensor conversion nodes (represented by C)
// in the graph between A and B, where A and B match any one
// of the following cases:
//
//  1) A = a node that generates output in the Mkl format,
//     B = a node that does not accept input in the Mkl format, and
//     A -> B (there is a direct edge between A and B); then
//     we insert C such that A->C->B.
//
//  2) A = a node that generates output in the Mkl format and
//     B = NULL (in other words, A is the last node in the graph); then
//     we insert C such that A->C. (C becomes the last node.)
//
//  Note that case 1 applies to all outputs of A that are inputs to B.
//  In other words, a conversion is required for every output
//  of A that is an input to B. For example, say the outputs of A
//  are A1, A2, and A3, of which A1 and A2 are in Mkl format but A3 is not,
//  and all of them are inputs to B. In that case, we do the conversion
//  for A1 and A2 only; no conversion is needed for A3.
//
// This pass relies on ops registering their Mkl compliance.
// An Mkl-compliant op can accept inputs in the Mkl format and produce outputs
// in the Mkl format. Non-compliant ops accept inputs and produce outputs in
// the TensorFlow format.
//
// ADDENDUM: For element-wise ops, a conversion may or may not be needed
// before the op executes. To handle this, we add a new op, called
// _MklInputConversion, before each element-wise MKL op to deal with its
// inputs. This pass has been extended to insert these nodes as well.
//
// The _MklInputConversion op checks the inputs to the element-wise op and
// ensures that either both are in MKL format or both are in TF format,
// depending on their initial state and whether broadcasting is needed.

class MklToTfConversionPass : public GraphOptimizationPass {
 public:
  MklToTfConversionPass() {}
  Status Run(const GraphOptimizationPassOptions& options);

  // Inserts layout conversion nodes in the graph pointed to by g.
  // The function scans the graph for candidate edges on which
  // conversion nodes need to be inserted.
  //
  // @return true if at least one conversion node is inserted;
  // false otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);

 private:
  // Does the input op support the Mkl-specific layout?
  //
  // @input op_name string of the op
  // @input T Datatype to use for checking the input op
  // @return true if op is Mkl supported; false otherwise.
  inline bool IsMklSupportedOp(const string& op_name, DataType T) const {
    return mkl_op_registry::IsMklOp(op_name, T);
  }

  // Does the input op support the Mkl-specific layout AND
  // is it element-wise?
  //
  // @input op_name string of the op
  // @input T Datatype to use for checking the input op
  // @return true if op is an Mkl-supported element-wise op; false otherwise.
  inline bool IsMklElementWiseOp(const string& op_name, DataType T) const {
    return mkl_op_registry::IsMklElementWiseOp(op_name, T);
  }

  // Insert a layout conversion node on the edge pointed to by 'e' in graph
  // 'g'.
  //
  // The edge is deleted once a call to this function succeeds.
  // Any attempt to use the edge after this call
  // will lead to undefined behavior.
  //
  // @return OK() if insertion is successful; otherwise an
  //         appropriate error status code.
  Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*);

  // For element-wise ops, we need to sanitize the inputs. For this, we add a
  // new node at the input of the replacement element-wise node that checks
  // the inputs and converts one or both of them as required. See the op code
  // comments for details.
  //
  // Insert an input conversion node as parent of 'n' in graph 'g'.
  //
  // @return OK() if insertion is successful; otherwise an
  //         appropriate error status code.
  Status InsertInputConversionNode(std::unique_ptr<Graph>* g, Node*);
};

// We register MklToTf insertion for phase 2 in post-partition grouping
// because we register MklLayoutRewritePass for phase 1 in post-partition
// grouping. We register this pass after partitioning so that we get a
// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);

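// A rough sketch of the rewrite performed by InsertConversionNodeOnEdge
// below (illustrative only; actual slot numbers depend on the nodes
// involved). An edge carrying the k-th data output of an Mkl node 'src'
// into a non-Mkl node 'dst' becomes:
//
//   Before:  src:k ----------------------------------> dst
//
//   After:   src:k ----------------> _MklToTf:0 ------> dst
//            src:meta(k) ----------/
//
// where meta(k) = DataIndexToMetaDataIndex(k, src->num_outputs()) is the
// slot carrying the Mkl metadata tensor paired with data output k.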
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
    std::unique_ptr<Graph>* g, Edge* e) {
  CHECK_NOTNULL(e);

  Node* src = e->src();
  Node* dst = e->dst();

  CHECK_NOTNULL(src);
  CHECK_NOTNULL(dst);

  Node* conversion_node = nullptr;
  DataType src_datatype = DT_INVALID;
  DataType dst_datatype = DT_INVALID;
  string data_format;

  TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype));
  bool dst_dtype_found =
      GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK();
  // We compare source and destination datatypes only when both are found.
  if (dst_dtype_found && (src_datatype != dst_datatype)) {
    string err_msg = "T attributes of " + src->name() + " and " + dst->name() +
                     " do not match. Will not insert" +
                     " MklToTf node in such a case.";
    return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
  }

  // Build the conversion node and specify src as its input.
  TF_CHECK_OK(
      NodeBuilder((*g)->NewName("Mkl2Tf"), "_MklToTf")
          .Input(src, e->src_output())
          .Input(src, DataIndexToMetaDataIndex(
                          e->src_output(),
                          src->num_outputs()))  // Get the Mkl tensor slot
                                                // paired with the Tf tensor
                                                // slot.
          .Device(src->def().device())  // We want the conversion node to be
                                        // on the same device as the source
                                        // node.
          .Attr("T", src_datatype)
          .Finalize(&**g, &conversion_node));

  CHECK_NOTNULL(conversion_node);
  if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK()) {
    conversion_node->AddAttr("data_format", data_format);
  }

  // Get the assigned device from the source node and apply it to the
  // conversion node. We want the conversion node to be on the same device as
  // the source node.
  conversion_node->set_assigned_device_name(src->assigned_device_name());

  // Set the Mkl op label for this op.
  conversion_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);

  // Now that we have added the edge src->conversion_node, add an edge from
  // the output of conversion_node to the dst node. Since conversion_node
  // has only 1 output, the src_output of conversion_node is 0.
  CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, dst, e->dst_input()));

  VLOG(1) << "MklToTfConversionPass: Inserting Conversion node on: "
          << src->type_string() << " and " << dst->type_string()
          << " successful.";

  // Remove the src->dst edge now.
  (*g)->RemoveEdge(e);
  return Status::OK();
}

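// A rough sketch of the rewrite performed by InsertInputConversionNode
// below (illustrative only). A binary element-wise Mkl op 'n' has four
// inputs in total: two data tensors and their two Mkl metadata tensors.
// The rewrite routes all four through a single _MklInputConversion node:
//
//   Before:  in0, in1, in2, in3 --------------------------------> n
//
//   After:   in0, in1, in2, in3 --> _MklInputConversion:0..3 ---> n
//
// The conversion node reconciles the layouts of the two data tensors (both
// Mkl or both TF, depending on whether broadcasting is needed) and forwards
// its four outputs to the same input slots of 'n'.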
Status MklToTfConversionPass::InsertInputConversionNode(
    std::unique_ptr<Graph>* g, Node* n) {
  CHECK_NOTNULL(n);

  // Get the input nodes and edges.
  std::vector<const Edge*> edges;
  TF_CHECK_OK(n->input_edges(&edges));
  if (edges.size() != 4) {
    return Status(error::Code::INVALID_ARGUMENT,
                  "MKL Binary Element-wise op should have exactly 2 data"
                  " inputs and 2 metadata inputs");
  }

  // Sanity check: ensure that both data inputs have the same type and that
  // it matches the op's input type.
  CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
           BaseType(edges[1]->src()->output_type(edges[1]->src_output())));
  CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
           BaseType(n->input_type(0)));

  // Check the ordering of the input edges.
  for (uint32 i = 0; i < 4; i++) {
    CHECK_EQ((edges[i]->dst_input() == i), true);
  }

  // Build the conversion node and specify the input edges as its inputs.
  Node* conversion_node = nullptr;

  TF_CHECK_OK(
      NodeBuilder((*g)->NewName("MklInputConversion"), "_MklInputConversion")
          .Input(edges[0]->src(), edges[0]->src_output())
          .Input(edges[1]->src(), edges[1]->src_output())
          .Input(edges[2]->src(), edges[2]->src_output())
          .Input(edges[3]->src(), edges[3]->src_output())
          .Device(n->def().device())
          .Attr("T", n->input_type(0))
          .Finalize(&**g, &conversion_node));

  CHECK_NOTNULL(conversion_node);

  // Change the destination of any control edges to the InputConversion node.
  if (edges.size() != n->in_edges().size()) {
    std::vector<const Edge*> edges_to_remove;
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) {
        CHECK_NOTNULL((*g)->AddControlEdge(e->src(), conversion_node));
        edges_to_remove.push_back(e);
      }
    }
    for (const Edge* e : edges_to_remove) {
      (*g)->RemoveEdge(e);
    }
  }

  string data_format;
  if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) ==
      Status::OK()) {
    conversion_node->AddAttr("data_format", data_format);
  }

  // Get the assigned device from the destination node and apply it to the
  // conversion node. We want the conversion node to be on the same device as
  // the destination node.
  conversion_node->set_assigned_device_name(n->assigned_device_name());

  // Set the Mkl op label for this op.
  conversion_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);

  // Now that we have added the edges src->conversion_node, add edges from
  // the outputs of conversion_node to the element-wise node.
  CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, n, edges[0]->dst_input()));
  CHECK_NOTNULL((*g)->AddEdge(conversion_node, 1, n, edges[1]->dst_input()));
  CHECK_NOTNULL((*g)->AddEdge(conversion_node, 2, n, edges[2]->dst_input()));
  CHECK_NOTNULL((*g)->AddEdge(conversion_node, 3, n, edges[3]->dst_input()));

  VLOG(1) << "MklToTfConversionPass - InputConversion: Inserting input "
          << "conversion node on: " << n->type_string() << " successful.";

  // Remove the original input edges now.
  (*g)->RemoveEdge(edges[0]);
  (*g)->RemoveEdge(edges[1]);
  (*g)->RemoveEdge(edges[2]);
  (*g)->RemoveEdge(edges[3]);

  return Status::OK();
}

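// RunPass works in two phases over the graph pointed to by 'g':
//   1. Scan all data edges and insert a _MklToTf conversion node on every
//      edge whose source is an Mkl-supported op and whose destination is not
//      (InsertConversionNodeOnEdge).
//   2. Scan all nodes and insert a _MklInputConversion node in front of every
//      Mkl element-wise op that does not already have one
//      (InsertInputConversionNode).
// Returns true if at least one conversion node was inserted.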
bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
  bool result = false;

  CHECK_NOTNULL(g);

  DumpGraph("Before MklToTfConversionPass", &**g);

  // Since we are looking for an Mkl-supported op node immediately
  // followed by a non-Mkl op node, we just iterate over the edge
  // set of the graph.
  // Edges whose source and destination are candidates for
  // inserting a conversion node.
  std::vector<Edge*> candidate_edges;

  for (const Edge* e : (*g)->edges()) {
    Node* src = e->src();
    Node* dst = e->dst();

    // We skip control edges.
    if (e->IsControlEdge()) {
      continue;
    }

    // We skip adding MklToTf on an edge X->MklToTf or
    // MklToTf->X, where X is any node.
    if (src->type_string().compare("_MklToTf") == 0 ||
        dst->type_string().compare("_MklToTf") == 0) {
      continue;
    }

    VLOG(1) << "MklToTfConversionPass: InsertConversionNodes: "
            << src->type_string() << " and " << dst->type_string();

    // Let's get the source and destination datatypes.
    // We cannot require a datatype on the destination node because the
    // destination node may not be an Mkl node.
    DataType src_datatype;
    DataType dst_datatype;
    bool src_is_mkl_op =
        (GetNodeAttr(src->def(), "T", &src_datatype) == Status::OK() &&
         IsMklSupportedOp(src->type_string(), src_datatype));
    bool dst_is_mkl_op =
        (GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK() &&
         IsMklSupportedOp(dst->type_string(), dst_datatype));

    // Check if src is Mkl-compliant while dst is not Mkl-compliant.
    if (src_is_mkl_op && !dst_is_mkl_op) {
      VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
              << " and " << dst->name() << " for inserting conversion nodes";
      candidate_edges.push_back(const_cast<Edge*>(e));
    }
  }

  // Process all candidate edges and insert conversion nodes on them.
  for (Edge* e : candidate_edges) {
    // Even if we insert a conversion node on only a single edge, we
    // need to return true.
    string src_name = e->src()->name();
    string dst_name = e->dst()->name();
    if (InsertConversionNodeOnEdge(g, e) == Status::OK()) {
      VLOG(1) << "MklToTfConversionPass: Inserted conversion "
              << "node on edge between " << src_name << " and " << dst_name;
      result = true;
    }
  }

  DumpGraph("After MklToTfConversionPass", &**g);

  //---------------------------------------------------------------------------
  // Check all nodes and add an input conversion node if the node is an Mkl
  // element-wise node.
  VLOG(1) << "Before running MklToTfConversionPass - InputConversion";

  std::vector<Node*> candidate_nodes;
  std::vector<Node*> order;
  GetReversePostOrder(**g, &order);  // This gives us a topological sort.

  for (Node* n : order) {
    // If the node is not an op or it does not have a datatype, skip it.
    DataType datatype;
    if (!n->IsOp() || (GetNodeAttr(n->def(), "T", &datatype) != Status::OK())) {
      continue;
    }
    if (IsMklElementWiseOp(n->type_string(), datatype)) {
      // If the input node is already an input-conversion op, skip this node.
      Node* input_node = nullptr;
      TF_CHECK_OK(n->input_node(0, &input_node));
      DataType input_datatype;
      if ((GetNodeAttr(n->def(), "T", &input_datatype) == Status::OK()) &&
          (input_node->type_string().compare("_MklInputConversion") == 0)) {
        continue;
      }

      VLOG(1) << "MklToTfConversionPass: InputConversion: Scheduled node "
              << n->name() << " for inserting input conversion node";
      candidate_nodes.push_back(const_cast<Node*>(n));
    }
  }

  // Process all candidate nodes and insert input conversion nodes for them.
  for (Node* n : candidate_nodes) {
    // Even if we insert an input conversion node for only a single node, we
    // need to return true.
    if (InsertInputConversionNode(g, n) == Status::OK()) {
      VLOG(1) << "MklToTfConversionPass: Inserted conversion "
              << "on node " << n->name();
      result = true;
    }
  }
  DumpGraph("After MklToTfConversionPass - InputConversion", &**g);

  // We need to return true if we insert even one conversion node
  // anywhere in the graph.
  return result;
}

//////////////////////////////////////////////////////////////////////////////
//              Run function for the pass
//////////////////////////////////////////////////////////////////////////////

bool InsertMklToTfConversionNodes(std::unique_ptr<Graph>* g) {
  return MklToTfConversionPass().RunPass(g);
}

Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) {
  if (options.graph == nullptr && options.partition_graphs == nullptr) {
    return Status::OK();
  }

  auto process_graph = [&](std::unique_ptr<Graph>* g) {
    // Take ownership of the graph.
    std::unique_ptr<Graph>* ng = std::move(g);
    RunPass(ng);
    // Return ownership of the graph.
    g->reset(ng->release());
  };

  if (kMklTfConvPassGroup != OptimizationPassRegistry::POST_PARTITIONING) {
    // For any pre-partitioning phase, the graph is stored in options.graph.
    process_graph(options.graph);
  } else {
    // For the post-partitioning phase, graphs are stored in
    // options.partition_graphs.
    for (auto& pg : *options.partition_graphs) {
      process_graph(&pg.second);
    }
  }

  return Status::OK();
}

}  // namespace tensorflow

#endif  // INTEL_MKL