1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
17
18#include "tensorflow/cc/framework/ops.h"
19#include "tensorflow/cc/ops/control_flow_ops_internal.h"
20#include "tensorflow/cc/ops/function_ops.h"
21#include "tensorflow/cc/ops/resource_variable_ops.h"
22#include "tensorflow/cc/ops/standard_ops.h"
23#include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h"
24#include "tensorflow/compiler/tf2xla/test_util.h"
25#include "tensorflow/compiler/xla/status_macros.h"
26#include "tensorflow/core/common_runtime/function.h"
27#include "tensorflow/core/framework/function.h"
28#include "tensorflow/core/framework/node_def_util.h"
29#include "tensorflow/core/framework/op.h"
30#include "tensorflow/core/graph/graph_constructor.h"
31#include "tensorflow/core/graph/graph_def_builder.h"
32#include "tensorflow/core/lib/core/status_test_util.h"
33#include "tensorflow/core/platform/test.h"
34#include "tensorflow/core/util/equal_graph_def.h"
35
36namespace tensorflow {
37namespace {
38
39// Returns the names of the "then" and "else" functions for the XlaIf node in a
40// graph.
41Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
42                         NameAttrList* then_fn, NameAttrList* else_fn) {
43  for (const NodeDef& node : graph.node()) {
44    if (node.op() == "XlaIf") {
45      *op_name = node.name();
46      const NameAttrList* result;
47      TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
48      *then_fn = *result;
49      TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result));
50      *else_fn = *result;
51      return Status::OK();
52    }
53  }
54  return errors::NotFound("No XlaIf node found in graph");
55}
56
57// Graph:
58// x = array_ops.placeholder(dtypes.int32)
59// y = array_ops.placeholder(dtypes.int32)
60// z = control_flow_ops.cond(
61//     math_ops.less(y, x), lambda: math_ops.multiply(y, 17),
62//     lambda: math_ops.add(x, 23))
63TEST(FunctionalizeControlFlow, Conditional) {
64  Graph graph(OpRegistry::Global());
65  {
66    Scope scope = Scope::NewRootScope().ExitOnError();
67
68    auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
69    auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
70    auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
71    auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less);
72
73    auto identity_t =
74        ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true);
75    auto seventeen = ops::Const<int32>(
76        scope.WithOpName("cond").WithControlDependencies(identity_t), 17);
77    auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less);
78    auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true,
79                             seventeen);
80
81    auto identity_f =
82        ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false);
83    auto twenty_three = ops::Const<int32>(
84        scope.WithOpName("cond").WithControlDependencies(identity_f), 23);
85    auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
86    auto add = ops::Add(scope.WithOpName("cond/false/add"),
87                        switch_3.output_false, twenty_three);
88
89    auto merge = ops::Merge(scope.WithOpName("cond/Merge"),
90                            std::initializer_list<Input>{add, mul});
91
92    TF_EXPECT_OK(scope.ToGraph(&graph));
93  }
94
95  FunctionLibraryDefinition library(OpRegistry::Global(), {});
96  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
97
98  GraphDef graph_def;
99  graph.ToGraphDef(&graph_def);
100  string op_name;
101  NameAttrList then_fn;
102  NameAttrList else_fn;
103  TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
104  InstantiationResultForTest else_result;
105  TF_EXPECT_OK(
106      InstantiateFunctionForTest(else_fn.name(), library, &else_result));
107
108  // Outer graph
109  {
110    Scope scope = Scope::NewRootScope().ExitOnError();
111    auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
112    auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
113    auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
114    auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
115                            std::initializer_list<Input>{less, y, x}, then_fn,
116                            else_fn, {DT_INT32});
117    GraphDef expected;
118    TF_EXPECT_OK(scope.ToGraphDef(&expected));
119    TF_EXPECT_GRAPH_EQ(expected, graph_def);
120  }
121
122  // then body.
123  {
124    Scope scope = Scope::NewRootScope().ExitOnError();
125    auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
126    auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
127    auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
128    auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
129    auto cond = ops::Const(
130        scope.WithOpName("cond").WithControlDependencies(identity), 17);
131    auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
132    auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0);
133
134    GraphDef expected;
135    TF_EXPECT_OK(scope.ToGraphDef(&expected));
136
137    InstantiationResultForTest result;
138    TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result));
139
140    EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
141    EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
142    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
143  }
144
145  // else body.
146  {
147    Scope scope = Scope::NewRootScope().ExitOnError();
148    auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
149    auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
150    auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
151    auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
152    auto cond_1 = ops::Const(
153        scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
154    auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
155    auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
156
157    GraphDef expected;
158    TF_EXPECT_OK(scope.ToGraphDef(&expected));
159
160    InstantiationResultForTest result;
161    TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result));
162
163    EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
164    EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
165    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
166  }
167}
168
169// Returns the names of the "cond" and "body" functions for the While node
170// in a graph.
171Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
172                            NameAttrList* body) {
173  for (const NodeDef& node : graph.node()) {
174    if (node.op() == "XlaWhile") {
175      const NameAttrList* result;
176      TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
177      *cond = *result;
178      TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result));
179      *body = *result;
180      return Status::OK();
181    }
182  }
183  return errors::NotFound("No XlaWhile node found in graph");
184}
185
186// Graph:
187// x = array_ops.placeholder(dtypes.int32)
188// y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
189TEST(FunctionalizeControlFlow, OneLoopVar) {
190  Graph graph(OpRegistry::Global());
191  {
192    Scope scope = Scope::NewRootScope().ExitOnError();
193
194    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
195
196    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
197    auto enter =
198        ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
199    // Add an unused Enter node. These should be ignored.
200    auto enter2 =
201        ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
202    auto merge = ops::Merge(scope.WithOpName("while/Merge"),
203                            std::initializer_list<Input>{enter, dummy});
204    auto ten = ops::Const<int32>(
205        scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
206        10);
207    auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
208    auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
209    auto switch_ =
210        ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
211    auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
212                                    switch_.output_false);
213    auto identity =
214        ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
215    auto one = ops::Const<int32>(
216        scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
217    auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
218    auto next_iteration =
219        ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
220
221    auto sink = ops::Identity(scope.WithOpName("sink"), exit);
222
223    // Remove the dummy node and add the loop backedge.
224    scope.graph()->RemoveNode(dummy.node());
225    scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
226
227    TF_EXPECT_OK(scope.ToGraph(&graph));
228  }
229
230  // Regression test: control edges from an Enter node to the graph sink should
231  // be ignored.
232  for (Node* n : graph.nodes()) {
233    if (n->name() == "while/Enter") {
234      graph.AddControlEdge(n, graph.sink_node());
235    }
236  }
237
238  FunctionLibraryDefinition library(OpRegistry::Global(), {});
239  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
240
241  GraphDef graph_def;
242  graph.ToGraphDef(&graph_def);
243
244  NameAttrList cond_fn, body_fn;
245  TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
246
247  // Outer graph
248  {
249    Scope scope = Scope::NewRootScope().ExitOnError();
250    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
251    auto while_op =
252        ops::XlaWhile(scope.WithOpName("while/LoopCond"),
253                      std::initializer_list<Input>{source}, cond_fn, body_fn);
254    auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
255    GraphDef expected;
256    TF_EXPECT_OK(scope.ToGraphDef(&expected));
257    TF_EXPECT_GRAPH_EQ(expected, graph_def);
258  }
259
260  // Condition graph
261  {
262    Scope scope = Scope::NewRootScope().ExitOnError();
263    auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
264    auto ten = ops::Const<int32>(
265        scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
266    auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
267    auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
268
269    GraphDef expected;
270    TF_EXPECT_OK(scope.ToGraphDef(&expected));
271
272    InstantiationResultForTest result;
273    TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
274
275    EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
276    EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
277    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
278  }
279
280  // Body graph.
281  {
282    Scope scope = Scope::NewRootScope().ExitOnError();
283    auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
284    auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
285    auto one = ops::Const<int32>(
286        scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
287    auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
288    auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
289
290    GraphDef expected;
291    TF_EXPECT_OK(scope.ToGraphDef(&expected));
292
293    InstantiationResultForTest result;
294    TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
295
296    EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
297    EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
298    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
299  }
300}
301
302// Tests functionalizing OneLoopVar where the loop value is not used post the
303// loop.
304// Graph:
305// x = array_ops.placeholder(dtypes.int32)
306// control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
307TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
308  Graph graph(OpRegistry::Global());
309  {
310    Scope scope = Scope::NewRootScope().ExitOnError();
311
312    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
313
314    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
315    auto enter =
316        ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
317    auto merge = ops::Merge(scope.WithOpName("while/Merge"),
318                            std::initializer_list<Input>{enter, dummy});
319    auto ten = ops::Const<int32>(
320        scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
321        10);
322    auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
323    auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
324    auto switch_ =
325        ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
326    auto identity =
327        ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
328    auto one = ops::Const<int32>(
329        scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
330    auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
331    auto next_iteration =
332        ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
333
334    // Remove the dummy node and add the loop backedge.
335    scope.graph()->RemoveNode(dummy.node());
336    scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
337
338    TF_EXPECT_OK(scope.ToGraph(&graph));
339  }
340
341  FunctionLibraryDefinition library(OpRegistry::Global(), {});
342  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
343
344  GraphDef graph_def;
345  graph.ToGraphDef(&graph_def);
346
347  NameAttrList cond_fn, body_fn;
348  TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
349
350  // Outer graph
351  {
352    Scope scope = Scope::NewRootScope().ExitOnError();
353    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
354    auto while_op =
355        ops::XlaWhile(scope.WithOpName("while/LoopCond"),
356                      std::initializer_list<Input>{source}, cond_fn, body_fn);
357    GraphDef expected;
358    TF_EXPECT_OK(scope.ToGraphDef(&expected));
359    TF_EXPECT_GRAPH_EQ(expected, graph_def);
360  }
361
362  // Condition graph
363  {
364    Scope scope = Scope::NewRootScope().ExitOnError();
365    auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
366    auto ten = ops::Const<int32>(
367        scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
368    auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
369    auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
370
371    GraphDef expected;
372    TF_EXPECT_OK(scope.ToGraphDef(&expected));
373
374    InstantiationResultForTest result;
375    TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
376
377    EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
378    EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
379    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
380  }
381
382  // Body graph.
383  {
384    Scope scope = Scope::NewRootScope().ExitOnError();
385    auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
386    auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
387    auto one = ops::Const<int32>(
388        scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
389    auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
390    auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
391
392    GraphDef expected;
393    TF_EXPECT_OK(scope.ToGraphDef(&expected));
394
395    InstantiationResultForTest result;
396    TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
397
398    EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
399    EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
400    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
401  }
402}
403
404// Graph:
405// x = array_ops.placeholder(dtypes.int32)
406// y = array_ops.placeholder(dtypes.int32)
407// cond = lambda (i, j): i + 3 < 10
408// body = lambda (i, j): (i < 10, j * 2)
409// z = control_flow_ops.while_loop(cond, body, [x, y])
410TEST(FunctionalizeControlFlow, TwoLoopVars) {
411  Graph graph(OpRegistry::Global());
412  {
413    Scope scope = Scope::NewRootScope().ExitOnError();
414
415    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
416
417    auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
418    auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
419    auto enter_x =
420        ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop");
421    auto enter_y =
422        ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop");
423    auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"),
424                              std::initializer_list<Input>{enter_x, dummy});
425    auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"),
426                              std::initializer_list<Input>{enter_y, dummy});
427
428    // Loop condition
429    auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
430                                       .WithControlDependencies(merge_x.output),
431                                   3);
432    auto cond_add =
433        ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
434    auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
435                                     .WithControlDependencies(merge_x.output),
436                                 10);
437    auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
438    auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
439
440    auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"),
441                                merge_x.output, loop_cond);
442    auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"),
443                                merge_y.output, loop_cond);
444
445    auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"),
446                                      switch_x.output_false);
447    auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"),
448                                      switch_y.output_false);
449
450    auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"),
451                                    switch_x.output_true);
452    auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
453                                    switch_y.output_true);
454
455    auto one = ops::Const<int32>(
456        scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
457        1);
458    auto two = ops::Const<int32>(
459        scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
460        2);
461
462    auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
463    auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
464    auto next_iteration_x =
465        ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add);
466    auto next_iteration_y =
467        ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul);
468
469    auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x);
470    auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y);
471
472    // Remove the dummy node and add the loop backedges.
473    scope.graph()->RemoveNode(dummy.node());
474    scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(),
475                           1);
476    scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(),
477                           1);
478
479    TF_EXPECT_OK(scope.ToGraph(&graph));
480  }
481
482  FunctionLibraryDefinition library(OpRegistry::Global(), {});
483  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
484
485  GraphDef graph_def;
486  graph.ToGraphDef(&graph_def);
487
488  NameAttrList cond_fn, body_fn;
489  TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
490
491  // Outer graph.
492  {
493    Scope scope = Scope::NewRootScope().ExitOnError();
494    auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
495    auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
496    auto while_op =
497        ops::XlaWhile(scope.WithOpName("while/LoopCond"),
498                      std::initializer_list<Input>{x, y}, cond_fn, body_fn);
499    auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
500    auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
501    GraphDef expected;
502    TF_EXPECT_OK(scope.ToGraphDef(&expected));
503    TF_EXPECT_GRAPH_EQ(expected, graph_def);
504  }
505
506  // Condition graph.
507  {
508    Scope scope = Scope::NewRootScope().ExitOnError();
509    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
510    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
511    auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
512                                       .WithControlDependencies(arg0.output),
513                                   3);
514    auto cond_add =
515        ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
516    auto ten = ops::Const<int32>(
517        scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output),
518        10);
519    auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
520    auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
521
522    GraphDef expected;
523    TF_EXPECT_OK(scope.ToGraphDef(&expected));
524
525    InstantiationResultForTest result;
526    TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
527
528    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
529    EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
530    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
531  }
532
533  // Body graph.
534  {
535    Scope scope = Scope::NewRootScope().ExitOnError();
536    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
537    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
538
539    auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0);
540    auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1);
541
542    auto one = ops::Const<int32>(
543        scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
544        1);
545    auto two = ops::Const<int32>(
546        scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
547        2);
548
549    auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
550    auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
551    auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
552    auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1);
553
554    GraphDef expected;
555    TF_EXPECT_OK(scope.ToGraphDef(&expected));
556
557    InstantiationResultForTest result;
558    TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
559
560    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
561    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types);
562    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
563  }
564}
565
566// Example with nesting, loop-invariant arguments, and resource variables.
567//
568// accum = resource_variable_ops.ResourceVariable(1)
569// x = array_ops.placeholder(2, dtype=dtypes.int32)
570// y = 3 + x
571//
572// def inner_body(j, k):
573//   add = state_ops.assign_add(accum, k * j + x)
574//   with ops.control_dependencies([add]):
575//     return [j + 1, k]
576//
577// def body(i):
578//   m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
579//                                   [1, y], name="inner")
580//   with ops.control_dependencies(m):
581//     return [i + 1]
582//
583// z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
584TEST(FunctionalizeControlFlow, Complex) {
585  Graph graph(OpRegistry::Global());
586  {
587    Scope scope = Scope::NewRootScope().ExitOnError();
588
589    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
590
591    auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
592    auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
593    auto y = ops::Add(scope.WithOpName("y"), x, three);
594
595    auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
596                                TensorShape({}));
597
598    // Outer loop
599    auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
600    auto enter_i =
601        ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
602    auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
603                              std::initializer_list<Input>{enter_i, dummy});
604    auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
605                                     .WithControlDependencies(merge_i.output),
606                                 10);
607    auto less_i =
608        ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
609    auto outer_loop_cond =
610        ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i);
611    auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"),
612                                merge_i.output, outer_loop_cond);
613    auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"),
614                                      switch_i.output_false);
615    auto identity_i =
616        ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true);
617
618    auto enter_x_outer =
619        ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer",
620                             ops::internal::Enter::Attrs().IsConstant(true));
621    auto enter_k_outer =
622        ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer",
623                             ops::internal::Enter::Attrs().IsConstant(true));
624    auto enter_var_outer =
625        ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer",
626                             ops::internal::Enter::Attrs().IsConstant(true));
627
628    // Inner loop
629    auto one_j = ops::Const<int32>(
630        scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
631    auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
632                                        one_j, "inner");
633    auto enter_k =
634        ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k")
635                                 .WithControlDependencies(identity_i),
636                             enter_k_outer, "inner");
637    auto enter_x = ops::internal::Enter(
638        scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner",
639        ops::internal::Enter::Attrs().IsConstant(true));
640    auto enter_var = ops::internal::Enter(
641        scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner",
642        ops::internal::Enter::Attrs().IsConstant(true));
643
644    auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"),
645                              std::initializer_list<Input>{enter_j, dummy});
646    auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
647                              std::initializer_list<Input>{enter_k, dummy});
648
649    auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
650                                      .WithControlDependencies(merge_j.output),
651                                  5);
652    auto less_j =
653        ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
654    auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j);
655
656    auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
657                                merge_j.output, loop_cond);
658    auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"),
659                                merge_k.output, loop_cond);
660    auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"),
661                                      switch_j.output_false);
662    auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"),
663                                      switch_k.output_false);
664    auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"),
665                                    switch_j.output_true);
666    auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"),
667                                    switch_k.output_true);
668
669    // Variable update
670    auto mul_jk =
671        ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
672    auto add_jkx =
673        ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x);
674    auto assign = ops::AssignAddVariableOp(
675        scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
676
677    auto one =
678        ops::Const<int32>(scope.WithOpName("outer/inner/One")
679                              .WithControlDependencies(
680                                  gtl::ArraySlice<Operation>{assign.operation}),
681                          1);
682    auto add_j =
683        ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
684
685    auto next_iteration_j = ops::NextIteration(
686        scope.WithOpName("outer/inner/NextIteration_j"), add_j);
687    auto next_iteration_k = ops::NextIteration(
688        scope.WithOpName("outer/inner/NextIteration_k"), identity_k);
689
690    // Body and backedge for outer loop.
691    auto one_outer = ops::Const<int32>(
692        scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
693    auto add_i =
694        ops::Add(scope.WithOpName("outer/add")
695                     .WithControlDependencies(gtl::ArraySlice<Operation>{
696                         exit_j.output.op(), exit_k.output.op()}),
697                 identity_i, one_outer);
698    auto next_iteration_i =
699        ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i);
700
701    auto sink = ops::Identity(scope.WithOpName("sink"), exit_i);
702
703    // Remove the dummy node and add the loop backedge.
704    scope.graph()->RemoveNode(dummy.node());
705    scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(),
706                           1);
707    scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(),
708                           1);
709    scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(),
710                           1);
711
712    TF_EXPECT_OK(scope.ToGraph(&graph));
713  }
714
715  FunctionLibraryDefinition library(OpRegistry::Global(), {});
716  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
717
718  GraphDef graph_def;
719  graph.ToGraphDef(&graph_def);
720
721  NameAttrList outer_cond_fn, outer_body_fn;
722  TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));
723
724  // Outer graph.
725  {
726    Scope scope = Scope::NewRootScope().ExitOnError();
727    auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
728    auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
729    auto y = ops::Add(scope.WithOpName("y"), x, three);
730
731    auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
732                                TensorShape({}));
733
734    auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
735
736    auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"),
737                                  std::initializer_list<Input>{zero, y, x, var},
738                                  outer_cond_fn, outer_body_fn);
739    auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
740    GraphDef expected;
741    TF_EXPECT_OK(scope.ToGraphDef(&expected));
742    TF_EXPECT_GRAPH_EQ(expected, graph_def);
743  }
744
745  // Outer condition graph.
746  {
747    Scope scope = Scope::NewRootScope().ExitOnError();
748    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
749    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
750    auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
751    auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
752
753    auto ten = ops::Const<int32>(
754        scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
755        10);
756    auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
757    auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
758
759    GraphDef expected;
760    TF_EXPECT_OK(scope.ToGraphDef(&expected));
761
762    InstantiationResultForTest result;
763    TF_EXPECT_OK(
764        InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));
765
766    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
767              result.arg_types);
768    EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
769    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
770  }
771
772  // Outer body graph.
773  NameAttrList inner_cond_fn, inner_body_fn;
774  {
775    InstantiationResultForTest result;
776    TF_EXPECT_OK(
777        InstantiateFunctionForTest(outer_body_fn.name(), library, &result));
778
779    // Find the inner condition and body names.
780    TF_EXPECT_OK(
781        FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));
782
783    Scope scope = Scope::NewRootScope().ExitOnError();
784    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
785    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
786    auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
787    auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
788
789    auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
790    auto one_j = ops::Const<int32>(
791        scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
792    auto while_op =
793        ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"),
794                      std::initializer_list<Input>{one_j, arg1, arg2, arg3},
795                      inner_cond_fn, inner_body_fn);
796
797    auto one_outer = ops::Const<int32>(
798        scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
799    auto add_i =
800        ops::Add(scope.WithOpName("outer/add")
801                     .WithControlDependencies(gtl::ArraySlice<Operation>{
802                         while_op[0].op(), while_op[1].op()}),
803                 identity_i, one_outer);
804
805    auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0);
806    auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1);
807    auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);
808
809    GraphDef expected;
810    TF_EXPECT_OK(scope.ToGraphDef(&expected));
811
812    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
813              result.arg_types);
814    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types);
815    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
816  }
817
818  // Inner condition graph.
819  {
820    Scope scope = Scope::NewRootScope().ExitOnError();
821    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
822    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
823    auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
824    auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
825
826    auto five = ops::Const<int32>(
827        scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5);
828    auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
829    auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0);
830
831    GraphDef expected;
832    TF_EXPECT_OK(scope.ToGraphDef(&expected));
833
834    InstantiationResultForTest result;
835    TF_EXPECT_OK(
836        InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));
837
838    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
839              result.arg_types);
840    EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
841    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
842  }
843
844  // Inner body graph.
845  {
846    Scope scope = Scope::NewRootScope().ExitOnError();
847    auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
848    auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
849    auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
850    auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
851
852    auto identity_j =
853        ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
854    auto identity_k =
855        ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);
856
857    auto mul_jk =
858        ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
859    auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
860    auto assign = ops::AssignAddVariableOp(
861        scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
862
863    auto one =
864        ops::Const<int32>(scope.WithOpName("outer/inner/One")
865                              .WithControlDependencies(
866                                  gtl::ArraySlice<Operation>{assign.operation}),
867                          1);
868    auto add_j =
869        ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
870
871    auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0);
872    auto retval1 =
873        ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1);
874    auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);
875
876    GraphDef expected;
877    TF_EXPECT_OK(scope.ToGraphDef(&expected));
878
879    InstantiationResultForTest result;
880    TF_EXPECT_OK(
881        InstantiateFunctionForTest(inner_body_fn.name(), library, &result));
882
883    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
884              result.arg_types);
885    EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types);
886    TF_EXPECT_GRAPH_EQ(expected, result.gdef);
887  }
888}
889
890}  // namespace
891}  // namespace tensorflow
892