control_flow_ops_test.py revision 6f898c6b2cbbc257d0966ee313a3670e88919463
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for control_flow_ops.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import numpy as np
23
24from tensorflow.core.framework import graph_pb2
25from tensorflow.core.framework import node_def_pb2
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework.test_util import TensorFlowTestCase
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import embedding_ops
36from tensorflow.python.ops import gradients_impl
37from tensorflow.python.ops import init_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import state_ops
40from tensorflow.python.ops import tensor_array_ops
41from tensorflow.python.ops import variable_scope
42from tensorflow.python.ops import variables
43import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
44from tensorflow.python.platform import googletest
45from tensorflow.python.training import momentum
46from tensorflow.python.util import nest
47
48
# Two-field namedtuple used as a nested return structure in the cond()/case()
# data-type tests.
TestTuple = collections.namedtuple(typename="TestTuple", field_names="a b")
# One-field namedtuple used by the singleton-structure cond()/case() tests.
SingletonTestTuple = collections.namedtuple(
    typename="SingletonTestTuple", field_names="a")
51
52
class GroupTestCase(TensorFlowTestCase):
  """Checks the NoOp nodes that control_flow_ops.group() adds to a graph."""

  def _StripNode(self, nd):
    """Returns a NodeDef with only the name, op, inputs and device of `nd`.

    The device field is copied only when `nd` has a non-empty device.
    """
    fields = dict(name=nd.name, op=nd.op, input=nd.input)
    if nd.device:
      fields["device"] = nd.device
    return node_def_pb2.NodeDef(**fields)

  def _StripGraph(self, gd):
    """Copies `gd`, keeping only node.name, node.op, node.input and
    node.device."""
    return graph_pb2.GraphDef(node=list(map(self._StripNode, gd.node)))

  def testGroup_NoDevices(self):
    # Without device placement, a single NoOp named "root" depends on all
    # inputs.
    with ops.Graph().as_default() as graph:
      a = constant_op.constant(0, name="a")
      b = constant_op.constant(0, name="b")
      c = constant_op.constant(0, name="c")
      control_flow_ops.group(a.op, b.op, c.op, name="root")
    expected = """
      node { name: "a" op: "Const"}
      node { name: "b" op: "Const"}
      node { name: "c" op: "Const"}
      node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
    """
    self.assertProtoEquals(expected, self._StripGraph(graph.as_graph_def()))

  def testGroup_OneDevice(self):
    # When all inputs share one device, the root NoOp is placed on it too.
    with ops.Graph().as_default() as graph:
      with graph.device("/task:0"):
        a = constant_op.constant(0, name="a")
        b = constant_op.constant(0, name="b")
      control_flow_ops.group(a.op, b.op, name="root")
    expected = """
      node { name: "a" op: "Const" device: "/task:0" }
      node { name: "b" op: "Const" device: "/task:0" }
      node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
    """
    self.assertProtoEquals(expected, self._StripGraph(graph.as_graph_def()))

  def testGroup_MultiDevice(self):
    # With inputs on two devices, the expected graph has one NoOp per device
    # plus a "root" NoOp (on the group's own device) depending on both.
    with ops.Graph().as_default() as graph:
      with graph.device("/task:0"):
        a = constant_op.constant(0, name="a")
        b = constant_op.constant(0, name="b")
      with graph.device("/task:1"):
        c = constant_op.constant(0, name="c")
        d = constant_op.constant(0, name="d")
      with graph.device("/task:2"):
        control_flow_ops.group(a.op, b.op, c.op, d.op, name="root")
    expected = """
      node { name: "a" op: "Const" device: "/task:0"}
      node { name: "b" op: "Const" device: "/task:0"}
      node { name: "c" op: "Const" device: "/task:1"}
      node { name: "d" op: "Const" device: "/task:1"}
      node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
             device: "/task:0" }
      node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
             device: "/task:1" }
      node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
             device: "/task:2" }
    """
    self.assertProtoEquals(expected, self._StripGraph(graph.as_graph_def()))
