control_flow_ops_test.py revision 0cf9ed3a719c0782695154d5a0bca260001cec15
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Tests for control_flow_ops.py."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import graph_pb2
22from tensorflow.python.framework import ops
23from tensorflow.python.framework.test_util import TensorFlowTestCase
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.ops import embedding_ops
26from tensorflow.python.ops import standard_ops as tf
27from tensorflow.python.platform import googletest
28from tensorflow.python.training import momentum
29
30
class GroupTestCase(TensorFlowTestCase):
  """Verifies the NoOp graph structure emitted by `tf.group`."""

  def _StripNode(self, nd):
    """Returns a NodeDef holding only the name, op, inputs and device of `nd`.

    The device field is copied only when `nd` has a non-empty device, so
    expectations for unplaced nodes can simply omit it.
    """
    kept_fields = {"name": nd.name, "op": nd.op, "input": nd.input}
    if nd.device:
      kept_fields["device"] = nd.device
    return graph_pb2.NodeDef(**kept_fields)

  def _StripGraph(self, gd):
    """Returns a copy of GraphDef `gd` with every node reduced by _StripNode."""
    stripped_graph = graph_pb2.GraphDef()
    stripped_graph.node.extend([self._StripNode(node_def)
                                for node_def in gd.node])
    return stripped_graph

  def testGroup_NoDevices(self):
    """Without device placement, one NoOp depends on every grouped op."""
    with ops.Graph().as_default() as g:
      consts = [tf.constant(0, name=label) for label in ("a", "b", "c")]
      tf.group(*[t.op for t in consts], name="root")
    expected = """
      node { name: "a" op: "Const"}
      node { name: "b" op: "Const"}
      node { name: "c" op: "Const"}
      node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
    """
    self.assertProtoEquals(expected, self._StripGraph(g.as_graph_def()))

  def testGroup_OneDevice(self):
    """Grouping ops on a single device places the NoOp on that device."""
    with ops.Graph().as_default() as g:
      with g.device("/task:0"):
        a, b = [tf.constant(0, name=label) for label in ("a", "b")]
      tf.group(*[t.op for t in (a, b)], name="root")
    expected = """
      node { name: "a" op: "Const" device: "/task:0" }
      node { name: "b" op: "Const" device: "/task:0" }
      node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
    """
    self.assertProtoEquals(expected, self._StripGraph(g.as_graph_def()))

  def testGroup_MultiDevice(self):
    """Ops spread over devices get a per-device NoOp joined by a root NoOp."""
    with ops.Graph().as_default() as g:
      with g.device("/task:0"):
        a, b = [tf.constant(0, name=label) for label in ("a", "b")]
      with g.device("/task:1"):
        c, d = [tf.constant(0, name=label) for label in ("c", "d")]
      with g.device("/task:2"):
        tf.group(*[t.op for t in (a, b, c, d)], name="root")
    expected = """
      node { name: "a" op: "Const" device: "/task:0"}
      node { name: "b" op: "Const" device: "/task:0"}
      node { name: "c" op: "Const" device: "/task:1"}
      node { name: "d" op: "Const" device: "/task:1"}
      node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
             device: "/task:0" }
      node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
             device: "/task:1" }
      node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
             device: "/task:2" }
    """
    self.assertProtoEquals(expected, self._StripGraph(g.as_graph_def()))
93
94
class ShapeTestCase(TensorFlowTestCase):
  """Tests static shape propagation through control-flow helpers."""

  def testShape(self):
    """`with_dependencies` keeps the static shape of the tensor it wraps."""
    with ops.Graph().as_default():
      tensor = tf.constant([1.0, 2.0])
      # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
      self.assertEqual([2], tensor.get_shape())
      self.assertEqual([2],
                       control_flow_ops.with_dependencies(
                           [tf.constant(1.0)], tensor).get_shape())
103                            [tf.constant(1.0)], tensor).get_shape())
104
105
class SwitchTestCase(TensorFlowTestCase):
  """Tests `switch` on IndexedSlices and IndexedSlices gradients in loops."""

  def testIndexedSlicesWithDenseShape(self):
    """Switching IndexedSlices forwards values, indices and dense_shape."""
    with self.test_session():
      data = ops.IndexedSlices(tf.constant([1, 2, 3]),
                               tf.constant([0, 1]),
                               dense_shape=tf.constant([3]))
      zero = tf.constant(0)
      one = tf.constant(1)
      # 0 < 1 holds, so the data is expected on the true output.
      less_op = tf.less(zero, one)
      _, switch_true = control_flow_ops.switch(data, less_op)
      self.assertAllEqual([1, 2, 3], switch_true.values.eval())
      self.assertAllEqual([0, 1], switch_true.indices.eval())
      # The point of this test is slices that carry a dense_shape, so check
      # that it is forwarded through the switch too.
      self.assertAllEqual([3], switch_true.dense_shape.eval())

  def testIndexedSlicesGradient(self):
    """Minimizing a cost built from lookups inside a While loop runs cleanly.

    There are no value assertions; the test passes if building and running
    the momentum training op for 10 steps raises no error.
    """
    with ops.Graph().as_default():
      embedding_matrix = tf.get_variable(
          "embedding_matrix", [5, 5],
          initializer=tf.random_normal_initializer())

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        # NOTE(review): `+ 0.0` presumably makes the lookup read a derived
        # tensor rather than the variable itself — confirm intent.
        embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
        cost += tf.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.While(
          Cond, Body, [tf.constant(0), tf.constant(0.0)])
      optimizer = momentum.MomentumOptimizer(0.1, 0.9)
      train_op = optimizer.minimize(cost)
      with self.test_session() as sess:
        sess.run(tf.initialize_all_variables())
        for _ in range(10):
          sess.run([train_op])

  def testIndexedSlicesGradientInCondInWhileLoop(self):
    """Gradients through a cond inside a While match an unrolled equivalent.

    With e = reduce_sum(lookup(embedding_matrix, [0])), the loop adds e at
    it = 0, 1, 2, squares at it = 3 and adds e at it = 4. That yields
    (e + e + e)^2 + e, which is also built statically for comparison.
    """
    with ops.Graph().as_default():
      embedding_matrix = tf.get_variable(
          "embedding_matrix", [5, 5],
          initializer=tf.random_normal_initializer())

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost = tf.cond(tf.equal(it, 3),
                       lambda: tf.square(cost),
                       lambda: cost + tf.reduce_sum(embedding))
        return it + 1, cost

      _, cost = control_flow_ops.While(
          Cond, Body, [tf.constant(0), tf.constant(0.0)])

      # Gradients w.r.t. the lookup arrive as IndexedSlices; segment_sum
      # folds the per-lookup contributions into one dense row per index.
      dynamic_grads = tf.gradients(cost, [embedding_matrix])[0]
      dynamic_grads = tf.segment_sum(dynamic_grads.values,
                                     dynamic_grads.indices)

      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      static = tf.square(
          tf.reduce_sum(embedding) +
          tf.reduce_sum(embedding) +
          tf.reduce_sum(embedding)) + tf.reduce_sum(embedding)
      static_grads = tf.gradients(static, [embedding_matrix])[0]
      static_grads = tf.segment_sum(static_grads.values, static_grads.indices)

      with self.test_session() as sess:
        sess.run(tf.initialize_all_variables())
        # The two graphs may sum the same float contributions in a different
        # order, so compare with a tolerance rather than bit-exactly.
        self.assertAllClose(*sess.run([static_grads, dynamic_grads]))
171        self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))
172
173
if __name__ == "__main__":
  # Run the test cases defined in this module with TensorFlow's googletest runner.
  googletest.main()
176