control_flow_ops_test.py revision bcdcb78854e8dfa1b2eda813b9e2910df522abb4
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for control_flow_ops.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import numpy as np
23
24from tensorflow.core.framework import graph_pb2
25from tensorflow.core.framework import node_def_pb2
26from tensorflow.python.client import session
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import errors
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.framework import tensor_shape
33from tensorflow.python.framework import test_util
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import embedding_ops
37from tensorflow.python.ops import gradients_impl
38from tensorflow.python.ops import init_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops import state_ops
41from tensorflow.python.ops import tensor_array_ops
42from tensorflow.python.ops import variable_scope
43from tensorflow.python.ops import variables
44import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
45from tensorflow.python.platform import googletest
46from tensorflow.python.training import momentum
47from tensorflow.python.util import nest
48
49
# Namedtuple fixtures used as structured branch outputs in the cond/case tests.
TestTuple = collections.namedtuple("TestTuple", ["a", "b"])
SingletonTestTuple = collections.namedtuple("SingletonTestTuple", ["a"])
52
53
class GroupTestCase(test_util.TensorFlowTestCase):
  """Tests the graph structure emitted by control_flow_ops.group()."""

  def _StripNode(self, nd):
    """Return a NodeDef holding only the name, op, inputs and device of nd."""
    kept_fields = {"name": nd.name, "op": nd.op, "input": nd.input}
    if nd.device:
      kept_fields["device"] = nd.device
    return node_def_pb2.NodeDef(**kept_fields)

  def _StripGraph(self, gd):
    """Copy gd keeping only node.name, node.op, node.input and node.device."""
    stripped_nodes = [self._StripNode(node) for node in gd.node]
    return graph_pb2.GraphDef(node=stripped_nodes)

  def testGroup_NoDevices(self):
    """Without devices, a single NoOp depends on every grouped op."""
    with ops.Graph().as_default() as graph:
      op_a = constant_op.constant(0, name="a").op
      op_b = constant_op.constant(0, name="b").op
      op_c = constant_op.constant(0, name="c").op
      control_flow_ops.group(op_a, op_b, op_c, name="root")
    expected = """
      node { name: "a" op: "Const"}
      node { name: "b" op: "Const"}
      node { name: "c" op: "Const"}
      node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
    """
    self.assertProtoEquals(expected, self._StripGraph(graph.as_graph_def()))

  def testGroup_OneDevice(self):
    """The root NoOp inherits the single device shared by its inputs."""
    with ops.Graph().as_default() as graph:
      with graph.device("/task:0"):
        op_a = constant_op.constant(0, name="a").op
        op_b = constant_op.constant(0, name="b").op
      control_flow_ops.group(op_a, op_b, name="root")
    expected = """
      node { name: "a" op: "Const" device: "/task:0" }
      node { name: "b" op: "Const" device: "/task:0" }
      node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
    """
    self.assertProtoEquals(expected, self._StripGraph(graph.as_graph_def()))

  def testGroup_MultiDevice(self):
    """Inputs on several devices get one NoOp per device plus a root NoOp."""
    with ops.Graph().as_default() as graph:
      with graph.device("/task:0"):
        op_a = constant_op.constant(0, name="a").op
        op_b = constant_op.constant(0, name="b").op
      with graph.device("/task:1"):
        op_c = constant_op.constant(0, name="c").op
        op_d = constant_op.constant(0, name="d").op
      with graph.device("/task:2"):
        control_flow_ops.group(op_a, op_b, op_c, op_d, name="root")
    expected = """
      node { name: "a" op: "Const" device: "/task:0"}
      node { name: "b" op: "Const" device: "/task:0"}
      node { name: "c" op: "Const" device: "/task:1"}
      node { name: "d" op: "Const" device: "/task:1"}
      node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
             device: "/task:0" }
      node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
             device: "/task:1" }
      node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
             device: "/task:2" }
    """
    self.assertProtoEquals(expected, self._StripGraph(graph.as_graph_def()))

  def testPassingList(self):
    """group() also accepts its inputs as a single list."""
    with ops.Graph().as_default() as graph:
      grouped = [constant_op.constant(0, name="a").op,
                 constant_op.constant(0, name="b").op]
      control_flow_ops.group(grouped, name="root")
    expected = """
      node { name: "a" op: "Const"}
      node { name: "b" op: "Const"}
      node { name: "root" op: "NoOp" input: "^a" input: "^b" }
    """
    self.assertProtoEquals(expected, self._StripGraph(graph.as_graph_def()))

  def testPassingNonTensors(self):
    """Plain Python values are rejected with a TypeError."""
    with ops.Graph().as_default():
      with self.assertRaises(TypeError):
        control_flow_ops.group(1, 2)