115
116
class ShapeTestCase(TensorFlowTestCase):
  """Checks that with_dependencies() keeps the static shape of its input."""

  def testShape(self):
    with ops.Graph().as_default():
      tensor = constant_op.constant([1.0, 2.0])
      # assertEqual replaces the deprecated assertEquals alias (removed in
      # Python 3.12).
      self.assertEqual([2], tensor.get_shape())
      self.assertEqual([2],
                       control_flow_ops.with_dependencies(
                           [constant_op.constant(1.0)], tensor).get_shape())
126
127
class WithDependenciesTestCase(TensorFlowTestCase):
  """Checks that with_dependencies() runs its dependencies before the output.

  The dependencies may be given as a tuple or a list.
  """

  def _CheckCounterDependency(self, container_type):
    """Builds and runs a graph with `container_type` dependencies.

    The dependencies increment a counter variable and are attached to a
    constant 7. Evaluating that constant must bump the counter from 0 to 1.

    Args:
      container_type: `tuple` or `list`; the type of the dependency container
        passed to with_dependencies().
    """
    with ops.Graph().as_default():
      counter = variable_scope.get_variable(
          "my_counter", shape=[], initializer=init_ops.zeros_initializer())
      increment_counter = state_ops.assign_add(counter, 1)
      dependencies = container_type(
          [increment_counter, constant_op.constant(42)])
      const_with_dep = control_flow_ops.with_dependencies(
          dependencies, constant_op.constant(7))
      with self.test_session():
        variables.global_variables_initializer().run()
        self.assertEqual(0, counter.eval())
        self.assertEqual(7, const_with_dep.eval())
        # Evaluating const_with_dep must have run the increment exactly once.
        self.assertEqual(1, counter.eval())

  def testTupleDependencies(self):
    self._CheckCounterDependency(tuple)

  def testListDependencies(self):
    self._CheckCounterDependency(list)
157
158
class SwitchTestCase(TensorFlowTestCase):
  """Tests switch() and gradients through switch, while_loop and cond."""

  def testIndexedSlicesWithDenseShape(self):
    """switch() on IndexedSlices forwards values, indices and dense_shape."""
    with self.test_session():
      data = ops.IndexedSlices(
          constant_op.constant([1, 2, 3]),
          constant_op.constant([0, 1]),
          dense_shape=constant_op.constant([3]))
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      less_op = math_ops.less(zero, one)
      _, switch_true = control_flow_ops.switch(data, less_op)
      self.assertAllEqual([1, 2, 3], switch_true.values.eval())
      self.assertAllEqual([0, 1], switch_true.indices.eval())
      # The test is about slices *with* a dense shape, so check it as well.
      self.assertAllEqual([3], switch_true.dense_shape.eval())

  def testIndexedSlicesGradient(self):
    """Minimizing a while-loop cost built from embedding lookups runs."""
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix", [5, 5],
          initializer=init_ops.random_normal_initializer())

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
        cost += math_ops.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])
      optimizer = momentum.MomentumOptimizer(0.1, 0.9)
      train_op = optimizer.minimize(cost)
      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        for _ in range(10):
          sess.run([train_op])

  def testResourceReadInLoop(self):
    """A resource variable read inside a while loop sees its initial value."""
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix",
          initializer=[[2.0], [3.0]],
          use_resource=True)

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost += math_ops.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])
      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        # 5 iterations, each adding row 0 (== 2.0).
        self.assertAllEqual(10.0, cost.eval())

  def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
    """Compares while/cond gradients against an unrolled static equivalent.

    Args:
      use_resource: whether the embedding variable is a resource variable.
    """
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix", [5, 5],
          initializer=init_ops.random_normal_initializer(),
          use_resource=use_resource)

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost = control_flow_ops.cond(
            math_ops.equal(it, 3), lambda: math_ops.square(cost),
            lambda: cost + math_ops.reduce_sum(embedding))
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])

      dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
      dynamic_grads = math_ops.segment_sum(dynamic_grads.values,
                                           dynamic_grads.indices)

      # Unrolled form of the loop: iterations 0-2 add, 3 squares, 4 adds.
      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      static = math_ops.square(
          math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
          math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
      static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
      static_grads = math_ops.segment_sum(static_grads.values,
                                          static_grads.indices)

      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))

  def testIndexedSlicesGradientInCondInWhileLoop(self):
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False)

  def testIndexedSlicesGradientInCondInWhileLoopResource(self):
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True)

  def testIndexedSlicesWithShapeGradientInWhileLoop(self):
    """Gradient of sum(gather-per-step) w.r.t. a fixed-size input is ones."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        num_steps = 9

        inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, size=num_steps)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          return i < num_steps  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(Cond, Body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
        self.assertEqual(o, 20)
        self.assertAllEqual(grad, [1] * num_steps)

  def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
    """Same as above, but the input length is only known at run time."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        inputs = array_ops.placeholder(dtype=dtype)
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, dynamic_size=True, size=1)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          return i < array_ops.size(inputs)  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(Cond, Body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [1, 3, 2]})
        self.assertEqual(o, 6)
        self.assertAllEqual(grad, [1] * 3)

  def testGradientThroughSingleBranchOutsideOfContext(self):
    """With pred=True, only the true switch output carries a gradient."""
    with self.test_session():
      x = constant_op.constant(2.)
      s = constant_op.constant(True)
      x_false, x_true = control_flow_ops.switch(x, s)
      grad_x_true = gradients_impl.gradients(x_true, x)[0]
      grad_x_false = gradients_impl.gradients(x_false, x)[0]
      self.assertEqual(grad_x_true.eval(), 1.)
      self.assertEqual(grad_x_false.eval(), 0.)
