1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" 17 18#include <algorithm> 19#include <deque> 20#include <stack> 21#include <unordered_set> 22#include <vector> 23 24#include "tensorflow/compiler/jit/graph_to_functiondef.h" 25#include "tensorflow/compiler/jit/union_find.h" 26#include "tensorflow/compiler/tf2xla/dump_graph.h" 27#include "tensorflow/compiler/tf2xla/tf2xla_util.h" 28#include "tensorflow/compiler/xla/ptr_util.h" 29#include "tensorflow/compiler/xla/status_macros.h" 30#include "tensorflow/core/common_runtime/function.h" 31#include "tensorflow/core/framework/node_def_builder.h" 32#include "tensorflow/core/graph/algorithm.h" 33#include "tensorflow/core/graph/control_flow.h" 34#include "tensorflow/core/lib/gtl/optional.h" 35 36namespace tensorflow { 37 38namespace { 39 40using xla::StatusOr; 41 42const char* const kArgOp = "_Arg"; 43const char* const kRetValOp = "_Retval"; 44 45// Information about a loop argument. 46struct Arg { 47 // Every loop argument has an Enter node. 48 Node* enter; 49 50 // Is the loop argument a loop-invariant value? Taken from the `is_constant` 51 // attribute on the Enter node. 52 bool is_loop_invariant; 53 54 // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant 55 // arguments must have all of the following nodes: 56 Node* merge = nullptr; 57 Node* switch_node = nullptr; 58 Node* next_iteration = nullptr; 59 Node* exit = nullptr; 60}; 61 62// Information about a loop frame. 63struct Frame { 64 string name; 65 66 // Pointer to the parent frame. The root frame has a pointer to itself. 67 Frame* parent = nullptr; 68 int num_children = 0; 69 70 // Arguments to this loop. 71 std::vector<Arg> args; 72 73 // The loop condition of the loop. There should be exactly one loop condition 74 // in every loop. 75 Node* loop_cond = nullptr; 76 77 // Set of nodes that belong to the loop frame. 78 std::unordered_set<Node*> nodes; 79}; 80 81// Comparison function used for sorting nodes consistently. 82// a) resource variables are last, and 83// b) sort lexicographically by name (for deterministic output). 84struct NodeCmp { 85 bool operator()(const Node* lhs, const Node* rhs) const { 86 bool lhs_is_resource = 87 lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; 88 bool rhs_is_resource = 89 rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; 90 return std::tie(lhs_is_resource, lhs->name()) < 91 std::tie(rhs_is_resource, rhs->name()); 92 } 93}; 94 95// Returns a textual representation of the names of the nodes in the input. 96template <typename T> 97string NodesToString(const T& nodes) { 98 return strings::StrCat("{", 99 str_util::Join(nodes, ",", 100 [](string* output, const Node* node) { 101 strings::StrAppend(output, 102 node->name()); 103 }), 104 "}"); 105} 106 107// Copies a subgraph from `graph` to `output` by performing a reverse DFS 108// starting at nodes in vector `stack`. 109// `node_map` is a vector indexed by source node ID to dest nodes. 110// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` 111// before the traversal clients can cut the graph. If a frame is provided (frame 112// != nullptr), then this functions will return an error if the 113// traversal leaves 'frame'; the client must add enough nodes to `node_map` to 114// cut the graph and prevent the traversal from escaping. 115// 116// `squash_src_outputs` contains a bool for each source node ID. If true, then 117// the source output on that node will be replaced by zero when copied. This is 118// used when replacing a Switch node with an _Arg node. The output we are 119// taking from the Switch node was not necessarily the first output, but _Arg 120// nodes only have one output. By adding the Switch node to `squash_src_outputs` 121// we rewrite the src_output of the corresponding edge to be 0. 122Status CopySubgraph(const Graph& graph, const Frame* frame, 123 std::vector<Node*> stack, 124 const std::vector<bool>& squash_src_outputs, 125 std::vector<Node*>* node_map, Graph* output) { 126 VLOG(3) << "Stack: " << NodesToString(stack); 127 std::vector<bool> visited(graph.num_node_ids(), false); 128 while (!stack.empty()) { 129 Node* n = stack.back(); 130 stack.pop_back(); 131 132 VLOG(5) << "Copying node " << n->name(); 133 134 if (visited[n->id()]) continue; 135 visited[n->id()] = true; 136 137 for (const Edge* e : n->in_edges()) { 138 Node* src = e->src(); 139 if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) { 140 // We traversed out of the loop frame, without encountering a cut node. 141 return errors::Internal("Graph traversal of loop frame ", frame->name, 142 " escaped frame at ", src->name(), 143 " without encountering an argument node."); 144 } 145 if ((*node_map)[src->id()] == nullptr) { 146 (*node_map)[src->id()] = output->CopyNode(src); 147 stack.push_back(src); 148 } 149 Node* src_copy = (*node_map)[e->src()->id()]; 150 int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge() 151 ? 0 152 : e->src_output(); 153 Node* dst_copy = (*node_map)[e->dst()->id()]; 154 output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); 155 } 156 } 157 return Status::OK(); 158} 159 160StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) { 161 Status status; 162 Node* inserted_node = graph->AddNode(node_def, &status); 163 if (!status.ok()) { 164 return status; 165 } 166 return inserted_node; 167} 168 169StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) { 170 NodeDef arg_def; 171 NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); 172 builder.Attr("T", type); 173 builder.Attr("index", index); 174 TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); 175 return AddNode(arg_def, graph); 176} 177 178StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) { 179 NodeDef ret_def; 180 ret_def.set_op(kRetValOp); 181 ret_def.set_name(strings::StrCat(kRetValOp, index)); 182 AddNodeAttr("T", type, &ret_def); 183 AddNodeAttr("index", index, &ret_def); 184 return AddNode(ret_def, graph); 185} 186 187// Builds a graph for the loop condition. 188Status BuildLoopCondition(const Graph& graph, Frame* frame, 189 std::unique_ptr<Graph>* cond_output) { 190 VLOG(2) << "Building loop condition for " << frame->name; 191 *cond_output = xla::MakeUnique<Graph>(graph.op_registry()); 192 Graph* output = cond_output->get(); 193 194 // Map from nodes in the original graph to the condition graph. 195 std::vector<Node*> node_map(graph.num_node_ids(), nullptr); 196 std::vector<bool> squash_src_outputs(graph.num_node_ids(), false); 197 198 // Build one _Arg node for each Enter node. 199 for (int i = 0; i < frame->args.size(); ++i) { 200 const Arg& arg = frame->args[i]; 201 202 TF_ASSIGN_OR_RETURN(Node * arg_node, 203 BuildArgNode(output, arg.enter->input_type(0), i)); 204 if (arg.is_loop_invariant) { 205 node_map[arg.enter->id()] = arg_node; 206 } else { 207 node_map[arg.merge->id()] = arg_node; 208 } 209 } 210 211 // Build a Retval node for the loop condition. The LoopCond nodes are always 212 // boolean because of the type constraints on the LoopCond op. 213 TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()], 214 BuildRetvalNode(output, DT_BOOL, 0)); 215 216 // Performs a reverse DFS, copying nodes and edges to the output graph. 217 // The _Arg and _Retval nodes were added unconditionally above, so we are 218 // guaranteed to get the correct function signature. 219 return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs, 220 &node_map, output); 221} 222 223// Builds a graph for the loop body. 224Status BuildLoopBody(const Graph& graph, Frame* frame, 225 DataTypeVector* arg_types, 226 std::unique_ptr<Graph>* body_output) { 227 VLOG(2) << "Building loop body for " << frame->name; 228 *body_output = xla::MakeUnique<Graph>(graph.op_registry()); 229 Graph* output = body_output->get(); 230 231 // Map from nodes in the original graph to the condition graph. 232 std::vector<Node*> node_map(graph.num_node_ids(), nullptr); 233 std::vector<bool> squash_src_outputs(graph.num_node_ids(), false); 234 235 // Build one _Arg node for each Enter node. 236 std::vector<Node*> next_iterations; 237 next_iterations.reserve(frame->args.size()); 238 arg_types->reserve(frame->args.size()); 239 for (int i = 0; i < frame->args.size(); ++i) { 240 const Arg& arg = frame->args[i]; 241 242 DataType dtype = arg.enter->input_type(0); 243 arg_types->push_back(dtype); 244 245 TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i)); 246 247 if (dtype == DT_RESOURCE) { 248 // The convention of the XLA bridge is that resource variable arguments 249 // are only inputs to the loop body and have no corresponding output. 250 // TODO(b/37741920): change the convention so that DT_RESOURCE variables 251 // are both inputs and outputs, and then remove this case. 252 TF_RET_CHECK(arg.is_loop_invariant); 253 node_map[arg.enter->id()] = arg_node; 254 } else { 255 TF_ASSIGN_OR_RETURN(Node * retval_node, 256 BuildRetvalNode(output, dtype, i)); 257 258 if (arg.is_loop_invariant) { 259 // Argument is loop-invariant. Forward it from the Arg to the Retval. 260 node_map[arg.enter->id()] = arg_node; 261 output->AddEdge(arg_node, 0, retval_node, 0); 262 } else { 263 // Argument is loop-varying. 264 node_map[arg.switch_node->id()] = arg_node; 265 // The Switch node has two outputs, but _Arg only has one. This tells 266 // the CopySubgraph function to rewrite the output number of edges from 267 // the _Arg node to be 0 rather than copying the output number from the 268 // Switch node. 269 squash_src_outputs[arg.switch_node->id()] = true; 270 node_map[arg.next_iteration->id()] = retval_node; 271 next_iterations.push_back(arg.next_iteration); 272 } 273 } 274 } 275 276 // Performs a reverse DFS, copying nodes and edges to the output graph. 277 // The _Arg and _Retval nodes were added unconditionally above, so we are 278 // guaranteed to get the correct function signature. 279 TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), 280 squash_src_outputs, &node_map, output)); 281 282 return Status::OK(); 283} 284 285Status FunctionalizeLoop(Graph* graph, Frame* frame, 286 FunctionLibraryDefinition* library) { 287 VLOG(2) << "Frame " << frame->name << " before: " 288 << dump_graph::DumpGraphToFile("functionalize_before", *graph, 289 library); 290 291 // Split loop-varying Enter nodes with multiple successors. If the same 292 // Tensor is fed as input to multiple loop arguments, we may end up with a 293 // shared Enter node. We clone Enter nodes with multiple successors to 294 // maintain the invariant of a unique Enter node per argument of the final 295 // loop. 296 std::vector<Arg> args; 297 for (const Arg& arg : frame->args) { 298 if (arg.is_loop_invariant) { 299 args.push_back(arg); 300 } else { 301 std::vector<const Edge*> edges(arg.enter->out_edges().begin(), 302 arg.enter->out_edges().end()); 303 for (int i = 0; i < edges.size(); ++i) { 304 if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) { 305 continue; 306 } 307 TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name(); 308 Arg new_arg; 309 new_arg.is_loop_invariant = false; 310 if (i == 0) { 311 new_arg.enter = arg.enter; 312 } else { 313 new_arg.enter = graph->CopyNode(arg.enter); 314 frame->nodes.insert(new_arg.enter); 315 for (Edge const* e : arg.enter->in_edges()) { 316 graph->AddEdge(e->src(), e->src_output(), new_arg.enter, 317 e->IsControlEdge() ? Graph::kControlSlot : 0); 318 } 319 Node* dst = edges[i]->dst(); 320 int dst_input = edges[i]->dst_input(); 321 graph->RemoveEdge(edges[i]); 322 graph->AddEdge(new_arg.enter, 0, dst, dst_input); 323 } 324 args.push_back(new_arg); 325 } 326 } 327 } 328 frame->args = std::move(args); 329 330 std::sort( 331 frame->args.begin(), frame->args.end(), 332 [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); }); 333 334 if (frame->loop_cond == nullptr) { 335 return errors::InvalidArgument("Loop ", frame->name, 336 " has no LoopCond node"); 337 } 338 339 // Find the set of Switch nodes that are successors of the LoopCond. 340 std::unordered_set<Node*> switches; 341 for (const Edge* edge : frame->loop_cond->out_edges()) { 342 if (!edge->IsControlEdge() && IsSwitch(edge->dst()) && 343 edge->dst_input() == 1) { 344 switches.insert(edge->dst()); 345 } 346 } 347 348 // For each non-constant argument, looks for the following pattern of nodes: 349 // Enter ----> Merge --------> Switch --> Exit 350 // ^ ^ 351 // | | 352 // NextIteration LoopCond 353 // ^ ^ 354 // | | 355 // ... ... 356 for (Arg& arg : frame->args) { 357 if (!arg.is_loop_invariant) { 358 // Follow the edge from the Enter to Merge. 359 const Edge* enter_merge = nullptr; 360 for (const Edge* e : arg.enter->out_edges()) { 361 // Ignore control-edges to the sink node. These are allowed by the 362 // graph invariants, although probably they should have been stripped 363 // off earlier. 364 if (e->IsControlEdge() && e->dst()->IsSink()) { 365 continue; 366 } 367 if (enter_merge != nullptr) { 368 return errors::Internal( 369 "Enter node for loop-varying argument ", arg.enter->name(), 370 " has multiple successors: ", enter_merge->dst()->name(), " and ", 371 e->dst()->name()); 372 } 373 enter_merge = e; 374 } 375 if (enter_merge == nullptr) { 376 return errors::Internal("Enter node for loop-varying argument ", 377 arg.enter->name(), " has zero successors"); 378 } 379 arg.merge = enter_merge->dst(); 380 if (!IsMerge(arg.merge)) { 381 return errors::InvalidArgument( 382 "Successor of Enter node for loop-varying argument ", 383 arg.merge->name(), 384 " is not a Merge node; got: ", arg.merge->type_string()); 385 } 386 387 // Find the NextIteration from the merge. There should be two inputs to 388 // the Merge and the NextIteration should be the other input. 389 if (arg.merge->input_types().size() != 2) { 390 return errors::InvalidArgument( 391 "Unexpected number of inputs to Merge node for loop-varying " 392 "argument ", 393 arg.merge->name(), "; expected 2, got ", 394 arg.merge->input_types().size()); 395 } 396 TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), 397 &arg.next_iteration)); 398 if (!IsNextIteration(arg.next_iteration)) { 399 return errors::InvalidArgument( 400 "Expected NextIteration node as input to Merge node; got node ", 401 arg.next_iteration->name(), " with kind ", 402 arg.next_iteration->type_string()); 403 } 404 405 // Find the Switch successor of the Merge. There should be exactly one 406 // Switch node that is a successor of both the Merge and the LoopCond. 407 for (const Edge* edge : arg.merge->out_edges()) { 408 if (edge->dst_input() == 0 && IsSwitch(edge->dst()) && 409 switches.find(edge->dst()) != switches.end()) { 410 if (arg.switch_node != nullptr) { 411 return errors::InvalidArgument("Duplicate Switch successors to ", 412 arg.merge->name()); 413 } 414 arg.switch_node = edge->dst(); 415 } 416 } 417 if (arg.switch_node == nullptr) { 418 return errors::InvalidArgument("Missing Switch successor to ", 419 arg.merge->name()); 420 } 421 422 // Update the device on the Identity outputs of the switch to match their 423 // target. These Identity outputs do not 424 425 // Loop over the switch node's output to: 426 // - Find the Exit successor. 427 // - Set the sharding on all Identity outputs of the switch. These 428 // identity nodes are values used by the loop body or condition. 429 // The Identity node may have the wrong device so copy the device from 430 // one of its outputs instead. 431 std::deque<const Edge*> possible_exit; 432 for (const Edge* edge : arg.switch_node->out_edges()) { 433 if (edge->src_output() == 0) { 434 possible_exit.push_back(edge); 435 } 436 if (IsIdentity(edge->dst())) { 437 TF_RETURN_IF_ERROR( 438 SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); 439 } 440 } 441 // TODO(b/67425339): Allow general graph between switch and exit. 442 while (!possible_exit.empty()) { 443 const Edge* edge = possible_exit.front(); 444 possible_exit.pop_front(); 445 if (IsExit(edge->dst())) { 446 if (arg.exit != nullptr) { 447 return errors::InvalidArgument("Duplicate Exit successors to ", 448 arg.switch_node->name()); 449 } 450 arg.exit = edge->dst(); 451 } else { 452 if (!IsIdentity(edge->dst())) { 453 return errors::Unimplemented("General graph between switch (", 454 arg.switch_node->name(), 455 ") and exit node of frame ", 456 frame->name, " not supported yet."); 457 } 458 for (const Edge* out : edge->dst()->out_edges()) { 459 possible_exit.push_back(out); 460 } 461 } 462 } 463 } 464 } 465 466 // Builds the condition and body functions. 467 std::unique_ptr<Graph> cond_graph; 468 TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); 469 DataTypeVector arg_types; 470 std::unique_ptr<Graph> body_graph; 471 TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); 472 473 VLOG(2) << "Frame " << frame->name << " condition: " 474 << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) 475 << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); 476 477 static std::atomic<int64> sequence_num(0LL); 478 int64 id = ++sequence_num; 479 NameAttrList cond_name; 480 cond_name.set_name(strings::StrCat("_functionalize_cond_", id)); 481 NameAttrList body_name; 482 body_name.set_name(strings::StrCat("_functionalize_body_", id)); 483 FunctionDef cond_fdef; 484 TF_RETURN_IF_ERROR( 485 GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); 486 FunctionDef body_fdef; 487 TF_RETURN_IF_ERROR( 488 GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef)); 489 490 TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); 491 TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); 492 493 // Builds a While operator. 494 NodeDef while_def; 495 NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); 496 builder.Attr("T", arg_types); 497 builder.Attr("cond", cond_name); 498 builder.Attr("body", body_name); 499 std::vector<NodeDefBuilder::NodeOut> inputs; 500 for (int i = 0; i < frame->args.size(); ++i) { 501 const Arg& arg = frame->args[i]; 502 const Edge* in_edge; 503 TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); 504 if (in_edge->IsControlEdge()) { 505 builder.ControlInput(in_edge->src()->name()); 506 } else { 507 inputs.push_back(NodeDefBuilder::NodeOut( 508 in_edge->src()->name(), in_edge->src_output(), arg_types[i])); 509 } 510 } 511 builder.Input(inputs); 512 TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); 513 TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph)); 514 515 // Copies edges to the Enter nodes and from the Exit nodes onto the While. 516 for (int i = 0; i < frame->args.size(); ++i) { 517 const Arg& arg = frame->args[i]; 518 const Edge* in_edge; 519 TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); 520 if (in_edge->IsControlEdge()) { 521 graph->AddControlEdge(in_edge->src(), while_node); 522 } else { 523 graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i); 524 } 525 526 if (!arg.is_loop_invariant) { 527 // Add output edges if the output of the loop is consumed. 528 if (arg.exit != nullptr) { 529 std::vector<const Edge*> edges(arg.exit->out_edges().begin(), 530 arg.exit->out_edges().end()); 531 for (const Edge* edge : edges) { 532 Node* dst = edge->dst(); 533 int dst_input = edge->dst_input(); 534 graph->RemoveEdge(edge); 535 536 if (dst_input == Graph::kControlSlot) { 537 graph->AddControlEdge(while_node, dst); 538 } else { 539 graph->AddEdge(while_node, i, dst, dst_input); 540 } 541 } 542 } 543 } 544 } 545 546 // Remove the old nodes from the graph, and add the while node to the parent 547 // frame. 548 for (Node* node : frame->nodes) { 549 graph->RemoveNode(node); 550 } 551 frame->nodes.clear(); 552 frame->parent->nodes.insert(while_node); 553 554 VLOG(2) << "Frame " << frame->name << " after: " 555 << dump_graph::DumpGraphToFile("functionalize_after", *graph, 556 library); 557 558 return Status::OK(); 559} 560 561class FunctionalizeCond { 562 public: 563 // All nodes are assumed to be either in no branch, then branch, else branch, 564 // or both branches (such as merge nodes). 565 enum Branch { 566 kElseBranch = 0, 567 kThenBranch = 1, 568 kBoth = 2, 569 kNeither = 3, 570 kNumBranchTypes = 4 571 }; 572 573 // Returns a textual representation of the Branch b. 574 static string Branch_Name(FunctionalizeCond::Branch b); 575 576 // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf 577 // nodes. That is, attempt to transform every remaining switch and merge nodes 578 // in the graph into XlaIf nodes. 579 // Precondition: All while loops have been removed from graph. 580 static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); 581 582 private: 583 // CondArgNode represents a input to the conditional and its corresponding 584 // switch nodes. 585 struct CondArgNode { 586 explicit CondArgNode(Node* input) : input(input) {} 587 string ToString() const { 588 return strings::StrCat("input=", input->name(), 589 " switches=", NodesToString(switches)); 590 } 591 592 Node* input; 593 std::vector<Node*> switches; 594 }; 595 using CondArgNodes = std::vector<CondArgNode>; 596 597 struct ForwardFlowNode { 598 explicit ForwardFlowNode(Branch branch = Branch::kNeither) 599 : branch(branch), count(0) {} 600 string ToString() const { 601 return strings::StrCat("branch=", Branch_Name(branch), " count=", count); 602 } 603 Branch branch; 604 int count; 605 }; 606 607 // Group of switch nodes that will be part of the same XlaIf. 608 struct SwitchCluster { 609 explicit SwitchCluster(Node* predicate) : predicate(predicate) {} 610 string ToString() const { 611 return strings::StrCat(name, " predicate=", predicate->name(), 612 " switches=", NodesToString(switches)); 613 } 614 615 string name; 616 Node* predicate; 617 std::vector<Node*> switches; 618 }; 619 620 FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, 621 bool dump_graphs) 622 : library_(library), graph_(graph), dump_graphs_(dump_graphs) {} 623 624 // Perform the actual cond functionalization. Iterate over groups of switch 625 // nodes (linked by common predicate), from innermost to outermost, and 626 // extract into XlaIf nodes. 627 Status FunctionalizeInternal(); 628 629 // Determines the branch_map (mapping from node to branch of cond) and 630 // frontier (the nodes where the cond ends). 631 StatusOr<std::pair<std::unordered_map<Node*, ForwardFlowNode>, 632 std::unordered_set<Node*>>> 633 DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster); 634 635 // Returns XlaIf node created from subgraph of merge and switch nodes. This 636 // encapsulates the process of extracting the bodies needed for the then and 637 // else branch, creates a XlaIf node, removing the nodes of the branches from 638 // the graph and replacing the merge node with a XlaIf. 639 StatusOr<Node*> ConvertToXlaIf(const CondArgNodes& cond_arg_nodes, 640 const SwitchCluster& switch_cluster, 641 const std::vector<Node*>& switches); 642 643 // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with. 644 StatusOr<Node*> BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes, 645 const SwitchCluster& switch_cluster, 646 const std::vector<Node*>& merge_nodes); 647 648 // Extracts a function body corresponding to the given input edge of the merge 649 // node. 650 Status ExtractBody(const CondArgNodes& cond_arg_nodes, 651 const std::vector<Node*>& switches, 652 const std::vector<Node*>& merge_nodes, int input_edge, 653 Graph* body); 654 655 // Adds all the input edges to `if_node` corresponding to the arguments. 656 Status AddInputEdges(const CondArgNodes& cond_arg_nodes, Node* predicate, 657 Node* if_node); 658 659 // Adds all output edges from the `if_node`. 660 Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node); 661 662 // Returns the switch clusters of graph_ in postorder. Dead switch nodes are 663 // skipped and removed from the graph. 664 StatusOr<std::vector<SwitchCluster>> DeterminePredicateSwitchOrder(); 665 666 // Update the state for destination based on the state of source and the node 667 // being updated. 668 Status Join(const ForwardFlowNode& src_state, const Node* dst, 669 ForwardFlowNode* dst_state); 670 671 // Ensure that all nodes in the branch_map are dominated by the switch 672 // nodes. Returns nodes that are not dominated by the switches but are a 673 // control dependency of a node in the cond, and remove such control 674 // dependencies. 675 StatusOr<std::vector<Node*>> EnsureDominanceAndReturnNonDominatedControlNodes( 676 const std::unordered_map<Node*, ForwardFlowNode>& branch_map, 677 const std::vector<Node*>& switches); 678 679 // Validates that the frontier of nodes for the conditional 680 // section are as expected. 681 Status ValidateFrontier( 682 const std::unordered_map<Node*, ForwardFlowNode>& branch_map, 683 const std::unordered_set<Node*>& frontier); 684 685 FunctionLibraryDefinition* library_; 686 Graph* graph_; 687 bool dump_graphs_; 688}; 689 690bool IsDeadSwitch(const Node* node) { 691 for (const Edge* e : node->out_edges()) { 692 const Node* dst = e->dst(); 693 if (!dst->IsIdentity()) { 694 return false; 695 } 696 for (const Edge* ee : dst->out_edges()) { 697 if (!ee->IsControlEdge() || !ee->dst()->IsSink()) { 698 return false; 699 } 700 } 701 } 702 return true; 703} 704 705string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) { 706 const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = { 707 "else", "then", "both", "neither", "count"}; 708 return branch_name[b]; 709} 710 711Status FunctionalizeCond::ValidateFrontier( 712 const std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>& 713 branch_map, 714 const std::unordered_set<Node*>& frontier) { 715 std::unordered_set<const Node*> pending[kNumBranchTypes]; 716 for (Node* n : frontier) { 717 pending[branch_map.at(n).branch].insert(n); 718 } 719 TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]); 720 for (const Node* n : pending[kBoth]) { 721 TF_RET_CHECK(IsMerge(n)) << n->DebugString(); 722 // Merge nodes may be in then or else branch too 723 } 724 int index = (pending[kThenBranch].size() <= pending[kElseBranch].size()) 725 ? kThenBranch 726 : kElseBranch; 727 int other = 1 - index; 728 for (const Node* n : pending[index]) { 729 if (pending[other].find(n) != pending[other].end()) { 730 return errors::Internal( 731 "Node (", n->DebugString().c_str(), 732 ") in both Else and Then branch should be in Both."); 733 } 734 } 735 // An empty frontier indicates a dead switch. Above we attempt to remove dead 736 // switch nodes, but not all are removed so don't treat it as an error yet. 737 // TODO(jpienaar): Find out why dead switch nodes remain. 738 // if (pending[kBoth].empty() && pending[kThenBranch].empty() && 739 // pending[kElseBranch].empty()) { 740 // return errors::Internal("Unexpected empty frontier for switch nodes"); 741 // } 742 return Status::OK(); 743} 744 745Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, 746 const Node* dst, ForwardFlowNode* dst_state) { 747 TF_RET_CHECK(dst_state->branch != Branch::kBoth && 748 dst_state->branch != Branch::kNumBranchTypes) 749 << "Unexpected/Invalid branch type: Merging " 750 << Branch_Name(src_state.branch) << " with " 751 << Branch_Name(dst_state->branch); 752 if (dst_state->branch == Branch::kNeither) { 753 dst_state->branch = src_state.branch; 754 } else if (src_state.branch != dst_state->branch && 755 src_state.branch != Branch::kNeither) { 756 if (IsMerge(dst)) { 757 dst_state->branch = Branch::kBoth; 758 } else { 759 return errors::Internal("Illegal merge: ", src_state.ToString(), " with ", 760 dst_state->ToString(), " for ", 761 dst->DebugString()); 762 } 763 } 764 ++dst_state->count; 765 return Status::OK(); 766} 767 768StatusOr<std::vector<FunctionalizeCond::SwitchCluster>> 769FunctionalizeCond::DeterminePredicateSwitchOrder() { 770 struct Cluster { 771 bool operator==(const Cluster& other) const { 772 return representative == other.representative; 773 } 774 int representative = -1; 775 }; 776 777 // Perform a DFS over the graph and 778 // * Determine the reverse topological order of the nodes (there should be no 779 // cycles at this point so the post-order numbering corresponds to the 780 // reverse topological sorting); 781 // * Identify dead switches; 782 // * Initialize the cluster's representative; 783 std::vector<UnionFind<Cluster>> clusters(graph_->num_node_ids()); 784 std::vector<Node*> dead_switches; 785 std::vector<Node*> switch_order; 786 std::vector<Node*> rev_topo_sorted_nodes; 787 DFS(*graph_, nullptr, [&](Node* n) { 788 clusters[n->id()].Get().representative = n->id(); 789 if (IsSwitch(n)) { 790 if (IsDeadSwitch(n)) { 791 dead_switches.push_back(n); 792 } else { 793 rev_topo_sorted_nodes.push_back(n); 794 switch_order.push_back(n); 795 } 796 } else if (n->IsOp()) { 797 // Exclude src and sink nodes from further consideration. 798 rev_topo_sorted_nodes.push_back(n); 799 } 800 }); 801 802 std::vector<SwitchCluster> switch_clusters; 803 // Return early if there are no switches in the graph. 804 if (switch_order.empty()) { 805 return switch_clusters; 806 } 807 808 // Remove all dead switch nodes. 809 for (Node* n : dead_switches) { 810 VLOG(2) << "Removing dead switch: " << n->DebugString(); 811 graph_->RemoveNode(n); 812 } 813 814 // Identify switch nodes that are part of the same control flow context by 815 // considering the operands of operations: an operation is part of the same 816 // control context as its operands unless the operation is a switch. Control 817 // dependencies are considered part of the same control flow context if the 818 // switch depth is the same (see comment below). 819 820 // entry_cluster records the input cluster to a switch node. This is used when 821 // merging with a merge node where the dst's cluster is merged with the entry 822 // cluster of the merge node's cluster (which corresponds to a switch cluster 823 // and so has an entry cluster). 824 std::unordered_map<int, UnionFind<Cluster>*> entry_cluster; 825 826 // Returns the output cluster of a node. Where the output cluster is cluster 827 // where the output of the node is used. For non-merge nodes this is simply 828 // the cluster they are part of, while for merge nodes it is the entry cluster 829 // of the cluster they are part of (this will correspond to the entry node of 830 // a switch node that dominates the merge). 831 auto find_output_cluster = [&](Node* n) { 832 UnionFind<Cluster>* cluster = &clusters[n->id()]; 833 if (!IsMerge(n)) return cluster; 834 auto it = entry_cluster.find(clusters[n->id()].Get().representative); 835 // If the cluster is not found in the entry_cluster map then an 836 // instruction not dominated by a switch node has been merged into the 837 // cluster of the merge. This indicates a failure of the clustering. 838 CHECK(it != entry_cluster.end()) 839 << "Unable to find entry for n=" << n->id() << " (" 840 << cluster->Get().representative << ")"; 841 return it->second; 842 }; 843 844 // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier. 845 std::vector<int> switch_depth(graph_->num_node_ids()); 846 for (auto it = rev_topo_sorted_nodes.rbegin(); 847 it != rev_topo_sorted_nodes.rend(); ++it) { 848 Node* n = *it; 849 850 // Compute switch depth. 851 int new_switch_depth = 0; 852 for (const Edge* e : n->in_edges()) { 853 Node* src = e->src(); 854 new_switch_depth = std::max( 855 new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0)); 856 } 857 switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0); 858 859 // Only merge the input operands of a switch. The switch's clustering itself 860 // is determined by the interaction of the switch's outputs. 861 if (IsSwitch(n)) { 862 Node* input; 863 TF_CHECK_OK(n->input_node(0, &input)); 864 entry_cluster[n->id()] = &clusters[input->id()]; 865 UnionFind<Cluster>* cluster = find_output_cluster(input); 866 int cluster_depth = switch_depth[cluster->Get().representative]; 867 // Merge the inputs of the switch node with one another. This results in 868 // predicates and control input residing in the same cluster. 869 for (const Edge* e : n->in_edges()) { 870 Node* src = e->src(); 871 UnionFind<Cluster>* src_cluster = find_output_cluster(src); 872 int src_cluster_depth = switch_depth[src_cluster->Get().representative]; 873 if (cluster_depth != src_cluster_depth) { 874 return errors::InvalidArgument( 875 "Unable to functionalize control flow in graph: Switch ('", 876 n->name(), "') has operands ('", input->name(), "' and '", 877 src->name(), "') that have different switch depths (", 878 cluster_depth, " != ", src_cluster_depth, ")"); 879 } 880 cluster->Merge(src_cluster); 881 } 882 continue; 883 } 884 885 for (const Edge* e : n->in_edges()) { 886 Node* src = e->src(); 887 if (!src->IsOp()) continue; 888 UnionFind<Cluster>* cluster = find_output_cluster(src); 889 // Merge a node with its data operands and with its control operands if 890 // the src and dst are in the same ControlContext. The ControlContext is 891 // not explicitly available here, and instead the switch depth is used as 892 // a proxy here. Due to the invariant that control edges can only be from 893 // a containing scope to an inner scope or from the inner scope to its 894 // containing scope (for exit nodes), the switch depth will only match if 895 // the src and dst are in the same ControlContext. Control edges between 896 // ControlContexts are handled during the extraction. 897 int src_id = cluster->Get().representative; 898 int src_depth = switch_depth[src_id]; 899 if (!e->IsControlEdge() || new_switch_depth == src_depth) { 900 if (src_depth != new_switch_depth) { 901 return errors::InvalidArgument( 902 "Unable to functionalize control flow in graph: Operand ('", 903 src->name(), "') and operator ('", n->name(), 904 "') have different switch depths (", src_depth, 905 " != ", new_switch_depth, ")"); 906 } 907 cluster->Merge(&clusters[n->id()]); 908 } 909 } 910 } 911 912 if (dump_graphs_) { 913 // Mark the switch cluster each node is part of. 914 for (Node* n : graph_->nodes()) { 915 n->ClearAttr("_XlaFunctionalizeSwitchGroup"); 916 n->AddAttr("_XlaFunctionalizeSwitchGroup", 917 clusters[n->id()].Get().representative); 918 } 919 LOG(INFO) << "FunctionalizeControlFlow (with_clusters): " 920 << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_, 921 library_); 922 } 923 924 // Verify all the nodes of a cluster are at the same depth. 925 std::unordered_map<int, std::pair<int, Node*>> cluster_to_depth_node; 926 for (Node* n : graph_->nodes()) { 927 int depth = switch_depth[n->id()]; 928 int cluster_rep = clusters[n->id()].Get().representative; 929 auto it = cluster_to_depth_node.find(cluster_rep); 930 if (it == cluster_to_depth_node.end()) { 931 cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n); 932 } else { 933 if (it->second.first != depth) { 934 return errors::Internal( 935 "Illegal clustering created, mismatch in depths:", "\n\t", 936 n->DebugString(), "(", clusters[n->id()].Get().representative, 937 ") at depth=", depth, " vs\n\t", it->second.second->DebugString(), 938 "(", clusters[n->id()].Get().representative, ") at depth ", 939 it->second.first); 940 } 941 } 942 } 943 944 struct Hash { 945 size_t operator()(const std::pair<Node*, Cluster>& item) const { 946 return Hash64Combine(hash<Node*>()(item.first), 947 std::hash<int>()(item.second.representative)); 948 } 949 }; 950 951 // Merge Switch nodes with common predicate. 952 std::unordered_map<std::pair<Node*, Cluster>, int, Hash> predicate_index; 953 // The nodes in switch_order are in reverse topological order, but the 954 // clustered switches need not be (i.e., when considered as a cluster one 955 // element of a cluster may be later in the topological order than another 956 // node whose cluster is later in the topological order of clustered 957 // switches). 958 for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { 959 Node* pred; 960 TF_CHECK_OK((*it)->input_node(1, &pred)); 961 auto repr = std::make_pair(pred, clusters[(*it)->id()].Get()); 962 if (predicate_index.find(repr) == predicate_index.end()) { 963 predicate_index[repr] = switch_clusters.size(); 964 switch_clusters.emplace_back(pred); 965 // Generate a name by concatenating with the cluster representative as 966 // there could be multiple switch clusters with the same predicate. 967 switch_clusters[predicate_index[repr]].name = 968 strings::StrCat(pred->name(), "_", repr.second.representative, "_If"); 969 } 970 switch_clusters[predicate_index[repr]].switches.push_back(*it); 971 } 972 973 return switch_clusters; 974} 975 976StatusOr<std::vector<Node*>> 977FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes( 978 const std::unordered_map<Node*, ForwardFlowNode>& branch_map, 979 const std::vector<Node*>& switches) { 980 std::vector<Node*> old_control_nodes; 981 for (const auto& kv : branch_map) { 982 if (kv.second.count != kv.first->in_edges().size()) { 983 std::vector<const Edge*> delete_edges; 984 for (const Edge* in : kv.first->in_edges()) { 985 auto it = branch_map.find(in->src()); 986 if (it == branch_map.end()) { 987 if (in->IsControlEdge()) { 988 old_control_nodes.push_back(in->src()); 989 delete_edges.push_back(in); 990 } else { 991 if (IsSwitch(in->src())) { 992 if (std::find(switches.begin(), switches.end(), in->src()) == 993 switches.end()) { 994 return errors::Internal( 995 "Unexpected switch node found during flow forward: ", 996 in->src()->DebugString()); 997 } 998 continue; 999 } 1000 return errors::InvalidArgument( 1001 "Value ", kv.first->name(), "'s input, ", in->src()->name(), 1002 ", is not dominated by switch nodes ", NodesToString(switches)); 1003 } 1004 } 1005 } 1006 // Remove control edges from nodes that are not dominated by the switch 1007 // nodes. New control dependencies will be added between these nodes and 1008 // the XlaIf node inserted. 1009 for (const Edge* e : delete_edges) { 1010 graph_->RemoveEdge(e); 1011 } 1012 } 1013 } 1014 return old_control_nodes; 1015} 1016 1017StatusOr< 1018 std::pair<std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>, 1019 std::unordered_set<Node*>>> 1020FunctionalizeCond::DetermineBranchMapAndFrontier( 1021 const SwitchCluster& switch_cluster) { 1022 std::unordered_map<Node*, ForwardFlowNode> branch_map; 1023 std::unordered_set<Node*> frontier; 1024 std::vector<Node*> stack = switch_cluster.switches; 1025 std::vector<bool> visited(graph_->num_node_ids(), false); 1026 while (!stack.empty()) { 1027 Node* n = stack.back(); 1028 stack.pop_back(); 1029 1030 if (visited[n->id()]) { 1031 continue; 1032 } 1033 visited[n->id()] = true; 1034 1035 // Propagate branch state along each edge of a switch node. 1036 bool sink_only = true; 1037 for (const Edge* e : n->out_edges()) { 1038 Node* out = e->dst(); 1039 if (!out->IsOp()) { 1040 continue; 1041 } 1042 sink_only = false; 1043 // Propagate branch information. 1044 ForwardFlowNode& ffn = branch_map[out]; 1045 if (IsSwitch(n)) { 1046 int index = e->IsControlEdge() ? Branch::kNeither : e->src_output(); 1047 TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn)); 1048 } else { 1049 TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn)); 1050 } 1051 if (IsMerge(out)) { 1052 if (out->in_edges().size() == ffn.count) { 1053 frontier.insert(out); 1054 } 1055 } else if (!visited[out->id()]) { 1056 stack.push_back(out); 1057 } 1058 } 1059 if (sink_only) { 1060 if (!IsIdentity(n)) { 1061 VLOG(1) << "Feeding into sink: " << n->DebugString(); 1062 } 1063 } 1064 } 1065 1066 if (dump_graphs_) { 1067 for (const auto& kv : branch_map) { 1068 // Append attribute to the graph if running with logging to make the 1069 // changes clearer in the visualization. 1070 kv.first->AddAttr("_XlaFunctionalizeBranch", 1071 Branch_Name(kv.second.branch)); 1072 } 1073 } 1074 return std::make_pair(std::move(branch_map), std::move(frontier)); 1075} 1076 1077Status FunctionalizeCond::FunctionalizeInternal() { 1078 TF_ASSIGN_OR_RETURN(std::vector<SwitchCluster> predicate_switch_order, 1079 DeterminePredicateSwitchOrder()); 1080 1081 // Iterate from innermost set of clustered switches to outermost, replacing 1082 // matching switch->merge subgraphs with single XlaIf nodes. 1083 for (auto it = predicate_switch_order.rbegin(); 1084 it != predicate_switch_order.rend(); ++it) { 1085 auto& ps = *it; 1086 VLOG(3) << "Flow down from: " << NodesToString(ps.switches) << " (" 1087 << ps.predicate->name() << ")"; 1088 1089 std::unordered_map<Node*, ForwardFlowNode> branch_map; 1090 std::unordered_set<Node*> frontier; 1091 TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier), 1092 DetermineBranchMapAndFrontier(ps)); 1093 1094 if (dump_graphs_) 1095 LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): " 1096 << dump_graph::DumpGraphToFile("functionalize_bc", *graph_, 1097 library_); 1098 TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier)); 1099 1100 // Sort the merge and switch nodes using NodeCmp. The switch-nodes are 1101 // further grouped (post sorting) by input to the switch node as in the 1102 // functionalized form each input will be passed in only once. This grouping 1103 // should retain the sorted order. 1104 CondArgNodes cond_arg_nodes; 1105 std::unordered_map<Node*, int> input_index; 1106 std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp()); 1107 for (Node* switch_node : ps.switches) { 1108 Node* in; 1109 TF_RETURN_IF_ERROR(switch_node->input_node(0, &in)); 1110 if (input_index.find(in) == input_index.end()) { 1111 input_index[in] = cond_arg_nodes.size(); 1112 cond_arg_nodes.emplace_back(in); 1113 } 1114 cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node); 1115 } 1116 std::vector<Node*> merge_nodes(frontier.begin(), frontier.end()); 1117 std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp()); 1118 1119 TF_ASSIGN_OR_RETURN(std::vector<Node*> old_control_nodes, 1120 EnsureDominanceAndReturnNonDominatedControlNodes( 1121 branch_map, ps.switches)); 1122 1123 TF_ASSIGN_OR_RETURN(Node * if_node, 1124 ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes)); 1125 for (Node* old : old_control_nodes) { 1126 graph_->AddControlEdge(old, if_node); 1127 } 1128 1129 for (auto& del_kv : branch_map) { 1130 graph_->RemoveNode(del_kv.first); 1131 } 1132 for (auto& kv : cond_arg_nodes) { 1133 for (Node* node : kv.switches) { 1134 graph_->RemoveNode(node); 1135 } 1136 } 1137 if (dump_graphs_) 1138 LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): " 1139 << dump_graph::DumpGraphToFile("functionalize_ac", *graph_, 1140 library_); 1141 } 1142 return Status::OK(); 1143} 1144 1145StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp( 1146 const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, 1147 const std::vector<Node*>& merge_nodes) { 1148 VLOG(2) << "Build if op for " << switch_cluster.name; 1149 1150 NodeDef if_def; 1151 // Create a new If node using the name of the merge node. 1152 NodeDefBuilder builder(switch_cluster.name, "XlaIf"); 1153 string branch[] = {"else_branch", "then_branch"}; 1154 for (int i = 0; i < 2; ++i) { 1155 static std::atomic<int64> sequence_num(0LL); 1156 int64 id = ++sequence_num; 1157 1158 NameAttrList body_name; 1159 body_name.set_name( 1160 strings::StrCat("_functionalize_if_", branch[i], "_", id)); 1161 auto body = xla::MakeUnique<Graph>(graph_->op_registry()); 1162 TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches, 1163 merge_nodes, i, body.get())); 1164 VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); 1165 FunctionDef body_fdef; 1166 TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); 1167 TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef)); 1168 builder.Attr(branch[i], body_name); 1169 } 1170 1171 // Build input type. 1172 std::vector<NodeDefBuilder::NodeOut> inputs; 1173 DataTypeVector in_arg_types; 1174 for (auto& kv : cond_arg_nodes) { 1175 bool inserted = false; 1176 for (const Node* arg : kv.switches) { 1177 const Edge* in_edge; 1178 TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); 1179 if (in_edge->IsControlEdge()) { 1180 builder.ControlInput(in_edge->src()->name()); 1181 } else { 1182 if (!inserted) { 1183 DataType dtype = arg->input_type(0); 1184 inputs.emplace_back(NodeDefBuilder::NodeOut( 1185 in_edge->src()->name(), in_edge->src_output(), dtype)); 1186 in_arg_types.push_back(dtype); 1187 inserted = true; 1188 } 1189 } 1190 } 1191 } 1192 builder.Attr("Tin", in_arg_types); 1193 1194 // Build output type. 1195 DataTypeVector out_type; 1196 for (const Node* merge : merge_nodes) { 1197 DataType dtype = merge->output_type(0); 1198 out_type.push_back(dtype); 1199 } 1200 builder.Attr("Tout", out_type); 1201 1202 builder.Attr("Tcond", DT_BOOL); 1203 builder.Device(switch_cluster.predicate->assigned_device_name()); 1204 // Conditional should be the first input ... 1205 builder.Input( 1206 NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0, 1207 switch_cluster.predicate->output_type(0))); 1208 // ... followed by the other inputs. 1209 builder.Input(inputs); 1210 1211 TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); 1212 TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_)); 1213 return if_node; 1214} 1215 1216Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, 1217 const std::vector<Node*>& switches, 1218 const std::vector<Node*>& merge_nodes, 1219 int input_edge, Graph* body) { 1220 VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge " 1221 << input_edge; 1222 std::vector<bool> squash_src_outputs(graph_->num_node_ids(), false); 1223 std::vector<Node*> node_map(graph_->num_node_ids(), nullptr); 1224 int arg_count = 0; 1225 for (auto& kv : cond_arg_nodes) { 1226 Node* arg_node = nullptr; 1227 for (const auto* arg : kv.switches) { 1228 DataType dtype = arg->input_type(0); 1229 if (arg_node == nullptr) { 1230 TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++)); 1231 } 1232 node_map.at(arg->id()) = arg_node; 1233 squash_src_outputs.at(arg->id()) = true; 1234 } 1235 } 1236 1237 std::vector<Node*> stack; 1238 stack.reserve(merge_nodes.size()); 1239 for (int j = 0; j < merge_nodes.size(); ++j) { 1240 Node* node = merge_nodes[j]; 1241 TF_ASSIGN_OR_RETURN(node_map.at(node->id()), 1242 BuildRetvalNode(body, node->output_type(0), 1243 /*index=*/j)); 1244 const Edge* in_edge; 1245 TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge)); 1246 Node* in = in_edge->src(); 1247 if (node_map.at(in->id()) == nullptr) { 1248 node_map.at(in->id()) = body->CopyNode(in); 1249 } 1250 1251 if (std::find(switches.begin(), switches.end(), in) == switches.end()) { 1252 body->AddEdge(node_map.at(in->id()), in_edge->src_output(), 1253 node_map.at(node->id()), 0); 1254 } else { 1255 body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0); 1256 // Don't include input nodes that are already just returned in stack. 1257 continue; 1258 } 1259 stack.push_back(in); 1260 } 1261 1262 return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map, 1263 body); 1264} 1265 1266Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes, 1267 Node* predicate, Node* if_node) { 1268 VLOG(3) << "AddInputEdges for " << if_node->name(); 1269 int index = 0; 1270 graph_->AddEdge(predicate, 0, if_node, index++); 1271 for (auto& kv : cond_arg_nodes) { 1272 bool inserted = false; 1273 for (const Node* arg : kv.switches) { 1274 const Edge* in_edge; 1275 TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); 1276 if (in_edge->IsControlEdge()) { 1277 graph_->AddControlEdge(in_edge->src(), if_node); 1278 } else { 1279 if (!inserted) { 1280 graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node, 1281 index++); 1282 inserted = true; 1283 } 1284 } 1285 } 1286 } 1287 return Status::OK(); 1288} 1289 1290Status FunctionalizeCond::AddOutputEdges(const std::vector<Node*>& outputs, 1291 Node* if_node) { 1292 VLOG(3) << "AddOutputEdges for " << if_node->name(); 1293 for (int i = 0; i < outputs.size(); ++i) { 1294 Node* node = outputs[i]; 1295 std::vector<const Edge*> edges(node->out_edges().begin(), 1296 node->out_edges().end()); 1297 for (const Edge* edge : edges) { 1298 Node* dst = edge->dst(); 1299 int dst_input = edge->dst_input(); 1300 1301 if (edge->src_output() > 0) { 1302 return errors::Unimplemented("Output of index (", edge->src_output(), 1303 ") of merge node ", node->name()); 1304 } 1305 graph_->RemoveEdge(edge); 1306 1307 int src_output = 1308 dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; 1309 graph_->AddEdge(if_node, src_output, dst, dst_input); 1310 } 1311 } 1312 return Status::OK(); 1313} 1314 1315StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf( 1316 const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, 1317 const std::vector<Node*>& merge_nodes) { 1318 VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> " 1319 << NodesToString(merge_nodes); 1320 1321 // Extract bodies and builds a If operator. 1322 TF_ASSIGN_OR_RETURN( 1323 Node * if_node, 1324 BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes)); 1325 TF_RETURN_IF_ERROR( 1326 AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node)); 1327 TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); 1328 1329 return if_node; 1330} 1331 1332Status FunctionalizeCond::Functionalize(Graph* graph, 1333 FunctionLibraryDefinition* library) { 1334 VLOG(1) << "FunctionalizeCond::Functionalize"; 1335 FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2)); 1336 return fc.FunctionalizeInternal(); 1337} 1338 1339} // namespace 1340 1341// Transformation that converts TensorFlow's graph control flow constructs into 1342// functional equivalents. 1343Status FunctionalizeControlFlow(Graph* graph, 1344 FunctionLibraryDefinition* library) { 1345 VLOG(2) << "FunctionalizeControlFlow (initial): " 1346 << dump_graph::DumpGraphToFile("functionalize_initial", *graph, 1347 library); 1348 // Note: BuildControlFlowInfo() requires that the graph's source node is 1349 // connected to all source nodes in the graph. Many graphs violate this 1350 // invariant. 1351 std::vector<ControlFlowInfo> cf_info; 1352 TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info)); 1353 1354 // Builds Frames, indexed by name. 1355 std::unordered_map<string, Frame> frames; 1356 for (Node* node : graph->op_nodes()) { 1357 const ControlFlowInfo& cf = cf_info[node->id()]; 1358 1359 VLOG(2) << "node: " << node->name() << " (" << node->id() 1360 << ") frame_name: " << cf.frame_name 1361 << " frame: " << (cf.frame ? cf.frame->name() : "---") 1362 << " parent_frame: " 1363 << (cf.parent_frame ? cf.parent_frame->name() : "---"); 1364 TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr); 1365 1366 Frame& frame = frames[cf.frame_name]; 1367 Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; 1368 if (frame.parent == nullptr) { 1369 frame.parent = parent; 1370 frame.name = cf.frame_name; 1371 ++parent->num_children; 1372 } else if (frame.parent != parent) { 1373 return errors::InvalidArgument("Mismatched parent frames for ", 1374 cf.frame->id(), ": ", parent->name, " vs ", 1375 frame.parent->name); 1376 } 1377 1378 if (IsEnter(node)) { 1379 Arg arg; 1380 arg.enter = node; 1381 TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant", 1382 &arg.is_loop_invariant)); 1383 frame.args.push_back(arg); 1384 } else if (IsLoopCond(node)) { 1385 if (frame.loop_cond) { 1386 return errors::InvalidArgument( 1387 "Loop ", cf.frame_name, 1388 " has more than one LoopCond node: ", node->name(), " and ", 1389 frame.loop_cond->name()); 1390 } 1391 frame.loop_cond = node; 1392 } 1393 frame.nodes.insert(node); 1394 } 1395 1396 // Adds frames with no children (i.e., the innermost frames) to a worklist. 1397 std::deque<Frame*> worklist; 1398 for (auto& frame : frames) { 1399 if (frame.second.num_children == 0) { 1400 worklist.push_back(&frame.second); 1401 } 1402 } 1403 1404 // Eliminate loops from innermost to outermost. 1405 while (!worklist.empty()) { 1406 Frame* frame = worklist.front(); 1407 worklist.pop_front(); 1408 if (frame->parent == frame) { 1409 // Skip the root frame. 1410 continue; 1411 } 1412 1413 TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library)); 1414 1415 // If the parent has no remaining children, add it to the worklist. 1416 --frame->parent->num_children; 1417 if (frame->parent->num_children == 0) { 1418 worklist.push_back(frame->parent); 1419 } 1420 } 1421 1422 // FunctionalizeControlFlow is invoked for every function, so the loops's 1423 // bodies and conditionals that were extracted into functions will be handled 1424 // in successive invocations. 1425 TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library)); 1426 1427 VLOG(2) << "FunctionalizeControlFlow (final): " 1428 << dump_graph::DumpGraphToFile("functionalize_final", *graph, 1429 library); 1430 return Status::OK(); 1431} 1432 1433} // namespace tensorflow 1434