133
134
class ShapeTestCase(test_util.TensorFlowTestCase):
  """Checks static shape propagation through control_flow_ops helpers."""

  def testShape(self):
    """with_dependencies() keeps the static shape of the wrapped tensor."""
    with ops.Graph().as_default():
      tensor = constant_op.constant([1.0, 2.0])
      # assertEqual replaces the deprecated assertEquals alias (removed from
      # unittest in Python 3.12).
      self.assertEqual([2], tensor.get_shape())
      self.assertEqual([2],
                       control_flow_ops.with_dependencies(
                           [constant_op.constant(1.0)], tensor).get_shape())
144
145
class WithDependenciesTestCase(test_util.TensorFlowTestCase):
  """Checks that with_dependencies() runs its dependencies with its output."""

  def _testDependencies(self, container_type):
    """Build and run a with_dependencies() op whose deps use container_type.

    Args:
      container_type: callable such as `tuple` or `list`, applied to the list
        of dependency ops before it is passed to with_dependencies().
    """
    with ops.Graph().as_default():
      counter = variable_scope.get_variable(
          "my_counter", shape=[], initializer=init_ops.zeros_initializer())
      increment_counter = state_ops.assign_add(counter, 1)
      const_with_dep = control_flow_ops.with_dependencies(
          container_type([increment_counter, constant_op.constant(42)]),
          constant_op.constant(7))
      with self.test_session():
        variables.global_variables_initializer().run()
        # The counter only changes once the dependent output is evaluated.
        self.assertEqual(0, counter.eval())
        self.assertEqual(7, const_with_dep.eval())
        self.assertEqual(1, counter.eval())

  def testTupleDependencies(self):
    self._testDependencies(tuple)

  def testListDependencies(self):
    self._testDependencies(list)
175
176
class SwitchTestCase(test_util.TensorFlowTestCase):
  """Tests for switch() and for gradients through while_loop()/cond()."""

  def testIndexedSlicesWithDenseShape(self):
    """switch() forwards values, indices and dense_shape of IndexedSlices."""
    with self.test_session():
      data = ops.IndexedSlices(
          constant_op.constant([1, 2, 3]),
          constant_op.constant([0, 1]),
          dense_shape=constant_op.constant([3]))
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      less_op = math_ops.less(zero, one)
      _, switch_true = control_flow_ops.switch(data, less_op)
      self.assertAllEqual([1, 2, 3], switch_true.values.eval())
      self.assertAllEqual([0, 1], switch_true.indices.eval())
      # This test targets the dense_shape path, so check it is forwarded too.
      self.assertAllEqual([3], switch_true.dense_shape.eval())

  def testIndexedSlicesGradient(self):
    """Training through an embedding lookup inside a while loop works."""
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix", [5, 5],
          initializer=init_ops.random_normal_initializer())

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
        cost += math_ops.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])
      optimizer = momentum.MomentumOptimizer(0.1, 0.9)
      train_op = optimizer.minimize(cost)
      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        for _ in range(10):
          sess.run([train_op])

  def testResourceReadInLoop(self):
    """A resource variable read inside a while loop sums correctly."""
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix",
          initializer=[[2.0], [3.0]],
          use_resource=True)

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost += math_ops.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])
      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        # Five iterations, each adding row 0 (== 2.0).
        self.assertAllEqual(10.0, cost.eval())

  def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
    """Compare loop/cond gradients with an equivalent unrolled graph."""
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix", [5, 5],
          initializer=init_ops.random_normal_initializer(),
          use_resource=use_resource)

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost = control_flow_ops.cond(
            math_ops.equal(it, 3), lambda: math_ops.square(cost),
            lambda: cost + math_ops.reduce_sum(embedding))
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])

      dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
      dynamic_grads = math_ops.segment_sum(dynamic_grads.values,
                                           dynamic_grads.indices)

      # Unrolled loop: iterations 0-2 add the embedding sum, iteration 3
      # squares the running cost, and iteration 4 adds the sum once more.
      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      static = math_ops.square(
          math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
          math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
      static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
      static_grads = math_ops.segment_sum(static_grads.values,
                                          static_grads.indices)

      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))

  def testIndexedSlicesGradientInCondInWhileLoop(self):
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False)

  def testIndexedSlicesGradientInCondInWhileLoopResource(self):
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True)

  def testIndexedSlicesWithShapeGradientInWhileLoop(self):
    """Gradient of a TensorArray built from a statically-shaped input."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        num_steps = 9

        inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, size=num_steps)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          return i < num_steps  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(Cond, Body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
        self.assertEqual(o, 20)
        self.assertAllEqual(grad, [1] * num_steps)

  def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
    """Gradient of a dynamically sized TensorArray built in a while loop."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        inputs = array_ops.placeholder(dtype=dtype)
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, dynamic_size=True, size=1)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          return i < array_ops.size(inputs)  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(Cond, Body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [1, 3, 2]})
        self.assertEqual(o, 6)
        self.assertAllEqual(grad, [1] * 3)

  def testGradientThroughSingleBranchOutsideOfContext(self):
    """Only the taken switch output propagates a non-zero gradient."""
    with self.test_session():
      x = constant_op.constant(2.)
      s = constant_op.constant(True)
      x_false, x_true = control_flow_ops.switch(x, s)
      grad_x_true = gradients_impl.gradients(x_true, x)[0]
      grad_x_false = gradients_impl.gradients(x_false, x)[0]
      self.assertEqual(grad_x_true.eval(), 1.)
      self.assertEqual(grad_x_false.eval(), 0.)