325
326
class CondTest(TensorFlowTestCase):
  """Tests cond() results and its true_fn/false_fn vs. fn1/fn2 arguments."""

  def testCondTrue(self):
    with self.test_session():
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      z = control_flow_ops.cond(
          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
          lambda: math_ops.add(y, 23))
      self.assertEqual(z.eval(), 34)

  def testCondFalse(self):
    with self.test_session():
      x = constant_op.constant(2)
      y = constant_op.constant(1)
      z = control_flow_ops.cond(
          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
          lambda: math_ops.add(y, 23))
      self.assertEqual(z.eval(), 24)

  def testCondTrueLegacy(self):
    # Same as testCondTrue, but using the legacy fn1/fn2 keyword names.
    with self.test_session():
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      z = control_flow_ops.cond(
          math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
          fn2=lambda: math_ops.add(y, 23))
      self.assertEqual(z.eval(), 34)

  def testCondFalseLegacy(self):
    # Same as testCondFalse, but using the legacy fn1/fn2 keyword names.
    with self.test_session():
      x = constant_op.constant(2)
      y = constant_op.constant(1)
      z = control_flow_ops.cond(
          math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
          fn2=lambda: math_ops.add(y, 23))
      self.assertEqual(z.eval(), 24)

  def testCondMissingArg1(self):
    # No true branch given.
    with self.test_session():
      x = constant_op.constant(1)
      with self.assertRaises(TypeError):
        control_flow_ops.cond(True, false_fn=lambda: x)

  def testCondMissingArg2(self):
    # No false branch given.
    with self.test_session():
      x = constant_op.constant(1)
      with self.assertRaises(TypeError):
        control_flow_ops.cond(True, lambda: x)

  def testCondDuplicateArg1(self):
    # True branch given both positionally and as fn1.
    with self.test_session():
      x = constant_op.constant(1)
      with self.assertRaises(TypeError):
        control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)

  def testCondDuplicateArg2(self):
    # False branch given both positionally and as fn2.
    with self.test_session():
      x = constant_op.constant(1)
      with self.assertRaises(TypeError):
        control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
