1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
17
18#include <algorithm>
19#include <deque>
20#include <stack>
21#include <unordered_set>
22#include <vector>
23
24#include "tensorflow/compiler/jit/graph_to_functiondef.h"
25#include "tensorflow/compiler/jit/union_find.h"
26#include "tensorflow/compiler/tf2xla/dump_graph.h"
27#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
28#include "tensorflow/compiler/xla/ptr_util.h"
29#include "tensorflow/compiler/xla/status_macros.h"
30#include "tensorflow/core/common_runtime/function.h"
31#include "tensorflow/core/framework/node_def_builder.h"
32#include "tensorflow/core/graph/algorithm.h"
33#include "tensorflow/core/graph/control_flow.h"
34#include "tensorflow/core/lib/gtl/optional.h"
35
36namespace tensorflow {
37
38namespace {
39
// Brings XLA's StatusOr into scope; it is the return type of the fallible
// node-building helpers (AddNode, BuildArgNode, BuildRetvalNode, ...).
using xla::StatusOr;
41
// Op types of the function argument / return-value nodes created when
// extracting loop condition and body functions. They double as the name
// prefix of those nodes (e.g. "_Arg0", "_Retval1").
constexpr char kArgOp[] = "_Arg";
constexpr char kRetValOp[] = "_Retval";
44
45// Information about a loop argument.
46struct Arg {
47  // Every loop argument has an Enter node.
48  Node* enter;
49
50  // Is the loop argument a loop-invariant value? Taken from the `is_constant`
51  // attribute on the Enter node.
52  bool is_loop_invariant;
53
54  // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
55  // arguments must have all of the following nodes:
56  Node* merge = nullptr;
57  Node* switch_node = nullptr;
58  Node* next_iteration = nullptr;
59  Node* exit = nullptr;
60};
61
62// Information about a loop frame.
63struct Frame {
64  string name;
65
66  // Pointer to the parent frame. The root frame has a pointer to itself.
67  Frame* parent = nullptr;
68  int num_children = 0;
69
70  // Arguments to this loop.
71  std::vector<Arg> args;
72
73  // The loop condition of the loop. There should be exactly one loop condition
74  // in every loop.
75  Node* loop_cond = nullptr;
76
77  // Set of nodes that belong to the loop frame.
78  std::unordered_set<Node*> nodes;
79};
80
81// Comparison function used for sorting nodes consistently.
82// a) resource variables are last, and
83// b) sort lexicographically by name (for deterministic output).
84struct NodeCmp {
85  bool operator()(const Node* lhs, const Node* rhs) const {
86    bool lhs_is_resource =
87        lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false;
88    bool rhs_is_resource =
89        rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false;
90    return std::tie(lhs_is_resource, lhs->name()) <
91           std::tie(rhs_is_resource, rhs->name());
92  }
93};
94
95// Returns a textual representation of the names of the nodes in the input.
96template <typename T>
97string NodesToString(const T& nodes) {
98  return strings::StrCat("{",
99                         str_util::Join(nodes, ",",
100                                        [](string* output, const Node* node) {
101                                          strings::StrAppend(output,
102                                                             node->name());
103                                        }),
104                         "}");
105}
106
107// Copies a subgraph from `graph` to `output` by performing a reverse DFS
108// starting at nodes in vector `stack`.
109// `node_map` is a vector indexed by source node ID to dest nodes.
110// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
111// before the traversal clients can cut the graph. If a frame is provided (frame
112// != nullptr), then this functions will return an error if the
113// traversal leaves 'frame'; the client must add enough nodes to `node_map` to
114// cut the graph and prevent the traversal from escaping.
115//
116// `squash_src_outputs` contains a bool for each source node ID. If true, then
117// the source output on that node will be replaced by zero when copied. This is
118// used when replacing a Switch node with an _Arg node. The output we are
119// taking from the Switch node was not necessarily the first output, but _Arg
120// nodes only have one output. By adding the Switch node to `squash_src_outputs`
121// we rewrite the src_output of the corresponding edge to be 0.
122Status CopySubgraph(const Graph& graph, const Frame* frame,
123                    std::vector<Node*> stack,
124                    const std::vector<bool>& squash_src_outputs,
125                    std::vector<Node*>* node_map, Graph* output) {
126  VLOG(3) << "Stack: " << NodesToString(stack);
127  std::vector<bool> visited(graph.num_node_ids(), false);
128  while (!stack.empty()) {
129    Node* n = stack.back();
130    stack.pop_back();
131
132    VLOG(5) << "Copying node " << n->name();
133
134    if (visited[n->id()]) continue;
135    visited[n->id()] = true;
136
137    for (const Edge* e : n->in_edges()) {
138      Node* src = e->src();
139      if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
140        // We traversed out of the loop frame, without encountering a cut node.
141        return errors::Internal("Graph traversal of loop frame ", frame->name,
142                                " escaped frame at ", src->name(),
143                                " without encountering an argument node.");
144      }
145      if ((*node_map)[src->id()] == nullptr) {
146        (*node_map)[src->id()] = output->CopyNode(src);
147        stack.push_back(src);
148      }
149      Node* src_copy = (*node_map)[e->src()->id()];
150      int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
151                           ? 0
152                           : e->src_output();
153      Node* dst_copy = (*node_map)[e->dst()->id()];
154      output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
155    }
156  }
157  return Status::OK();
158}
159
160StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) {
161  Status status;
162  Node* inserted_node = graph->AddNode(node_def, &status);
163  if (!status.ok()) {
164    return status;
165  }
166  return inserted_node;
167}
168
169StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
170  NodeDef arg_def;
171  NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp);
172  builder.Attr("T", type);
173  builder.Attr("index", index);
174  TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
175  return AddNode(arg_def, graph);
176}
177
178StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
179  NodeDef ret_def;
180  ret_def.set_op(kRetValOp);
181  ret_def.set_name(strings::StrCat(kRetValOp, index));
182  AddNodeAttr("T", type, &ret_def);
183  AddNodeAttr("index", index, &ret_def);
184  return AddNode(ret_def, graph);
185}
186
187// Builds a graph for the loop condition.
188Status BuildLoopCondition(const Graph& graph, Frame* frame,
189                          std::unique_ptr<Graph>* cond_output) {
190  VLOG(2) << "Building loop condition for " << frame->name;
191  *cond_output = xla::MakeUnique<Graph>(graph.op_registry());
192  Graph* output = cond_output->get();
193
194  // Map from nodes in the original graph to the condition graph.
195  std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
196  std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
197
198  // Build one _Arg node for each Enter node.
199  for (int i = 0; i < frame->args.size(); ++i) {
200    const Arg& arg = frame->args[i];
201
202    TF_ASSIGN_OR_RETURN(Node * arg_node,
203                        BuildArgNode(output, arg.enter->input_type(0), i));
204    if (arg.is_loop_invariant) {
205      node_map[arg.enter->id()] = arg_node;
206    } else {
207      node_map[arg.merge->id()] = arg_node;
208    }
209  }
210
211  // Build a Retval node for the loop condition. The LoopCond nodes are always
212  // boolean because of the type constraints on the LoopCond op.
213  TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
214                      BuildRetvalNode(output, DT_BOOL, 0));
215
216  // Performs a reverse DFS, copying nodes and edges to the output graph.
217  // The _Arg and _Retval nodes were added unconditionally above, so we are
218  // guaranteed to get the correct function signature.
219  return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
220                      &node_map, output);
221}
222
223// Builds a graph for the loop body.
224Status BuildLoopBody(const Graph& graph, Frame* frame,
225                     DataTypeVector* arg_types,
226                     std::unique_ptr<Graph>* body_output) {
227  VLOG(2) << "Building loop body for " << frame->name;
228  *body_output = xla::MakeUnique<Graph>(graph.op_registry());
229  Graph* output = body_output->get();
230
231  // Map from nodes in the original graph to the condition graph.
232  std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
233  std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
234
235  // Build one _Arg node for each Enter node.
236  std::vector<Node*> next_iterations;
237  next_iterations.reserve(frame->args.size());
238  arg_types->reserve(frame->args.size());
239  for (int i = 0; i < frame->args.size(); ++i) {
240    const Arg& arg = frame->args[i];
241
242    DataType dtype = arg.enter->input_type(0);
243    arg_types->push_back(dtype);
244
245    TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
246
247    if (dtype == DT_RESOURCE) {
248      // The convention of the XLA bridge is that resource variable arguments
249      // are only inputs to the loop body and have no corresponding output.
250      // TODO(b/37741920): change the convention so that DT_RESOURCE variables
251      // are both inputs and outputs, and then remove this case.
252      TF_RET_CHECK(arg.is_loop_invariant);
253      node_map[arg.enter->id()] = arg_node;
254    } else {
255      TF_ASSIGN_OR_RETURN(Node * retval_node,
256                          BuildRetvalNode(output, dtype, i));
257
258      if (arg.is_loop_invariant) {
259        // Argument is loop-invariant. Forward it from the Arg to the Retval.
260        node_map[arg.enter->id()] = arg_node;
261        output->AddEdge(arg_node, 0, retval_node, 0);
262      } else {
263        // Argument is loop-varying.
264        node_map[arg.switch_node->id()] = arg_node;
265        // The Switch node has two outputs, but _Arg only has one. This tells
266        // the CopySubgraph function to rewrite the output number of edges from
267        // the _Arg node to be 0 rather than copying the output number from the
268        // Switch node.
269        squash_src_outputs[arg.switch_node->id()] = true;
270        node_map[arg.next_iteration->id()] = retval_node;
271        next_iterations.push_back(arg.next_iteration);
272      }
273    }
274  }
275
276  // Performs a reverse DFS, copying nodes and edges to the output graph.
277  // The _Arg and _Retval nodes were added unconditionally above, so we are
278  // guaranteed to get the correct function signature.
279  TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
280                                  squash_src_outputs, &node_map, output));
281
282  return Status::OK();
283}
284
285Status FunctionalizeLoop(Graph* graph, Frame* frame,
286                         FunctionLibraryDefinition* library) {
287  VLOG(2) << "Frame " << frame->name << " before: "
288          << dump_graph::DumpGraphToFile("functionalize_before", *graph,
289                                         library);
290
291  // Split loop-varying Enter nodes with multiple successors. If the same
292  // Tensor is fed as input to multiple loop arguments, we may end up with a
293  // shared Enter node. We clone Enter nodes with multiple successors to
294  // maintain the invariant of a unique Enter node per argument of the final
295  // loop.
296  std::vector<Arg> args;
297  for (const Arg& arg : frame->args) {
298    if (arg.is_loop_invariant) {
299      args.push_back(arg);
300    } else {
301      std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
302                                     arg.enter->out_edges().end());
303      for (int i = 0; i < edges.size(); ++i) {
304        if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
305          continue;
306        }
307        TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
308        Arg new_arg;
309        new_arg.is_loop_invariant = false;
310        if (i == 0) {
311          new_arg.enter = arg.enter;
312        } else {
313          new_arg.enter = graph->CopyNode(arg.enter);
314          frame->nodes.insert(new_arg.enter);
315          for (Edge const* e : arg.enter->in_edges()) {
316            graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
317                           e->IsControlEdge() ? Graph::kControlSlot : 0);
318          }
319          Node* dst = edges[i]->dst();
320          int dst_input = edges[i]->dst_input();
321          graph->RemoveEdge(edges[i]);
322          graph->AddEdge(new_arg.enter, 0, dst, dst_input);
323        }
324        args.push_back(new_arg);
325      }
326    }
327  }
328  frame->args = std::move(args);
329
330  std::sort(
331      frame->args.begin(), frame->args.end(),
332      [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); });
333
334  if (frame->loop_cond == nullptr) {
335    return errors::InvalidArgument("Loop ", frame->name,
336                                   " has no LoopCond node");
337  }
338
339  // Find the set of Switch nodes that are successors of the LoopCond.
340  std::unordered_set<Node*> switches;
341  for (const Edge* edge : frame->loop_cond->out_edges()) {
342    if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
343        edge->dst_input() == 1) {
344      switches.insert(edge->dst());
345    }
346  }
347
348  // For each non-constant argument, looks for the following pattern of nodes:
349  // Enter ----> Merge  -------->  Switch  --> Exit
350  //               ^                  ^
351  //               |                  |
352  //         NextIteration         LoopCond
353  //               ^                  ^
354  //               |                  |
355  //              ...                ...
356  for (Arg& arg : frame->args) {
357    if (!arg.is_loop_invariant) {
358      // Follow the edge from the Enter to Merge.
359      const Edge* enter_merge = nullptr;
360      for (const Edge* e : arg.enter->out_edges()) {
361        // Ignore control-edges to the sink node. These are allowed by the
362        // graph invariants, although probably they should have been stripped
363        // off earlier.
364        if (e->IsControlEdge() && e->dst()->IsSink()) {
365          continue;
366        }
367        if (enter_merge != nullptr) {
368          return errors::Internal(
369              "Enter node for loop-varying argument ", arg.enter->name(),
370              " has multiple successors: ", enter_merge->dst()->name(), " and ",
371              e->dst()->name());
372        }
373        enter_merge = e;
374      }
375      if (enter_merge == nullptr) {
376        return errors::Internal("Enter node for loop-varying argument ",
377                                arg.enter->name(), " has zero successors");
378      }
379      arg.merge = enter_merge->dst();
380      if (!IsMerge(arg.merge)) {
381        return errors::InvalidArgument(
382            "Successor of Enter node for loop-varying argument ",
383            arg.merge->name(),
384            " is not a Merge node; got: ", arg.merge->type_string());
385      }
386
387      // Find the NextIteration from the merge. There should be two inputs to
388      // the Merge and the NextIteration should be the other input.
389      if (arg.merge->input_types().size() != 2) {
390        return errors::InvalidArgument(
391            "Unexpected number of inputs to Merge node for loop-varying "
392            "argument ",
393            arg.merge->name(), "; expected 2, got ",
394            arg.merge->input_types().size());
395      }
396      TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
397                                               &arg.next_iteration));
398      if (!IsNextIteration(arg.next_iteration)) {
399        return errors::InvalidArgument(
400            "Expected NextIteration node as input to Merge node; got node ",
401            arg.next_iteration->name(), " with kind ",
402            arg.next_iteration->type_string());
403      }
404
405      // Find the Switch successor of the Merge. There should be exactly one
406      // Switch node that is a successor of both the Merge and the LoopCond.
407      for (const Edge* edge : arg.merge->out_edges()) {
408        if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
409            switches.find(edge->dst()) != switches.end()) {
410          if (arg.switch_node != nullptr) {
411            return errors::InvalidArgument("Duplicate Switch successors to ",
412                                           arg.merge->name());
413          }
414          arg.switch_node = edge->dst();
415        }
416      }
417      if (arg.switch_node == nullptr) {
418        return errors::InvalidArgument("Missing Switch successor to ",
419                                       arg.merge->name());
420      }
421
422      // Update the device on the Identity outputs of the switch to match their
423      // target. These Identity outputs do not
424
425      // Loop over the switch node's output to:
426      // - Find the Exit successor.
427      // - Set the sharding on all Identity outputs of the switch. These
428      //   identity nodes are values used by the loop body or condition.
429      //   The Identity node may have the wrong device so copy the device from
430      //   one of its outputs instead.
431      std::deque<const Edge*> possible_exit;
432      for (const Edge* edge : arg.switch_node->out_edges()) {
433        if (edge->src_output() == 0) {
434          possible_exit.push_back(edge);
435        }
436        if (IsIdentity(edge->dst())) {
437          TF_RETURN_IF_ERROR(
438              SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
439        }
440      }
441      // TODO(b/67425339): Allow general graph between switch and exit.
442      while (!possible_exit.empty()) {
443        const Edge* edge = possible_exit.front();
444        possible_exit.pop_front();
445        if (IsExit(edge->dst())) {
446          if (arg.exit != nullptr) {
447            return errors::InvalidArgument("Duplicate Exit successors to ",
448                                           arg.switch_node->name());
449          }
450          arg.exit = edge->dst();
451        } else {
452          if (!IsIdentity(edge->dst())) {
453            return errors::Unimplemented("General graph between switch (",
454                                         arg.switch_node->name(),
455                                         ") and exit node of frame ",
456                                         frame->name, " not supported yet.");
457          }
458          for (const Edge* out : edge->dst()->out_edges()) {
459            possible_exit.push_back(out);
460          }
461        }
462      }
463    }
464  }
465
466  // Builds the condition and body functions.
467  std::unique_ptr<Graph> cond_graph;
468  TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
469  DataTypeVector arg_types;
470  std::unique_ptr<Graph> body_graph;
471  TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
472
473  VLOG(2) << "Frame " << frame->name << " condition: "
474          << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
475          << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);
476
477  static std::atomic<int64> sequence_num(0LL);
478  int64 id = ++sequence_num;
479  NameAttrList cond_name;
480  cond_name.set_name(strings::StrCat("_functionalize_cond_", id));
481  NameAttrList body_name;
482  body_name.set_name(strings::StrCat("_functionalize_body_", id));
483  FunctionDef cond_fdef;
484  TF_RETURN_IF_ERROR(
485      GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
486  FunctionDef body_fdef;
487  TF_RETURN_IF_ERROR(
488      GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
489
490  TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
491  TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
492
493  // Builds a While operator.
494  NodeDef while_def;
495  NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
496  builder.Attr("T", arg_types);
497  builder.Attr("cond", cond_name);
498  builder.Attr("body", body_name);
499  std::vector<NodeDefBuilder::NodeOut> inputs;
500  for (int i = 0; i < frame->args.size(); ++i) {
501    const Arg& arg = frame->args[i];
502    const Edge* in_edge;
503    TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
504    if (in_edge->IsControlEdge()) {
505      builder.ControlInput(in_edge->src()->name());
506    } else {
507      inputs.push_back(NodeDefBuilder::NodeOut(
508          in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
509    }
510  }
511  builder.Input(inputs);
512  TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
513  TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph));
514
515  // Copies edges to the Enter nodes and from the Exit nodes onto the While.
516  for (int i = 0; i < frame->args.size(); ++i) {
517    const Arg& arg = frame->args[i];
518    const Edge* in_edge;
519    TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
520    if (in_edge->IsControlEdge()) {
521      graph->AddControlEdge(in_edge->src(), while_node);
522    } else {
523      graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
524    }
525
526    if (!arg.is_loop_invariant) {
527      // Add output edges if the output of the loop is consumed.
528      if (arg.exit != nullptr) {
529        std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
530                                       arg.exit->out_edges().end());
531        for (const Edge* edge : edges) {
532          Node* dst = edge->dst();
533          int dst_input = edge->dst_input();
534          graph->RemoveEdge(edge);
535
536          if (dst_input == Graph::kControlSlot) {
537            graph->AddControlEdge(while_node, dst);
538          } else {
539            graph->AddEdge(while_node, i, dst, dst_input);
540          }
541        }
542      }
543    }
544  }
545
546  // Remove the old nodes from the graph, and add the while node to the parent
547  // frame.
548  for (Node* node : frame->nodes) {
549    graph->RemoveNode(node);
550  }
551  frame->nodes.clear();
552  frame->parent->nodes.insert(while_node);
553
554  VLOG(2) << "Frame " << frame->name << " after: "
555          << dump_graph::DumpGraphToFile("functionalize_after", *graph,
556                                         library);
557
558  return Status::OK();
559}
560
// Rewrites the Switch/Merge-based conditionals of a loop-free graph into XlaIf
// nodes. The public surface is the static Functionalize() entry point plus
// Branch_Name(); the constructor is private, so instances are presumably only
// created inside Functionalize().
class FunctionalizeCond {
 public:
  // All nodes are assumed to be either in no branch, then branch, else branch,
  // or both branches (such as merge nodes). kNumBranchTypes is not a real
  // branch: it is the number of branch kinds and is used to size per-branch
  // arrays.
  enum Branch {
    kElseBranch = 0,
    kThenBranch = 1,
    kBoth = 2,
    kNeither = 3,
    kNumBranchTypes = 4
  };