343
344
@test_util.with_c_api
class CondTest(test_util.TensorFlowTestCase):
  """Tests for control_flow_ops.cond() results and argument validation."""

  def testCondTrue(self):
    # Create new Graph and Session for each test so we pick up _USE_C_API
    # correctly.
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        z = control_flow_ops.cond(
            math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
            lambda: math_ops.add(y, 23))
        self.assertEqual(z.eval(), 34)

  def testCondFalse(self):
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(2)
        y = constant_op.constant(1)
        z = control_flow_ops.cond(
            math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
            lambda: math_ops.add(y, 23))
        self.assertEqual(z.eval(), 24)

  def testCondTrueLegacy(self):
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        z = control_flow_ops.cond(
            math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
            fn2=lambda: math_ops.add(y, 23))
        self.assertEqual(z.eval(), 34)

  def testCondFalseLegacy(self):
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(2)
        y = constant_op.constant(1)
        z = control_flow_ops.cond(
            math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
            fn2=lambda: math_ops.add(y, 23))
        self.assertEqual(z.eval(), 24)

  def testCondModifyBoolPred(self):
    # This test in particular used to fail only when running on GPU. Note that
    # no use_gpu / device option is passed to the session below.
    with ops.Graph().as_default():
      with session.Session() as sess:
        bool_var = variable_scope.get_variable("bool_var", dtype=dtypes.bool,
                                               initializer=True)
        cond_on_bool_var = control_flow_ops.cond(
            pred=bool_var,
            true_fn=lambda: state_ops.assign(bool_var, False),
            false_fn=lambda: True)
        sess.run(bool_var.initializer)
        # First run takes the true branch, which flips the predicate.
        self.assertEqual(sess.run(cond_on_bool_var), False)
        self.assertEqual(sess.run(cond_on_bool_var), True)

  def testCondMissingArg1(self):
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        with self.assertRaises(TypeError):
          control_flow_ops.cond(True, false_fn=lambda: x)

  def testCondMissingArg2(self):
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        with self.assertRaises(TypeError):
          control_flow_ops.cond(True, lambda: x)

  def testCondDuplicateArg1(self):
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        with self.assertRaises(TypeError):
          control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)

  def testCondDuplicateArg2(self):
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        with self.assertRaises(TypeError):
          control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