388
389
class ContextTest(TensorFlowTestCase):
  """Round-trips control flow contexts through their protobuf form."""

  def testCondContext(self):
    with self.test_session() as sess:
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      control_flow_ops.cond(
          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
          lambda: math_ops.add(y, 23))
      for op in sess.graph.get_operations():
        ctx = op._get_control_flow_context()
        if ctx:
          round_tripped = control_flow_ops.CondContext.from_proto(
              ctx.to_proto())
          self.assertProtoEquals(ctx.to_proto(), round_tripped.to_proto())

  def testWhileContext(self):
    with self.test_session() as sess:
      i = constant_op.constant(0)
      # Distinct names so the loop condition is not shadowed by the context
      # variable below.
      loop_cond = lambda i: math_ops.less(i, 10)
      loop_body = lambda i: math_ops.add(i, 1)
      control_flow_ops.while_loop(loop_cond, loop_body, [i])
      for op in sess.graph.get_operations():
        ctx = op._get_control_flow_context()
        if ctx:
          round_tripped = control_flow_ops.WhileContext.from_proto(
              ctx.to_proto())
          self.assertProtoEquals(ctx.to_proto(), round_tripped.to_proto())

  def testControlContextImportScope(self):
    with self.test_session():
      constant_op.constant(0, name="a")
      constant_op.constant(2, name="test_scope/a")
      b1 = constant_op.constant(1, name="b")
      b2 = constant_op.constant(3, name="test_scope/b")

      c = control_flow_ops.ControlFlowContext()
      c._values = ["a", "b"]
      c._external_values = {"a": b1}

      c_with_scope = control_flow_ops.ControlFlowContext._from_proto(
          c._to_proto(), import_scope="test_scope")

      # _values and _external_values should have the scope prepended.
      self.assertEqual(
          c_with_scope._values, set(["test_scope/a", "test_scope/b"]))
      self.assertEqual(
          c_with_scope._external_values, {"test_scope/a": b2})

      # Calling _to_proto() with export_scope should remove "test_scope".
      self.assertProtoEquals(
          c._to_proto(),
          c_with_scope._to_proto(export_scope="test_scope"))
443
444
def _GetNestedShape(nested):
  """Maps every leaf of the `nested` structure to a shape-like descriptor.

  TensorArray leaves map to the TensorArray class itself, which serves as a
  marker. IndexedSlices leaves map to their `dense_shape`. Every other leaf
  maps to `leaf.get_shape()`.
  """

  def _LeafShape(leaf):
    if isinstance(leaf, tensor_array_ops.TensorArray):
      return tensor_array_ops.TensorArray
    if isinstance(leaf, ops.IndexedSlices):
      return leaf.dense_shape
    return leaf.get_shape()

  return nest.map_structure(_LeafShape, nested)
455
456
def _CreateTensorArray(size, shape):
  """Returns a float32 TensorArray holding `size` entries of zeros(shape).

  The array is created with clear_after_read=False. Presumably this is so
  that an entry can be read more than once, for example when the same branch
  function is traced by both cond() and case().
  """
  zeros_array = tensor_array_ops.TensorArray(
      dtype=dtypes.float32, size=size, clear_after_read=False)
  index = 0
  while index < size:
    zeros_array = zeros_array.write(index, array_ops.zeros(shape))
    index += 1
  return zeros_array
463
464
def _RawNestedShape(nested_shape):
  """Converts every TensorShape leaf into a plain list of dimension values.

  A leaf becomes None if it is not a TensorShape, or if it is a TensorShape
  of unknown rank.
  """

  def _ToRaw(leaf):
    has_known_rank = (isinstance(leaf, tensor_shape.TensorShape) and
                      leaf.ndims is not None)
    return [dim.value for dim in leaf] if has_known_rank else None

  return nest.map_structure(_ToRaw, nested_shape)