  // Returns a textual representation of the Branch b.
  static string Branch_Name(FunctionalizeCond::Branch b);

  // Functionalizes all the switch-merge nodes of a loop-free graph into XlaIf
  // nodes. That is, attempts to transform every remaining switch and merge
  // node in the graph into an XlaIf node.
  // Precondition: All while loops have been removed from graph.
  static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library);

 private:
  // CondArgNode represents an input to the conditional and its corresponding
  // switch nodes.
  struct CondArgNode {
    explicit CondArgNode(Node* input) : input(input) {}
    string ToString() const {
      return strings::StrCat("input=", input->name(),
                             " switches=", NodesToString(switches));
    }

    Node* input;
    std::vector<Node*> switches;
  };
  using CondArgNodes = std::vector<CondArgNode>;

  // Per-node state used while flowing branch membership forward through the
  // graph: the branch the node belongs to, and `count`, the number of inputs
  // folded into it so far (incremented by Join()).
  struct ForwardFlowNode {
    explicit ForwardFlowNode(Branch branch = Branch::kNeither)
        : branch(branch), count(0) {}
    string ToString() const {
      return strings::StrCat("branch=", Branch_Name(branch), " count=", count);
    }
    Branch branch;
    int count;
  };

  // Group of switch nodes that will be part of the same XlaIf.
  struct SwitchCluster {
    explicit SwitchCluster(Node* predicate) : predicate(predicate) {}
    string ToString() const {
      return strings::StrCat(name, " predicate=", predicate->name(),
                             " switches=", NodesToString(switches));
    }

    string name;
    Node* predicate;
    std::vector<Node*> switches;
  };

  FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
                    bool dump_graphs)
      : library_(library), graph_(graph), dump_graphs_(dump_graphs) {}

  // Perform the actual cond functionalization. Iterate over groups of switch
  // nodes (linked by common predicate), from innermost to outermost, and
  // extract into XlaIf nodes.
  Status FunctionalizeInternal();

  // Determines the branch_map (mapping from node to branch of cond) and
  // frontier (the nodes where the cond ends).
  StatusOr<std::pair<std::unordered_map<Node*, ForwardFlowNode>,
                     std::unordered_set<Node*>>>
  DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster);

  // Returns an XlaIf node created from the subgraph of merge and switch nodes.
  // This encapsulates the process of extracting the bodies needed for the then
  // and else branch, creating an XlaIf node, removing the nodes of the
  // branches from the graph and replacing the merge node with an XlaIf.
  StatusOr<Node*> ConvertToXlaIf(const CondArgNodes& cond_arg_nodes,
                                 const SwitchCluster& switch_cluster,
                                 const std::vector<Node*>& switches);

  // Builds an XlaIf op to replace the Switch-Graph-Merge cluster with.
  StatusOr<Node*> BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes,
                                     const SwitchCluster& switch_cluster,
                                     const std::vector<Node*>& merge_nodes);

  // Extracts a function body corresponding to the given input edge of the merge
  // node.
  Status ExtractBody(const CondArgNodes& cond_arg_nodes,
                     const std::vector<Node*>& switches,
                     const std::vector<Node*>& merge_nodes, int input_edge,
                     Graph* body);

  // Adds all the input edges to `if_node` corresponding to the arguments.
  Status AddInputEdges(const CondArgNodes& cond_arg_nodes, Node* predicate,
                       Node* if_node);

  // Adds all output edges from the `if_node`.
  Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node);

  // Returns the switch clusters of graph_ in postorder. Dead switch nodes are
  // skipped and removed from the graph.
  StatusOr<std::vector<SwitchCluster>> DeterminePredicateSwitchOrder();

  // Update the state for destination based on the state of source and the node
  // being updated.
  Status Join(const ForwardFlowNode& src_state, const Node* dst,
              ForwardFlowNode* dst_state);

  // Ensure that all nodes in the branch_map are dominated by the switch
  // nodes. Returns nodes that are not dominated by the switches but are a
  // control dependency of a node in the cond, and remove such control
  // dependencies.
  StatusOr<std::vector<Node*>> EnsureDominanceAndReturnNonDominatedControlNodes(
      const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
      const std::vector<Node*>& switches);

  // Validates that the frontier of nodes for the conditional
  // section are as expected.
  Status ValidateFrontier(
      const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
      const std::unordered_set<Node*>& frontier);

  // Library the functions are presumably added to; not owned.
  FunctionLibraryDefinition* library_;
  // Graph being transformed; not owned.
  Graph* graph_;
  // Whether intermediate graphs are dumped (presumably for debugging).
  bool dump_graphs_;
};
689
690bool IsDeadSwitch(const Node* node) {
691  for (const Edge* e : node->out_edges()) {
692    const Node* dst = e->dst();
693    if (!dst->IsIdentity()) {
694      return false;
695    }
696    for (const Edge* ee : dst->out_edges()) {
697      if (!ee->IsControlEdge() || !ee->dst()->IsSink()) {
698        return false;
699      }
700    }
701  }
702  return true;
703}
704
705string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) {
706  const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = {
707      "else", "then", "both", "neither", "count"};
708  return branch_name[b];
709}
710
711Status FunctionalizeCond::ValidateFrontier(
712    const std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>&
713        branch_map,
714    const std::unordered_set<Node*>& frontier) {
715  std::unordered_set<const Node*> pending[kNumBranchTypes];
716  for (Node* n : frontier) {
717    pending[branch_map.at(n).branch].insert(n);
718  }
719  TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]);
720  for (const Node* n : pending[kBoth]) {
721    TF_RET_CHECK(IsMerge(n)) << n->DebugString();
722    // Merge nodes may be in then or else branch too
723  }
724  int index = (pending[kThenBranch].size() <= pending[kElseBranch].size())
725                  ? kThenBranch
726                  : kElseBranch;
727  int other = 1 - index;
728  for (const Node* n : pending[index]) {
729    if (pending[other].find(n) != pending[other].end()) {
730      return errors::Internal(
731          "Node (", n->DebugString().c_str(),
732          ") in both Else and Then branch should be in Both.");
733    }
734  }
735  // An empty frontier indicates a dead switch. Above we attempt to remove dead
736  // switch nodes, but not all are removed so don't treat it as an error yet.
737  // TODO(jpienaar): Find out why dead switch nodes remain.
738  // if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
739  //     pending[kElseBranch].empty()) {
740  //   return errors::Internal("Unexpected empty frontier for switch nodes");
741  // }
742  return Status::OK();
743}
744
745Status FunctionalizeCond::Join(const ForwardFlowNode& src_state,
746                               const Node* dst, ForwardFlowNode* dst_state) {
747  TF_RET_CHECK(dst_state->branch != Branch::kBoth &&
748               dst_state->branch != Branch::kNumBranchTypes)
749      << "Unexpected/Invalid branch type: Merging "
750      << Branch_Name(src_state.branch) << " with "
751      << Branch_Name(dst_state->branch);
752  if (dst_state->branch == Branch::kNeither) {
753    dst_state->branch = src_state.branch;
754  } else if (src_state.branch != dst_state->branch &&
755             src_state.branch != Branch::kNeither) {
756    if (IsMerge(dst)) {
757      dst_state->branch = Branch::kBoth;
758    } else {
759      return errors::Internal("Illegal merge: ", src_state.ToString(), " with ",
760                              dst_state->ToString(), " for ",
761                              dst->DebugString());
762    }
763  }
764  ++dst_state->count;
765  return Status::OK();
766}
767
// Partitions the live Switch nodes of graph_ into SwitchClusters.
//
// Switches land in the same SwitchCluster when they share a predicate
// (Switch input 1) and belong to the same union-find cluster, i.e. the same
// control-flow context as inferred below. Along the way this:
//  * collects nodes in DFS post-order (reverse topological order, since the
//    graph is expected to be acyclic here) and separates dead switches
//    (IsDeadSwitch) from live ones;
//  * removes the dead switches from graph_ (only if at least one live switch
//    exists, see the early return);
//  * assigns each node a "switch depth": the maximum depth over its inputs
//    (inputs coming from a Merge count one less), plus one if the node itself
//    is a Switch;
//  * when dump_graphs_ is set, tags nodes with _XlaFunctionalizeSwitchGroup
//    and dumps the clustered graph.
//
// The returned clusters are ordered by the topological position of the first
// switch encountered for each cluster. Returns InvalidArgument when operands
// at different switch depths would be merged, and Internal when a cluster
// ends up containing nodes of different depths.
StatusOr<std::vector<FunctionalizeCond::SwitchCluster>>
FunctionalizeCond::DeterminePredicateSwitchOrder() {
  // Union-find payload: the node id chosen as the cluster's representative.
  struct Cluster {
    bool operator==(const Cluster& other) const {
      return representative == other.representative;
    }
    int representative = -1;
  };

  // Perform a DFS over the graph and
  // * Determine the reverse topological order of the nodes (there should be no
  //   cycles at this point so the post-order numbering corresponds to the
  //   reverse topological sorting);
  // * Identify dead switches;
  // * Initialize the cluster's representative;
  std::vector<UnionFind<Cluster>> clusters(graph_->num_node_ids());
  std::vector<Node*> dead_switches;
  std::vector<Node*> switch_order;
  std::vector<Node*> rev_topo_sorted_nodes;
  // Only a leave (post-order) callback is supplied.
  DFS(*graph_, nullptr, [&](Node* n) {
    clusters[n->id()].Get().representative = n->id();
    if (IsSwitch(n)) {
      if (IsDeadSwitch(n)) {
        dead_switches.push_back(n);
      } else {
        rev_topo_sorted_nodes.push_back(n);
        switch_order.push_back(n);
      }
    } else if (n->IsOp()) {
      // Exclude src and sink nodes from further consideration.
      rev_topo_sorted_nodes.push_back(n);
    }
  });

  std::vector<SwitchCluster> switch_clusters;
  // Return early if there are no switches in the graph.
  // NOTE(review): this return happens before the dead-switch removal below,
  // so a graph containing only dead switches keeps them.
  if (switch_order.empty()) {
    return switch_clusters;
  }

  // Remove all dead switch nodes.
  for (Node* n : dead_switches) {
    VLOG(2) << "Removing dead switch: " << n->DebugString();
    graph_->RemoveNode(n);
  }

  // Identify switch nodes that are part of the same control flow context by
  // considering the operands of operations: an operation is part of the same
  // control context as its operands unless the operation is a switch. Control
  // dependencies are considered part of the same control flow context if the
  // switch depth is the same (see comment below).

  // entry_cluster records the input cluster to a switch node. This is used when
  // merging with a merge node where the dst's cluster is merged with the entry
  // cluster of the merge node's cluster (which corresponds to a switch cluster
  // and so has an entry cluster).
  std::unordered_map<int, UnionFind<Cluster>*> entry_cluster;

  // Returns the output cluster of a node. Where the output cluster is cluster
  // where the output of the node is used. For non-merge nodes this is simply
  // the cluster they are part of, while for merge nodes it is the entry cluster
  // of the cluster they are part of (this will correspond to the entry node of
  // a switch node that dominates the merge).
  // Note: a lookup failure here aborts the process via CHECK rather than
  // returning a Status.
  auto find_output_cluster = [&](Node* n) {
    UnionFind<Cluster>* cluster = &clusters[n->id()];
    if (!IsMerge(n)) return cluster;
    auto it = entry_cluster.find(clusters[n->id()].Get().representative);
    // If the cluster is not found in the entry_cluster map then an
    // instruction not dominated by a switch node has been merged into the
    // cluster of the merge. This indicates a failure of the clustering.
    CHECK(it != entry_cluster.end())
        << "Unable to find entry for n=" << n->id() << " ("
        << cluster->Get().representative << ")";
    return it->second;
  };

  // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier.
  // Walk nodes in topological order (reverse of the post-order collected
  // above) so every operand's depth and cluster are known before its users.
  std::vector<int> switch_depth(graph_->num_node_ids());
  for (auto it = rev_topo_sorted_nodes.rbegin();
       it != rev_topo_sorted_nodes.rend(); ++it) {
    Node* n = *it;

    // Compute switch depth.
    int new_switch_depth = 0;
    for (const Edge* e : n->in_edges()) {
      Node* src = e->src();
      new_switch_depth = std::max(
          new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0));
    }
    switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0);

    // Only merge the input operands of a switch. The switch's clustering itself
    // is determined by the interaction of the switch's outputs.
    if (IsSwitch(n)) {
      Node* input;
      TF_CHECK_OK(n->input_node(0, &input));
      entry_cluster[n->id()] = &clusters[input->id()];
      UnionFind<Cluster>* cluster = find_output_cluster(input);
      int cluster_depth = switch_depth[cluster->Get().representative];
      // Merge the inputs of the switch node with one another. This results in
      // predicates and control input residing in the same cluster.
      for (const Edge* e : n->in_edges()) {
        Node* src = e->src();
        UnionFind<Cluster>* src_cluster = find_output_cluster(src);
        int src_cluster_depth = switch_depth[src_cluster->Get().representative];
        if (cluster_depth != src_cluster_depth) {
          return errors::InvalidArgument(
              "Unable to functionalize control flow in graph: Switch ('",
              n->name(), "') has operands ('", input->name(), "' and '",
              src->name(), "') that have different switch depths (",
              cluster_depth, " != ", src_cluster_depth, ")");
        }
        cluster->Merge(src_cluster);
      }
      continue;
    }

    for (const Edge* e : n->in_edges()) {
      Node* src = e->src();
      if (!src->IsOp()) continue;
      UnionFind<Cluster>* cluster = find_output_cluster(src);
      // Merge a node with its data operands and with its control operands if
      // the src and dst are in the same ControlContext. The ControlContext is
      // not explicitly available here, and instead the switch depth is used as
      // a proxy here. Due to the invariant that control edges can only be from
      // a containing scope to an inner scope or from the inner scope to its
      // containing scope (for exit nodes), the switch depth will only match if
      // the src and dst are in the same ControlContext. Control edges between
      // ControlContexts are handled during the extraction.
      int src_id = cluster->Get().representative;
      int src_depth = switch_depth[src_id];
      if (!e->IsControlEdge() || new_switch_depth == src_depth) {
        if (src_depth != new_switch_depth) {
          return errors::InvalidArgument(
              "Unable to functionalize control flow in graph: Operand ('",
              src->name(), "') and operator ('", n->name(),
              "') have different switch depths (", src_depth,
              " != ", new_switch_depth, ")");
        }
        cluster->Merge(&clusters[n->id()]);
      }
    }
  }

  if (dump_graphs_) {
    // Mark the switch cluster each node is part of.
    for (Node* n : graph_->nodes()) {
      n->ClearAttr("_XlaFunctionalizeSwitchGroup");
      n->AddAttr("_XlaFunctionalizeSwitchGroup",
                 clusters[n->id()].Get().representative);
    }
    LOG(INFO) << "FunctionalizeControlFlow (with_clusters): "
              << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_,
                                             library_);
  }

  // Verify all the nodes of a cluster are at the same depth.
  // Maps cluster representative -> (depth, first node seen in that cluster).
  std::unordered_map<int, std::pair<int, Node*>> cluster_to_depth_node;
  for (Node* n : graph_->nodes()) {
    int depth = switch_depth[n->id()];
    int cluster_rep = clusters[n->id()].Get().representative;
    auto it = cluster_to_depth_node.find(cluster_rep);
    if (it == cluster_to_depth_node.end()) {
      cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n);
    } else {
      if (it->second.first != depth) {
        // NOTE(review): both representatives printed below come from
        // clusters[n->id()]; since the two nodes share cluster_rep they are
        // the same value.
        return errors::Internal(
            "Illegal clustering created, mismatch in depths:", "\n\t",
            n->DebugString(), "(", clusters[n->id()].Get().representative,
            ") at depth=", depth, " vs\n\t", it->second.second->DebugString(),
            "(", clusters[n->id()].Get().representative, ") at depth ",
            it->second.first);
      }
    }
  }

  struct Hash {
    size_t operator()(const std::pair<Node*, Cluster>& item) const {
      return Hash64Combine(hash<Node*>()(item.first),
                           std::hash<int>()(item.second.representative));
    }
  };

  // Merge Switch nodes with common predicate.
  // Key: (predicate node, cluster) -> index into switch_clusters.
  std::unordered_map<std::pair<Node*, Cluster>, int, Hash> predicate_index;
  // The nodes in switch_order are in reverse topological order, but the
  // clustered switches need not be (i.e., when considered as a cluster one
  // element of a cluster may be later in the topological order than another
  // node whose cluster is later in the topological order of clustered
  // switches).
  for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) {
    Node* pred;
    TF_CHECK_OK((*it)->input_node(1, &pred));
    auto repr = std::make_pair(pred, clusters[(*it)->id()].Get());
    if (predicate_index.find(repr) == predicate_index.end()) {
      predicate_index[repr] = switch_clusters.size();
      switch_clusters.emplace_back(pred);
      // Generate a name by concatenating with the cluster representative as
      // there could be multiple switch clusters with the same predicate.
      switch_clusters[predicate_index[repr]].name =
          strings::StrCat(pred->name(), "_", repr.second.representative, "_If");
    }
    switch_clusters[predicate_index[repr]].switches.push_back(*it);
  }

  return switch_clusters;
}
975
976StatusOr<std::vector<Node*>>
977FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes(
978    const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
979    const std::vector<Node*>& switches) {
980  std::vector<Node*> old_control_nodes;
981  for (const auto& kv : branch_map) {
982    if (kv.second.count != kv.first->in_edges().size()) {
983      std::vector<const Edge*> delete_edges;
984      for (const Edge* in : kv.first->in_edges()) {
985        auto it = branch_map.find(in->src());
986        if (it == branch_map.end()) {
987          if (in->IsControlEdge()) {
988            old_control_nodes.push_back(in->src());
989            delete_edges.push_back(in);
990          } else {
991            if (IsSwitch(in->src())) {
992              if (std::find(switches.begin(), switches.end(), in->src()) ==
993                  switches.end()) {
994                return errors::Internal(
995                    "Unexpected switch node found during flow forward: ",
996                    in->src()->DebugString());
997              }
998              continue;
999            }
1000            return errors::InvalidArgument(
1001                "Value ", kv.first->name(), "'s input, ", in->src()->name(),
1002                ", is not dominated by switch nodes ", NodesToString(switches));
1003          }
1004        }
1005      }
1006      // Remove control edges from nodes that are not dominated by the switch
1007      // nodes. New control dependencies will be added between these nodes and
1008      // the XlaIf node inserted.
1009      for (const Edge* e : delete_edges) {
1010        graph_->RemoveEdge(e);
1011      }
1012    }
1013  }
1014  return old_control_nodes;
1015}
1016
1017StatusOr<
1018    std::pair<std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>,
1019              std::unordered_set<Node*>>>
1020FunctionalizeCond::DetermineBranchMapAndFrontier(
1021    const SwitchCluster& switch_cluster) {
1022  std::unordered_map<Node*, ForwardFlowNode> branch_map;
1023  std::unordered_set<Node*> frontier;
1024  std::vector<Node*> stack = switch_cluster.switches;
1025  std::vector<bool> visited(graph_->num_node_ids(), false);
1026  while (!stack.empty()) {
1027    Node* n = stack.back();
1028    stack.pop_back();
1029
1030    if (visited[n->id()]) {
1031      continue;
1032    }
1033    visited[n->id()] = true;
1034
1035    // Propagate branch state along each edge of a switch node.
1036    bool sink_only = true;
1037    for (const Edge* e : n->out_edges()) {
1038      Node* out = e->dst();
1039      if (!out->IsOp()) {
1040        continue;
1041      }
1042      sink_only = false;
1043      // Propagate branch information.
1044      ForwardFlowNode& ffn = branch_map[out];
1045      if (IsSwitch(n)) {
1046        int index = e->IsControlEdge() ? Branch::kNeither : e->src_output();
1047        TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn));
1048      } else {
1049        TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn));
1050      }
1051      if (IsMerge(out)) {
1052        if (out->in_edges().size() == ffn.count) {
1053          frontier.insert(out);
1054        }
1055      } else if (!visited[out->id()]) {
1056        stack.push_back(out);
1057      }
1058    }
1059    if (sink_only) {
1060      if (!IsIdentity(n)) {
1061        VLOG(1) << "Feeding into sink: " << n->DebugString();
1062      }
1063    }
1064  }
1065
1066  if (dump_graphs_) {
1067    for (const auto& kv : branch_map) {
1068      // Append attribute to the graph if running with logging to make the
1069      // changes clearer in the visualization.
1070      kv.first->AddAttr("_XlaFunctionalizeBranch",
1071                        Branch_Name(kv.second.branch));
1072    }
1073  }
1074  return std::make_pair(std::move(branch_map), std::move(frontier));
1075}
1076
1077Status FunctionalizeCond::FunctionalizeInternal() {
1078  TF_ASSIGN_OR_RETURN(std::vector<SwitchCluster> predicate_switch_order,
1079                      DeterminePredicateSwitchOrder());
1080
1081  // Iterate from innermost set of clustered switches to outermost, replacing
1082  // matching switch->merge subgraphs with single XlaIf nodes.
1083  for (auto it = predicate_switch_order.rbegin();
1084       it != predicate_switch_order.rend(); ++it) {
1085    auto& ps = *it;
1086    VLOG(3) << "Flow down from: " << NodesToString(ps.switches) << " ("
1087            << ps.predicate->name() << ")";
1088
1089    std::unordered_map<Node*, ForwardFlowNode> branch_map;
1090    std::unordered_set<Node*> frontier;
1091    TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier),
1092                        DetermineBranchMapAndFrontier(ps));
1093
1094    if (dump_graphs_)
1095      LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): "
1096                << dump_graph::DumpGraphToFile("functionalize_bc", *graph_,
1097                                               library_);
1098    TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier));
1099
1100    // Sort the merge and switch nodes using NodeCmp. The switch-nodes are
1101    // further grouped (post sorting) by input to the switch node as in the
1102    // functionalized form each input will be passed in only once. This grouping
1103    // should retain the sorted order.
1104    CondArgNodes cond_arg_nodes;
1105    std::unordered_map<Node*, int> input_index;
1106    std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp());
1107    for (Node* switch_node : ps.switches) {
1108      Node* in;
1109      TF_RETURN_IF_ERROR(switch_node->input_node(0, &in));
1110      if (input_index.find(in) == input_index.end()) {
1111        input_index[in] = cond_arg_nodes.size();
1112        cond_arg_nodes.emplace_back(in);
1113      }
1114      cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node);
1115    }
1116    std::vector<Node*> merge_nodes(frontier.begin(), frontier.end());
1117    std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp());
1118
1119    TF_ASSIGN_OR_RETURN(std::vector<Node*> old_control_nodes,
1120                        EnsureDominanceAndReturnNonDominatedControlNodes(
1121                            branch_map, ps.switches));
1122
1123    TF_ASSIGN_OR_RETURN(Node * if_node,
1124                        ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes));
1125    for (Node* old : old_control_nodes) {
1126      graph_->AddControlEdge(old, if_node);
1127    }
1128
1129    for (auto& del_kv : branch_map) {
1130      graph_->RemoveNode(del_kv.first);
1131    }
1132    for (auto& kv : cond_arg_nodes) {
1133      for (Node* node : kv.switches) {
1134        graph_->RemoveNode(node);
1135      }
1136    }
1137    if (dump_graphs_)
1138      LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): "
1139                << dump_graph::DumpGraphToFile("functionalize_ac", *graph_,
1140                                               library_);
1141  }
1142  return Status::OK();
1143}
1144
1145StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
1146    const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
1147    const std::vector<Node*>& merge_nodes) {
1148  VLOG(2) << "Build if op for " << switch_cluster.name;
1149
1150  NodeDef if_def;
1151  // Create a new If node using the name of the merge node.
1152  NodeDefBuilder builder(switch_cluster.name, "XlaIf");
1153  string branch[] = {"else_branch", "then_branch"};
1154  for (int i = 0; i < 2; ++i) {
1155    static std::atomic<int64> sequence_num(0LL);
1156    int64 id = ++sequence_num;
1157
1158    NameAttrList body_name;
1159    body_name.set_name(
1160        strings::StrCat("_functionalize_if_", branch[i], "_", id));
1161    auto body = xla::MakeUnique<Graph>(graph_->op_registry());
1162    TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches,
1163                                   merge_nodes, i, body.get()));
1164    VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get());
1165    FunctionDef body_fdef;
1166    TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef));
1167    TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef));
1168    builder.Attr(branch[i], body_name);
1169  }
1170
1171  // Build input type.
1172  std::vector<NodeDefBuilder::NodeOut> inputs;
1173  DataTypeVector in_arg_types;
1174  for (auto& kv : cond_arg_nodes) {
1175    bool inserted = false;
1176    for (const Node* arg : kv.switches) {
1177      const Edge* in_edge;
1178      TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
1179      if (in_edge->IsControlEdge()) {
1180        builder.ControlInput(in_edge->src()->name());
1181      } else {
1182        if (!inserted) {
1183          DataType dtype = arg->input_type(0);
1184          inputs.emplace_back(NodeDefBuilder::NodeOut(
1185              in_edge->src()->name(), in_edge->src_output(), dtype));
1186          in_arg_types.push_back(dtype);
1187          inserted = true;
1188        }
1189      }
1190    }
1191  }
1192  builder.Attr("Tin", in_arg_types);
1193
1194  // Build output type.
1195  DataTypeVector out_type;
1196  for (const Node* merge : merge_nodes) {
1197    DataType dtype = merge->output_type(0);
1198    out_type.push_back(dtype);
1199  }
1200  builder.Attr("Tout", out_type);
1201
1202  builder.Attr("Tcond", DT_BOOL);
1203  builder.Device(switch_cluster.predicate->assigned_device_name());
1204  // Conditional should be the first input ...
1205  builder.Input(
1206      NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0,
1207                              switch_cluster.predicate->output_type(0)));
1208  // ... followed by the other inputs.
1209  builder.Input(inputs);
1210
1211  TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
1212  TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_));
1213  return if_node;
1214}
1215
1216Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
1217                                      const std::vector<Node*>& switches,
1218                                      const std::vector<Node*>& merge_nodes,
1219                                      int input_edge, Graph* body) {
1220  VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge "
1221          << input_edge;
1222  std::vector<bool> squash_src_outputs(graph_->num_node_ids(), false);
1223  std::vector<Node*> node_map(graph_->num_node_ids(), nullptr);
1224  int arg_count = 0;
1225  for (auto& kv : cond_arg_nodes) {
1226    Node* arg_node = nullptr;
1227    for (const auto* arg : kv.switches) {
1228      DataType dtype = arg->input_type(0);
1229      if (arg_node == nullptr) {
1230        TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++));
1231      }
1232      node_map.at(arg->id()) = arg_node;
1233      squash_src_outputs.at(arg->id()) = true;
1234    }
1235  }
1236
1237  std::vector<Node*> stack;
1238  stack.reserve(merge_nodes.size());
1239  for (int j = 0; j < merge_nodes.size(); ++j) {
1240    Node* node = merge_nodes[j];
1241    TF_ASSIGN_OR_RETURN(node_map.at(node->id()),
1242                        BuildRetvalNode(body, node->output_type(0),
1243                                        /*index=*/j));
1244    const Edge* in_edge;
1245    TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge));
1246    Node* in = in_edge->src();
1247    if (node_map.at(in->id()) == nullptr) {
1248      node_map.at(in->id()) = body->CopyNode(in);
1249    }
1250
1251    if (std::find(switches.begin(), switches.end(), in) == switches.end()) {
1252      body->AddEdge(node_map.at(in->id()), in_edge->src_output(),
1253                    node_map.at(node->id()), 0);
1254    } else {
1255      body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0);
1256      // Don't include input nodes that are already just returned in stack.
1257      continue;
1258    }
1259    stack.push_back(in);
1260  }
1261
1262  return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map,
1263                      body);
1264}
1265
1266Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes,
1267                                        Node* predicate, Node* if_node) {
1268  VLOG(3) << "AddInputEdges for " << if_node->name();
1269  int index = 0;
1270  graph_->AddEdge(predicate, 0, if_node, index++);
1271  for (auto& kv : cond_arg_nodes) {
1272    bool inserted = false;
1273    for (const Node* arg : kv.switches) {
1274      const Edge* in_edge;
1275      TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
1276      if (in_edge->IsControlEdge()) {
1277        graph_->AddControlEdge(in_edge->src(), if_node);
1278      } else {
1279        if (!inserted) {
1280          graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node,
1281                          index++);
1282          inserted = true;
1283        }
1284      }
1285    }
1286  }
1287  return Status::OK();
1288}
1289
1290Status FunctionalizeCond::AddOutputEdges(const std::vector<Node*>& outputs,
1291                                         Node* if_node) {
1292  VLOG(3) << "AddOutputEdges for " << if_node->name();
1293  for (int i = 0; i < outputs.size(); ++i) {
1294    Node* node = outputs[i];
1295    std::vector<const Edge*> edges(node->out_edges().begin(),
1296                                   node->out_edges().end());
1297    for (const Edge* edge : edges) {
1298      Node* dst = edge->dst();
1299      int dst_input = edge->dst_input();
1300
1301      if (edge->src_output() > 0) {
1302        return errors::Unimplemented("Output of index (", edge->src_output(),
1303                                     ") of merge node ", node->name());
1304      }
1305      graph_->RemoveEdge(edge);
1306
1307      int src_output =
1308          dst_input == Graph::kControlSlot ? Graph::kControlSlot : i;
1309      graph_->AddEdge(if_node, src_output, dst, dst_input);
1310    }
1311  }
1312  return Status::OK();
1313}
1314
1315StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
1316    const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
1317    const std::vector<Node*>& merge_nodes) {
1318  VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> "
1319          << NodesToString(merge_nodes);
1320
1321  // Extract bodies and builds a If operator.
1322  TF_ASSIGN_OR_RETURN(
1323      Node * if_node,
1324      BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes));
1325  TF_RETURN_IF_ERROR(
1326      AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node));
1327  TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));
1328
1329  return if_node;
1330}
1331
1332Status FunctionalizeCond::Functionalize(Graph* graph,
1333                                        FunctionLibraryDefinition* library) {
1334  VLOG(1) << "FunctionalizeCond::Functionalize";
1335  FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2));
1336  return fc.FunctionalizeInternal();
1337}
1338
1339}  // namespace
1340
// Transformation that converts TensorFlow's graph control flow constructs into
// functional equivalents.
//
// Loops are handled first: op nodes are grouped into Frames by the frame name
// computed by BuildControlFlowInfo, and FunctionalizeLoop is applied to the
// frames from the innermost outwards. Afterwards the remaining Switch/Merge
// conditionals are converted by FunctionalizeCond::Functionalize. `library` is
// passed to both steps. Returns InvalidArgument for inconsistent frame
// parents or multiple LoopCond nodes in one frame; other failures are
// propagated from the helpers.
Status FunctionalizeControlFlow(Graph* graph,
                                FunctionLibraryDefinition* library) {
  VLOG(2) << "FunctionalizeControlFlow (initial): "
          << dump_graph::DumpGraphToFile("functionalize_initial", *graph,
                                         library);
  // Note: BuildControlFlowInfo() requires that the graph's source node is
  // connected to all source nodes in the graph. Many graphs violate this
  // invariant.
  std::vector<ControlFlowInfo> cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info));

  // Builds Frames, indexed by name.
  std::unordered_map<string, Frame> frames;
  for (Node* node : graph->op_nodes()) {
    const ControlFlowInfo& cf = cf_info[node->id()];

    VLOG(2) << "node: " << node->name() << " (" << node->id()
            << ") frame_name: " << cf.frame_name
            << " frame: " << (cf.frame ? cf.frame->name() : "---")
            << " parent_frame: "
            << (cf.parent_frame ? cf.parent_frame->name() : "---");
    TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);

    // Register the frame on first sight and link it to its parent; every node
    // of a frame must agree on that parent. The root frame resolves to itself
    // as parent (see the self-parent check in the worklist loop below), so it
    // also counts itself in num_children.
    Frame& frame = frames[cf.frame_name];
    Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
    if (frame.parent == nullptr) {
      frame.parent = parent;
      frame.name = cf.frame_name;
      ++parent->num_children;
    } else if (frame.parent != parent) {
      return errors::InvalidArgument("Mismatched parent frames for ",
                                     cf.frame->id(), ": ", parent->name, " vs ",
                                     frame.parent->name);
    }

    if (IsEnter(node)) {
      // Every Enter node is a loop argument; its `is_constant` attribute
      // records whether the argument is loop-invariant.
      Arg arg;
      arg.enter = node;
      TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
                                     &arg.is_loop_invariant));
      frame.args.push_back(arg);
    } else if (IsLoopCond(node)) {
      // A frame may have at most one LoopCond.
      if (frame.loop_cond) {
        return errors::InvalidArgument(
            "Loop ", cf.frame_name,
            " has more than one LoopCond node: ", node->name(), " and ",
            frame.loop_cond->name());
      }
      frame.loop_cond = node;
    }
    frame.nodes.insert(node);
  }

  // Adds frames with no children (i.e., the innermost frames) to a worklist.
  std::deque<Frame*> worklist;
  for (auto& frame : frames) {
    if (frame.second.num_children == 0) {
      worklist.push_back(&frame.second);
    }
  }

  // Eliminate loops from innermost to outermost.
  while (!worklist.empty()) {
    Frame* frame = worklist.front();
    worklist.pop_front();
    if (frame->parent == frame) {
      // Skip the root frame.
      // NOTE(review): because the root counts itself as a child above, its
      // num_children appears never to reach 0, so this looks defensive.
      continue;
    }

    TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));

    // If the parent has no remaining children, add it to the worklist.
    --frame->parent->num_children;
    if (frame->parent->num_children == 0) {
      worklist.push_back(frame->parent);
    }
  }

  // FunctionalizeControlFlow is invoked for every function, so the loops's
  // bodies and conditionals that were extracted into functions will be handled
  // in successive invocations.
  TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library));

  VLOG(2) << "FunctionalizeControlFlow (final): "
          << dump_graph::DumpGraphToFile("functionalize_final", *graph,
                                         library);
  return Status::OK();
}
1432
1433}  // namespace tensorflow
1434