432
433
class ContextTest(test_util.TensorFlowTestCase):
  """Round-trip tests for control-flow context proto serialization."""

  def testCondContext(self):
    """Every CondContext survives a to_proto/from_proto round trip."""
    with self.test_session() as sess:
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      control_flow_ops.cond(
          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
          lambda: math_ops.add(y, 23))
      for op in sess.graph.get_operations():
        context = op._get_control_flow_context()
        if context:
          self.assertProtoEquals(
              context.to_proto(),
              control_flow_ops.CondContext.from_proto(
                  context.to_proto()).to_proto())

  def testWhileContext(self):
    """Every WhileContext survives a to_proto/from_proto round trip."""
    with self.test_session() as sess:
      i = constant_op.constant(0)
      cond = lambda i: math_ops.less(i, 10)
      body = lambda i: math_ops.add(i, 1)
      control_flow_ops.while_loop(cond, body, [i])
      # Use a distinct name so the loop-condition lambda is not shadowed.
      for op in sess.graph.get_operations():
        context = op._get_control_flow_context()
        if context:
          self.assertProtoEquals(
              context.to_proto(),
              control_flow_ops.WhileContext.from_proto(
                  context.to_proto()).to_proto())

  def testControlContextImportScope(self):
    """import_scope / export_scope are prepended to / stripped from names."""
    with self.test_session():
      constant_op.constant(0, name="a")
      constant_op.constant(2, name="test_scope/a")
      b1 = constant_op.constant(1, name="b")
      b2 = constant_op.constant(3, name="test_scope/b")

      c = control_flow_ops.ControlFlowContext()
      c._values = ["a", "b"]
      c._external_values = {"a": b1}

      c_with_scope = control_flow_ops.ControlFlowContext._from_proto(
          c._to_proto(), import_scope="test_scope")

      # _values and _external_values should have the scope prepended.
      self.assertEqual(
          c_with_scope._values, set(["test_scope/a", "test_scope/b"]))
      self.assertEqual(
          c_with_scope._external_values, {"test_scope/a": b2})

      # Calling _to_proto() with export_scope should remove "test_scope".
      self.assertProtoEquals(
          c._to_proto(),
          c_with_scope._to_proto(export_scope="test_scope"))
487
488
def _GetNestedShape(nested):
  """Map every leaf of `nested` to a shape-like descriptor.

  Each leaf maps as follows:
  - a TensorArray maps to the TensorArray class itself;
  - an IndexedSlices maps to its `dense_shape`;
  - anything else maps to its `get_shape()`.
  """
  def _LeafShape(leaf):
    if isinstance(leaf, tensor_array_ops.TensorArray):
      return tensor_array_ops.TensorArray
    if isinstance(leaf, ops.IndexedSlices):
      return leaf.dense_shape
    return leaf.get_shape()

  return nest.map_structure(_LeafShape, nested)
499
500
def _CreateTensorArray(size, shape):
  """Return a float32 TensorArray holding `size` zero tensors of `shape`.

  The array is created with `clear_after_read=False`.
  """
  tensor_array = tensor_array_ops.TensorArray(
      dtype=dtypes.float32, size=size, clear_after_read=False)
  for index in range(size):
    zeros = array_ops.zeros(shape)
    tensor_array = tensor_array.write(index, zeros)
  return tensor_array
507
508
def _RawNestedShape(nested_shape):
  """Convert each known-rank TensorShape leaf into a list of dim values.

  Leaves that are not TensorShapes, or whose rank is unknown, become None.
  """
  def _ToRawShape(leaf):
    known_rank = (isinstance(leaf, tensor_shape.TensorShape) and
                  leaf.ndims is not None)
    if not known_rank:
      return None
    return [dim.value for dim in leaf]

  return nest.map_structure(_ToRawShape, nested_shape)
