1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OiR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=g-long-lambda
17"""Tests for tensorflow.ops.control_flow_ops."""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import collections
24import math
25import time
26
27import numpy as np
28from six.moves import xrange  # pylint: disable=redefined-builtin
29
30from tensorflow.core.protobuf import config_pb2
31from tensorflow.python.client import device_lib
32from tensorflow.python.client import session
33from tensorflow.python.eager import context
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors_impl
37from tensorflow.python.framework import function
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import sparse_tensor
40from tensorflow.python.framework import tensor_shape
41from tensorflow.python.framework import test_util
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import control_flow_ops
44from tensorflow.python.ops import data_flow_ops
45from tensorflow.python.ops import functional_ops
46from tensorflow.python.ops import gen_array_ops
47from tensorflow.python.ops import gen_control_flow_ops
48from tensorflow.python.ops import gen_data_flow_ops
49from tensorflow.python.ops import gen_logging_ops
50from tensorflow.python.ops import gen_state_ops
51from tensorflow.python.ops import gradients_impl
52from tensorflow.python.ops import init_ops
53from tensorflow.python.ops import linalg_ops
54from tensorflow.python.ops import logging_ops
55from tensorflow.python.ops import math_ops
56from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
57from tensorflow.python.ops import nn_ops
58from tensorflow.python.ops import random_ops
59from tensorflow.python.ops import resource_variable_ops
60from tensorflow.python.ops import script_ops
61from tensorflow.python.ops import state_ops
62from tensorflow.python.ops import variable_scope
63from tensorflow.python.ops import variables
64# pylint: disable=unused-import
65import tensorflow.python.ops.tensor_array_grad
66# pylint: enable=unused-import
67from tensorflow.python.platform import test
68from tensorflow.python.training import adam
69from tensorflow.python.training import gradient_descent
70from tensorflow.python.util import nest
71
72
73def check_consumers(graph):
74  """Sanity check on the consumer list of the tensors."""
75
76  consumer_count = {}
77  for op in graph.get_operations():
78    for v in op.inputs:
79      cnt = consumer_count.get(v, 0)
80      consumer_count[v] = cnt + 1
81  for k, v in consumer_count.items():
82    if len(k.consumers()) != v:
83      return False
84  return True
85
86
87def all_fetchables():
88  tensor_names = []
89  graph = ops.get_default_graph()
90  for op in graph.get_operations():
91    for t in op.outputs:
92      if graph.is_fetchable(t):
93        tensor_names.append(t.name)
94  return tensor_names
95
96
97def all_feedables():
98  feedable_tensors = []
99  graph = ops.get_default_graph()
100  for op in graph.get_operations():
101    for t in op.inputs:
102      if graph.is_feedable(t):
103        feedable_tensors.append(t)
104  return feedable_tensors
105
106
107def opt_cfg():
108  return config_pb2.ConfigProto(
109      allow_soft_placement=True,
110      graph_options=config_pb2.GraphOptions(
111          optimizer_options=config_pb2.OptimizerOptions(
112              opt_level=config_pb2.OptimizerOptions.L1,
113              do_function_inlining=True,
114              do_constant_folding=True)))
115
116
117def isum(s, maximum_iterations=None):
118  i = constant_op.constant(0, name="i")
119  c = lambda i, s: math_ops.less(i, 10)
120  b = lambda i, s: [math_ops.add(i, 1), math_ops.add(i, s)]
121  _, r_s = control_flow_ops.while_loop(
122      c, b, [i, s], maximum_iterations=maximum_iterations)
123  return r_s
124
125
126@test_util.with_c_api
127class ControlFlowTest(test.TestCase):
128
129  def testRefIdentity(self):
130    with self.test_session():
131      v = variables.Variable(7)
132
133      v = control_flow_ops._Identity(v)
134      op = state_ops.assign(v, 9)
135      v2 = control_flow_ops.with_dependencies([op], v)
136
137      self.assertTrue(isinstance(v2, ops.Tensor))
138      variables.global_variables_initializer().run()
139      self.assertEqual(9, v2.eval())
140
141  def testRefEnter(self):
142    with self.test_session():
143      v = variables.Variable(7)
144
145      enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
146      nine = constant_op.constant(9)
147      enter_nine = gen_control_flow_ops._enter(nine, "foo_1")
148      op = state_ops.assign(enter_v, enter_nine)
149      v2 = control_flow_ops.with_dependencies([op], enter_v)
150      v3 = control_flow_ops.exit(v2)
151      variables.global_variables_initializer().run()
152      self.assertEqual(9, v3.eval())
153
154  def testRefSwitch(self):
155    with self.test_session():
156      v = variables.Variable(7)
157
158      p = constant_op.constant(True)
159      v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p)  # pylint: disable=protected-access
160      v2 = state_ops.assign(v1[1], 9)
161      variables.global_variables_initializer().run()
162      self.assertEqual(9, v2.eval())
163
164  def testEnterMulExit(self):
165    with self.test_session():
166      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
167      enter_data = gen_control_flow_ops._enter(data, "foo_1", False)
168      five = constant_op.constant(5)
169      enter_five = gen_control_flow_ops._enter(five, "foo_1", False)
170      mul_op = math_ops.multiply(enter_data, enter_five)
171      exit_op = control_flow_ops.exit(mul_op)
172
173      result = exit_op.eval()
174    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
175
176  def testEnterShapePropagation(self):
177    with self.test_session():
178      v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)
179
180      # If is_constant=True, the shape information should be propagated.
181      enter_v_constant = gen_control_flow_ops._enter(
182          v, "frame1", is_constant=True)
183      self.assertEqual(enter_v_constant.shape, [2])
184
185      # Otherwise, the shape should be unknown.
186      enter_v_non_constant = gen_control_flow_ops._enter(
187          v, "frame2", is_constant=False)
188      self.assertEqual(enter_v_non_constant.shape, None)
189
190  def testSwitchMergeIndexedSlices(self):
191    with self.test_session():
192      values = constant_op.constant([1, 2, 3, 4, 5, 6])
193      indices = constant_op.constant([0, 2, 4, 6, 8, 10])
194      data = ops.IndexedSlices(values, indices)
195      pred = ops.convert_to_tensor(True)
196      switch_op = control_flow_ops.switch(data, pred)
197      merge_op = control_flow_ops.merge(switch_op)[0]
198
199      val = merge_op.values.eval()
200      ind = merge_op.indices.eval()
201    self.assertAllEqual(np.arange(1, 7), val)
202    self.assertAllEqual(np.arange(0, 12, 2), ind)
203
204  def testSwitchDeadBranch(self):
205    with self.test_session():
206      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
207      ports = ops.convert_to_tensor(True, name="ports")
208      switch_op = control_flow_ops.switch(data, ports)
209      dead_branch = array_ops.identity(switch_op[0])
210
211      with self.assertRaisesWithPredicateMatch(
212          errors_impl.InvalidArgumentError,
213          lambda e: "Retval[0] does not have value" in str(e)):
214        dead_branch.eval()
215
216  def testSwitchMergeLess(self):
217    with self.test_session():
218      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
219      zero = ops.convert_to_tensor(0)
220      one = ops.convert_to_tensor(1)
221      less_op = math_ops.less(zero, one)
222      switch_op = control_flow_ops.switch(data, less_op)
223      merge_op = control_flow_ops.merge(switch_op)[0]
224
225      result = merge_op.eval()
226    self.assertAllEqual(np.arange(1, 7), result)
227
228  def testSwitchMergeAddIdentity(self):
229    with self.test_session():
230      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
231      ports = ops.convert_to_tensor(False, name="ports")
232      switch_op = control_flow_ops.switch(data, ports)
233      one = constant_op.constant(1)
234      add_op = math_ops.add(switch_op[0], one)
235      id_op = array_ops.identity(switch_op[1])
236      merge_op = control_flow_ops.merge([add_op, id_op])[0]
237
238      result = merge_op.eval()
239    self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)
240
241  def testSwitchMergeAddMul(self):
242    with self.test_session():
243      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
244      ports = ops.convert_to_tensor(True, name="ports")
245      switch_op = control_flow_ops.switch(data, ports)
246      one = constant_op.constant(1)
247      add_op = math_ops.add(switch_op[0], one)
248      five = constant_op.constant(5)
249      mul_op = math_ops.multiply(switch_op[1], five)
250      merge_op = control_flow_ops.merge([add_op, mul_op])[0]
251
252      result = merge_op.eval()
253    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
254
255  def testLoop_false(self):
256    with self.test_session():
257      false = ops.convert_to_tensor(False)
258      n = constant_op.constant(10)
259
260      enter_false = gen_control_flow_ops._enter(false, "foo_1", False)
261      enter_n = gen_control_flow_ops._enter(n, "foo_1", False)
262
263      merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
264      switch_n = control_flow_ops.switch(merge_n, enter_false)
265      exit_n = control_flow_ops.exit(switch_n[0])
266      next_n = control_flow_ops.next_iteration(switch_n[0])
267      merge_n.op._update_input(1, next_n)
268
269      result = exit_n.eval()
270    self.assertAllEqual(10, result)
271
272  def testLoop_1(self):
273    with self.test_session():
274      zero = constant_op.constant(0)
275      one = constant_op.constant(1)
276      n = constant_op.constant(10)
277
278      enter_i = gen_control_flow_ops._enter(zero, "foo", False)
279      enter_one = gen_control_flow_ops._enter(one, "foo", True)
280      enter_n = gen_control_flow_ops._enter(n, "foo", True)
281
282      with ops.device(test.gpu_device_name()):
283        merge_i = control_flow_ops.merge([enter_i, enter_i])[0]
284
285      less_op = math_ops.less(merge_i, enter_n)
286      cond_op = control_flow_ops.loop_cond(less_op)
287      switch_i = control_flow_ops.switch(merge_i, cond_op)
288
289      add_i = math_ops.add(switch_i[1], enter_one)
290
291      next_i = control_flow_ops.next_iteration(add_i)
292      merge_i.op._update_input(1, next_i)
293
294      exit_i = control_flow_ops.exit(switch_i[0])
295      result = exit_i.eval()
296    self.assertAllEqual(10, result)
297
298  def testLoop_2(self):
299    with self.test_session():
300      zero = constant_op.constant(0)
301      one = constant_op.constant(1)
302      n = constant_op.constant(10)
303
304      enter_i = gen_control_flow_ops._enter(zero, "foo", False)
305      enter_one = gen_control_flow_ops._enter(one, "foo", True)
306      enter_n = gen_control_flow_ops._enter(n, "foo", True)
307
308      merge_i = control_flow_ops.merge([enter_i, enter_i])[0]
309
310      less_op = math_ops.less(merge_i, enter_n)
311      cond_op = control_flow_ops.loop_cond(less_op)
312      switch_i = control_flow_ops.switch(merge_i, cond_op)
313
314      add_i = math_ops.add(switch_i[1], enter_one)
315
316      with ops.device(test.gpu_device_name()):
317        next_i = control_flow_ops.next_iteration(add_i)
318      merge_i.op._update_input(1, next_i)
319
320      exit_i = control_flow_ops.exit(switch_i[0])
321      result = exit_i.eval()
322    self.assertAllEqual(10, result)
323
324  def testDifferentFrame(self):
325    with self.test_session():
326      data = array_ops.placeholder(dtypes.float32, shape=[])
327      enter_1 = gen_control_flow_ops._enter(data, "foo_1", False)
328      enter_2 = gen_control_flow_ops._enter(data, "foo_2", False)
329      res = math_ops.add(enter_1, enter_2)
330      with self.assertRaisesOpError("has inputs from different frames"):
331        res.eval(feed_dict={data: 1.0})
332
333  def testCondBool(self):
334    values = constant_op.constant(10)
335    fn1 = lambda: math_ops.add(values, 1)
336    fn2 = lambda: math_ops.subtract(values, 1)
337    with self.assertRaisesRegexp(TypeError, "must not be a Python bool"):
338      _ = control_flow_ops.cond(False, fn1, fn2)
339
340  def testCondInt(self):
341    p = array_ops.placeholder(dtypes.bool, shape=[])
342    v = constant_op.constant(10)
343    fn1 = lambda: math_ops.add(v, 1)
344    fn2 = lambda: math_ops.subtract(v, 1)
345    y = control_flow_ops.cond(p, fn1, fn2)
346    grad = gradients_impl.gradients(y, [v])
347    self.assertAllEqual([None], grad)
348
349  def testFetchable(self):
350    with self.test_session() as sess:
351      x = array_ops.placeholder(dtypes.float32)
352      control_flow_ops.cond(
353          constant_op.constant(True), lambda: x + 2, lambda: x + 0)
354      graph = ops.get_default_graph()
355      for op in graph.get_operations():
356        for t in op.inputs:
357          if graph.is_fetchable(t.op):
358            sess.run(t, feed_dict={x: 3})
359          else:
360            with self.assertRaisesRegexp(ValueError,
361                                         "has been marked as not fetchable"):
362              sess.run(t, feed_dict={x: 3})
363
364  def testFeedable(self):
365    with self.test_session() as sess:
366      c = constant_op.constant(2)
367      i0 = constant_op.constant(0)
368      r = control_flow_ops.while_loop(lambda i: i < 1000,
369                                      lambda i: math_ops.square(c) + i, [i0])
370      self.assertEqual(1000, r.eval(feed_dict={i0: 0}))
371      feedable_tensors = all_feedables()
372      for t in feedable_tensors:
373        sess.run(r, feed_dict={t: 3})
374      graph = ops.get_default_graph()
375      for op in graph.get_operations():
376        for t in op.inputs:
377          if t not in feedable_tensors and t.dtype is dtypes.int32:
378            with self.assertRaisesRegexp(ValueError, "may not be fed"):
379              sess.run(r, feed_dict={t: 3})
380
381  def testCondIndexedSlices(self):
382    with self.test_session():
383      values = constant_op.constant(10)
384      indices = constant_op.constant(0)
385      x = ops.IndexedSlices(values, indices)
386      pred = math_ops.less(1, 2)
387      fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices)
388      fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), indices)
389      r = control_flow_ops.cond(pred, fn1, fn2)
390
391      val = r.values.eval()
392      ind = r.indices.eval()
393    self.assertAllEqual(11, val)
394    self.assertAllEqual(0, ind)
395
396  def testCondSparseTensor(self):
397    with self.test_session():
398      values = constant_op.constant([2.0, 4.0], name="values")
399      indices = constant_op.constant(
400          [[0], [3]], dtype=dtypes.int64, name="indices")
401      shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
402      x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
403      pred = math_ops.less(1, 2)
404      fn1 = lambda: sparse_tensor.SparseTensor(
405          indices + 1, x.values + 1, dense_shape=shape)
406      fn2 = lambda: sparse_tensor.SparseTensor(
407          indices, x.values - 1, dense_shape=shape)
408      r = control_flow_ops.cond(pred, fn1, fn2)
409      self.assertAllEqual([3.0, 5.0], r.values.eval())
410      self.assertAllEqual([[1], [4]], r.indices.eval())
411      self.assertAllEqual(r.values.get_shape(), (2,))
412
413  def testCondResource(self):
414    with self.test_session():
415      rv = resource_variable_ops.ResourceVariable(True)
416      variables.global_variables_initializer().run()
417      t = ops.convert_to_tensor(1.0)
418
419      def case():
420        assign = resource_variable_ops.assign_variable_op(rv.handle, False)
421        with ops.control_dependencies([assign]):
422          return array_ops.identity(t)
423
424      self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
425
426  def testCondIndexedSlicesDifferentTypes(self):
427    with self.test_session():
428      values = constant_op.constant(10)
429      i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
430      i_64 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int64)
431      x = ops.IndexedSlices(values, i_32)
432      pred = math_ops.less(1, 2)
433      fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), i_32)
434      fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), i_64)
435      r = control_flow_ops.cond(pred, fn1, fn2)
436
437      val = r.values.eval()
438      ind = r.indices.eval()
439    self.assertAllEqual(11, val)
440    self.assertAllEqual(0, ind)
441    self.assertTrue(ind.dtype == np.int64)
442
443  def testCondColocation(self):
444    with self.test_session(use_gpu=True):
445      with ops.device("/cpu:0"):
446        v = variables.Variable(7.0)
447
448      x = constant_op.constant(10.0)
449      pred = math_ops.less(1.0, 2.0)
450      fn1 = lambda: math_ops.add(v, 1.0)
451      fn2 = lambda: math_ops.subtract(x, 1.0)
452      r = control_flow_ops.cond(pred, fn1, fn2)
453
454      for op in x.graph.get_operations():
455        if op.name == "cond/Add/Switch":
456          self.assertDeviceEqual(op.device, "/cpu:0")
457
458  def _testCond_1(self, use_gpu):
459    with self.test_session(use_gpu=use_gpu):
460      x = constant_op.constant(10)
461      pred = math_ops.less(1, 2)
462      fn1 = lambda: math_ops.add(x, 1)
463      fn2 = lambda: math_ops.subtract(x, 1)
464      r = control_flow_ops.cond(pred, fn1, fn2)
465
466      result = r.eval()
467    self.assertAllEqual(11, result)
468
469  def testCond_1(self):
470    self._testCond_1(use_gpu=False)
471    self._testCond_1(use_gpu=True)
472
473  def testCond_2(self):
474    with self.test_session():
475      x = constant_op.constant(10)
476      r = control_flow_ops.cond(
477          math_ops.less(1, 0), lambda: math_ops.add(x, 1),
478          lambda: math_ops.subtract(x, 1))
479      result = r.eval()
480    self.assertAllEqual(9, result)
481
482  def testCond_3(self):
483    with self.test_session():
484      x = constant_op.constant(10)
485      pred = math_ops.less(1, 2)
486      fn1 = lambda: math_ops.add(x, 1)
487      fn2 = lambda: math_ops.subtract(x, 1)
488      fn3 = lambda: math_ops.add(control_flow_ops.cond(pred, fn1, fn2), 1)
489      r = control_flow_ops.cond(pred, fn3, fn2)
490
491      result = r.eval()
492    self.assertAllEqual(12, result)
493
494  def testCond_4(self):
495    with self.test_session():
496      v1 = variables.Variable(7)
497      v2 = variables.Variable(7)
498      v3 = variables.Variable(7)
499
500      age = constant_op.constant(3)
501      max_age = constant_op.constant(2)
502      pred = math_ops.greater(age, max_age)
503      fn1 = lambda: [state_ops.assign(v1, 1).op, state_ops.assign(v2, 2).op]
504      fn2 = lambda: [state_ops.assign(v3, 3).op, constant_op.constant(10).op]
505      r = control_flow_ops.cond(pred, fn1, fn2)
506
507      variables.global_variables_initializer().run()
508      self.assertEqual(len(r), 2)
509      result = r[1].eval()
510      self.assertAllEqual(True, result)
511      self.assertAllEqual(7, v1.eval())
512      self.assertAllEqual(2, v2.eval())
513      self.assertAllEqual(7, v3.eval())
514
515  def testCond_5(self):
516    with self.test_session():
517      alive = constant_op.constant(True, name="alive")
518      count = constant_op.constant(0, name="count")
519
520      def body(i):
521        return control_flow_ops.cond(
522            alive, lambda: [math_ops.less(i, 3), math_ops.add(count, 1)],
523            lambda: [alive, count])
524
525      for i in range(10):
526        alive, count = body(i)
527      self.assertAllEqual(4, count.eval())
528
529  def testCond_6(self):
530    with self.test_session():
531      v1 = variables.Variable([7])
532
533      age = constant_op.constant(3)
534      pred = math_ops.greater(age, 4)
535      fn1 = lambda: age
536      fn2 = lambda: v1
537      r = control_flow_ops.cond(pred, fn1, fn2)
538
539      variables.global_variables_initializer().run()
540      result = r.eval()
541      self.assertAllEqual(np.array([7]), result)
542
543  def testCond_7(self):
544    with self.test_session() as sess:
545      x = constant_op.constant(10)
546      y = constant_op.constant(200)
547      pred = math_ops.less(1, 2)
548      fn1 = lambda: [math_ops.add(x, 1), math_ops.add(x, 2)]
549      fn2 = lambda: [y, y]
550      r = control_flow_ops.cond(pred, fn1, fn2)
551      self.assertAllEqual([11, 12], sess.run(r))
552
553  def testCondRef(self):
554    with self.test_session():
555      x = gen_state_ops._variable(
556          shape=[1],
557          dtype=dtypes.float32,
558          name="x",
559          container="",
560          shared_name="")
561      true_fn = lambda: x
562      false_fn = lambda: constant_op.constant([2.0])
563      r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
564      self.assertAllEqual([2.0], r.eval())
565
566  def testCondWithControl(self):
567    with self.test_session() as sess:
568      control_holder = array_ops.placeholder(dtypes.float32, shape=())
569      a = constant_op.constant(3)
570
571      def true_branch():
572        with ops.control_dependencies([control_holder]):
573          _ = a + 1
574        return a + 2
575
576      r = control_flow_ops.cond(
577          constant_op.constant(True), true_branch,
578          lambda: constant_op.constant(1))
579      self.assertEqual(5, r.eval())
580
581  def testUninitializedRefIdentity(self):
582    with self.test_session() as sess:
583      v = gen_state_ops._variable(
584          shape=[1],
585          dtype=dtypes.float32,
586          name="v",
587          container="",
588          shared_name="")
589      inited = state_ops.is_variable_initialized(v)
590      v_f, v_t = control_flow_ops.ref_switch(v, inited)
591      # Both v_f and v_t are uninitialized references. However, an actual use
592      # of the reference in the 'true' branch in the 'tf.identity' op will
593      # not 'fire' when v is uninitialized, so this is a valid construction.
594      # This test tests that _ref_identity allows uninitialized ref as input
595      # so that this construction is allowed.
596      v_f_op = gen_array_ops._ref_identity(v_f)
597      v_t_op = gen_array_ops._ref_identity(v_t)
598      with ops.control_dependencies([v_f_op]):
599        assign_v = state_ops.assign(v, [1.0])
600      with ops.control_dependencies([v_t_op]):
601        orig_v = array_ops.identity(v)
602      merged_op = control_flow_ops.merge([assign_v, orig_v])
603      self.assertAllEqual([1.0], sess.run(merged_op.output))
604
605  def testCondSwitchIdentity(self):
606    # Make sure the recv identity is not removed by optimization.
607    with session.Session(config=opt_cfg()) as sess:
608      pred = constant_op.constant(True)
609
610      def fn1():
611        return control_flow_ops.no_op()
612
613      def fn2():
614        return control_flow_ops.Assert(False, ["Wrong branch!!!"])
615
616      r = control_flow_ops.cond(pred, fn1, fn2)
617      sess.run(r)
618
619  def testCondRecvIdentity(self):
620    # Make sure the switch identity is not removed by optimization.
621    with session.Session(config=opt_cfg()) as sess:
622      with ops.device(test.gpu_device_name()):
623        pred = constant_op.constant(True)
624
625      def fn1():
626        return control_flow_ops.no_op()
627
628      def fn2():
629        with ops.device("/cpu:0"):
630          return control_flow_ops.Assert(False, ["Wrong branch!!!"])
631
632      r = control_flow_ops.cond(pred, fn1, fn2)
633      sess.run(r)
634
635  def testCondGrad_1(self):
636    with self.test_session():
637      x = constant_op.constant(10.0, name="x")
638      pred = math_ops.less(1, 2)
639      fn1 = lambda: array_ops.identity(x)
640      fn2 = lambda: array_ops.identity(x)
641      r = control_flow_ops.cond(pred, fn1, fn2)
642
643      grad = gradients_impl.gradients(r, [x])[0]
644      result = grad.eval()
645    self.assertAllEqual(1.0, result)
646
647  def testCondGrad_2(self):
648    with self.test_session():
649      c = array_ops.placeholder(dtypes.int32, shape=[])
650      x = constant_op.constant(10.0)
651      pred = math_ops.less(c, 2)
652      fn1 = lambda: math_ops.multiply(x, 42.0)
653      fn2 = lambda: math_ops.multiply(x, 3.0)
654      r = control_flow_ops.cond(pred, fn1, fn2)
655
656      grad = gradients_impl.gradients(r, [x])[0]
657      self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
658      self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
659
660  def testNestedCond_Simple(self):
661    with self.test_session():
662      x = constant_op.constant(0., name="X")
663      y = control_flow_ops.cond(
664          constant_op.constant(True), lambda: x,
665          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
666      result = gradients_impl.gradients(y, x)[0]
667      self.assertEqual(1.0, result.eval())
668
669      z = control_flow_ops.cond(
670          constant_op.constant(False), lambda: x,
671          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
672      result = gradients_impl.gradients(z, x)[0]
673      self.assertEqual(1.0, result.eval())
674
675  def testCondGrad_Gather(self):
676    with self.test_session() as sess:
677      v1 = variables.Variable([1.0, 42.0])
678      c = array_ops.placeholder(dtypes.int32, shape=[])
679      pred = math_ops.less(c, 2)
680      fn1 = lambda: array_ops.identity(v1)
681      fn2 = lambda: array_ops.gather(v1, [1, 1])
682      r = control_flow_ops.cond(pred, fn1, fn2)
683      grad = gradients_impl.gradients(r, [v1])[0]
684      variables.global_variables_initializer().run()
685      # Should just be [1, 1], but possibly a sparse representation
686      gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 1})
687      dense_gv = [
688          sum([y for (x, y) in zip(gi, gv) if x == i]) for i in range(2)
689      ]
690      self.assertAllEqual(dense_gv, [1.0, 1.0])
691      # Should be [0, 2], as the else forwards v1[1] twice
692      gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 3})
693      dense_gv = [
694          sum([y for (x, y) in zip(gi, gv) if x == i]) for i in range(2)
695      ]
696      self.assertAllEqual(dense_gv, [0.0, 2.0])
697
698  # Microbenchmark: 256,000 iterations/s.
699  def testWhile_1(self):
700    with self.test_session():
701      n = constant_op.constant(0)
702      c = lambda x: math_ops.less(x, 10000)
703      b = lambda x: math_ops.add(x, 1)
704      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
705      self.assertEqual(10000, r.eval())
706
707  def testWhileExternalControlDependencies(self):
708    with self.test_session():
709      v = variables.Variable(0.0)
710      v.initializer.run()
711      increment = v.assign_add(1.0)
712
713      def body_fn(i):
714        with ops.control_dependencies([increment]):
715          return i + i
716
717      result = control_flow_ops.while_loop(cond=lambda i: i < 1,
718                                           body=body_fn, loop_vars=[1])
719      result.eval()
720      self.assertAllEqual(v.eval(), 1.0)
721
722  def testWhileExternalControlDependenciesNoInput(self):
723    with self.test_session():
724      v = variables.Variable(0.0)
725      v.initializer.run()
726      increment = v.assign_add(1.0)
727
728      def body_fn(unused_i):
729        with ops.control_dependencies([increment]):
730          return constant_op.constant(5, name="five")
731
732      result = control_flow_ops.while_loop(cond=lambda i: i < 5,
733                                           body=body_fn, loop_vars=[0])
734      result.eval()
735      self.assertAllEqual(v.eval(), 1.0)
736
737  def testWhileWithRefs_1(self):
738    with self.test_session() as sess:
739      x = variables.Variable(0)._ref()  # pylint: disable=protected-access
740      i = constant_op.constant(0)
741      c = lambda i, x: math_ops.less(i, 100)
742
743      self.assertEqual(x.dtype, dtypes.int32_ref)
744
745      def b(i, x):
746        self.assertEqual(x.dtype, dtypes.int32_ref)
747        return (i + 1, gen_array_ops._ref_identity(x))
748
749      r = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=5)
750
751      variables.global_variables_initializer().run()
752
753      self.assertEqual(r[0].dtype, dtypes.int32)
754      self.assertEqual(r[1].dtype, dtypes.int32_ref)
755
756      value_i, value_x = sess.run(r)
757
758    self.assertEqual(100, value_i)
759    self.assertEqual(0, value_x)
760
761  def testWhile_2(self):
762    with self.test_session():
763      s = constant_op.constant(0)
764      r = isum(s)
765      self.assertAllEqual(45, r.eval())
766
767  def testWhileWithMaximumIterations(self):
768    with self.test_session():
769      s = constant_op.constant([1, 2, 3, 4, 5])
770      r = isum(s, maximum_iterations=3)
771      self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval())
772
773  def testWhileWithMaximumIterationsAndSingleArgument(self):
774    with self.test_session():
775      r = control_flow_ops.while_loop(
776          lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
777      self.assertEqual(1, r.eval())
778
779  def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
780    v = constant_op.constant(1.0)
781
782    def training_loop_with_gradient(i):
783      out = control_flow_ops.while_loop(
784          lambda i_, _: i_ < 3,
785          lambda i_, j: [i_ + 1, j * v], [0, 1.0],
786          maximum_iterations=i)
787      g = gradients_impl.gradients(out, v)
788      with ops.control_dependencies(g):
789        return i + 1
790
791    xla_context = control_flow_ops.XLAControlFlowContext()
792    xla_context.Enter()
793    # Create training loop, ensure we can call gradient() of
794    # while_loop inside the training loop.
795    loop = control_flow_ops.while_loop(lambda i: i < 3,
796                                       training_loop_with_gradient, [0])
797    xla_context.Exit()
798
799    loop_execute = array_ops.identity(loop)  # Because loop is not fetchable.
800
801    # Should execute without issue.
802    self.assertEqual(3, self.evaluate(loop_execute))
803
804  def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
805    v = constant_op.constant(1.0)
806
807    def inner_body(i, x):
808      out = control_flow_ops.while_loop(
809          lambda i, _: i < 3,
810          lambda i, j: [i + 1, j * v], [0, x],
811          maximum_iterations=i)
812      return out
813
814    def create_while_loop(maximum_iterations=None):
815      return control_flow_ops.while_loop(
816          lambda i, _: i < 3,
817          inner_body, [0, 1.0],
818          maximum_iterations=maximum_iterations)
819
820    loop_no_xla = create_while_loop(maximum_iterations=5)
821    # maximum_iterations is fine outside of an XLA scope
822    gs = gradients_impl.gradients(loop_no_xla, v)
823    self.evaluate(gs)  # This should execute without error.
824
825    xla_context = control_flow_ops.XLAControlFlowContext()
826    xla_context.Enter()
827    loop_no_maxiter = create_while_loop()
828    loop_with_maxiter = create_while_loop(maximum_iterations=2)
829    xla_context.Exit()
830
831    with self.assertRaisesRegexp(
832        ValueError,
833        r"Cannot create a gradient accumulator for tensor '.+' inside "
834        r"XLA while_loop because maximum_iterations was not passed to "
835        r"the tf.while_loop call \('.+'\)."):
836      _ = gradients_impl.gradients(loop_no_maxiter, v)
837
838    with self.assertRaisesRegexp(
839        ValueError,
840        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
841        r"while_loop. maximum_iterations tensor '.+' for while_loop context "
842        r"'.+' must be statically known \(e.g. a constant value or known "
843        r"shape dimension\), or be defined at or outside the while loop "
844        r"context '.*' \(currently defined in '.*'\)"):
845      _ = gradients_impl.gradients(loop_with_maxiter, v)
846
847  def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
848    v = constant_op.constant(1.0)
849
850    def create_while_loop():
851      max_iter_holder = []
852
853      def create_mi():
854        max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=()))
855        return 1.0
856
857      _ = control_flow_ops.cond(
858          constant_op.constant(True), create_mi, create_mi)
859
860      return control_flow_ops.while_loop(
861          lambda i, _: i < 3,
862          lambda i, x: (i + 1, v * x), (0, 1.0),
863          maximum_iterations=max_iter_holder[0])
864
865    xla_context = control_flow_ops.XLAControlFlowContext()
866    xla_context.Enter()
867    loop = create_while_loop()
868    xla_context.Exit()
869
870    with self.assertRaisesRegexp(
871        ValueError,
872        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
873        r"while_loop. maximum_iterations tensor '.*Placeholder:0' for "
874        r"while_loop context '.+' must be statically known \(e.g. a constant "
875        r"value or known shape dimension\), or be defined at or outside the "
876        r"while loop context '' \(currently defined in 'cond/.+'\)"):
877      _ = gradients_impl.gradients(loop, v)
878
879  def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
880    v = constant_op.constant(1.0)
881
882    p = array_ops.placeholder(dtype=dtypes.int32)
883
884    def mid_body_builder(iterations):
885
886      def mid_body(i, x):
887        r = control_flow_ops.while_loop(
888            lambda *_: True,
889            lambda i, x: (i + 1, v * x), (0, x),
890            maximum_iterations=iterations,
891            name="inner")
892        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])
893
894      return mid_body
895
896    def outer_body(i, x):
897      iterations = array_ops.size(p, name="iterations")
898      return (i + 1, x + control_flow_ops.while_loop(
899          lambda *_: True,
900          mid_body_builder(iterations), (0, x),
901          maximum_iterations=iterations,
902          name="mid")[1])
903
904    def create_while_loop():
905      with ops.device("/cpu:0"):
906        r = control_flow_ops.while_loop(
907            lambda *_: True,
908            outer_body, (0, 1.0),
909            maximum_iterations=5,
910            name="outer")
911        return array_ops.identity(r[1])
912
913    xla_context = control_flow_ops.XLAControlFlowContext()
914    xla_context.Enter()
915    final_with_xla_context = create_while_loop()
916    xla_context.Exit()
917
918    final_without_xla_context = create_while_loop()
919
920    with self.test_session(use_gpu=False) as sess:
921      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
922      run_metadata = config_pb2.RunMetadata()
923
924      final_value_without_xla_context = sess.run(
925          final_without_xla_context, feed_dict={
926              p: [0, 0, 0]
927          })
928
929      final_value_with_xla_context = sess.run(
930          final_with_xla_context,
931          feed_dict={p: [0, 0, 0]},
932          options=opts,
933          run_metadata=run_metadata)
934
935      node_stats = run_metadata.step_stats.dev_stats[0].node_stats
936      stack_push_count = len(
937          [x for x in node_stats if x.node_name.endswith("StackPushV2")])
938      # Pushes to the stack = product of maximum_iterations values;
939      # the last two "3"s comes from size(p), when p == [0, 0, 0].
940      self.assertEqual(stack_push_count, 5 * 3 * 3)
941
942      self.assertAllClose(final_value_with_xla_context,
943                          final_value_without_xla_context)
944
945  # Have more than 10 parallel iterations and hence exercise k-bound
946  # most of the time.
947  def testWhile_3(self):
948    with self.test_session():
949
950      def compute(i, m, c, o):
951        m, c = [math_ops.add(m, 1), math_ops.add(c, 1)]
952        o = math_ops.add(o, m)
953        o = math_ops.add(o, c)
954        i = math_ops.add(i, 1)
955        return [i, m, c, o]
956
957      i = ops.convert_to_tensor(0)
958      m = ops.convert_to_tensor(0)
959      c = ops.convert_to_tensor(0)
960      o = ops.convert_to_tensor(0)
961      d = ops.convert_to_tensor(100)
962      r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, d),
963                                      compute, [i, m, c, o])
964      result = r[3].eval()
965    self.assertAllEqual(10100, result)
966
967  def testWhile_4(self):
968    with self.test_session():
969
970      def compute(i, m, c, o):
971        m, c = [array_ops.gather(x, i), array_ops.gather(x, i)]
972        o = math_ops.add(o, m)
973        o = math_ops.add(o, c)
974        i = math_ops.add(i, 1)
975        return [i, m, c, o]
976
977      i = ops.convert_to_tensor(0)
978      m = ops.convert_to_tensor(0)
979      c = ops.convert_to_tensor(0)
980      o = ops.convert_to_tensor(0)
981      x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
982      s = array_ops.size(x)
983      r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, s),
984                                      compute, [i, m, c, o])
985      result = r[3].eval()
986    self.assertAllEqual(42, result)
987
988  def testWhile_5(self):
989    with self.test_session():
990
991      def compute(i, c, o):
992        c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0),
993                                    [1] + array_ops.expand_dims(i, 0))
994        o = array_ops.concat([o, c], 0)
995        i = math_ops.add(i, 1)
996        return [i, c, o]
997
998      i = ops.convert_to_tensor(0)
999      c = ops.convert_to_tensor([0])
1000      o = ops.convert_to_tensor([0])
1001      x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
1002      s = array_ops.size(x)
1003      r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s),
1004                                      compute, [i, c, o], [
1005                                          i.get_shape(),
1006                                          tensor_shape.unknown_shape(),
1007                                          tensor_shape.unknown_shape()
1008                                      ])
1009      result = r[2].eval()
1010    self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
1011
1012  def testBufferForwarding(self):
1013    run_options = config_pb2.RunOptions(
1014        trace_level=config_pb2.RunOptions.FULL_TRACE)
1015    run_metadata = config_pb2.RunMetadata()
1016
1017    with self.test_session() as sess:
1018      with ops.device("/cpu:0"):
1019        c = constant_op.constant(2)
1020        i0 = constant_op.constant(0)
1021        r = control_flow_ops.while_loop(lambda i: i < 1000,
1022                                        lambda i: math_ops.square(c) + i, [i0])
1023      r_val = sess.run(r, options=run_options, run_metadata=run_metadata)
1024      self.assertEqual(1000, r_val)
1025      self.assertTrue(run_metadata.HasField("step_stats"))
1026      unique_allocs = set()
1027      for node_stat in run_metadata.step_stats.dev_stats[0].node_stats:
1028        for output in node_stat.output:
1029          unique_allocs.add(
1030              output.tensor_description.allocation_description.ptr)
1031      # Prior to cl/147536680, the number of unique allocations was about 1005.
1032      self.assertLess(len(unique_allocs), 756)
1033
1034  def _testWhile_Gpu_1(self, use_gpu):
1035    with self.test_session(use_gpu=use_gpu):
1036      n = constant_op.constant(1.0)
1037      c = lambda x: math_ops.less(x, 10.0)
1038      b = lambda x: math_ops.add(x, 1.0)
1039      r = control_flow_ops.while_loop(c, b, [n])
1040      self.assertAllClose(10.0, r.eval())
1041
1042  def testWhile_Gpu_1(self):
1043    self._testWhile_Gpu_1(use_gpu=False)
1044    self._testWhile_Gpu_1(use_gpu=True)
1045
1046  def _testWhile_Gpu_2(self, use_gpu):
1047    with self.test_session(use_gpu=use_gpu):
1048      n = constant_op.constant(1.0)
1049      c = lambda x: math_ops.less(x, 10.0)
1050
1051      def b(x):
1052        with ops.device("/cpu:0"):
1053          return math_ops.add(x, 1.0)
1054
1055      r = control_flow_ops.while_loop(c, b, [n])
1056      self.assertAllClose(10.0, r.eval())
1057
1058  def testWhile_Gpu_2(self):
1059    self._testWhile_Gpu_1(use_gpu=False)
1060    self._testWhile_Gpu_1(use_gpu=True)
1061
1062  def testWhileShape(self):
1063    with self.test_session():
1064      i = constant_op.constant(0)
1065      m = array_ops.ones([2, 2])
1066      c = lambda i, j: math_ops.less(i, 2)
1067
1068      def _b(i, j):
1069        new_i = math_ops.add(i, 1)
1070        new_j = array_ops.tile(j, [2, 2])
1071        return [new_i, new_j]
1072
1073      r = control_flow_ops.while_loop(
1074          c, _b, [i, m],
1075          [i.get_shape(), tensor_shape.unknown_shape()])
1076      r = r[1] * array_ops.ones([8, 8])
1077      self.assertAllEqual(np.ones((8, 8)), r.eval())
1078
1079  def testWhileWithNonTensorInput_Scalar(self):
1080    with self.test_session():
1081      n = 0
1082      c = lambda x: x < 10000
1083      b = lambda x: x + 1
1084      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
1085      self.assertEqual(10000, r.eval())
1086
1087  def testWhileWithNonTensorInput_Vector(self):
1088    with self.test_session():
1089      n = np.array([0])  # Note, [0] would not work here; that is a list
1090      c = lambda x: x[0] < 10000
1091      b = lambda x: array_ops.stack([x[0] + 1])
1092      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
1093      self.assertEqual([10000], r.eval())
1094
1095  def testWhileShapeInference(self):
1096    with self.test_session():
1097      i = constant_op.constant(0)
1098      m = array_ops.ones([2, 2])
1099      c = lambda i, j: math_ops.less(i, 2)
1100
1101      def b(i, j):
1102        new_i = math_ops.add(i, 1)
1103        new_j = array_ops.concat([j, j], 0)
1104        return [new_i, new_j]
1105
1106      r = control_flow_ops.while_loop(
1107          c, b, [i, m],
1108          [i.get_shape(), tensor_shape.TensorShape([None, 2])])
1109      self.assertTrue(r[1].get_shape()[0].value is None)
1110      self.assertEqual(r[1].get_shape()[1], tensor_shape.Dimension(2))
1111
1112      with self.assertRaisesRegexp(
1113          ValueError,
1114          r"The shape for while_1/Merge_1:0 is not an invariant for the loop. "
1115          r"It enters the loop with shape \(2, 2\), but has shape \(4, 2\) "
1116          r"after one iteration. Provide shape invariants using either the "
1117          r"`shape_invariants` argument of tf.while_loop or set_shape\(\) on "
1118          r"the loop variables."):
1119        r = control_flow_ops.while_loop(c, b, [i, m])
1120
1121  def testWhileShapeInferenceSparseTensor(self):
1122    with self.test_session():
1123      values = constant_op.constant([2.0, 4.0], name="values")
1124      indices = constant_op.constant(
1125          [[0], [3]], dtype=dtypes.int64, name="indices")
1126      shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
1127      i = constant_op.constant(0)
1128      x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
1129
1130      def c(i, _):
1131        return i < 10
1132
1133      def b(i, x):
1134        return [
1135            i + 1,
1136            sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
1137        ]
1138
1139      _, r = control_flow_ops.while_loop(c, b, [i, x])
1140      self.assertEqual(r.dense_shape.get_shape()[0].value, 1)
1141
1142      _, r = control_flow_ops.while_loop(
1143          c, b, [i, x],
1144          [i.get_shape(), tensor_shape.TensorShape([None])])
1145      self.assertTrue(r.dense_shape.get_shape()[0].value is None)
1146
1147      with self.assertRaisesRegexp(ValueError, "is not compatible with"):
1148        _, r = control_flow_ops.while_loop(
1149            c, b, [i, x],
1150            [i.get_shape(), tensor_shape.TensorShape([5])])
1151
1152  def testWhileShapeInferenceIndexedSlices(self):
1153    with self.test_session():
1154      values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
1155      indices = constant_op.constant([0, 3], name="indices")
1156      shape = constant_op.constant([10, 2], name="dense_shape")
1157      i = constant_op.constant(0)
1158      x = ops.IndexedSlices(values, indices, dense_shape=shape)
1159
1160      def c(i, _):
1161        return i < 10
1162
1163      def b(i, x):
1164        return [
1165            i + 1,
1166            ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
1167        ]
1168
1169      _, r = control_flow_ops.while_loop(c, b, [i, x])
1170      self.assertEqual(r.dense_shape.get_shape()[0].value, 2)
1171      self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2]))
1172
1173      _, r = control_flow_ops.while_loop(
1174          c, b, [i, x],
1175          [i.get_shape(), tensor_shape.TensorShape([None, 2])])
1176      self.assertEqual(r.dense_shape.get_shape()[0].value, 2)
1177      self.assertTrue(r.values.get_shape()[0].value is None)
1178      self.assertEqual(r.values.get_shape()[1].value, 2)
1179
1180      with self.assertRaisesRegexp(ValueError, "is not compatible with"):
1181        _, r = control_flow_ops.while_loop(
1182            c, b, [i, x],
1183            [i.get_shape(), tensor_shape.TensorShape([None, 5])])
1184
1185  def _testNestedWhile_1(self, use_gpu):
1186    with self.test_session(use_gpu=use_gpu):
1187      n = constant_op.constant(0)
1188
1189      def cpu_sum(s):
1190        c = lambda i, s: math_ops.less(i, 10)
1191
1192        def b(i, s):
1193          i1 = math_ops.add(i, 1)
1194          with ops.device("/cpu:0"):
1195            s1 = math_ops.add(i, s)
1196          return i1, s1
1197
1198        _, r_s = control_flow_ops.while_loop(c, b, [n, s])
1199        return r_s
1200
1201      c = lambda x: math_ops.less(x, 200)
1202      b = lambda x: math_ops.add(x, cpu_sum(n))
1203      r = control_flow_ops.while_loop(c, b, [n])
1204      self.assertEqual(225, r.eval())
1205
1206  def testNestedWhile_1(self):
1207    self._testNestedWhile_1(use_gpu=False)
1208    self._testNestedWhile_1(use_gpu=True)
1209
1210  def _testNestedWhile_2(self, use_gpu):
1211    # Test the cases that A -> Enter and Exit -> A are partitioned.
1212    with self.test_session(use_gpu=use_gpu):
1213      s0 = constant_op.constant(2.0)
1214
1215      def inner_loop(s):
1216        c = lambda s: math_ops.less(s, 20.0)
1217
1218        def b(s):
1219          s1 = math_ops.add(s, s)
1220          return s1
1221
1222        r_s = control_flow_ops.while_loop(c, b, [s], parallel_iterations=1)
1223        return r_s
1224
1225      outer_c = lambda x: math_ops.less(x, 3000.0)
1226
1227      def outer_b(x):
1228        x = logging_ops.Print(x, [x])  # Edge "Print -> Enter" is partitioned
1229        x = inner_loop(x)
1230        with ops.device("/cpu:0"):
1231          x = math_ops.square(x)  # Edge "Exit -> Square" is partitioned
1232        return x
1233
1234      r = control_flow_ops.while_loop(
1235          outer_c, outer_b, [s0], parallel_iterations=1)
1236      self.assertEqual(1048576.0, r.eval())
1237
1238  def testNestedWhile_2(self):
1239    self._testNestedWhile_2(use_gpu=False)
1240    self._testNestedWhile_2(use_gpu=True)
1241
1242  def testWhileWithControl_1(self):
1243    with self.test_session():
1244      n = constant_op.constant(0)
1245      r = constant_op.constant(0)
1246      condition = lambda n_, r_: math_ops.less(n_, 10)
1247
1248      def body(n_, r_):
1249        n_ = math_ops.add(n_, 1)
1250        with r_.graph.control_dependencies([r_]):
1251          r_ = constant_op.constant(12)
1252        return [n_, r_]
1253
1254      res = control_flow_ops.while_loop(
1255          condition, body, [n, r], parallel_iterations=1)
1256      self.assertAllEqual(12, res[1].eval())
1257
1258  def testWhileWithControl_2(self):
1259    with self.test_session():
1260      r = constant_op.constant(0)
1261      condition = lambda r_: math_ops.less(r_, 10)
1262
1263      def body(r_):
1264        with r_.graph.control_dependencies([r_]):
1265          r_ = constant_op.constant(12)
1266        return [r_]
1267
1268      res = control_flow_ops.while_loop(
1269          condition, body, [r], parallel_iterations=1)
1270      self.assertAllEqual(12, res.eval())
1271
1272  def testWhileWithControl_3(self):
1273    with self.test_session() as sess:
1274      b = array_ops.placeholder(dtypes.bool)
1275      c = constant_op.constant(1)
1276      x0 = constant_op.constant(0)
1277      with ops.control_dependencies([b]):
1278        r = control_flow_ops.while_loop(lambda x: x < 10, lambda x: x + c, [x0])
1279      self.assertEqual(10, sess.run(r, {b: True}))
1280
1281  def testWhileWithControl_4(self):
1282    with self.test_session() as sess:
1283      b = array_ops.placeholder(dtypes.bool)
1284      c = constant_op.constant(1)
1285      x0 = constant_op.constant(0)
1286      with ops.control_dependencies([b]):
1287        r = control_flow_ops.while_loop(
1288            lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
1289      self.assertEqual(10, sess.run(r, {b: True}))
1290
1291  def testWhileWithControl_5(self):
1292    with self.test_session() as sess:
1293      b = array_ops.placeholder(dtypes.bool)
1294      c = constant_op.constant(1)
1295      x0 = constant_op.constant(0)
1296
1297      def body(x):
1298        with ops.control_dependencies([b]):
1299          return x + c
1300
1301      r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0])
1302      self.assertEqual(10, sess.run(r, {b: True}))
1303
1304  def testWhileCondWithControl(self):
1305    # Ensure that no control edges by an outer control dependency context are
1306    # added to nodes inside cond/while contexts.
1307    with self.test_session() as sess:
1308      const_true = lambda: constant_op.constant(True)
1309      const_false = lambda: constant_op.constant(False)
1310      cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false)
1311      body = lambda i: control_flow_ops.cond(i > 0, lambda: i - 1, lambda: i)
1312
1313      with ops.control_dependencies([control_flow_ops.no_op()]):
1314        loop = control_flow_ops.while_loop(cond, body,
1315                                           (constant_op.constant(5),))
1316      self.assertEqual(0, sess.run(loop))
1317
1318  def testWhileCondWithControl_1(self):
1319    with self.test_session():
1320      v = variable_scope.get_variable(
1321          "v", [], initializer=init_ops.constant_initializer(2))
1322      i0 = constant_op.constant(0)
1323      with ops.control_dependencies([i0]):
1324
1325        def loop_condition(i):
1326          return i < 4
1327
1328        def loop_body(i):
1329          some_cond = control_flow_ops.cond(
1330              constant_op.constant(True),
1331              lambda: state_ops.assign(v, math_ops.square(v)), lambda: v)
1332          with ops.control_dependencies([some_cond]):
1333            return i + 1
1334
1335      r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,))
1336      variables.global_variables_initializer().run()
1337      self.assertEqual(4, r.eval())
1338      self.assertAllClose(65536.0, v.eval())
1339
1340  def testWhileCondExitControl(self):
1341    with self.test_session():
1342      v = variables.Variable(1)
1343
1344      def false_branch():
1345        cond = lambda i: i < 100
1346
1347        def body(i):
1348          x = state_ops.assign(v, i)
1349          return x + 1
1350
1351        loop = control_flow_ops.while_loop(cond, body, [0])
1352        # Make sure to handle correctly control edge from Exit to a node.
1353        with ops.control_dependencies([loop]):
1354          return constant_op.constant(6.0)
1355
1356      r = control_flow_ops.cond(
1357          constant_op.constant(False), lambda: constant_op.constant(1.0),
1358          false_branch)
1359      variables.global_variables_initializer().run()
1360      self.assertEqual(6.0, r.eval())
1361      self.assertEqual(99, v.eval())
1362
1363  def testCondWhile_1(self):
1364    with self.test_session():
1365      n = ops.convert_to_tensor(0, name="n")
1366      c = lambda x: math_ops.less(x, 10)
1367      b = lambda x: math_ops.add(x, 1)
1368      r = control_flow_ops.cond(
1369          math_ops.less(0, 1), lambda: control_flow_ops.while_loop(c, b, [n]),
1370          lambda: n)
1371      self.assertAllEqual(10, r.eval())
1372
1373  def testCondWhile_2(self):
1374    with self.test_session():
1375      n = ops.convert_to_tensor(0)
1376      c = lambda x: math_ops.less(x, 10)
1377      b = lambda x: math_ops.add(x, 1)
1378      r = control_flow_ops.cond(
1379          math_ops.less(1, 0), lambda: math_ops.add(n, 1),
1380          lambda: control_flow_ops.while_loop(c, b, [n]))
1381      self.assertAllEqual(10, r.eval())
1382
1383  def _testCondWhile_3(self, use_gpu):
1384    with self.test_session(use_gpu=use_gpu) as sess:
1385      p = array_ops.placeholder(dtypes.bool)
1386      n = constant_op.constant(0.0)
1387
1388      def c(x):
1389        return math_ops.less(x, 10.0)
1390
1391      def b(x):
1392        with ops.device("/cpu:0"):
1393          x1 = math_ops.add(x, 1.0)
1394        return x1
1395
1396      r = control_flow_ops.cond(p,
1397                                lambda: control_flow_ops.while_loop(c, b, [n]),
1398                                lambda: math_ops.multiply(n, 2.0))
1399      r1 = gradients_impl.gradients(r, [n])
1400      self.assertEqual(10, sess.run(r, {p: True}))
1401      self.assertEqual([1.0], sess.run(r1, {p: True}))
1402      self.assertEqual(0.0, sess.run(r, {p: False}))
1403      self.assertEqual([2.0], sess.run(r1, {p: False}))
1404
1405  def testCondWhile_3(self):
1406    self._testCondWhile_3(use_gpu=False)
1407    self._testCondWhile_3(use_gpu=True)
1408
1409  def testWhileCond_1(self):
1410    with self.test_session():
1411      i = ops.convert_to_tensor(0, name="i")
1412      n = ops.convert_to_tensor(10, name="n")
1413      one = ops.convert_to_tensor(1, name="one")
1414      c = lambda x: math_ops.less(x, n)
1415      # pylint: disable=undefined-variable
1416      # for OSS build
1417      b = lambda x: control_flow_ops.cond(
1418          constant_op.constant(True),
1419          lambda: math_ops.add(x, one), lambda: math_ops.subtract(x, one))
1420      # pylint: enable=undefined-variable
1421      r = control_flow_ops.while_loop(c, b, [i])
1422      self.assertAllEqual(10, r.eval())
1423
1424  def testWhileCond_2(self):
1425    with self.test_session():
1426      n = ops.convert_to_tensor(0, name="n")
1427      c = lambda x: math_ops.less(x, 10)
1428      b = lambda x: control_flow_ops.cond(constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n)
1429      r = control_flow_ops.while_loop(c, b, [n])
1430      self.assertAllEqual(10, r.eval())
1431
1432  def testWhileCond_3(self):
1433    with self.test_session():
1434      n = ops.convert_to_tensor(0)
1435      c = lambda x: math_ops.less(x, 10)
1436      # pylint: disable=undefined-variable
1437      # for OSS build
1438      b = lambda x: control_flow_ops.cond(math_ops.less(0, 1),
1439                                          lambda: math_ops.add(x, 1),
1440                                          lambda: math_ops.subtract(x, 1))
1441      # pylint: enable=undefined-variable
1442      r = control_flow_ops.while_loop(c, b, [n])
1443      self.assertAllEqual(10, r.eval())
1444
1445  # NOTE: It is ok to have parallel_iterations > 1
1446  def testWhileUpdateVariable_1(self):
1447    with self.test_session():
1448      select = variables.Variable([3.0, 4.0, 5.0])
1449      n = constant_op.constant(0)
1450
1451      def loop_iterator(j):
1452        return math_ops.less(j, 3)
1453
1454      def loop_body(j):
1455        ns = state_ops.scatter_update(select, j, 10.0)
1456        nj = math_ops.add(j, 1)
1457        op = control_flow_ops.group(ns)
1458        nj = control_flow_ops.with_dependencies([op], nj)
1459        return [nj]
1460
1461      r = control_flow_ops.while_loop(
1462          loop_iterator, loop_body, [n], parallel_iterations=1)
1463      variables.global_variables_initializer().run()
1464      self.assertEqual(3, r.eval())
1465      result = select.eval()
1466      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
1467
1468  def testWhileUpdateVariable_2(self):
1469    with self.test_session():
1470      select1 = variables.Variable([3.0, 4.0, 5.0])
1471      select2 = variables.Variable([3.0, 4.0, 5.0])
1472      n = constant_op.constant(0)
1473
1474      def loop_iterator(j):
1475        return math_ops.less(j, 3)
1476
1477      def loop_body(j):
1478        ns1 = state_ops.scatter_update(select1, j, 10.0)
1479        ns2 = state_ops.scatter_update(select2, j, 10.0)
1480        nj = math_ops.add(j, 1)
1481        op = control_flow_ops.group(ns1, ns2)
1482        nj = control_flow_ops.with_dependencies([op], nj)
1483        return [nj]
1484
1485      r = control_flow_ops.while_loop(
1486          loop_iterator, loop_body, [n], parallel_iterations=1)
1487      variables.global_variables_initializer().run()
1488      self.assertEqual(3, r.eval())
1489      result1 = select1.eval()
1490      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1)
1491      result2 = select2.eval()
1492      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
1493
1494  def testWhileUpdateVariable_3(self):
1495    with self.test_session():
1496      select = variables.Variable([3.0, 4.0, 5.0])
1497      n = constant_op.constant(0)
1498
1499      def loop_iterator(j, _):
1500        return math_ops.less(j, 3)
1501
1502      def loop_body(j, _):
1503        ns = state_ops.scatter_update(select, j, 10.0)
1504        nj = math_ops.add(j, 1)
1505        return [nj, ns]
1506
1507      r = control_flow_ops.while_loop(
1508          loop_iterator,
1509          loop_body, [n, array_ops.identity(select)],
1510          parallel_iterations=1)
1511      variables.global_variables_initializer().run()
1512      result = r[1].eval()
1513    self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
1514
1515  # b/24814703
1516  def testWhileUpdateVariable_4(self):
1517    with self.test_session():
1518      var_a = variables.Variable(0, name="a")
1519      var_b = variables.Variable(0, name="b")
1520      variables.global_variables_initializer().run()
1521
1522      c = constant_op.constant(0, name="c")
1523      asn1 = state_ops.assign_add(var_a, 1, name="a_add")
1524
1525      # Loop condition
1526      def pred(i):
1527        return math_ops.less(i, 10)
1528
1529      # Loop body
1530      def loop_body(i):
1531        asn2 = state_ops.assign_add(var_b, asn1, name="b_add")
1532        with ops.control_dependencies([asn2]):
1533          ni = math_ops.add(i, 1, name="i_add")
1534        return ni
1535
1536      lpa = control_flow_ops.while_loop(
1537          pred, loop_body, [c], parallel_iterations=1)
1538
1539      self.assertEqual(0, var_b.eval())
1540      lpa.eval()  # Run the loop
1541      self.assertEqual(10, var_b.eval())
1542
1543  # b/24736492
1544  def testWhileUpdateVariable_5(self):
1545    with self.test_session():
1546      # Create some variables.
1547      var_a = variables.Variable(0, name="a")
1548      var_b = variables.Variable(0, name="b")
1549      variables.global_variables_initializer().run()
1550
1551      # Change condition to check var_b
1552      def pred(_):
1553        return math_ops.less(var_b, 10)
1554
1555      # Change body to increment var_b
1556      def loop_body(i):
1557        asn1 = state_ops.assign_add(
1558            var_a, constant_op.constant(1), name="a_add")
1559        asn2 = state_ops.assign_add(
1560            var_b, constant_op.constant(1), name="b_add")
1561        with ops.control_dependencies([asn1, asn2]):
1562          inc_b = array_ops.identity(var_b)
1563        return inc_b
1564
1565      lpa = control_flow_ops.while_loop(
1566          pred, loop_body, [var_b], parallel_iterations=1, name="loop")
1567
1568      self.assertEqual(0, var_b.eval())
1569      lpa.eval()  # Run the loop
1570      self.assertEqual(10, var_a.eval())
1571      self.assertEqual(10, var_b.eval())
1572
1573  # b/24814668
1574  def testWhileUpdateVariable_6(self):
1575    with self.test_session():
1576      # Create some variables.
1577      var_a = variables.Variable(0, name="a")
1578      var_b = variables.Variable(0, name="b")
1579      c = constant_op.constant(0)
1580      variables.global_variables_initializer().run()
1581
1582      # Loop condition
1583      def pred(i):
1584        return math_ops.less(i, 10)
1585
1586      # Loop body
1587      def loop_body(i):
1588        asn1 = state_ops.assign_add(var_a, 1, name="a_add")
1589        with ops.control_dependencies([asn1]):
1590          asn2 = state_ops.assign_add(var_b, var_a, name="b_add")
1591        with ops.control_dependencies([asn2]):
1592          ni = math_ops.add(i, 1, name="i_add")
1593          return ni
1594
1595      lpa = control_flow_ops.while_loop(
1596          pred, loop_body, [c], parallel_iterations=1, name="loop")
1597
1598      self.assertEqual(0, var_b.eval())
1599      lpa.eval()  # Run the loop
1600      self.assertEqual(55, var_b.eval())
1601      self.assertEqual(10, var_a.eval())
1602
1603  def testWhileQueue_1(self):
1604    with self.test_session():
1605      q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
1606      i = constant_op.constant(0)
1607
1608      def c(i):
1609        return math_ops.less(i, 10)
1610
1611      def b(i):
1612        ni = math_ops.add(i, 1)
1613        ni = control_flow_ops.with_dependencies([q.enqueue((i,))], ni)
1614        return ni
1615
1616      r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
1617      self.assertEqual([10], r.eval())
1618      for i in xrange(10):
1619        self.assertEqual([i], q.dequeue().eval())
1620
1621  def testWhileStack_1(self):
1622    with self.test_session():
1623      s = gen_data_flow_ops._stack_v2(-1, dtypes.int32, stack_name="foo")
1624      i = constant_op.constant(0)
1625
1626      def c(i):
1627        return math_ops.less(i, 10)
1628
1629      def b(i):
1630        ni = math_ops.add(i, 1)
1631        ni = control_flow_ops.with_dependencies(
1632            [gen_data_flow_ops._stack_push_v2(s, i)], ni)
1633        return ni
1634
1635      r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
1636
1637      x = constant_op.constant(0)
1638
1639      def c1(i, _):
1640        return math_ops.greater(i, 0)
1641
1642      def b1(i, x):
1643        ni = math_ops.subtract(i, 1)
1644        nx = x + gen_data_flow_ops._stack_pop_v2(s, dtypes.int32)
1645        return [ni, nx]
1646
1647      _, rx = control_flow_ops.while_loop(
1648          c1,
1649          b1, [r, x],
1650          [r.get_shape(), tensor_shape.unknown_shape()],
1651          parallel_iterations=1)
1652      self.assertEqual(45, rx.eval())
1653
1654  def _testWhileGrad_ColocateGradients(self, colocate):
1655    gpu_dev_name = test.gpu_device_name() if test.is_gpu_available(
1656    ) else "/device:GPU:0"
1657
1658    graph = ops.Graph()
1659    with graph.as_default():
1660      v = constant_op.constant(2.0, name="v")
1661      c = lambda v: math_ops.less(v, 100.0)
1662
1663      def b(x):
1664        with ops.device(gpu_dev_name):
1665          return math_ops.square(x)
1666
1667      loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
1668      r = gradients_impl.gradients(
1669          loop, v, colocate_gradients_with_ops=colocate)[0]
1670
1671    r_ops = graph.get_operations()
1672    r_devices = [(op.name, op.device) for op in r_ops]
1673
1674    self.assertTrue(any("Square" in op.name for op in r_ops))
1675
1676    for (name, dev) in r_devices:
1677      if not colocate and name.endswith("Square"):
1678        # Only forward graph contain gpu in Square device
1679        self.assertTrue(gpu_dev_name in dev)
1680      elif colocate and "Square" in name:
1681        # Forward and backward graphs contain gpu in Square/Square_grad devices
1682        self.assertTrue(gpu_dev_name in dev)
1683      else:
1684        self.assertFalse(gpu_dev_name in dev)
1685
1686    with self.test_session(graph=graph) as sess:
1687      self.assertAllClose(1024.0, sess.run(r))
1688
1689  def testWhileGrad_ColocateGradients(self):
1690    self._testWhileGrad_ColocateGradients(colocate=False)
1691    self._testWhileGrad_ColocateGradients(colocate=True)
1692
1693  def testWhileGrad_Square(self):
1694    with self.test_session():
1695      v = constant_op.constant(2.0, name="v")
1696      c = lambda v: math_ops.less(v, 100.0)
1697      b = math_ops.square
1698      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
1699      r = control_flow_ops.cond(math_ops.less(1, 2), lambda: r, lambda: v)
1700
1701      r = gradients_impl.gradients(r, v)[0]
1702      self.assertAllClose(1024.0, r.eval())
1703
1704  def testWhileGrad_Shape(self):
1705    with self.test_session():
1706      x = array_ops.placeholder(dtypes.float32, shape=[None])
1707      v = constant_op.constant([2.0], name="v")
1708      n = constant_op.constant(0, name="n")
1709      c = lambda i, v: math_ops.less(i, 5)
1710      b = lambda i, v: [i + 1, math_ops.multiply(x, v)]
1711      r = control_flow_ops.while_loop(
1712          c,
1713          b, [n, v],
1714          [n.get_shape(), tensor_shape.unknown_shape()],
1715          parallel_iterations=1)
1716
1717      r = gradients_impl.gradients(r[1], x)[0]
1718      self.assertEqual([None], r.get_shape().as_list())
1719      self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))
1720
1721  def testWhileGrad_BaseShape(self):
1722    with self.test_session() as sess:
1723      x = array_ops.placeholder(dtypes.float32, [None])
1724      v0 = constant_op.constant([2.0, 2.0], name="v")
1725      c = lambda v: constant_op.constant(False)
1726      b = lambda v: math_ops.multiply(v, x)
1727      r = control_flow_ops.while_loop(c, b, [v0])
1728      y = math_ops.square(x)
1729
1730      r = gradients_impl.gradients([r, y], x)[0]
1731      self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))
1732
1733  def testWhileGrad_MultipleUses(self):
1734    with self.test_session():
1735      v = constant_op.constant(2.0, name="v")
1736      c = lambda v: math_ops.less(v, 100.0)
1737      b = math_ops.square
1738      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
1739      r = math_ops.multiply(r, r)
1740
1741      r = gradients_impl.gradients(r, v)[0]
1742      self.assertEqual(524288.0, r.eval())
1743
1744  def testWhileGrad_LoopAdd(self):
1745    with self.test_session():
1746      v = constant_op.constant(2.0, name="v")
1747      c = lambda v: math_ops.less(v, 100.0)
1748      b = math_ops.square
1749      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
1750      r = math_ops.add(r, r)
1751
1752      r = gradients_impl.gradients(r, v)[0]
1753      self.assertAllClose(2048.0, r.eval())
1754
1755  def _testWhileGrad_Mul(self, use_gpu, p_iters):
1756    with self.test_session(use_gpu=use_gpu) as sess:
1757      a = constant_op.constant(3.0, name="a")
1758      v = constant_op.constant(2.0, name="v")
1759      c = lambda v: math_ops.less(v, 100.0)
1760      b = lambda v: math_ops.multiply(v, a)
1761      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=p_iters)
1762
1763      grad_a, grad_v = gradients_impl.gradients(r, [a, v])
1764      grad_a_val, grad_v_val = sess.run([grad_a, grad_v])
1765      self.assertAllClose(216.0, grad_a_val)
1766      self.assertAllClose(81.0, grad_v_val)
1767
1768  def testWhileGrad_Mul(self):
1769    self._testWhileGrad_Mul(use_gpu=False, p_iters=1)
1770    self._testWhileGrad_Mul(use_gpu=False, p_iters=10)
1771    self._testWhileGrad_Mul(use_gpu=True, p_iters=1)
1772    self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
1773
1774  def _testNestedWhileCondWhileGrad(self, use_gpu):
1775    with self.test_session(use_gpu=use_gpu):
1776      v = constant_op.constant(1.0)
1777
1778      def inner_loop(s):
1779        z = constant_op.constant(0)
1780        c = lambda i, x: math_ops.less(i, 4)
1781        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
1782        return control_flow_ops.while_loop(c, b, [z, s])
1783
1784      c = lambda x: math_ops.less(x, 128.0)
1785
1786      def b(x):
1787        return control_flow_ops.cond(
1788            constant_op.constant(True),
1789            lambda: math_ops.square(inner_loop(x)[1]),
1790            lambda: math_ops.multiply(x, 2.0))
1791
1792      r = control_flow_ops.while_loop(c, b, [v])
1793      r = gradients_impl.gradients(r, v)[0]
1794      self.assertAllClose(512.0, r.eval())
1795
1796  def testNestedWhileCondWhileGrad(self):
1797    self._testNestedWhileCondWhileGrad(use_gpu=False)
1798    self._testNestedWhileCondWhileGrad(use_gpu=True)
1799
1800  def testWhileGrad_Variable(self):
1801    with self.test_session():
1802      a = variables.Variable(3.0)
1803      v = constant_op.constant(2.0, name="v")
1804      c = lambda v: math_ops.less(v, 100.0)
1805      b = lambda v: math_ops.multiply(v, a)
1806      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
1807
1808      r = gradients_impl.gradients(r, a)
1809      variables.global_variables_initializer().run()
1810      self.assertAllClose(216.0, r[0].eval())
1811
1812  def testWhileGradInCond(self):
1813    with self.test_session():
1814      n = ops.convert_to_tensor(1.0, name="n")
1815      x = array_ops.placeholder(dtypes.float32, shape=None)
1816      c = lambda n: math_ops.less(n, 10.0)
1817      b = lambda n: math_ops.add(n, x)
1818
1819      def fn1():
1820        r = control_flow_ops.while_loop(c, b, [n],
1821                                        [tensor_shape.unknown_shape()])
1822        return gradients_impl.gradients(r, x)
1823
1824      r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
1825      self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
1826
1827  def testWhileGradInWhile(self):
1828    with self.test_session():
1829      n = ops.convert_to_tensor(1.0, name="n")
1830      x = array_ops.placeholder(dtypes.float32, shape=None)
1831      c = lambda n: math_ops.less(n, 10.0)
1832      b = lambda n: math_ops.add(n, x)
1833
1834      def b1(n):
1835        r = control_flow_ops.while_loop(c, b, [n],
1836                                        [tensor_shape.unknown_shape()])
1837        return gradients_impl.gradients(r, x)
1838
1839      r = control_flow_ops.while_loop(lambda n: n < 6.0, b1, [n],
1840                                      [tensor_shape.unknown_shape()])
1841      self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
1842
1843  def testWhile_NestedInput(self):
1844    with self.test_session() as sess:
1845      named = collections.namedtuple("named", ("a", "b"))
1846      loop_vars = [
1847          named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
1848          (constant_op.constant(2.0), constant_op.constant(3.0)),
1849          constant_op.constant(4.0)
1850      ]
1851      c = lambda lv0, _1, _2: lv0.a < 100.0
1852
1853      def b(lv0, lv1, lv2):
1854        lv0 = named(a=lv0.a + 1, b=lv0.b)
1855        lv1 = (lv1[0] + 1, lv1[1])
1856        lv2 += 2
1857        return [lv0, lv1, lv2]
1858
1859      r = control_flow_ops.while_loop(c, b, loop_vars)
1860
1861      self.assertTrue(isinstance(r, list))
1862      self.assertTrue(isinstance(r[0], named))
1863      self.assertTrue(isinstance(r[1], tuple))
1864      self.assertTrue(isinstance(r[2], ops.Tensor))
1865
1866      r_flattened = nest.flatten(r)
1867      self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
1868                       sess.run(r_flattened))
1869
1870  def testWhile_NestedBadArityFails(self):
1871    with self.test_session():
1872      named = collections.namedtuple("named", ("a", "b"))
1873      loop_vars = [
1874          named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
1875          (constant_op.constant(2.0), constant_op.constant(3.0)),
1876          constant_op.constant(4.0)
1877      ]
1878      c = lambda lv0, _1, _2: lv0.a < 100.0
1879
1880      def b(lv0, lv1, _):
1881        return [lv0, lv1]
1882
1883      with self.assertRaisesRegexp(ValueError, "the same number of elements"):
1884        control_flow_ops.while_loop(c, b, loop_vars)
1885
1886  def testWhileGrad_ys_xs(self):
1887    with self.test_session():
1888      x = constant_op.constant(3.0, name="x")
1889      y = constant_op.constant(2.0, name="y")
1890
1891      c = lambda x, y: math_ops.less(x, 100.0)
1892
1893      def b(x, y):
1894        y1 = math_ops.add(x, y)
1895        x1 = math_ops.multiply(x, y1)
1896        return x1, y1
1897
1898      rx, ry = control_flow_ops.while_loop(c, b, [x, y], parallel_iterations=1)
1899
1900      r = gradients_impl.gradients([rx, ry], x)
1901      self.assertAllClose(304.0, r[0].eval())
1902      r = gradients_impl.gradients([rx, ry], y)
1903      self.assertAllClose(124.0, r[0].eval())
1904      r = gradients_impl.gradients([rx], x)
1905      self.assertAllClose(295.0, r[0].eval())
1906      r = gradients_impl.gradients([rx], y)
1907      self.assertAllClose(120.0, r[0].eval())
1908
1909  def testWhileGrad_Dependency(self):
1910    with self.test_session():
1911      i = constant_op.constant(0, name="i")
1912      x = constant_op.constant(2.0, name="x")
1913
1914      c = lambda i, x: math_ops.less(i, 10)
1915
1916      def b(i, x):
1917        x = math_ops.multiply(x, 2.0)
1918        i = math_ops.add(i, 1)
1919        return i, x
1920
1921      ri, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
1922
1923      r = gradients_impl.gradients([ri, rx], x)
1924      self.assertAllClose(1024.0, r[0].eval())
1925      r = gradients_impl.gradients([rx], x)
1926      self.assertAllClose(1024.0, r[0].eval())
1927
1928  def testWhileGrad_NoGradient(self):
1929    with self.test_session():
1930      v = constant_op.constant(2.0, name="v")
1931      c = lambda v: math_ops.less(v, 100.0)
1932      b = math_ops.square
1933      r = control_flow_ops.while_loop(c, b, [v], back_prop=False)
1934      r = math_ops.add(r, v)
1935      r = gradients_impl.gradients(r, v)
1936      self.assertAllClose(1.0, r[0].eval())
1937
1938  def testWhileGrad_NoDependency(self):
1939    with self.test_session() as sess:
1940      variable = variables.Variable(array_ops.ones([2, 3]))
1941      duration = array_ops.zeros([], dtype=dtypes.int32)
1942
1943      def cond(duration, tensor, _):
1944        del tensor
1945        return duration < 10
1946
1947      def body(duration, tensor, _):
1948        return (duration + 1, tensor, tensor)
1949
1950      loop_vars = [duration, variable, variable]
1951      tensors = control_flow_ops.while_loop(
1952          cond=cond, body=body, loop_vars=loop_vars)
1953      cost = math_ops.reduce_sum(tensors[2])
1954      grad = gradients_impl.gradients(cost, [variable])
1955      variables.global_variables_initializer().run()
1956      self.assertAllClose(np.ones([2, 3]), sess.run(grad[0]))
1957
1958  def testWhileGrad_Const(self):
1959    with self.test_session() as sess:
1960      c0 = constant_op.constant(0.0, name="c0")
1961      c1 = constant_op.constant(1.0, name="c1")
1962      duration = constant_op.constant(0, name="t")
1963
1964      def cond(duration, _):
1965        return duration < 1
1966
1967      def body(duration, _):
1968        return duration + 1, c1
1969
1970      loop_vars = [duration, c0]
1971      tensors = control_flow_ops.while_loop(
1972          cond=cond, body=body, loop_vars=loop_vars)
1973      cost = math_ops.reduce_sum(tensors[1])
1974      grad = gradients_impl.gradients(cost, [c0])
1975      self.assertAllClose(0.0, sess.run(grad[0]))
1976
1977  def testWhileGrad_SerialTwoLoops(self):
1978    with self.test_session():
1979      i = constant_op.constant(0, name="i")
1980      x = constant_op.constant(2.0, name="x")
1981
1982      c = lambda i, x: math_ops.less(i, 5)
1983
1984      def b(i, x):
1985        x = math_ops.multiply(x, 2.0)
1986        i = math_ops.add(i, 1)
1987        return i, x
1988
1989      _, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
1990      _, rx = control_flow_ops.while_loop(c, b, [i, rx], parallel_iterations=1)
1991
1992      r = gradients_impl.gradients([rx], x)
1993      self.assertAllClose(1024.0, r[0].eval())
1994
1995  def testWhileGrad_ParallelTwoLoops(self):
1996    with self.test_session():
1997      i = constant_op.constant(0, name="i")
1998      x = constant_op.constant(2.0, name="x")
1999
2000      c = lambda i, x: math_ops.less(i, 5)
2001
2002      def b(i, x):
2003        x = math_ops.multiply(x, 2.0)
2004        i = math_ops.add(i, 1)
2005        return i, x
2006
2007      _, r1 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
2008      _, r2 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
2009      rx = math_ops.add(r1, r2)
2010
2011      r = gradients_impl.gradients([rx], x)
2012      self.assertAllClose(64.0, r[0].eval())
2013
2014  def testWhileGrad_OneOutputWithControlDependencyOnSecond(self):
2015    with self.test_session():
2016      i = constant_op.constant(0, name="i")
2017      x = constant_op.constant(1.0, name="x")
2018      y = constant_op.constant(1.0, name="y")
2019      c = lambda i, *_: math_ops.less(i, 1, name="cond_less")
2020
2021      def b(i, xi, yi):
2022        # return (i + 1, xi, xi + yi)
2023        return (math_ops.add(i, 1, name="inc"), array_ops.identity(
2024            xi, name="xi"), math_ops.add(xi, yi, name="xi_plus_yi"))
2025
2026      _, x_f, y_f = control_flow_ops.while_loop(c, b, [i, x, y])
2027      with ops.control_dependencies([x_f]):
2028        y_f_d = array_ops.identity(y_f, name="y_f_d")
2029
2030      self.assertAllClose(2.0, y_f_d.eval())  # y_f_d = 1.0 + 1.0
2031      g = gradients_impl.gradients([y_f_d], [x])[0]
2032      self.assertTrue(g is not None)
2033      self.assertAllClose(1.0, g.eval())  # y_f_d = x + 1.0, dy_f_d/dx = 1.0
2034
2035  def _testNestedWhileGrad_Simple(self, use_gpu):
2036    with self.test_session(use_gpu=use_gpu):
2037      v = constant_op.constant(1.0)
2038
2039      def inner_loop(s):
2040        c = lambda x: math_ops.less(x, 4.0)
2041        b = lambda x: math_ops.multiply(x, 2.0)
2042        return control_flow_ops.while_loop(c, b, [s])
2043
2044      c = lambda x: math_ops.less(x, 2.0)
2045      b = lambda x: math_ops.multiply(inner_loop(x), 2.0)
2046      r = control_flow_ops.while_loop(c, b, [v])
2047
2048      r = gradients_impl.gradients(r, v)[0]
2049      self.assertAllClose(8.0, r.eval())
2050
2051  def testNestedWhileGrad_Simple(self):
2052    self._testNestedWhileGrad_Simple(use_gpu=False)
2053    self._testNestedWhileGrad_Simple(use_gpu=True)
2054
2055  def testNestedWhileGrad_SerialInner(self):
2056    with self.test_session():
2057      v = constant_op.constant(1.0)
2058
2059      def inner_loop1(s):
2060        z = constant_op.constant(0)
2061        c = lambda i, x: math_ops.less(i, 4)
2062        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
2063        return control_flow_ops.while_loop(c, b, [z, s])
2064
2065      def inner_loop2(s):
2066        z = constant_op.constant(0)
2067        c = lambda i, x: math_ops.less(i, 4)
2068        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
2069        return control_flow_ops.while_loop(c, b, [z, s])
2070
2071      c = lambda x: math_ops.less(x, 128.0)
2072      b = lambda x: inner_loop2(inner_loop1(x)[1])[1]
2073      r = control_flow_ops.while_loop(c, b, [v])
2074
2075      r = gradients_impl.gradients(r, v)[0]
2076      self.assertAllClose(256.0, r.eval())
2077
2078  def testNestedWhileGrad_ParallelInner(self):
2079    with self.test_session():
2080      v = constant_op.constant(1.0)
2081
2082      def inner_loop1(s):
2083        z = constant_op.constant(0)
2084        c = lambda i, x: math_ops.less(i, 4)
2085        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
2086        return control_flow_ops.while_loop(c, b, [z, s])
2087
2088      def inner_loop2(s):
2089        z = constant_op.constant(0)
2090        c = lambda i, x: math_ops.less(i, 4)
2091        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
2092        return control_flow_ops.while_loop(c, b, [z, s])
2093
2094      c = lambda x: math_ops.less(x, 128.0)
2095      b = lambda x: math_ops.multiply(inner_loop1(x)[1], inner_loop2(x)[1])
2096      r = control_flow_ops.while_loop(c, b, [v])
2097
2098      r = gradients_impl.gradients(r, v)[0]
2099      self.assertAllClose(512.0, r.eval())
2100
2101  def testNestedWhileGrad_ParallelIterations(self):
2102    # Make sure the stack pushes and pops of an inner loop are executed in
2103    # the sequential order of the iterations of its outer loop.
2104    with self.test_session() as sess:
2105
2106      def inner_loop(t):
2107        fn = lambda n: n + math_ops.square(var)
2108        return functional_ops.map_fn(fn=fn, elems=t, parallel_iterations=10)
2109
2110      def outer_loop(inp):
2111        return functional_ops.map_fn(
2112            fn=inner_loop, elems=inp, parallel_iterations=10)
2113
2114      var = variables.Variable(constant_op.constant(3.0))
2115      inp = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
2116      res = outer_loop(inp)
2117      optimizer = adam.AdamOptimizer(learning_rate=0.001)
2118      train_op = optimizer.minimize(math_ops.reduce_mean(math_ops.square(res)))
2119      sess.run(variables.global_variables_initializer())
2120      sess.run(train_op)
2121      self.assertAllClose(2.999, var.eval())
2122
2123  def _testWhileCondGrad_Simple(self, use_gpu):
2124    with self.test_session(use_gpu=use_gpu):
2125      v = ops.convert_to_tensor(2.0, name="v")
2126      n = ops.convert_to_tensor(100.0, name="n")
2127      one = ops.convert_to_tensor(1.0, name="one")
2128      c = lambda x: math_ops.less(x, n)
2129      # pylint: disable=undefined-variable
2130      # for OSS build
2131      b = lambda x: control_flow_ops.cond(constant_op.constant(True),
2132                                          lambda: math_ops.square(x),
2133                                          lambda: math_ops.subtract(x, one))
2134      # pylint: enable=undefined-variable
2135      r = control_flow_ops.while_loop(c, b, [v])
2136      r = gradients_impl.gradients(r, v)[0]
2137      self.assertAllClose(1024.0, r.eval())
2138
2139  def testWhileCondGrad_Simple(self):
2140    self._testWhileCondGrad_Simple(use_gpu=False)
2141    self._testWhileCondGrad_Simple(use_gpu=True)
2142
2143  def testWhileCondGrad_UnknownShape(self):
2144    with self.test_session() as sess:
2145      v = array_ops.placeholder(dtypes.float32)
2146      n = ops.convert_to_tensor(100.0, name="n")
2147      one = ops.convert_to_tensor(1.0, name="one")
2148      c = lambda x: math_ops.less(x, n)
2149      # pylint: disable=undefined-variable
2150      # for OSS build
2151      b = lambda x: control_flow_ops.cond(constant_op.constant(True),
2152                                          lambda: math_ops.square(x),
2153                                          lambda: math_ops.subtract(x, one))
2154      # pylint: enable=undefined-variable
2155      r = control_flow_ops.while_loop(c, b, [v])
2156      r = gradients_impl.gradients(r, v)[0]
2157      r = sess.run(r, feed_dict={v: 2.0})
2158      self.assertAllClose(1024.0, r)
2159
2160  def testWhileGrad_Concat(self):
2161    with self.test_session() as sess:
2162      x = variable_scope.get_variable("x", initializer=[[1., 2.]])
2163      i0 = constant_op.constant(0)
2164      h0 = array_ops.zeros([0, 2])
2165
2166      def condition(i, _):
2167        return i < 2
2168
2169      def body(i, h):
2170        return i + 1, array_ops.concat([h, x], 0)
2171
2172      _, h = control_flow_ops.while_loop(
2173          condition, body, [i0, h0],
2174          [i0.get_shape(), tensor_shape.TensorShape([None, 2])])
2175      s = math_ops.reduce_sum(h)
2176
2177      sess.run(variables.global_variables_initializer())
2178      optimizer = gradient_descent.GradientDescentOptimizer(0.01)
2179      op = optimizer.minimize(s)
2180      sess.run(op)
2181      self.assertAllClose([[0.98000002, 1.98000002]], sess.run(x))
2182
2183  def testWhileWithRefsWithGradients_1(self):
2184    with self.test_session() as sess:
2185      x = variables.Variable(0)._ref()  # pylint: disable=protected-access
2186      i = constant_op.constant(0)
2187      c = lambda i, x: math_ops.less(i, 10)
2188
2189      self.assertEqual(x.dtype, dtypes.int32_ref)
2190
2191      # pylint: disable=protected-access
2192      def body(i, x):
2193        self.assertEqual(x.dtype, dtypes.int32_ref)
2194        return [i + 1, gen_array_ops._ref_identity(x)]
2195
2196      # pylint: enable=protected-access
2197
2198      r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5)
2199
2200      grad_ys = [variables.Variable(73)._ref()]  # pylint: disable=protected-access
2201      grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys)
2202
2203      variables.global_variables_initializer().run()
2204
2205      self.assertEqual(r[0].dtype, dtypes.int32)
2206      self.assertEqual(r[1].dtype, dtypes.int32_ref)
2207
2208      value_i, value_x, value_x_grad = sess.run(r + grad)
2209
2210    self.assertEqual(10, value_i)
2211    self.assertEqual(0, value_x)
2212    self.assertEqual(73, value_x_grad)
2213
2214  def testWhileGrad_IndexedSlices(self):
2215    with self.test_session():
2216      values = constant_op.constant([2.0, 4.0], name="values")
2217      indices = constant_op.constant([0, 3], name="indices")
2218      shape = constant_op.constant([10], name="dense_shape")
2219      i = constant_op.constant(0)
2220      x = ops.IndexedSlices(values, indices, dense_shape=shape)
2221
2222      def c(i, _):
2223        return i < 10
2224
2225      def b(i, x):
2226        return [
2227            i + 1,
2228            ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
2229        ]
2230
2231      _, r = control_flow_ops.while_loop(c, b, [i, x])
2232      r = gradients_impl.gradients(r.values, values)[0]
2233      self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
2234
2235  def testWhileGrad_SparseTensor(self):
2236    with self.test_session():
2237      values = constant_op.constant([2.0, 4.0], name="values")
2238      indices = constant_op.constant(
2239          [[0], [3]], dtype=dtypes.int64, name="indices")
2240      shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
2241      i = constant_op.constant(0)
2242      x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
2243
2244      def c(i, _):
2245        return i < 10
2246
2247      def b(i, x):
2248        return [
2249            i + 1,
2250            sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
2251        ]
2252
2253      _, r = control_flow_ops.while_loop(c, b, [i, x])
2254      r = gradients_impl.gradients(r.values, values)[0]
2255      self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
2256
2257  def testCallGradInLoop(self):
2258    with self.test_session() as sess:
2259      i0 = constant_op.constant(0)
2260      params = constant_op.constant(5.0)
2261      params_1 = math_ops.square(params)
2262
2263      def c(i, _):
2264        return i < 10
2265
2266      def b(i, x):
2267        data = constant_op.constant([1.0, 2.0, 3.0])
2268        data = math_ops.multiply(data, params_1)
2269        x1 = x + gradients_impl.gradients(data, params)[0]
2270        return i + 1, x1
2271
2272      output_grad = control_flow_ops.while_loop(
2273          c, b, [i0, constant_op.constant(0.0)])
2274      self.assertAllClose(600.0, sess.run(output_grad)[1])
2275
2276  def testWhileAndTensorArray(self):
2277    with self.test_session() as sess:
2278      param = constant_op.constant(2.0)
2279      n0 = constant_op.constant(0)
2280      y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
2281
2282      def c(i, _):
2283        return i < 10
2284
2285      def b(i, y):
2286        return [
2287            i + 1,
2288            functional_ops.map_fn(lambda x: math_ops.multiply(x, param), y)
2289        ]
2290
2291      r = control_flow_ops.while_loop(c, b, [n0, y0], parallel_iterations=1)
2292      r = gradients_impl.gradients(r, param)[0]
2293      self.assertAllClose(107520.0, sess.run(r))
2294
2295  def testWhileGrad_StopGrad(self):
2296    with self.test_session():
2297      x = constant_op.constant(3.0, name="x")
2298      y = constant_op.constant(2.0, name="y")
2299
2300      c = lambda x, y: math_ops.less(x, 100.0)
2301
2302      def b(x, y):
2303        y1 = math_ops.square(y)
2304        x1 = math_ops.add(math_ops.square(x), y1)
2305        return x1, y1
2306
2307      rx, ry = control_flow_ops.while_loop(c, b, [x, y])
2308
2309      r = gradients_impl.gradients(rx, y)[0]
2310      self.assertEqual(136.0, r.eval())
2311      r = gradients_impl.gradients(ry, y)[0]
2312      self.assertEqual(32.0, r.eval())
2313
2314      r = gradients_impl.gradients(array_ops.stop_gradient(rx), y)[0]
2315      self.assertEqual(r, None)
2316      r = gradients_impl.gradients(array_ops.stop_gradient(ry), y)[0]
2317      self.assertEqual(r, None)
2318
2319      r = gradients_impl.gradients(
2320          array_ops.stop_gradient(math_ops.square(rx)), y)[0]
2321      self.assertEqual(r, None)
2322      r = gradients_impl.gradients(
2323          array_ops.stop_gradient(math_ops.add(rx, ry)), x)[0]
2324      self.assertEqual(r, None)
2325      r = gradients_impl.gradients(
2326          array_ops.stop_gradient(math_ops.add(rx, ry)), y)[0]
2327      self.assertEqual(r, None)
2328
2329      r = gradients_impl.gradients(math_ops.add(rx, ry), y)[0]
2330      self.assertEqual(168.0, r.eval())
2331      r = gradients_impl.gradients(
2332          math_ops.add(rx, array_ops.stop_gradient(ry)), y)[0]
2333      self.assertEqual(136.0, r.eval())
2334      r = gradients_impl.gradients(
2335          math_ops.add(array_ops.stop_gradient(rx), ry), y)[0]
2336      self.assertEqual(32.0, r.eval())
2337
2338  def testWhileGrad_StopGradInside(self):
2339    with self.test_session():
2340      x = constant_op.constant(3.0, name="x")
2341      y = constant_op.constant(2.0, name="y")
2342
2343      c = lambda x, y: math_ops.less(x, 100.0)
2344
2345      def b(x, y):
2346        y1 = array_ops.stop_gradient(math_ops.square(y))
2347        x1 = math_ops.add(math_ops.square(x), y1)
2348        return x1, y1
2349
2350      rx, _ = control_flow_ops.while_loop(c, b, [x, y])
2351
2352      r = gradients_impl.gradients(rx, y)[0]
2353      self.assertAllClose(0.0, r.eval())
2354      r = gradients_impl.gradients(rx, x)[0]
2355      self.assertAllClose(156.0, r.eval())
2356
2357  def testWhileGrad_StopGradInsideNoShape(self):
2358    with self.test_session() as sess:
2359      x = array_ops.placeholder(dtypes.float32)
2360      y = array_ops.placeholder(dtypes.float32)
2361
2362      c = lambda x, y: math_ops.less(math_ops.reduce_sum(x), 100.0)
2363
2364      def b(x, y):
2365        y1 = array_ops.stop_gradient(math_ops.square(y, name="stopped"))
2366        x1 = math_ops.add(math_ops.square(x), y1)
2367        return x1, y1
2368
2369      rx, _ = control_flow_ops.while_loop(c, b, [x, y])
2370
2371      r = gradients_impl.gradients(rx, y)[0]
2372      feed_dict = {x: [3.0, 4.0], y: [2.0, 3.0]}
2373      self.assertAllClose([0.0, 0.0], sess.run(r, feed_dict=feed_dict))
2374      r = gradients_impl.gradients(rx, x)[0]
2375      self.assertAllClose([156.0, 400.0], sess.run(r, feed_dict=feed_dict))
2376      name = "gradients/while/stopped_grad"
2377      all_ops = x.graph.get_operations()
2378      self.assertFalse(any([name in op.name for op in all_ops]))
2379
2380  def testWhileGradGradFail(self):
2381    theta = variables.Variable(initial_value=1.)
2382
2383    def fn(prev, x):
2384      return prev + x * theta
2385
2386    result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
2387    grad_theta = gradients_impl.gradients(result, theta)
2388    with self.assertRaisesRegexp(TypeError, "Second-order gradient"):
2389      gradients_impl.gradients(grad_theta, theta)
2390    grad_theta_stopped = array_ops.stop_gradient(grad_theta)
2391    gradients_impl.gradients(grad_theta_stopped, theta)
2392
2393  def testStopGradOnWhileGrad(self):
2394    with self.test_session():
2395      x = constant_op.constant(2.0, name="x")
2396      y = constant_op.constant(2.0, name="y")
2397
2398      c = lambda x: math_ops.less(x, 100.0)
2399      b = lambda x: math_ops.multiply(x, y)
2400      rx = control_flow_ops.while_loop(c, b, [x])
2401
2402      rg = gradients_impl.gradients(rx, y)[0]
2403      rg = array_ops.stop_gradient(rg)
2404      r = math_ops.add(math_ops.square(y), rx)
2405      r = math_ops.add(r, rg)
2406      r = gradients_impl.gradients(r, y)[0]
2407      self.assertEqual(388.0, r.eval())
2408
2409  def testStopGradMultiFlows(self):
2410    with self.test_session():
2411
2412      def body(i, y, r):
2413        x = variable_scope.get_variable(
2414            "x",
2415            shape=(),
2416            dtype=dtypes.float32,
2417            initializer=init_ops.ones_initializer())
2418        y *= x
2419        return [i + 1, y, r + math_ops.reduce_sum(y)]
2420
2421      i0 = constant_op.constant(0)
2422      y0 = array_ops.ones(5)
2423      r0 = constant_op.constant(0.0)
2424      cond = lambda i, y, r: i < 1
2425      _, _, r = control_flow_ops.while_loop(
2426          cond, body, [i0, y0, r0], back_prop=True)
2427
2428      vars_ = variables.global_variables()
2429      grads = linalg_ops.norm(gradients_impl.gradients(r, vars_)[0])
2430      z = math_ops.add(r, array_ops.stop_gradient(math_ops.reduce_sum(grads)))
2431      result = gradients_impl.gradients(z, vars_)[0]
2432      variables.global_variables_initializer().run()
2433      self.assertEqual(5.0, result.eval())
2434
2435  def testOneValueCond(self):
2436    with self.test_session():
2437      c = array_ops.placeholder(dtypes.int32, shape=[])
2438      one = ops.convert_to_tensor(1, name="one")
2439      two = ops.convert_to_tensor(2, name="two")
2440      p = math_ops.greater_equal(c, 1)
2441      i = control_flow_ops.cond(p, lambda: one, lambda: two)
2442      self.assertTrue(isinstance(i, ops.Tensor))
2443
2444      # True case: c = 2 is >= 1
2445      self.assertEqual([1], i.eval(feed_dict={c: 2}))
2446
2447      # False case: c = 0 is not >= 1
2448      self.assertEqual([2], i.eval(feed_dict={c: 0}))
2449
2450  def testExampleCond(self):
2451    with self.test_session():
2452      x = ops.convert_to_tensor([-2.0, 2.0], name="x")
2453      d = array_ops.placeholder(dtypes.int32, shape=[])
2454
2455      def l2():
2456        return math_ops.sqrt(math_ops.reduce_sum(math_ops.square(x)))
2457
2458      def l1():
2459        return math_ops.reduce_sum(math_ops.abs(x))
2460
2461      i = control_flow_ops.cond(math_ops.equal(d, 2), l2, l1)
2462      self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
2463      self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
2464
2465  def testCase(self):
2466    with self.test_session():
2467      x = constant_op.constant(1)
2468      y = constant_op.constant(2)
2469      z = constant_op.constant(3)
2470      f1 = lambda: constant_op.constant(17)
2471      f2 = lambda: constant_op.constant(23)
2472      f3 = lambda: constant_op.constant(-1)
2473
2474      r1 = control_flow_ops.case(
2475          {
2476              x < y: f1,
2477              x > z: f2
2478          }, default=f3, exclusive=True)
2479      self.assertAllEqual(r1.eval(), 17)
2480
2481      r2 = control_flow_ops.case([(y > z, f1), (y > x, f2)], default=f3)
2482      self.assertAllEqual(r2.eval(), 23)
2483
2484      # Duplicate events can happen, first one is selected
2485      r3 = control_flow_ops.case([(x < y, f1), (x < y, f2)], default=f3)
2486      self.assertAllEqual(r3.eval(), 17)
2487
2488      # Duplicate events cause an error if exclusive = True
2489      r4 = control_flow_ops.case(
2490          [(x < y, f1), (x < y, f2)], default=f3, exclusive=True)
2491      with self.assertRaisesOpError("Input error:"):
2492        r4.eval()
2493
2494      # Check that the default is called if none of the others are
2495      r5 = control_flow_ops.case({x > y: f1}, default=f3)
2496      self.assertAllEqual(r5.eval(), -1)
2497
2498      ran_once = [False, False, False]
2499
2500      def break_run_twice(ix):
2501
2502        def _break():
2503          ran_once[ix] = True
2504          return constant_op.constant(ix)
2505
2506        return _break
2507
2508      # Should not fail - each conditional gets called exactly once
2509      # except default.  Default gets called twice: once to create an
2510      # empty output and once for the actual cond switch.
2511      r6 = control_flow_ops.case(
2512          [(x < y, break_run_twice(0)), (x > y, break_run_twice(1))],
2513          default=lambda: constant_op.constant(2))
2514
2515      self.assertAllEqual(r6.eval(), 0)
2516
2517  def testCaseSideEffects(self):
2518    with self.test_session() as sess:
2519      v0 = variables.Variable(-1)
2520      v1 = variables.Variable(-1)
2521      v2 = variables.Variable(-1)
2522
2523      a = lambda: control_flow_ops.with_dependencies([state_ops.assign(v0, 0)], 0)
2524      b = lambda: control_flow_ops.with_dependencies([state_ops.assign(v1, 1)], 1)
2525      c = lambda: control_flow_ops.with_dependencies([state_ops.assign(v2, 2)], 2)
2526
2527      x = constant_op.constant(1)
2528      y = constant_op.constant(2)
2529
2530      r0 = control_flow_ops.case(
2531          ((x < y, a), (x > y, b)), default=c, exclusive=True)
2532      r1 = control_flow_ops.case(
2533          ((x > y, a), (x < y, b)), default=c, exclusive=True)
2534      r2 = control_flow_ops.case(
2535          ((x > y, a), (x > y, b)), default=c, exclusive=True)
2536
2537      variables.global_variables_initializer().run()
2538      self.assertAllEqual(sess.run([v0, v1, v2]), [-1] * 3)
2539      self.assertEqual(2, r2.eval())
2540      self.assertAllEqual(sess.run([v0, v1, v2]), [-1, -1, 2])
2541
2542      variables.global_variables_initializer().run()
2543      self.assertAllEqual(sess.run([v0, v1, v2]), [-1] * 3)
2544      self.assertEqual(1, r1.eval())
2545      self.assertAllEqual(sess.run([v0, v1, v2]), [-1, 1, -1])
2546
2547      variables.global_variables_initializer().run()
2548      self.assertAllEqual(sess.run([v0, v1, v2]), [-1] * 3)
2549      self.assertEqual(0, r0.eval())
2550      self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
2551
2552  def testOneOpCond(self):
2553    with self.test_session():
2554      v = variables.Variable(0)
2555      c = ops.convert_to_tensor(0)
2556      one = ops.convert_to_tensor(1)
2557      two = ops.convert_to_tensor(2)
2558      p = math_ops.greater_equal(c, 1)
2559
2560      def a():
2561        return state_ops.assign(v, one)
2562
2563      def b():
2564        return state_ops.assign(v, two)
2565
2566      i = control_flow_ops.cond(p, a, b)
2567      self.assertTrue(isinstance(i, ops.Tensor))
2568      variables.global_variables_initializer().run()
2569
2570      self.assertEqual(0, v.eval())
2571
2572      # True case: c = 2 is >= 1, v is set to 1.
2573      self.assertEqual(1, i.eval(feed_dict={c.name: 2}))
2574      self.assertEqual(1, v.eval())
2575
2576      # False case: c = 0 is not >= 1, v is set to 2.
2577      self.assertEqual(2, i.eval(feed_dict={c.name: 0}))
2578      self.assertEqual(2, v.eval())
2579
2580  def testWithOpsDependencies(self):
2581    with self.test_session() as sess:
2582      v = variables.Variable(0.0)
2583      c = constant_op.constant(10)
2584
2585      # Fetching v directly will result in an uninitialized error
2586      with self.assertRaisesOpError("Attempting to use uninitialized value"):
2587        sess.run([c, v])
2588
2589      # Use a control dependency to ensure init_variable is run
2590      # while asking for c
2591      real_v = control_flow_ops.with_dependencies(
2592          name="real_tensor",
2593          output_tensor=v._ref(),  # pylint: disable=protected-access
2594          dependencies=[v.initializer])
2595      c_val, real_v_val = sess.run([c, real_v])
2596
2597    # Ensure the result of 'real_c' is the same as 'c'
2598    self.assertAllEqual(10, c_val)
2599
2600    # Ensure that 'v' is initialized
2601    self.assertAllClose(0.0, real_v_val)
2602
2603  def testWithTensorDependencies(self):
2604    with self.test_session():
2605      v = variables.Variable(0.0)
2606      c1 = constant_op.constant(10)
2607      c2 = constant_op.constant(20)
2608
2609      # c1_with_init_v depends on the init op for v
2610      c1_with_init_v = control_flow_ops.with_dependencies(
2611          name="c1_with_init_v", output_tensor=c1, dependencies=[v.initializer])
2612      # c2_with_c1 depends on the value of c1_with_init_v
2613      c2_with_c1_dep = control_flow_ops.with_dependencies(
2614          name="c2_with_c1_dep",
2615          output_tensor=c2,
2616          dependencies=[c1_with_init_v])
2617
2618      # Fetching v directly will result in an uninitialized error
2619      with self.assertRaisesOpError("Attempting to use uninitialized value"):
2620        v.eval()
2621
2622      # Get the value of 'c2_with_c1_dep', which should cause 'v'
2623      # to be initialized.
2624      self.assertAllEqual(20, c2_with_c1_dep.eval())
2625
2626      # Ensure that 'v' is initialized
2627      self.assertAllClose(0.0, v.eval())
2628
2629  def testWithIndexedSlicesDependencies(self):
2630    with self.test_session():
2631      v = variables.Variable(
2632          np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
2633      v_at_1 = ops.IndexedSlices(v, constant_op.constant([1]))
2634      gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices)
2635      v_at_1_after_init = control_flow_ops.with_dependencies([v.initializer],
2636                                                             v_at_1)
2637      gather_v_at_1_after_init = array_ops.gather(v_at_1_after_init.values,
2638                                                  v_at_1_after_init.indices)
2639
2640      # Fetching gather_v_at_1 will result in an uninitialized error
2641      with self.assertRaisesOpError("Attempting to use uninitialized value"):
2642        gather_v_at_1.eval()
2643
2644      # Getting gather_v_at_1_after_init will work, and initialize v.
2645      self.assertAllEqual([[10.0, 11.0]], gather_v_at_1_after_init.eval())
2646
2647      # Double check that 'v' is initialized
2648      self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]], v.eval())
2649
2650  def testDependenciesDevice(self):
2651    with ops.Graph().as_default():
2652      # device set on tensor => same device on dep.
2653      with ops.device("/job:ps"):
2654        vd = variables.Variable([0.0])
2655      with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd)
2656      self.assertTrue("/job:ps" in with_vd_dep.device)
2657
2658      # No device set on tensor => no device on dep.
2659      vnod = variables.Variable([0.0])
2660      with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
2661                                                         vnod)
2662      self.assertDeviceEqual(None, with_vnod_dep.device)
2663
2664      # device set on tensor, default device on graph => default device on dep.
2665      vdef = variables.Variable([0.0], name="vdef")
2666      with ops.device("/job:worker/device:GPU:1"):
2667        with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
2668                                                           vdef)
2669        # The device is empty, but the colocation constraint is set.
2670        self.assertDeviceEqual("", with_vdef_dep.device)
2671        self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups())
2672
2673  def testGroup(self):
2674    with self.test_session() as sess:
2675      v1 = variables.Variable([0.0])
2676      v2 = variables.Variable([1.0])
2677
2678      # Group init1 and init2 and run.
2679      init = control_flow_ops.group(v1.initializer, v2.initializer)
2680      # Fetching v1 directly will result in an uninitialized error
2681      with self.assertRaisesOpError("Attempting to use uninitialized value"):
2682        v1.eval()
2683
2684      # Runs "init" before fetching v1 and v2.
2685      init.run()
2686      v1_val, v2_val = sess.run([v1, v2])
2687
2688    # Ensure that v1 and v2 are initialized
2689    self.assertAllClose([0.0], v1_val)
2690    self.assertAllClose([1.0], v2_val)
2691
2692  def testGroupEmpty(self):
2693    op = control_flow_ops.group()
2694    self.assertEqual(op.type, "NoOp")
2695    self.assertEqual(op.control_inputs, [])
2696
2697  def testMergeShapes(self):
2698    # All inputs unknown.
2699    p1 = array_ops.placeholder(dtypes.float32)
2700    p2 = array_ops.placeholder(dtypes.float32)
2701    p3 = array_ops.placeholder(dtypes.float32)
2702    m, index = control_flow_ops.merge([p1, p2, p3])
2703    self.assertIs(None, m.get_shape().ndims)
2704    self.assertEqual([], index.get_shape())
2705
2706    # All inputs known with different ranks.
2707    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
2708    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2, 3])
2709    m, index = control_flow_ops.merge([p1, p2])
2710    self.assertIs(None, m.get_shape().ndims)
2711    self.assertEqual([], index.get_shape())
2712
2713    # All inputs known with some dimensions different.
2714    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
2715    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 1])
2716    m, index = control_flow_ops.merge([p1, p2])
2717    self.assertEqual([None, None], m.get_shape().as_list())
2718    self.assertEqual([], index.get_shape())
2719
2720    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
2721    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
2722    m, index = control_flow_ops.merge([p1, p2])
2723    self.assertEqual([None, 2], m.get_shape().as_list())
2724    self.assertEqual([], index.get_shape())
2725
2726    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
2727    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 2])
2728    m, index = control_flow_ops.merge([p1, p2])
2729    self.assertEqual([None, 2], m.get_shape().as_list())
2730    self.assertEqual([], index.get_shape())
2731
2732    # All inputs known with same dimensions.
2733    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
2734    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
2735    m, index = control_flow_ops.merge([p1, p2])
2736    self.assertEqual([1, 2], m.get_shape().as_list())
2737    self.assertEqual([], index.get_shape())
2738
2739    p1 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
2740    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
2741    m, index = control_flow_ops.merge([p1, p2])
2742    self.assertEqual([None, 2], m.get_shape().as_list())
2743    self.assertEqual([], index.get_shape())
2744
2745    p1 = array_ops.placeholder(dtypes.float32, shape=[None, None])
2746    p2 = array_ops.placeholder(dtypes.float32, shape=[None, None])
2747    m, index = control_flow_ops.merge([p1, p2])
2748    self.assertEqual([None, None], m.get_shape().as_list())
2749    self.assertEqual([], index.get_shape())
2750
2751  def testRefSelect(self):
2752    index = array_ops.placeholder(dtypes.int32)
2753
2754    # All inputs unknown.
2755    p1 = array_ops.placeholder(dtypes.float32)
2756    p2 = array_ops.placeholder(dtypes.float32)
2757    p3 = array_ops.placeholder(dtypes.float32)
2758    v1 = variables.Variable(p1, validate_shape=False)
2759    v2 = variables.Variable(p2, validate_shape=False)
2760    v3 = variables.Variable(p3, validate_shape=False)
2761    self.assertIs(None, v1.get_shape().ndims)
2762    s = control_flow_ops.ref_select(index, [v1, v2, v3])
2763    self.assertIs(None, s.get_shape().ndims)
2764
2765    # All inputs known but different.
2766    v1 = variables.Variable([[1, 2]])
2767    v2 = variables.Variable([[2], [1]])
2768    s = control_flow_ops.ref_select(index, [v1, v2])
2769    self.assertIs(None, s.get_shape().ndims)
2770
2771    # All inputs known and same.
2772    v1 = variables.Variable([[1, 2]])
2773    v2 = variables.Variable([[1, 2]])
2774    s = control_flow_ops.ref_select(index, [v1, v2])
2775    self.assertEqual([1, 2], s.get_shape())
2776
2777    # Possibly the same but not guaranteed.
2778    v1 = variables.Variable([[1., 2.]])
2779    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
2780    v2 = variables.Variable(p2, validate_shape=False)
2781    s = control_flow_ops.ref_select(index, [v1, v2])
2782    self.assertEqual(None, s.get_shape())
2783
2784  def testRunLoopTensor(self):
2785    with self.test_session() as sess:
2786      tensor_list = []
2787
2788      def condition(t):
2789        return t < constant_op.constant(5)
2790
2791      def body(_):
2792        tensor_list.append(constant_op.constant(5))
2793        return constant_op.constant(10)
2794
2795      result = control_flow_ops.while_loop(condition, body,
2796                                           [constant_op.constant(4)])
2797      self.assertEqual(10, sess.run(result))
2798
2799      # Ensure that we cannot run a tensor that escapes the loop body
2800      # accidentally.
2801      with self.assertRaises(ValueError):
2802        sess.run(tensor_list[0])
2803
2804  def testWhilePyFuncBasic(self):
2805
2806    def func(x):
2807      return np.square(x)
2808
2809    with self.test_session():
2810      r = control_flow_ops.while_loop(
2811          lambda i, v: i < 4,
2812          lambda i, v: [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]],
2813          [constant_op.constant(0), constant_op.constant(2.0, dtypes.float32)],
2814          [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()])
2815      self.assertEqual(r[1].eval(), 65536.0)
2816
2817  def testWhileFuncBasic(self):
2818
2819    @function.Defun(dtypes.float32)
2820    def func(x):
2821      return math_ops.square(math_ops.square(x))
2822
2823    with self.test_session():
2824      x = constant_op.constant(2.0, dtypes.float32)
2825      r = control_flow_ops.while_loop(
2826          lambda i, v: i < 2, lambda i, v: [i + 1, func(v)],
2827          [constant_op.constant(0), x],
2828          [tensor_shape.unknown_shape(),
2829           tensor_shape.unknown_shape()])
2830      self.assertEqual(r[1].eval(), 65536.0)
2831
2832      r = gradients_impl.gradients(r, x)[0]
2833      self.assertEqual(r.eval(), 524288.0)
2834      self.assertEqual(
2835          len([op for op in x.graph.get_operations() if op.type == "StackV2"]),
2836          1)
2837
2838
2839@test_util.with_c_api
2840class ControlFlowContextCheckTest(test.TestCase):
2841
2842  def _getWhileTensor(self):
2843    """Creates and returns a tensor from a while context."""
2844    tensor = []
2845
2846    def body(i):
2847      if not tensor:
2848        tensor.append(constant_op.constant(1))
2849      return i + tensor[0]
2850
2851    control_flow_ops.while_loop(lambda i: i < 10, body, [0])
2852    return tensor[0]
2853
2854  def _getCondTensor(self):
2855    cond_tensor = []
2856
2857    def true_fn():
2858      if not cond_tensor:
2859        cond_tensor.append(constant_op.constant(1))
2860      return cond_tensor[0]
2861
2862    control_flow_ops.cond(
2863        math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
2864    return cond_tensor[0]
2865
2866  def testInvalidContext(self):
2867    # Accessing a while loop tensor outside of control flow is illegal.
2868    while_tensor = self._getWhileTensor()
2869    with self.assertRaisesRegexp(
2870        ValueError,
2871        "Cannot use 'while/Const_1' as input to 'Add' because 'while/Const_1' "
2872        "is in a while loop. See info log for more details."):
2873      math_ops.add(1, while_tensor)
2874
2875  def testInvalidContextInCond(self):
2876    # Accessing a while loop tensor in cond is illegal.
2877    while_tensor = self._getWhileTensor()
2878    with self.assertRaisesRegexp(
2879        ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because "
2880        "'while/Const_1' is in a while loop. See info log for more details."):
2881      # TODO(skyewm): this passes if we return while_tensor directly instead
2882      # of using it as input to another op.
2883      control_flow_ops.cond(
2884          math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor),
2885          lambda: constant_op.constant(0))
2886
2887  def testInvalidContextInWhile(self):
2888    # Accessing a while loop tensor in a different while loop is illegal.
2889    while_tensor = self._getWhileTensor()
2890    with self.assertRaisesRegexp(
2891        ValueError,
2892        "Cannot use 'while_1/Add' as input to 'while/Const_1' because they are "
2893        "in different while loops. See info log for more details."):
2894      control_flow_ops.while_loop(lambda i: i < 10,
2895                                  lambda x: math_ops.add(1, while_tensor), [0])
2896
2897    with self.assertRaisesRegexp(
2898        ValueError,
2899        "Cannot use 'while_2/NextIteration' as input to 'while/Const_1' "
2900        "because they are in different while loops. See info log for more "
2901        "details."):
2902      control_flow_ops.while_loop(lambda i: i < 10, lambda i: while_tensor, [0])
2903
2904  def testValidCondContext(self):
2905    # Accessing a tensor from a cond context is OK (although dangerous).
2906    cond_tensor = self._getCondTensor()
2907    math_ops.add(1, cond_tensor)
2908
2909  def testValidCondContextBranches(self):
2910    # Accessing a tensor from a cond context from the other branch's cond
2911    # context is OK (although dangerous).
2912    cond_tensor = []
2913
2914    def branch_fn():
2915      if not cond_tensor:
2916        cond_tensor.append(constant_op.constant(1))
2917      return cond_tensor[0]
2918
2919    control_flow_ops.cond(math_ops.less(1, 2), branch_fn, branch_fn)
2920
2921  def testValidWhileContext(self):
2922    # Accessing a tensor in a nested while is OK.
2923    def body(_):
2924      c = constant_op.constant(1)
2925      return control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + c, [0])
2926
2927    control_flow_ops.while_loop(lambda i: i < 5, body, [0])
2928
2929  def testValidNestedContexts(self):
2930    # Accessing a tensor from a cond context in a while context, all inside an
2931    # outer while context, is OK.
2932    def body(_):
2933      cond_tensor = self._getCondTensor()
2934      # Create another cond containing the while loop for good measure
2935      return control_flow_ops.cond(
2936          math_ops.less(1, 2),
2937          lambda: control_flow_ops.while_loop(lambda i: i < 3,
2938                                              lambda i: i + cond_tensor, [0]),
2939          lambda: constant_op.constant(0))
2940
2941    control_flow_ops.while_loop(lambda i: i < 5, body, [0])
2942
2943  def testInvalidNestedContexts(self):
2944    # Accessing a tensor from a while context in a different while context, all
2945    # inside a cond context, is illegal.
2946    def true_fn():
2947      while_tensor = self._getWhileTensor()
2948      return control_flow_ops.while_loop(lambda i: i < 3,
2949                                         lambda i: i + while_tensor, [0])
2950
2951    with self.assertRaisesRegexp(
2952        ValueError,
2953        "Cannot use 'cond/while_1/add' as input to 'cond/while/Const_1' because"
2954        " they are in different while loops. See info log for more details."):
2955      control_flow_ops.cond(
2956          math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
2957
2958
2959@test_util.with_c_api
2960class TupleTest(test.TestCase):
2961
2962  def testTensors(self):
2963    for v1_first in [True, False]:
2964      with self.test_session():
2965        v1 = variables.Variable([1.0])
2966        add1 = math_ops.add(
2967            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
2968            2.0)
2969        v2 = variables.Variable([10.0])
2970        add2 = math_ops.add(
2971            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
2972            20.0)
2973        t1, _, t2 = control_flow_ops.tuple([add1, None, add2])
2974
2975        # v1 is not initialized.
2976        with self.assertRaisesOpError("Attempting to use uninitialized value"):
2977          v1.eval()
2978
2979        # v2 is not initialized.
2980        with self.assertRaisesOpError("Attempting to use uninitialized value"):
2981          v2.eval()
2982
2983        if v1_first:
2984          # Getting t1 initializes v2.
2985          self.assertAllClose([3.0], t1.eval())
2986          self.assertAllClose([10.0], v2.eval())
2987        else:
2988          # Getting t2 initializes v1.
2989          self.assertAllClose([30.0], t2.eval())
2990          self.assertAllClose([1.0], v1.eval())
2991
2992  def testIndexedSlices(self):
2993    for v1_first in [True, False]:
2994      with self.test_session():
2995        v1 = variables.Variable(
2996            np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
2997                np.float32))
2998        v1_at_1 = ops.IndexedSlices(
2999            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
3000            constant_op.constant([1]))
3001
3002        v2 = variables.Variable(
3003            np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
3004                np.float32))
3005        v2_at_1 = ops.IndexedSlices(
3006            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
3007            constant_op.constant([1]))
3008
3009        st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1])
3010        g1 = array_ops.gather(st1.values, st1.indices)
3011        g2 = array_ops.gather(st2.values, st2.indices)
3012
3013        # v1 is not initialized.
3014        with self.assertRaisesOpError("Attempting to use uninitialized value"):
3015          v1.eval()
3016
3017        # v2 is not initialized.
3018        with self.assertRaisesOpError("Attempting to use uninitialized value"):
3019          v2.eval()
3020
3021        if v1_first:
3022          # Getting g1 initializes v2.
3023          self.assertAllClose([[10.0, 11.0]], g1.eval())
3024          self.assertAllClose([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]],
3025                              v2.eval())
3026        else:
3027          # Getting g2 initializes v1.
3028          self.assertAllClose([[10.1, 11.1]], g2.eval())
3029          self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
3030                              v1.eval())
3031
3032  def testAcceptTensorsAsControlInputs(self):
3033    with self.test_session():
3034      var = variables.Variable(0)
3035      assign = state_ops.assign(var, 1)
3036      t, = control_flow_ops.tuple(
3037          [constant_op.constant(0)], control_inputs=[assign])
3038
3039      # Should trigger the assign.
3040      t.eval()
3041
3042      self.assertEquals(1, var.eval())
3043
3044
3045@test_util.with_c_api
3046class AssertTest(test.TestCase):
3047
3048  def testGuardedAssertDoesNotCopyWhenTrue(self):
3049    with self.test_session(use_gpu=True) as sess:
3050      with ops.device(test.gpu_device_name()):
3051        value = constant_op.constant(1.0)
3052      with ops.device("/cpu:0"):
3053        true = constant_op.constant(True)
3054        guarded_assert = control_flow_ops.Assert(true, [value], name="guarded")
3055        unguarded_assert = gen_logging_ops._assert(
3056            true, [value], name="unguarded")
3057      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
3058      guarded_metadata = config_pb2.RunMetadata()
3059      sess.run(guarded_assert, options=opts, run_metadata=guarded_metadata)
3060      unguarded_metadata = config_pb2.RunMetadata()
3061      sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata)
3062      guarded_nodestat_names = [
3063          n.node_name
3064          for d in guarded_metadata.step_stats.dev_stats
3065          for n in d.node_stats
3066      ]
3067      unguarded_nodestat_names = [
3068          n.node_name
3069          for d in unguarded_metadata.step_stats.dev_stats
3070          for n in d.node_stats
3071      ]
3072      guarded_memcpy_nodestat_names = [
3073          n for n in guarded_nodestat_names if "MEMCPYDtoH" in n
3074      ]
3075      unguarded_memcpy_nodestat_names = [
3076          n for n in unguarded_nodestat_names if "MEMCPYDtoH" in n
3077      ]
3078      if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
3079        # A copy was performed for the unguarded assert
3080        self.assertLess(0, len(unguarded_memcpy_nodestat_names))
3081      # No copy was performed for the guarded assert
3082      self.assertEqual([], guarded_memcpy_nodestat_names)
3083
3084
3085@test_util.with_c_api
3086class WhileOpBenchmark(test.Benchmark):
3087  """Evaluate the performance of while_loop op."""
3088
3089  def _getInitVariables(self):
3090    batch_size = 10
3091    image_size = 256
3092    kernel_size = 3
3093    depth = 16
3094
3095    init_step = constant_op.constant(-1)
3096    image = variable_scope.get_variable(
3097        "image",
3098        initializer=random_ops.random_normal(
3099            [batch_size, image_size, image_size, depth],
3100            dtype=dtypes.float32,
3101            stddev=1e-1))
3102    kernel = variable_scope.get_variable(
3103        "weights",
3104        initializer=random_ops.truncated_normal(
3105            [kernel_size, kernel_size, depth, depth],
3106            dtype=dtypes.float32,
3107            stddev=1e-1))
3108    return init_step, image, kernel
3109
3110  def _runOneBenchmark(self,
3111                       default_device,
3112                       num_iters=10,
3113                       static_unroll=False,
3114                       steps=10):
3115    """Evaluate the while loop performance.
3116
3117    Args:
3118      default_device: The default device to run all ops except the loop_body.
3119        loop_body is always run on GPU.
3120      num_iters: Number of iterations to run.
3121      static_unroll: If true, run unrolled version; otherwise, run while_loop.
3122      steps: Total number of repeated steps to run the loop.
3123
3124    Returns:
3125      The duration of the run in seconds.
3126    """
3127
3128    def loop_body(i, x):
3129      with ops.device("/gpu:0"):
3130        # Always put loop body on GPU.
3131        nx = nn_ops.conv2d(
3132            input=x,
3133            filter=kernel,
3134            strides=[1, 1, 1, 1],
3135            padding="SAME",
3136            data_format="NHWC",
3137            name="conv2d")
3138        ni = math_ops.add(i, 1)
3139        return ni, nx
3140
3141    ops.reset_default_graph()
3142    with session.Session() as sess, ops.device(default_device):
3143      # Get the initial id i, input x, and kernel.
3144      i, x, kernel = self._getInitVariables()
3145      sess.run(variables.global_variables_initializer())
3146
3147      if static_unroll:
3148        for _ in xrange(steps):
3149          i, x = loop_body(i, x)
3150      else:
3151        i, x = control_flow_ops.while_loop(
3152            lambda i, _: i < steps,
3153            loop_body, [i, x],
3154            parallel_iterations=steps,
3155            swap_memory=True)
3156
3157      r = math_ops.reduce_sum(x)
3158      dx, dk = gradients_impl.gradients(r, [x, kernel])
3159      # Use group to avoid fetching back results.
3160      r = control_flow_ops.group(dx, dk)
3161
3162      for _ in xrange(3):
3163        # exclude warm up time
3164        sess.run(r)
3165
3166      start_time = time.time()
3167      for _ in xrange(num_iters):
3168        sess.run(r)
3169      return (time.time() - start_time) / num_iters
3170
3171  def benchmarkWhileOpCrossDevicePlacement(self):
3172    iters = 10
3173    # Run loop body on GPU, but other ops on CPU.
3174    duration = self._runOneBenchmark("cpu", iters, static_unroll=False)
3175    self.report_benchmark(
3176        name="while_op_cross_device", iters=iters, wall_time=duration)
3177
3178  def benchmarkWhileOpSameDevicePlacement(self):
3179    iters = 10
3180    # Run all ops on the same GPU device.
3181    duration = self._runOneBenchmark("gpu", iters, static_unroll=False)
3182    self.report_benchmark(
3183        name="while_op_same_device", iters=iters, wall_time=duration)
3184
3185  def benchmarkWhileOpUnrollCrossDevicePlacement(self):
3186    iters = 10
3187    # Run loop body on GPU, but other ops on CPU.
3188    duration = self._runOneBenchmark("cpu", iters, static_unroll=True)
3189    self.report_benchmark(
3190        name="unroll_cross_device_cpu", iters=iters, wall_time=duration)
3191
3192  def benchmarkWhileOpUnrollSameDevicePlacement(self):
3193    iters = 10
3194    # Run all ops on GPU.
3195    duration = self._runOneBenchmark("gpu", iters, static_unroll=True)
3196    self.report_benchmark(
3197        name="unroll_same_device", iters=iters, wall_time=duration)
3198
3199
3200@test_util.with_c_api
3201class EagerTest(test.TestCase):
3202
3203  def testCond(self):
3204    with context.eager_mode():
3205      pred = math_ops.less(1, 2)
3206      fn1 = lambda: [constant_op.constant(10)]
3207      fn2 = lambda: [constant_op.constant(20)]
3208      r = control_flow_ops.cond(pred, fn1, fn2)
3209
3210      self.assertAllEqual(r.numpy(), 10)
3211      self.assertFalse(isinstance(r, list))
3212
3213  def testWhileLoop(self):
3214    with context.eager_mode():
3215      tensor = constant_op.constant([1, 2, 3, 4, 5])
3216      self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50])
3217
3218  def testWhileLoopWithMaxIterations(self):
3219    with context.eager_mode():
3220      tensor = constant_op.constant([1, 2, 3, 4, 5])
3221      self.assertAllEqual(
3222          isum(tensor, maximum_iterations=3).numpy(),
3223          [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3])
3224
3225  def testWhileWithMaximumIterationsAndSingleArgument(self):
3226    with context.eager_mode():
3227      tensor = constant_op.constant(0)
3228      r = control_flow_ops.while_loop(
3229          lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1)
3230      self.assertEqual(1, r.numpy())
3231
3232  def testWithDependencies(self):
3233    with context.eager_mode():
3234      t1 = constant_op.constant(1)
3235      t2 = constant_op.constant(2)
3236      t3 = control_flow_ops.with_dependencies(t1, t2)
3237      self.assertAllEqual(t2.numpy(), t3.numpy())
3238
3239  def testTuple(self):
3240    with context.eager_mode():
3241      t1 = constant_op.constant(1)
3242      t2 = constant_op.constant(2)
3243      tup1, tup2 = control_flow_ops.tuple([t1, t2])
3244      self.assertAllEqual(t1.numpy(), tup1.numpy())
3245      self.assertAllEqual(t2.numpy(), tup2.numpy())
3246
3247  def testCase(self):
3248    with context.eager_mode():
3249      x = constant_op.constant(1)
3250      y = constant_op.constant(2)
3251      z = constant_op.constant(3)
3252      f1 = lambda: constant_op.constant(17)
3253      f2 = lambda: constant_op.constant(23)
3254      f3 = lambda: constant_op.constant(-1)
3255
3256      r1 = control_flow_ops.case(
3257          [(x < y, f1), (x > z, f2)], default=f3, exclusive=True)
3258      self.assertAllEqual(r1.numpy(), 17)
3259
3260
3261if __name__ == "__main__":
3262  test.main()
3263