1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" 17 18#include "tensorflow/cc/framework/ops.h" 19#include "tensorflow/cc/ops/control_flow_ops_internal.h" 20#include "tensorflow/cc/ops/function_ops.h" 21#include "tensorflow/cc/ops/resource_variable_ops.h" 22#include "tensorflow/cc/ops/standard_ops.h" 23#include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h" 24#include "tensorflow/compiler/tf2xla/test_util.h" 25#include "tensorflow/compiler/xla/status_macros.h" 26#include "tensorflow/core/common_runtime/function.h" 27#include "tensorflow/core/framework/function.h" 28#include "tensorflow/core/framework/node_def_util.h" 29#include "tensorflow/core/framework/op.h" 30#include "tensorflow/core/graph/graph_constructor.h" 31#include "tensorflow/core/graph/graph_def_builder.h" 32#include "tensorflow/core/lib/core/status_test_util.h" 33#include "tensorflow/core/platform/test.h" 34#include "tensorflow/core/util/equal_graph_def.h" 35 36namespace tensorflow { 37namespace { 38 39// Returns the names of the "then" and "else" functions for the XlaIf node in a 40// graph. 41Status FindIfThenAndElse(const GraphDef& graph, string* op_name, 42 NameAttrList* then_fn, NameAttrList* else_fn) { 43 for (const NodeDef& node : graph.node()) { 44 if (node.op() == "XlaIf") { 45 *op_name = node.name(); 46 const NameAttrList* result; 47 TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); 48 *then_fn = *result; 49 TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result)); 50 *else_fn = *result; 51 return Status::OK(); 52 } 53 } 54 return errors::NotFound("No XlaIf node found in graph"); 55} 56 57// Graph: 58// x = array_ops.placeholder(dtypes.int32) 59// y = array_ops.placeholder(dtypes.int32) 60// z = control_flow_ops.cond( 61// math_ops.less(y, x), lambda: math_ops.multiply(y, 17), 62// lambda: math_ops.add(x, 23)) 63TEST(FunctionalizeControlFlow, Conditional) { 64 Graph graph(OpRegistry::Global()); 65 { 66 Scope scope = Scope::NewRootScope().ExitOnError(); 67 68 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); 69 auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); 70 auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); 71 auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less); 72 73 auto identity_t = 74 ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true); 75 auto seventeen = ops::Const<int32>( 76 scope.WithOpName("cond").WithControlDependencies(identity_t), 17); 77 auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less); 78 auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true, 79 seventeen); 80 81 auto identity_f = 82 ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false); 83 auto twenty_three = ops::Const<int32>( 84 scope.WithOpName("cond").WithControlDependencies(identity_f), 23); 85 auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); 86 auto add = ops::Add(scope.WithOpName("cond/false/add"), 87 switch_3.output_false, twenty_three); 88 89 auto merge = ops::Merge(scope.WithOpName("cond/Merge"), 90 std::initializer_list<Input>{add, mul}); 91 92 TF_EXPECT_OK(scope.ToGraph(&graph)); 93 } 94 95 FunctionLibraryDefinition library(OpRegistry::Global(), {}); 96 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); 97 98 GraphDef graph_def; 99 graph.ToGraphDef(&graph_def); 100 string op_name; 101 NameAttrList then_fn; 102 NameAttrList else_fn; 103 TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); 104 InstantiationResultForTest else_result; 105 TF_EXPECT_OK( 106 InstantiateFunctionForTest(else_fn.name(), library, &else_result)); 107 108 // Outer graph 109 { 110 Scope scope = Scope::NewRootScope().ExitOnError(); 111 auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); 112 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); 113 auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); 114 auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, 115 std::initializer_list<Input>{less, y, x}, then_fn, 116 else_fn, {DT_INT32}); 117 GraphDef expected; 118 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 119 TF_EXPECT_GRAPH_EQ(expected, graph_def); 120 } 121 122 // then body. 123 { 124 Scope scope = Scope::NewRootScope().ExitOnError(); 125 auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); 126 auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); 127 auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); 128 auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); 129 auto cond = ops::Const( 130 scope.WithOpName("cond").WithControlDependencies(identity), 17); 131 auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); 132 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); 133 134 GraphDef expected; 135 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 136 137 InstantiationResultForTest result; 138 TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result)); 139 140 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); 141 EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); 142 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 143 } 144 145 // else body. 146 { 147 Scope scope = Scope::NewRootScope().ExitOnError(); 148 auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); 149 auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); 150 auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); 151 auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); 152 auto cond_1 = ops::Const( 153 scope.WithOpName("cond_1").WithControlDependencies(identity), 23); 154 auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); 155 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); 156 157 GraphDef expected; 158 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 159 160 InstantiationResultForTest result; 161 TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result)); 162 163 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); 164 EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); 165 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 166 } 167} 168 169// Returns the names of the "cond" and "body" functions for the While node 170// in a graph. 171Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, 172 NameAttrList* body) { 173 for (const NodeDef& node : graph.node()) { 174 if (node.op() == "XlaWhile") { 175 const NameAttrList* result; 176 TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result)); 177 *cond = *result; 178 TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result)); 179 *body = *result; 180 return Status::OK(); 181 } 182 } 183 return errors::NotFound("No XlaWhile node found in graph"); 184} 185 186// Graph: 187// x = array_ops.placeholder(dtypes.int32) 188// y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) 189TEST(FunctionalizeControlFlow, OneLoopVar) { 190 Graph graph(OpRegistry::Global()); 191 { 192 Scope scope = Scope::NewRootScope().ExitOnError(); 193 194 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); 195 196 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); 197 auto enter = 198 ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); 199 // Add an unused Enter node. These should be ignored. 200 auto enter2 = 201 ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop"); 202 auto merge = ops::Merge(scope.WithOpName("while/Merge"), 203 std::initializer_list<Input>{enter, dummy}); 204 auto ten = ops::Const<int32>( 205 scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), 206 10); 207 auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); 208 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); 209 auto switch_ = 210 ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); 211 auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), 212 switch_.output_false); 213 auto identity = 214 ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); 215 auto one = ops::Const<int32>( 216 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); 217 auto add = ops::Add(scope.WithOpName("while/add"), identity, one); 218 auto next_iteration = 219 ops::NextIteration(scope.WithOpName("while/NextIteration"), add); 220 221 auto sink = ops::Identity(scope.WithOpName("sink"), exit); 222 223 // Remove the dummy node and add the loop backedge. 224 scope.graph()->RemoveNode(dummy.node()); 225 scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); 226 227 TF_EXPECT_OK(scope.ToGraph(&graph)); 228 } 229 230 // Regression test: control edges from an Enter node to the graph sink should 231 // be ignored. 232 for (Node* n : graph.nodes()) { 233 if (n->name() == "while/Enter") { 234 graph.AddControlEdge(n, graph.sink_node()); 235 } 236 } 237 238 FunctionLibraryDefinition library(OpRegistry::Global(), {}); 239 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); 240 241 GraphDef graph_def; 242 graph.ToGraphDef(&graph_def); 243 244 NameAttrList cond_fn, body_fn; 245 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); 246 247 // Outer graph 248 { 249 Scope scope = Scope::NewRootScope().ExitOnError(); 250 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); 251 auto while_op = 252 ops::XlaWhile(scope.WithOpName("while/LoopCond"), 253 std::initializer_list<Input>{source}, cond_fn, body_fn); 254 auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); 255 GraphDef expected; 256 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 257 TF_EXPECT_GRAPH_EQ(expected, graph_def); 258 } 259 260 // Condition graph 261 { 262 Scope scope = Scope::NewRootScope().ExitOnError(); 263 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); 264 auto ten = ops::Const<int32>( 265 scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); 266 auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); 267 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); 268 269 GraphDef expected; 270 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 271 272 InstantiationResultForTest result; 273 TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); 274 275 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); 276 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); 277 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 278 } 279 280 // Body graph. 281 { 282 Scope scope = Scope::NewRootScope().ExitOnError(); 283 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); 284 auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); 285 auto one = ops::Const<int32>( 286 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); 287 auto add = ops::Add(scope.WithOpName("while/add"), identity, one); 288 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); 289 290 GraphDef expected; 291 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 292 293 InstantiationResultForTest result; 294 TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); 295 296 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); 297 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); 298 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 299 } 300} 301 302// Tests functionalizing OneLoopVar where the loop value is not used post the 303// loop. 304// Graph: 305// x = array_ops.placeholder(dtypes.int32) 306// control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) 307TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { 308 Graph graph(OpRegistry::Global()); 309 { 310 Scope scope = Scope::NewRootScope().ExitOnError(); 311 312 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); 313 314 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); 315 auto enter = 316 ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); 317 auto merge = ops::Merge(scope.WithOpName("while/Merge"), 318 std::initializer_list<Input>{enter, dummy}); 319 auto ten = ops::Const<int32>( 320 scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), 321 10); 322 auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); 323 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); 324 auto switch_ = 325 ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); 326 auto identity = 327 ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); 328 auto one = ops::Const<int32>( 329 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); 330 auto add = ops::Add(scope.WithOpName("while/add"), identity, one); 331 auto next_iteration = 332 ops::NextIteration(scope.WithOpName("while/NextIteration"), add); 333 334 // Remove the dummy node and add the loop backedge. 335 scope.graph()->RemoveNode(dummy.node()); 336 scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); 337 338 TF_EXPECT_OK(scope.ToGraph(&graph)); 339 } 340 341 FunctionLibraryDefinition library(OpRegistry::Global(), {}); 342 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); 343 344 GraphDef graph_def; 345 graph.ToGraphDef(&graph_def); 346 347 NameAttrList cond_fn, body_fn; 348 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); 349 350 // Outer graph 351 { 352 Scope scope = Scope::NewRootScope().ExitOnError(); 353 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); 354 auto while_op = 355 ops::XlaWhile(scope.WithOpName("while/LoopCond"), 356 std::initializer_list<Input>{source}, cond_fn, body_fn); 357 GraphDef expected; 358 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 359 TF_EXPECT_GRAPH_EQ(expected, graph_def); 360 } 361 362 // Condition graph 363 { 364 Scope scope = Scope::NewRootScope().ExitOnError(); 365 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); 366 auto ten = ops::Const<int32>( 367 scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); 368 auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); 369 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); 370 371 GraphDef expected; 372 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 373 374 InstantiationResultForTest result; 375 TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); 376 377 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); 378 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); 379 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 380 } 381 382 // Body graph. 383 { 384 Scope scope = Scope::NewRootScope().ExitOnError(); 385 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); 386 auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); 387 auto one = ops::Const<int32>( 388 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); 389 auto add = ops::Add(scope.WithOpName("while/add"), identity, one); 390 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); 391 392 GraphDef expected; 393 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 394 395 InstantiationResultForTest result; 396 TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); 397 398 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); 399 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); 400 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 401 } 402} 403 404// Graph: 405// x = array_ops.placeholder(dtypes.int32) 406// y = array_ops.placeholder(dtypes.int32) 407// cond = lambda (i, j): i + 3 < 10 408// body = lambda (i, j): (i < 10, j * 2) 409// z = control_flow_ops.while_loop(cond, body, [x, y]) 410TEST(FunctionalizeControlFlow, TwoLoopVars) { 411 Graph graph(OpRegistry::Global()); 412 { 413 Scope scope = Scope::NewRootScope().ExitOnError(); 414 415 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); 416 417 auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); 418 auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); 419 auto enter_x = 420 ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop"); 421 auto enter_y = 422 ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop"); 423 auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"), 424 std::initializer_list<Input>{enter_x, dummy}); 425 auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"), 426 std::initializer_list<Input>{enter_y, dummy}); 427 428 // Loop condition 429 auto three = ops::Const<int32>(scope.WithOpName("while/cond/three") 430 .WithControlDependencies(merge_x.output), 431 3); 432 auto cond_add = 433 ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three); 434 auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten") 435 .WithControlDependencies(merge_x.output), 436 10); 437 auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); 438 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); 439 440 auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"), 441 merge_x.output, loop_cond); 442 auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"), 443 merge_y.output, loop_cond); 444 445 auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"), 446 switch_x.output_false); 447 auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"), 448 switch_y.output_false); 449 450 auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), 451 switch_x.output_true); 452 auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), 453 switch_y.output_true); 454 455 auto one = ops::Const<int32>( 456 scope.WithOpName("while/add/one").WithControlDependencies(identity_x), 457 1); 458 auto two = ops::Const<int32>( 459 scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), 460 2); 461 462 auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); 463 auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); 464 auto next_iteration_x = 465 ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add); 466 auto next_iteration_y = 467 ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul); 468 469 auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x); 470 auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y); 471 472 // Remove the dummy node and add the loop backedges. 473 scope.graph()->RemoveNode(dummy.node()); 474 scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(), 475 1); 476 scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(), 477 1); 478 479 TF_EXPECT_OK(scope.ToGraph(&graph)); 480 } 481 482 FunctionLibraryDefinition library(OpRegistry::Global(), {}); 483 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); 484 485 GraphDef graph_def; 486 graph.ToGraphDef(&graph_def); 487 488 NameAttrList cond_fn, body_fn; 489 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); 490 491 // Outer graph. 492 { 493 Scope scope = Scope::NewRootScope().ExitOnError(); 494 auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); 495 auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); 496 auto while_op = 497 ops::XlaWhile(scope.WithOpName("while/LoopCond"), 498 std::initializer_list<Input>{x, y}, cond_fn, body_fn); 499 auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); 500 auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); 501 GraphDef expected; 502 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 503 TF_EXPECT_GRAPH_EQ(expected, graph_def); 504 } 505 506 // Condition graph. 507 { 508 Scope scope = Scope::NewRootScope().ExitOnError(); 509 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); 510 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); 511 auto three = ops::Const<int32>(scope.WithOpName("while/cond/three") 512 .WithControlDependencies(arg0.output), 513 3); 514 auto cond_add = 515 ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); 516 auto ten = ops::Const<int32>( 517 scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output), 518 10); 519 auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); 520 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); 521 522 GraphDef expected; 523 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 524 525 InstantiationResultForTest result; 526 TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); 527 528 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); 529 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); 530 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 531 } 532 533 // Body graph. 534 { 535 Scope scope = Scope::NewRootScope().ExitOnError(); 536 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); 537 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); 538 539 auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0); 540 auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1); 541 542 auto one = ops::Const<int32>( 543 scope.WithOpName("while/add/one").WithControlDependencies(identity_x), 544 1); 545 auto two = ops::Const<int32>( 546 scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), 547 2); 548 549 auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); 550 auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); 551 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); 552 auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); 553 554 GraphDef expected; 555 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 556 557 InstantiationResultForTest result; 558 TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); 559 560 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); 561 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); 562 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 563 } 564} 565 566// Example with nesting, loop-invariant arguments, and resource variables. 567// 568// accum = resource_variable_ops.ResourceVariable(1) 569// x = array_ops.placeholder(2, dtype=dtypes.int32) 570// y = 3 + x 571// 572// def inner_body(j, k): 573// add = state_ops.assign_add(accum, k * j + x) 574// with ops.control_dependencies([add]): 575// return [j + 1, k] 576// 577// def body(i): 578// m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body, 579// [1, y], name="inner") 580// with ops.control_dependencies(m): 581// return [i + 1] 582// 583// z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer") 584TEST(FunctionalizeControlFlow, Complex) { 585 Graph graph(OpRegistry::Global()); 586 { 587 Scope scope = Scope::NewRootScope().ExitOnError(); 588 589 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); 590 591 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); 592 auto three = ops::Const<int32>(scope.WithOpName("three"), 3); 593 auto y = ops::Add(scope.WithOpName("y"), x, three); 594 595 auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, 596 TensorShape({})); 597 598 // Outer loop 599 auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0); 600 auto enter_i = 601 ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer"); 602 auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"), 603 std::initializer_list<Input>{enter_i, dummy}); 604 auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y") 605 .WithControlDependencies(merge_i.output), 606 10); 607 auto less_i = 608 ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten); 609 auto outer_loop_cond = 610 ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i); 611 auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"), 612 merge_i.output, outer_loop_cond); 613 auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"), 614 switch_i.output_false); 615 auto identity_i = 616 ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true); 617 618 auto enter_x_outer = 619 ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer", 620 ops::internal::Enter::Attrs().IsConstant(true)); 621 auto enter_k_outer = 622 ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer", 623 ops::internal::Enter::Attrs().IsConstant(true)); 624 auto enter_var_outer = 625 ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer", 626 ops::internal::Enter::Attrs().IsConstant(true)); 627 628 // Inner loop 629 auto one_j = ops::Const<int32>( 630 scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); 631 auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"), 632 one_j, "inner"); 633 auto enter_k = 634 ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k") 635 .WithControlDependencies(identity_i), 636 enter_k_outer, "inner"); 637 auto enter_x = ops::internal::Enter( 638 scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner", 639 ops::internal::Enter::Attrs().IsConstant(true)); 640 auto enter_var = ops::internal::Enter( 641 scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner", 642 ops::internal::Enter::Attrs().IsConstant(true)); 643 644 auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"), 645 std::initializer_list<Input>{enter_j, dummy}); 646 auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"), 647 std::initializer_list<Input>{enter_k, dummy}); 648 649 auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five") 650 .WithControlDependencies(merge_j.output), 651 5); 652 auto less_j = 653 ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five); 654 auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j); 655 656 auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"), 657 merge_j.output, loop_cond); 658 auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"), 659 merge_k.output, loop_cond); 660 auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"), 661 switch_j.output_false); 662 auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"), 663 switch_k.output_false); 664 auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"), 665 switch_j.output_true); 666 auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"), 667 switch_k.output_true); 668 669 // Variable update 670 auto mul_jk = 671 ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); 672 auto add_jkx = 673 ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x); 674 auto assign = ops::AssignAddVariableOp( 675 scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx); 676 677 auto one = 678 ops::Const<int32>(scope.WithOpName("outer/inner/One") 679 .WithControlDependencies( 680 gtl::ArraySlice<Operation>{assign.operation}), 681 1); 682 auto add_j = 683 ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); 684 685 auto next_iteration_j = ops::NextIteration( 686 scope.WithOpName("outer/inner/NextIteration_j"), add_j); 687 auto next_iteration_k = ops::NextIteration( 688 scope.WithOpName("outer/inner/NextIteration_k"), identity_k); 689 690 // Body and backedge for outer loop. 691 auto one_outer = ops::Const<int32>( 692 scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); 693 auto add_i = 694 ops::Add(scope.WithOpName("outer/add") 695 .WithControlDependencies(gtl::ArraySlice<Operation>{ 696 exit_j.output.op(), exit_k.output.op()}), 697 identity_i, one_outer); 698 auto next_iteration_i = 699 ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i); 700 701 auto sink = ops::Identity(scope.WithOpName("sink"), exit_i); 702 703 // Remove the dummy node and add the loop backedge. 704 scope.graph()->RemoveNode(dummy.node()); 705 scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(), 706 1); 707 scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(), 708 1); 709 scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(), 710 1); 711 712 TF_EXPECT_OK(scope.ToGraph(&graph)); 713 } 714 715 FunctionLibraryDefinition library(OpRegistry::Global(), {}); 716 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); 717 718 GraphDef graph_def; 719 graph.ToGraphDef(&graph_def); 720 721 NameAttrList outer_cond_fn, outer_body_fn; 722 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); 723 724 // Outer graph. 725 { 726 Scope scope = Scope::NewRootScope().ExitOnError(); 727 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); 728 auto three = ops::Const<int32>(scope.WithOpName("three"), 3); 729 auto y = ops::Add(scope.WithOpName("y"), x, three); 730 731 auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, 732 TensorShape({})); 733 734 auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0); 735 736 auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"), 737 std::initializer_list<Input>{zero, y, x, var}, 738 outer_cond_fn, outer_body_fn); 739 auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); 740 GraphDef expected; 741 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 742 TF_EXPECT_GRAPH_EQ(expected, graph_def); 743 } 744 745 // Outer condition graph. 746 { 747 Scope scope = Scope::NewRootScope().ExitOnError(); 748 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); 749 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); 750 auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); 751 auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); 752 753 auto ten = ops::Const<int32>( 754 scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), 755 10); 756 auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); 757 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); 758 759 GraphDef expected; 760 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 761 762 InstantiationResultForTest result; 763 TF_EXPECT_OK( 764 InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); 765 766 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), 767 result.arg_types); 768 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); 769 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 770 } 771 772 // Outer body graph. 773 NameAttrList inner_cond_fn, inner_body_fn; 774 { 775 InstantiationResultForTest result; 776 TF_EXPECT_OK( 777 InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); 778 779 // Find the inner condition and body names. 780 TF_EXPECT_OK( 781 FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); 782 783 Scope scope = Scope::NewRootScope().ExitOnError(); 784 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); 785 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); 786 auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); 787 auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); 788 789 auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); 790 auto one_j = ops::Const<int32>( 791 scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); 792 auto while_op = 793 ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"), 794 std::initializer_list<Input>{one_j, arg1, arg2, arg3}, 795 inner_cond_fn, inner_body_fn); 796 797 auto one_outer = ops::Const<int32>( 798 scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); 799 auto add_i = 800 ops::Add(scope.WithOpName("outer/add") 801 .WithControlDependencies(gtl::ArraySlice<Operation>{ 802 while_op[0].op(), while_op[1].op()}), 803 identity_i, one_outer); 804 805 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); 806 auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); 807 auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); 808 809 GraphDef expected; 810 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 811 812 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), 813 result.arg_types); 814 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); 815 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 816 } 817 818 // Inner condition graph. 819 { 820 Scope scope = Scope::NewRootScope().ExitOnError(); 821 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); 822 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); 823 auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); 824 auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); 825 826 auto five = ops::Const<int32>( 827 scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5); 828 auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); 829 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); 830 831 GraphDef expected; 832 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 833 834 InstantiationResultForTest result; 835 TF_EXPECT_OK( 836 InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); 837 838 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), 839 result.arg_types); 840 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); 841 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 842 } 843 844 // Inner body graph. 845 { 846 Scope scope = Scope::NewRootScope().ExitOnError(); 847 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); 848 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); 849 auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); 850 auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); 851 852 auto identity_j = 853 ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); 854 auto identity_k = 855 ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); 856 857 auto mul_jk = 858 ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); 859 auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); 860 auto assign = ops::AssignAddVariableOp( 861 scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); 862 863 auto one = 864 ops::Const<int32>(scope.WithOpName("outer/inner/One") 865 .WithControlDependencies( 866 gtl::ArraySlice<Operation>{assign.operation}), 867 1); 868 auto add_j = 869 ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); 870 871 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); 872 auto retval1 = 873 ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); 874 auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); 875 876 GraphDef expected; 877 TF_EXPECT_OK(scope.ToGraphDef(&expected)); 878 879 InstantiationResultForTest result; 880 TF_EXPECT_OK( 881 InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); 882 883 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), 884 result.arg_types); 885 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); 886 TF_EXPECT_GRAPH_EQ(expected, result.gdef); 887 } 888} 889 890} // namespace 891} // namespace tensorflow 892