516
517
518# TODO(yori): Add tests for indexed slices.
519class DataTypesTest(test_util.TensorFlowTestCase):
520
521  def assertAllEqualNested(self, a, b):
522    if isinstance(a, (list, tuple)):
523      for entry_a, entry_b in zip(a, b):
524        self.assertAllEqualNested(entry_a, entry_b)
525    else:
526      self.assertAllEqual(a, b)
527
528  def _testShape(self, fn_true, fn_false, expected_shape,
529                 strict=False):
530    condition = array_ops.placeholder(dtypes.bool)
531    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
532                                        strict=strict)
533    self.assertEqual(_RawNestedShape(_GetNestedShape(output_cond)),
534                     _RawNestedShape(expected_shape))
535
536    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
537                                        strict=strict)
538    self.assertEqual(_RawNestedShape(_GetNestedShape(output_case)),
539                     _RawNestedShape(expected_shape))
540
541  def _testReturnValues(self, fn_true, fn_false, expected_value_true,
542                        expected_value_false, strict=False,
543                        check_cond=True, feed_dict=None):
544    if feed_dict is None: feed_dict = {}
545
546    condition = array_ops.placeholder(dtypes.bool)
547    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
548                                        strict=strict)
549    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
550                                        strict=strict)
551
552    with self.test_session() as sess:
553      variables.global_variables_initializer().run()
554      true_feed_dict = {condition: True}
555      true_feed_dict.update(feed_dict)
556      result_cond, result_case = sess.run([output_cond, output_case],
557                                          feed_dict=true_feed_dict)
558      self.assertAllEqualNested(result_cond, expected_value_true)
559      if check_cond:
560        self.assertAllEqualNested(result_case, expected_value_true)
561      false_feed_dict = {condition: False}
562      false_feed_dict.update(feed_dict)
563      result_cond, result_case = sess.run([output_cond, output_case],
564                                          feed_dict=false_feed_dict)
565      self.assertAllEqualNested(result_cond, expected_value_false)
566      if check_cond:
567        self.assertAllEqualNested(result_case, expected_value_false)
568
569  def test_int(self):
570    shape = tensor_shape.TensorShape([])
571    fn_true = lambda: 1
572    fn_false = lambda: 2
573    self._testShape(fn_true, fn_false, shape)
574    self._testReturnValues(fn_true, fn_false, 1, 2)
575    self._testShape(fn_true, fn_false, shape, strict=True)
576    self._testReturnValues(fn_true, fn_false, 1, 2, strict=True)
577
578  def test_float(self):
579    shape = tensor_shape.TensorShape([])
580    fn_true = lambda: 1.0
581    fn_false = lambda: 2.0
582    self._testShape(fn_true, fn_false, shape)
583    self._testReturnValues(fn_true, fn_false, 1.0, 2.0)
584
585  def test_noop(self):
586    shape = tensor_shape.TensorShape(None)
587    self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape)
588    self._testReturnValues(control_flow_ops.no_op, control_flow_ops.no_op,
589                           True, False, check_cond=False)
590
591  def test_string(self):
592    shape = tensor_shape.TensorShape([])
593    fn_true = lambda: "abc"
594    fn_false = lambda: "xyz"
595    self._testShape(fn_true, fn_false, shape)
596    self._testReturnValues(fn_true, fn_false, b"abc", b"xyz")
597
598  def test_variable(self):
599    shape = tensor_shape.TensorShape([])
600    fn_true = lambda: variables.Variable(3.0)
601    fn_false = lambda: variables.Variable(4.0)
602    self._testShape(fn_true, fn_false, shape)
603    self._testReturnValues(fn_true, fn_false, 3.0, 4.0)
604
605  def test_none(self):
606    fn_none = lambda: None
607    fn_tensor = lambda: constant_op.constant(1)
608
609    with self.assertRaises(ValueError):
610      control_flow_ops.cond(constant_op.constant(True), fn_none, fn_tensor)
611
612    with self.assertRaises(ValueError):
613      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none)
614
615  def test_tensors(self):
616    def _BuildTrueBranch(dtype):
617      def _Build():
618        return (array_ops.zeros([2, 2], dtype=dtype),
619                array_ops.ones([3, 3], dtype=dtype))
620      return _Build
621
622    def _BuildFalseBranch(dtype):
623      def _Build():
624        return (array_ops.ones([2, 2], dtype=dtype),
625                array_ops.zeros([3, 3], dtype=dtype))
626      return _Build
627
628    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
629      shape = (tensor_shape.TensorShape([2, 2]),
630               tensor_shape.TensorShape([3, 3]))
631      fn_true = _BuildTrueBranch(dtype)
632      fn_false = _BuildFalseBranch(dtype)
633      self._testShape(fn_true, fn_false, shape)
634      self._testReturnValues(fn_true, fn_false,
635                             (np.zeros([2, 2]), np.ones([3, 3])),
636                             (np.ones([2, 2]), np.zeros([3, 3])))
637
638  def test_tensors_unknown_shape(self):
639    def _BuildTrueBranch(dtype):
640      tensor = array_ops.placeholder(dtype=dtype, shape=None)
641      def _Build():
642        return tensor
643      return _Build, tensor
644
645    def _BuildFalseBranch(dtype):
646      tensor = array_ops.placeholder(dtype=dtype, shape=None)
647      def _Build():
648        return tensor
649      return _Build, tensor
650
651    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
652      shape = tensor_shape.TensorShape(None)
653      fn_true, true_tensor = _BuildTrueBranch(dtype)
654      fn_false, false_tensor = _BuildFalseBranch(dtype)
655      self._testShape(fn_true, fn_false, shape)
656      self._testReturnValues(fn_true, fn_false,
657                             np.zeros([2, 2]), np.ones([2, 2]),
658                             feed_dict={true_tensor: np.zeros([2, 2]),
659                                        false_tensor: np.ones([2, 2])})
660
661  def test_sparse_tensors(self):
662    shape = tensor_shape.TensorShape([None, None])
663
664    def FnTrue():
665      return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],
666                                         values=[1, 2], dense_shape=[3, 4])]
667
668    def FnFalse():
669      return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]],
670                                         values=[3, 4], dense_shape=[3, 4])]
671
672    value1 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 2]],
673                                             values=[1, 2], dense_shape=[3, 4])
674    value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]],
675                                             values=[3, 4], dense_shape=[3, 4])
676    self._testShape(FnTrue, FnFalse, shape)
677    self._testReturnValues(FnTrue, FnFalse, value1, value2)
678    self._testShape(FnTrue, FnFalse, [shape], strict=True)
679    self._testReturnValues(FnTrue, FnFalse, [value1], [value2], strict=True)
680
681  def test_tensors_with_partially_specified_shapes(self):
682    def _BuildBranch(dtype, shape):
683      a = array_ops.placeholder(dtype=dtype, shape=shape[0])
684      b = array_ops.placeholder(dtype=dtype, shape=shape[1])
685      c = array_ops.placeholder(dtype=dtype, shape=shape[2])
686      def _Build():
687        return a, b, c
688      return _Build, (a, b, c)
689
690    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
691      shape = (tensor_shape.TensorShape([None, 2]),
692               tensor_shape.TensorShape([None]),
693               tensor_shape.TensorShape([3, None]))
694      fn_true, true_tensors = _BuildBranch(dtype, shape)
695      fn_false, false_tensors = _BuildBranch(dtype, shape)
696      self._testShape(fn_true, fn_false, shape)
697      self._testReturnValues(fn_true, fn_false,
698                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
699                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
700                             feed_dict={true_tensors[0]: np.zeros([2, 2]),
701                                        false_tensors[0]: np.zeros([2, 2]),
702                                        true_tensors[1]: np.zeros([5]),
703                                        false_tensors[1]: np.zeros([5]),
704                                        true_tensors[2]: np.ones([3, 3]),
705                                        false_tensors[2]: np.ones([3, 3])})
706
707  def test_tensor_arrays(self):
708    element_shape = tensor_shape.TensorShape([2])
709    ta1 = _CreateTensorArray(4, element_shape)
710    ta2 = _CreateTensorArray(4, element_shape)
711    shape = tensor_array_ops.TensorArray
712    fn_true = lambda: ta1
713    fn_false = lambda: ta2
714    self._testShape(fn_true, fn_false, shape)
715
716  def test_tensor_array_reads(self):
717    shape = tensor_shape.TensorShape([2])
718    ta = _CreateTensorArray(4, shape)
719    fn_true = lambda: ta.read(0)
720    fn_false = lambda: ta.read(1)
721    self._testShape(fn_true, fn_false, shape)
722
723  def test_list(self):
724    shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
725             tensor_shape.TensorShape([])]
726    fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)]
727    fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)]
728    self._testShape(fn_true, fn_false, shape)
729    self._testReturnValues(fn_true, fn_false, [1, 2, 3.0], [3, 4, 5.0])
730
731  def test_non_strict(self):
732    shape = tensor_shape.TensorShape([])
733    fn_tensor = lambda: constant_op.constant(1)
734    fn_list = lambda: [constant_op.constant(2)]
735    fn_tuple = lambda: (constant_op.constant(3),)
736    self._testShape(fn_tensor, fn_list, shape)
737    self._testShape(fn_tensor, fn_tuple, shape)
738    self._testShape(fn_list, fn_tuple, shape)
739    self._testReturnValues(fn_tensor, fn_list, 1, 2)
740    self._testReturnValues(fn_tensor, fn_tuple, 1, 3)
741    self._testReturnValues(fn_list, fn_tuple, 2, 3)
742
743  def test_singleton_strict(self):
744    fn_tensor = lambda: constant_op.constant(1)
745    fn_list = lambda: [constant_op.constant(2)]
746    fn_tuple = lambda: (constant_op.constant(3),)
747
748    with self.assertRaises(ValueError):
749      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_list,
750                            strict=True)
751
752    with self.assertRaises(TypeError):
753      control_flow_ops.cond(constant_op.constant(True), fn_list, fn_tuple,
754                            strict=True)
755
756    with self.assertRaises(ValueError):
757      control_flow_ops.case([(constant_op.constant(True), fn_tensor)], fn_list,
758                            strict=True)
759
760    with self.assertRaises(TypeError):
761      control_flow_ops.case([(constant_op.constant(True), fn_list)], fn_tuple,
762                            strict=True)
763
764  def test_singleton_list(self):
765    shape = tensor_shape.TensorShape([])
766    fn_true = lambda: [constant_op.constant(1)]
767    fn_false = lambda: [constant_op.constant(3)]
768    self._testShape(fn_true, fn_false, shape)
769    self._testReturnValues(fn_true, fn_false, 1, 3)
770    self._testShape(fn_true, fn_false, [shape], strict=True)
771    self._testReturnValues(fn_true, fn_false, [1], [3], strict=True)
772
773  def test_singleton_tuple(self):
774    shape = tensor_shape.TensorShape([])
775    fn_true = lambda: (constant_op.constant(1),)
776    fn_false = lambda: (constant_op.constant(3),)
777    self._testShape(fn_true, fn_false, shape)
778    self._testReturnValues(fn_true, fn_false, 1, 3)
779    self._testShape(fn_true, fn_false, (shape,), strict=True)
780    self._testReturnValues(fn_true, fn_false, (1,), (3,),
781                           strict=True)
782
783  def test_singleton_namedtuple(self):
784    shape = tensor_shape.TensorShape([])
785    fn_true = lambda: SingletonTestTuple(constant_op.constant(1))
786    fn_false = lambda: SingletonTestTuple(constant_op.constant(3))
787    self._testShape(fn_true, fn_false, shape)
788    self._testReturnValues(fn_true, fn_false, 1, 3)
789    self._testShape(fn_true, fn_false, SingletonTestTuple(shape),
790                    strict=True)
791    self._testReturnValues(fn_true, fn_false, SingletonTestTuple(1),
792                           SingletonTestTuple(3), strict=True)
793
794  def test_tuple(self):
795    shape = (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
796    fn_true = lambda: (constant_op.constant(1), 2)
797    fn_false = lambda: (constant_op.constant(3), 4)
798    self._testShape(fn_true, fn_false, shape)
799    self._testReturnValues(fn_true, fn_false, (1, 2), (3, 4))
800
801  def test_namedtuple(self):
802    shape = TestTuple(tensor_shape.TensorShape([]),
803                      tensor_shape.TensorShape([]))
804    fn_true = lambda: TestTuple(constant_op.constant(1), 2)
805    fn_false = lambda: TestTuple(constant_op.constant(3), 4)
806    self._testShape(fn_true, fn_false, shape)
807    self._testReturnValues(fn_true, fn_false, TestTuple(1, 2), TestTuple(3, 4))
808
809  def test_nested(self):
810    shape = [tensor_shape.TensorShape([]),
811             TestTuple(tensor_shape.TensorShape([]),
812                       [tensor_shape.TensorShape([]),
813                        tensor_shape.TensorShape([])]),
814             tensor_shape.TensorShape([5, 5]),
815             tensor_shape.TensorShape([])]
816
817    def FnTrue():
818      return [constant_op.constant(1),
819              TestTuple(constant_op.constant(2), [3, 4]),
820              array_ops.zeros([5, 5]), 6]
821
822    def FnFalse():
823      return [constant_op.constant(11),
824              TestTuple(constant_op.constant(12), [13, 14]),
825              array_ops.ones([5, 5]), 16]
826
827    self._testShape(FnTrue, FnFalse, shape)
828    self._testReturnValues(FnTrue, FnFalse,
829                           [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6],
830                           [11, TestTuple(12, [13, 14]), np.ones([5, 5]), 16])
831
832  def test_cond_inside_while_loop(self):
833    def Body(i, matrix):
834      result_tuple, unused_matrix = control_flow_ops.cond(
835          constant_op.constant(True),
836          lambda: (TestTuple(matrix * 2, matrix * 4), matrix),
837          lambda: (TestTuple(matrix * 4, matrix * 2), matrix))
838      return [i+1, result_tuple.a]
839
840    iteration, matrix = control_flow_ops.while_loop(
841        lambda i, matrix: i < 10,
842        Body,
843        loop_vars=[constant_op.constant(0), array_ops.ones([2, 2])])
844
845    self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([]))
846    self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2]))
847
848
class CaseTest(test_util.TensorFlowTestCase):
  """Tests for control_flow_ops.case with and without a default branch."""

  def _assertCaseOutputs(self, sess, output, x, expectations):
    """Evaluates `output` once per fed value of `x` and checks each result.

    Args:
      sess: Session used to evaluate `output`.
      output: Tensor returned by control_flow_ops.case.
      x: Placeholder fed on every run.
      expectations: Sequence of (fed_value, expected_result) pairs, checked
        in order.
    """
    for fed_value, expected_result in expectations:
      actual = sess.run(output, feed_dict={x: fed_value})
      self.assertEqual(actual, expected_result)

  def testCase_withDefault(self):
    """The default branch is taken when no condition holds."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
    ]
    default_fn = lambda: constant_op.constant(6)
    output = control_flow_ops.case(pred_fn_pairs, default_fn, exclusive=True)
    with self.test_session() as sess:
      self._assertCaseOutputs(sess, output, x, [(1, 2), (2, 4), (3, 6)])

  def testCase_multiple_matches_exclusive(self):
    """With exclusive=True, two True conditions fail at run time."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(6)),
    ]
    default_fn = lambda: constant_op.constant(8)
    output = control_flow_ops.case(pred_fn_pairs, default_fn, exclusive=True)
    with self.test_session() as sess:
      self._assertCaseOutputs(sess, output, x, [(1, 2), (3, 8)])
      # x == 2 satisfies both the second and the third condition.
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   "More than one condition evaluated as True"):
        sess.run(output, feed_dict={x: 2})

  def testCase_multiple_matches_non_exclusive(self):
    """With exclusive=False, the first matching condition wins."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(6)),
    ]
    default_fn = lambda: constant_op.constant(8)
    output = control_flow_ops.case(pred_fn_pairs, default_fn, exclusive=False)
    with self.test_session() as sess:
      self._assertCaseOutputs(sess, output, x, [(1, 2), (2, 4), (3, 8)])

  def testCase_withoutDefault(self):
    """Without a default, a run where no condition holds raises."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
        (math_ops.equal(x, 3), lambda: constant_op.constant(6)),
    ]
    output = control_flow_ops.case(pred_fn_pairs, exclusive=True)
    with self.test_session() as sess:
      self._assertCaseOutputs(sess, output, x, [(1, 2), (2, 4), (3, 6)])
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r"\[None of the conditions evaluated as True. "
          r"Conditions: \(Equal:0, Equal_1:0, Equal_2:0\), Values:\] "
          r"\[0 0 0\]"):
        sess.run(output, feed_dict={x: 4})

  def testCase_withoutDefault_oneCondition(self):
    """A single condition with no default raises when it is False."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [(math_ops.equal(x, 1), lambda: constant_op.constant(2))]
    output = control_flow_ops.case(pred_fn_pairs, exclusive=True)
    with self.test_session() as sess:
      self._assertCaseOutputs(sess, output, x, [(1, 2)])
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r"\[None of the conditions evaluated as True. "
          r"Conditions: \(Equal:0\), Values:\] \[0\]"):
        sess.run(output, feed_dict={x: 4})
916
917
if __name__ == "__main__":
  # Entry point when this file is executed directly rather than imported.
  googletest.main()
920