472
473
474# TODO(yori): Add tests for indexed slices.
class DataTypesTest(TensorFlowTestCase):
  """Checks output structure, static shape and values of cond() and case().

  Each test builds a true-branch and a false-branch function. It runs them
  through both control_flow_ops.cond and control_flow_ops.case, in the
  default (non-strict) mode and/or strict mode.
  """

  def assertAllEqualNested(self, a, b):
    """Recursively asserts that nested lists/tuples `a` and `b` are equal.

    Non-sequence leaves are compared with assertAllEqual. The lengths of
    matching sequences are compared first. Without that check, zip() would
    silently ignore missing or extra elements.
    """
    if isinstance(a, (list, tuple)):
      self.assertEqual(len(a), len(b))
      for entry_a, entry_b in zip(a, b):
        self.assertAllEqualNested(entry_a, entry_b)
    else:
      self.assertAllEqual(a, b)

  def _testShape(self, fn_true, fn_false, expected_shape,
                 strict=False):
    """Asserts cond() and case() outputs have the `expected_shape` structure.

    The predicate is a bool placeholder, so both branches are built. Shapes
    are compared after _RawNestedShape normalization.
    """
    condition = array_ops.placeholder(dtypes.bool)
    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
                                        strict=strict)
    self.assertEqual(_RawNestedShape(_GetNestedShape(output_cond)),
                     _RawNestedShape(expected_shape))

    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
                                        strict=strict)
    self.assertEqual(_RawNestedShape(_GetNestedShape(output_case)),
                     _RawNestedShape(expected_shape))

  def _testReturnValues(self, fn_true, fn_false, expected_value_true,
                        expected_value_false, strict=False,
                        check_cond=True):
    """Runs cond() and case() with the predicate fed True, then False.

    Args:
      fn_true: branch function used when the predicate is True.
      fn_false: branch function used when the predicate is False.
      expected_value_true: expected fetched value when the predicate is True.
      expected_value_false: expected fetched value when the predicate is False.
      strict: forwarded to cond() and case().
      check_cond: despite its name, this flag controls whether the *case()*
        result is compared. The cond() result is always checked.
    """
    condition = array_ops.placeholder(dtypes.bool)
    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
                                        strict=strict)
    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
                                        strict=strict)

    with self.test_session() as sess:
      # Needed by branches that create variables (e.g. test_variable).
      variables.global_variables_initializer().run()
      result_cond, result_case = sess.run([output_cond, output_case],
                                          feed_dict={condition: True})
      self.assertAllEqualNested(result_cond, expected_value_true)
      if check_cond:
        self.assertAllEqualNested(result_case, expected_value_true)
      result_cond, result_case = sess.run([output_cond, output_case],
                                          feed_dict={condition: False})
      self.assertAllEqualNested(result_cond, expected_value_false)
      if check_cond:
        self.assertAllEqualNested(result_case, expected_value_false)

  def test_int(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: 1
    fn_false = lambda: 2
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 2)
    self._testShape(fn_true, fn_false, shape, strict=True)
    self._testReturnValues(fn_true, fn_false, 1, 2, strict=True)

  def test_float(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: 1.0
    fn_false = lambda: 2.0
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1.0, 2.0)

  def test_noop(self):
    # Both branches return an Operation. The test expects cond() to yield
    # True/False. The case() values are not compared (check_cond=False).
    shape = tensor_shape.TensorShape(None)
    self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape)
    self._testReturnValues(control_flow_ops.no_op, control_flow_ops.no_op,
                           True, False, check_cond=False)

  def test_string(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: "abc"
    fn_false = lambda: "xyz"
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, b"abc", b"xyz")

  def test_variable(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: variables.Variable(3.0)
    fn_false = lambda: variables.Variable(4.0)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 3.0, 4.0)

  def test_none(self):
    # A branch returning None must be rejected, whichever side it is on.
    fn_none = lambda: None
    fn_tensor = lambda: constant_op.constant(1)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_none, fn_tensor)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none)

  def test_tensors(self):
    def _BuildTrueBranch(dtype):
      def _Build():
        return (array_ops.zeros([2, 2], dtype=dtype),
                array_ops.ones([3, 3], dtype=dtype))
      return _Build

    def _BuildFalseBranch(dtype):
      def _Build():
        return (array_ops.ones([2, 2], dtype=dtype),
                array_ops.zeros([3, 3], dtype=dtype))
      return _Build

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = (tensor_shape.TensorShape([2, 2]),
               tensor_shape.TensorShape([3, 3]))
      fn_true = _BuildTrueBranch(dtype)
      fn_false = _BuildFalseBranch(dtype)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             (np.zeros([2, 2]), np.ones([3, 3])),
                             (np.ones([2, 2]), np.zeros([3, 3])))

  def test_tensors_unknown_shape(self):
    # Static shapes are overwritten with unknown rank to test shape merging.
    def _BuildTrueBranch(dtype):
      def _Build():
        tensor = array_ops.zeros([2, 2], dtype=dtype)
        tensor._shape = tensor_shape.TensorShape(None)
        return tensor
      return _Build

    def _BuildFalseBranch(dtype):
      def _Build():
        tensor = array_ops.ones([2, 2], dtype=dtype)
        tensor._shape = tensor_shape.TensorShape(None)
        return tensor
      return _Build

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = tensor_shape.TensorShape(None)
      fn_true = _BuildTrueBranch(dtype)
      fn_false = _BuildFalseBranch(dtype)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             np.zeros([2, 2]), np.ones([2, 2]))

  def test_sparse_tensors(self):
    shape = tensor_shape.TensorShape([None, None])

    def FnTrue():
      return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],
                                         values=[1, 2], dense_shape=[3, 4])]

    def FnFalse():
      return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]],
                                         values=[3, 4], dense_shape=[3, 4])]

    value1 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 2]],
                                             values=[1, 2], dense_shape=[3, 4])
    value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]],
                                             values=[3, 4], dense_shape=[3, 4])
    self._testShape(FnTrue, FnFalse, shape)
    self._testReturnValues(FnTrue, FnFalse, value1, value2)
    self._testShape(FnTrue, FnFalse, [shape], strict=True)
    self._testReturnValues(FnTrue, FnFalse, [value1], [value2], strict=True)

  def test_tensors_with_partially_specified_shapes(self):
    def _BuildBranch(dtype, shape):
      def _Build():
        a = array_ops.zeros([2, 2], dtype=dtype)
        b = array_ops.zeros([5], dtype=dtype)
        c = array_ops.ones([3, 3], dtype=dtype)
        a._shape = tensor_shape.TensorShape(shape[0])
        b._shape = tensor_shape.TensorShape(shape[1])
        c._shape = tensor_shape.TensorShape(shape[2])
        return a, b, c
      return _Build

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = (tensor_shape.TensorShape([None, 2]),
               tensor_shape.TensorShape([None]),
               tensor_shape.TensorShape([3, None]))
      fn_true = _BuildBranch(dtype, shape)
      fn_false = _BuildBranch(dtype, shape)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])))

  def test_tensor_arrays(self):
    element_shape = tensor_shape.TensorShape([2])
    ta1 = _CreateTensorArray(4, element_shape)
    ta2 = _CreateTensorArray(4, element_shape)
    # _GetNestedShape maps TensorArrays to the TensorArray class as a marker.
    shape = tensor_array_ops.TensorArray
    fn_true = lambda: ta1
    fn_false = lambda: ta2
    self._testShape(fn_true, fn_false, shape)

  def test_tensor_array_reads(self):
    shape = tensor_shape.TensorShape([2])
    ta = _CreateTensorArray(4, shape)
    fn_true = lambda: ta.read(0)
    fn_false = lambda: ta.read(1)
    self._testShape(fn_true, fn_false, shape)

  def test_list(self):
    shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
             tensor_shape.TensorShape([])]
    fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)]
    fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)]
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, [1, 2, 3.0], [3, 4, 5.0])

  def test_non_strict(self):
    # In non-strict mode, a tensor and singleton list/tuple are interchangeable.
    shape = tensor_shape.TensorShape([])
    fn_tensor = lambda: constant_op.constant(1)
    fn_list = lambda: [constant_op.constant(2)]
    fn_tuple = lambda: (constant_op.constant(3),)
    self._testShape(fn_tensor, fn_list, shape)
    self._testShape(fn_tensor, fn_tuple, shape)
    self._testShape(fn_list, fn_tuple, shape)
    self._testReturnValues(fn_tensor, fn_list, 1, 2)
    self._testReturnValues(fn_tensor, fn_tuple, 1, 3)
    self._testReturnValues(fn_list, fn_tuple, 2, 3)

  def test_singleton_strict(self):
    # In strict mode, mismatched branch structures must raise.
    fn_tensor = lambda: constant_op.constant(1)
    fn_list = lambda: [constant_op.constant(2)]
    fn_tuple = lambda: (constant_op.constant(3),)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_list,
                            strict=True)

    with self.assertRaises(TypeError):
      control_flow_ops.cond(constant_op.constant(True), fn_list, fn_tuple,
                            strict=True)

    with self.assertRaises(ValueError):
      control_flow_ops.case([(constant_op.constant(True), fn_tensor)], fn_list,
                            strict=True)

    with self.assertRaises(TypeError):
      control_flow_ops.case([(constant_op.constant(True), fn_list)], fn_tuple,
                            strict=True)

  def test_singleton_list(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: [constant_op.constant(1)]
    fn_false = lambda: [constant_op.constant(3)]
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, [shape], strict=True)
    self._testReturnValues(fn_true, fn_false, [1], [3], strict=True)

  def test_singleton_tuple(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: (constant_op.constant(1),)
    fn_false = lambda: (constant_op.constant(3),)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, (shape,), strict=True)
    self._testReturnValues(fn_true, fn_false, (1,), (3,),
                           strict=True)

  def test_singleton_namedtuple(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: SingletonTestTuple(constant_op.constant(1))
    fn_false = lambda: SingletonTestTuple(constant_op.constant(3))
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, SingletonTestTuple(shape),
                    strict=True)
    self._testReturnValues(fn_true, fn_false, SingletonTestTuple(1),
                           SingletonTestTuple(3), strict=True)

  def test_tuple(self):
    shape = (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
    fn_true = lambda: (constant_op.constant(1), 2)
    fn_false = lambda: (constant_op.constant(3), 4)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, (1, 2), (3, 4))

  def test_namedtuple(self):
    shape = TestTuple(tensor_shape.TensorShape([]),
                      tensor_shape.TensorShape([]))
    fn_true = lambda: TestTuple(constant_op.constant(1), 2)
    fn_false = lambda: TestTuple(constant_op.constant(3), 4)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, TestTuple(1, 2), TestTuple(3, 4))

  def test_nested(self):
    shape = [tensor_shape.TensorShape([]),
             TestTuple(tensor_shape.TensorShape([]),
                       [tensor_shape.TensorShape([]),
                        tensor_shape.TensorShape([])]),
             tensor_shape.TensorShape([5, 5]),
             tensor_shape.TensorShape([])]

    def FnTrue():
      return [constant_op.constant(1),
              TestTuple(constant_op.constant(2), [3, 4]),
              array_ops.zeros([5, 5]), 6]

    def FnFalse():
      return [constant_op.constant(11),
              TestTuple(constant_op.constant(12), [13, 14]),
              array_ops.ones([5, 5]), 16]

    self._testShape(FnTrue, FnFalse, shape)
    self._testReturnValues(FnTrue, FnFalse,
                           [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6],
                           [11, TestTuple(12, [13, 14]), np.ones([5, 5]), 16])

  def test_cond_inside_while_loop(self):
    # A cond returning a nested structure inside a while loop must keep the
    # loop variables' static shapes.
    def Body(i, matrix):
      result_tuple, unused_matrix = control_flow_ops.cond(
          constant_op.constant(True),
          lambda: (TestTuple(matrix * 2, matrix * 4), matrix),
          lambda: (TestTuple(matrix * 4, matrix * 2), matrix))
      return [i+1, result_tuple.a]

    iteration, matrix = control_flow_ops.while_loop(
        lambda i, matrix: i < 10,
        Body,
        loop_vars=[constant_op.constant(0), array_ops.ones([2, 2])])

    self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([]))
    self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2]))
