control_flow_ops_test.py revision a1d2a4ab90bd9df7312408f3971a2236810a1074
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for control_flow_ops.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import graph_pb2
22from tensorflow.core.framework import node_def_pb2
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework.test_util import TensorFlowTestCase
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import embedding_ops
30from tensorflow.python.ops import gradients_impl
31from tensorflow.python.ops import init_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import state_ops
34from tensorflow.python.ops import tensor_array_ops
35from tensorflow.python.ops import variable_scope
36from tensorflow.python.ops import variables
37import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
38from tensorflow.python.platform import googletest
39from tensorflow.python.training import momentum
40from tensorflow.python.util.protobuf import compare
41
42
class GroupTestCase(TensorFlowTestCase):
  """Tests the graph structure produced by control_flow_ops.group."""

  def _StripNode(self, nd):
    """Returns a NodeDef holding only nd's name, op, inputs and device.

    The device field is copied only when nd has a non-empty device.
    """
    stripped = node_def_pb2.NodeDef()
    stripped.name = nd.name
    stripped.op = nd.op
    stripped.input.extend(nd.input)
    if nd.device:
      stripped.device = nd.device
    return stripped

  def _StripGraph(self, gd):
    """Copies gd, keeping only node.name, node.op, node.input, node.device."""
    result = graph_pb2.GraphDef()
    result.node.extend([self._StripNode(nd) for nd in gd.node])
    return result

  def testGroup_NoDevices(self):
    with ops.Graph().as_default() as graph:
      const_a = constant_op.constant(0, name="a")
      const_b = constant_op.constant(0, name="b")
      const_c = constant_op.constant(0, name="c")
      control_flow_ops.group(const_a.op, const_b.op, const_c.op, name="root")
    # Without any device placement a single NoOp depends on every input op.
    expected = """
      node { name: "a" op: "Const"}
      node { name: "b" op: "Const"}
      node { name: "c" op: "Const"}
      node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
    """
    self.assertProtoEquals(expected, self._StripGraph(graph.as_graph_def()))

  def testGroup_OneDevice(self):
    with ops.Graph().as_default() as graph:
      with graph.device("/task:0"):
        const_a = constant_op.constant(0, name="a")
        const_b = constant_op.constant(0, name="b")
      control_flow_ops.group(const_a.op, const_b.op, name="root")
    # When all inputs share one device, the grouping NoOp is placed there too.
    expected = """
      node { name: "a" op: "Const" device: "/task:0" }
      node { name: "b" op: "Const" device: "/task:0" }
      node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
    """
    self.assertProtoEquals(expected, self._StripGraph(graph.as_graph_def()))

  def testGroup_MultiDevice(self):
    with ops.Graph().as_default() as graph:
      with graph.device("/task:0"):
        const_a = constant_op.constant(0, name="a")
        const_b = constant_op.constant(0, name="b")
      with graph.device("/task:1"):
        const_c = constant_op.constant(0, name="c")
        const_d = constant_op.constant(0, name="d")
      with graph.device("/task:2"):
        control_flow_ops.group(
            const_a.op, const_b.op, const_c.op, const_d.op, name="root")
    # One intermediate NoOp per input device, then a root NoOp on the device
    # active at the group() call that depends on the per-device NoOps.
    expected = """
      node { name: "a" op: "Const" device: "/task:0"}
      node { name: "b" op: "Const" device: "/task:0"}
      node { name: "c" op: "Const" device: "/task:1"}
      node { name: "d" op: "Const" device: "/task:1"}
      node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
             device: "/task:0" }
      node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
             device: "/task:1" }
      node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
             device: "/task:2" }
    """
    self.assertProtoEquals(expected, self._StripGraph(graph.as_graph_def()))
105
106
class ShapeTestCase(TensorFlowTestCase):
  """Tests that with_dependencies preserves the static shape of its output."""

  def testShape(self):
    with ops.Graph().as_default():
      tensor = constant_op.constant([1.0, 2.0])
      self.assertEqual([2], tensor.get_shape())
      # Attaching a control dependency must not alter the static shape.
      self.assertEqual([2],
                       control_flow_ops.with_dependencies(
                           [constant_op.constant(1.0)], tensor).get_shape())
116
117
class WithDependenciesTestCase(TensorFlowTestCase):
  """Tests with_dependencies for different dependency container types."""

  def _checkDependenciesRun(self, make_dependencies):
    """Checks that evaluating the output runs the dependency ops.

    Args:
      make_dependencies: container type (e.g. `tuple` or `list`) applied to
        the dependency ops before they are passed to with_dependencies.
    """
    with ops.Graph().as_default():
      counter = variable_scope.get_variable(
          "my_counter", shape=[], initializer=init_ops.zeros_initializer())
      increment_counter = state_ops.assign_add(counter, 1)
      const_with_dep = control_flow_ops.with_dependencies(
          make_dependencies([increment_counter, constant_op.constant(42)]),
          constant_op.constant(7))
      with self.test_session():
        variables.global_variables_initializer().run()
        self.assertEqual(0, counter.eval())
        # Evaluating the output must trigger the assign_add dependency.
        self.assertEqual(7, const_with_dep.eval())
        self.assertEqual(1, counter.eval())

  def testTupleDependencies(self):
    self._checkDependenciesRun(tuple)

  def testListDependencies(self):
    self._checkDependenciesRun(list)