794
795
class CaseTest(TensorFlowTestCase):
  """Tests for control_flow_ops.case with and without a default branch."""

  def testCase_withDefault(self):
    """An unmatched input falls through to the default branch."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
    ]
    default_fn = lambda: constant_op.constant(6)
    result = control_flow_ops.case(pred_fn_pairs, default_fn, exclusive=True)
    with self.test_session() as session:
      for fed, expected in ((1, 2), (2, 4), (3, 6)):
        self.assertEqual(session.run(result, feed_dict={x: fed}), expected)

  def testCase_multiple_matches_exclusive(self):
    """With exclusive=True, two true predicates raise at run time."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(6)),
    ]
    default_fn = lambda: constant_op.constant(8)
    result = control_flow_ops.case(pred_fn_pairs, default_fn, exclusive=True)
    with self.test_session() as session:
      for fed, expected in ((1, 2), (3, 8)):
        self.assertEqual(session.run(result, feed_dict={x: fed}), expected)
      # x == 2 satisfies both the second and third predicates.
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   "More than one condition evaluated as True"):
        session.run(result, feed_dict={x: 2})

  def testCase_multiple_matches_non_exclusive(self):
    """With exclusive=False, the first matching predicate wins."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(6)),
    ]
    default_fn = lambda: constant_op.constant(8)
    result = control_flow_ops.case(pred_fn_pairs, default_fn, exclusive=False)
    with self.test_session() as session:
      for fed, expected in ((1, 2), (2, 4), (3, 8)):
        self.assertEqual(session.run(result, feed_dict={x: fed}), expected)

  def testCase_withoutDefault(self):
    """Without a default, an unmatched input raises listing all conditions."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [
        (math_ops.equal(x, 1), lambda: constant_op.constant(2)),
        (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
        (math_ops.equal(x, 3), lambda: constant_op.constant(6)),
    ]
    result = control_flow_ops.case(pred_fn_pairs, exclusive=True)
    # The op names in the message depend on the Equal ops above being the
    # first Equal ops created in this test's graph.
    expected_error = (r"\[None of the conditions evaluated as True. "
                      r"Conditions: \(Equal:0, Equal_1:0, Equal_2:0\), Values:\] "
                      r"\[0 0 0\]")
    with self.test_session() as session:
      for fed, expected in ((1, 2), (2, 4), (3, 6)):
        self.assertEqual(session.run(result, feed_dict={x: fed}), expected)
      with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_error):
        session.run(result, feed_dict={x: 4})

  def testCase_withoutDefault_oneCondition(self):
    """Single condition, no default: an unmatched input raises."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    pred_fn_pairs = [(math_ops.equal(x, 1), lambda: constant_op.constant(2))]
    result = control_flow_ops.case(pred_fn_pairs, exclusive=True)
    expected_error = (r"\[None of the conditions evaluated as True. "
                      r"Conditions: \(Equal:0\), Values:\] \[0\]")
    with self.test_session() as session:
      self.assertEqual(session.run(result, feed_dict={x: 1}), 2)
      with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_error):
        session.run(result, feed_dict={x: 4})
863
864
if __name__ == "__main__":
  # Script entry point: hand control to googletest's test runner.
  # NOTE(review): `googletest` is not among the imports visible in the file
  # header shown; confirm it is imported (presumably from
  # tensorflow.python.platform) further down the import block.
  googletest.main()
867