147
148
class SwitchTestCase(TensorFlowTestCase):
  """Tests switch on IndexedSlices and IndexedSlices gradients in loops."""

  def testIndexedSlicesWithDenseShape(self):
    with self.test_session():
      data = ops.IndexedSlices(
          constant_op.constant([1, 2, 3]),
          constant_op.constant([0, 1]),
          dense_shape=constant_op.constant([3]))
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      less_op = math_ops.less(zero, one)
      # The predicate (0 < 1) is true, so only the true output is checked.
      _, switch_true = control_flow_ops.switch(data, less_op)
      self.assertAllEqual([1, 2, 3], switch_true.values.eval())
      self.assertAllEqual([0, 1], switch_true.indices.eval())
      # The dense_shape component must be forwarded through the switch as
      # well; this is the property the test is named for.
      self.assertAllEqual([3], switch_true.dense_shape.eval())

  def testIndexedSlicesGradient(self):
    """Smoke test: optimizer steps through an in-loop lookup run cleanly."""
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix", [5, 5],
          initializer=init_ops.random_normal_initializer())

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
        cost += math_ops.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])
      optimizer = momentum.MomentumOptimizer(0.1, 0.9)
      train_op = optimizer.minimize(cost)
      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        for _ in range(10):
          sess.run([train_op])

  def testResourceReadInLoop(self):
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix",
          initializer=[[2.0], [3.0]],
          use_resource=True)

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost += math_ops.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])
      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        # Row 0 is [2.0] and the body runs 5 times: 5 * 2.0 == 10.0.
        self.assertAllEqual(10.0, cost.eval())

  def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
    """Compares gradients of a cond-in-while loop with an unrolled version.

    Args:
      use_resource: passed to get_variable for embedding_matrix.
    """
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix", [5, 5],
          initializer=init_ops.random_normal_initializer(),
          use_resource=use_resource)

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost = control_flow_ops.cond(
            math_ops.equal(it, 3), lambda: math_ops.square(cost),
            lambda: cost + math_ops.reduce_sum(embedding))
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])

      dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
      dynamic_grads = math_ops.segment_sum(dynamic_grads.values,
                                           dynamic_grads.indices)

      # Unrolled equivalent of the loop: iterations 0-2 add sum(embedding),
      # iteration 3 squares the running cost, iteration 4 adds once more.
      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      static = math_ops.square(
          math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
          math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
      static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
      static_grads = math_ops.segment_sum(static_grads.values,
                                          static_grads.indices)

      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))

  def testIndexedSlicesGradientInCondInWhileLoop(self):
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False)

  def testIndexedSlicesGradientInCondInWhileLoopResource(self):
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True)

  def testIndexedSlicesWithShapeGradientInWhileLoop(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        num_steps = 9

        inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, size=num_steps)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          return i < num_steps  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(Cond, Body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
        # Sum of the fed values; d(sum)/d(input_k) is 1 for every k.
        self.assertEqual(o, 20)
        self.assertAllEqual(grad, [1] * num_steps)

  def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        inputs = array_ops.placeholder(dtype=dtype)
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, dynamic_size=True, size=1)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          return i < array_ops.size(inputs)  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(Cond, Body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [1, 3, 2]})
        # 1 + 3 + 2 == 6; each input contributes with gradient 1.
        self.assertEqual(o, 6)
        self.assertAllEqual(grad, [1] * 3)
305
306
class ContextTest(TensorFlowTestCase):
  """Checks that control flow contexts survive a to_proto/from_proto trip."""

  def testCondContext(self):
    with self.test_session() as sess:
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      control_flow_ops.cond(
          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
          lambda: math_ops.add(y, 23))
      num_contexts = 0
      for op in sess.graph.get_operations():
        ctx = op._get_control_flow_context()
        if ctx:
          # Assert on the round trip; a comparison whose result is ignored
          # could never make the test fail.
          self.assertProtoEquals(
              ctx.to_proto(),
              control_flow_ops.CondContext.from_proto(
                  ctx.to_proto()).to_proto())
          num_contexts += 1
      # The branch ops are built inside a CondContext; make sure something
      # was actually compared rather than passing vacuously.
      self.assertGreater(num_contexts, 0)

  def testWhileContext(self):
    with self.test_session() as sess:
      i = constant_op.constant(0)
      # Named distinctly so they are not shadowed by the loop variable below.
      loop_cond = lambda i: math_ops.less(i, 10)
      loop_body = lambda i: math_ops.add(i, 1)
      control_flow_ops.while_loop(loop_cond, loop_body, [i])
      num_contexts = 0
      for op in sess.graph.get_operations():
        ctx = op._get_control_flow_context()
        if ctx:
          # Assert on the round trip; a comparison whose result is ignored
          # could never make the test fail.
          self.assertProtoEquals(
              ctx.to_proto(),
              control_flow_ops.WhileContext.from_proto(
                  ctx.to_proto()).to_proto())
          num_contexts += 1
      # The loop body ops are built inside a WhileContext; make sure
      # something was actually compared rather than passing vacuously.
      self.assertGreater(num_contexts, 0)
335
336
def _main():
  """Runs this module's test cases through TensorFlow's googletest runner."""
  googletest.main()


if __name__ == "__main__":
  _main()
339