control_flow_ops_test.py revision 0cf9ed3a719c0782695154d5a0bca260001cec15
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for control_flow_ops.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import standard_ops as tf
from tensorflow.python.platform import googletest
from tensorflow.python.training import momentum


class GroupTestCase(TensorFlowTestCase):
  """Checks the NoOp structure that `tf.group` builds across devices."""

  def _StripNode(self, nd):
    """Return a copy of NodeDef `nd` holding only name, op, input and device.

    All other fields (e.g. attrs) are dropped. This lets the expected protos
    in the tests list only the structural fields.
    """
    snode = graph_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
    # Only set device when present so an empty device stays unset in the proto.
    if nd.device:
      snode.device = nd.device
    return snode

  def _StripGraph(self, gd):
    """Copy gd keeping only node.name, node.op, node.input, and node.device."""
    return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])

  def testGroup_NoDevices(self):
    """Without devices, group yields one NoOp with control inputs on all ops."""
    with ops.Graph().as_default() as g:
      a = tf.constant(0, name="a")
      b = tf.constant(0, name="b")
      c = tf.constant(0, name="c")
      tf.group(a.op, b.op, c.op, name="root")
    gd = g.as_graph_def()
    self.assertProtoEquals("""
      node { name: "a" op: "Const"}
      node { name: "b" op: "Const"}
      node { name: "c" op: "Const"}
      node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
    """, self._StripGraph(gd))

  def testGroup_OneDevice(self):
    """When all inputs share a device, the single NoOp is placed there too."""
    with ops.Graph().as_default() as g:
      with g.device("/task:0"):
        a = tf.constant(0, name="a")
        b = tf.constant(0, name="b")
      tf.group(a.op, b.op, name="root")
    gd = g.as_graph_def()
    self.assertProtoEquals("""
      node { name: "a" op: "Const" device: "/task:0" }
      node { name: "b" op: "Const" device: "/task:0" }
      node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
    """, self._StripGraph(gd))

  def testGroup_MultiDevice(self):
    """Inputs on several devices get one per-device NoOp joined by the root."""
    with ops.Graph().as_default() as g:
      with g.device("/task:0"):
        a = tf.constant(0, name="a")
        b = tf.constant(0, name="b")
      with g.device("/task:1"):
        c = tf.constant(0, name="c")
        d = tf.constant(0, name="d")
      with g.device("/task:2"):
        tf.group(a.op, b.op, c.op, d.op, name="root")
    gd = g.as_graph_def()
    # Expected: one NoOp per input device (root/NoOp, root/NoOp_1), and the
    # root NoOp (on the enclosing device, /task:2) depends on both of them.
    self.assertProtoEquals("""
      node { name: "a" op: "Const" device: "/task:0"}
      node { name: "b" op: "Const" device: "/task:0"}
      node { name: "c" op: "Const" device: "/task:1"}
      node { name: "d" op: "Const" device: "/task:1"}
      node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
             device: "/task:0" }
      node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
             device: "/task:1" }
      node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
             device: "/task:2" }
    """, self._StripGraph(gd))


class ShapeTestCase(TensorFlowTestCase):
  """Checks static shape propagation through control-flow helpers."""

  def testShape(self):
    """with_dependencies must preserve the static shape of its tensor."""
    with ops.Graph().as_default():
      tensor = tf.constant([1.0, 2.0])
      # assertEqual (not the deprecated assertEquals alias, removed in 3.12).
      self.assertEqual([2], tensor.get_shape())
      self.assertEqual([2],
                       control_flow_ops.with_dependencies(
                           [tf.constant(1.0)], tensor).get_shape())


class SwitchTestCase(TensorFlowTestCase):
  """Checks switch/While behaviour with IndexedSlices values and gradients."""

  def testIndexedSlicesWithDenseShape(self):
    """Switching IndexedSlices forwards values, indices and dense_shape."""
    with self.test_session():
      data = ops.IndexedSlices(tf.constant([1, 2, 3]),
                               tf.constant([0, 1]),
                               dense_shape=tf.constant([3]))
      zero = tf.constant(0)
      one = tf.constant(1)
      less_op = tf.less(zero, one)
      # pred (0 < 1) is true, so only the true branch is live; the false
      # branch output is dead and is deliberately not evaluated.
      _, switch_true = control_flow_ops.switch(data, less_op)
      self.assertAllEqual([1, 2, 3], switch_true.values.eval())
      self.assertAllEqual([0, 1], switch_true.indices.eval())
      # The point of this test: dense_shape must survive the switch too.
      self.assertAllEqual([3], switch_true.dense_shape.eval())

  def testIndexedSlicesGradient(self):
    """Smoke test: minimize a While-loop cost built from embedding lookups.

    Passes as long as building the training op and running it 10 times
    raises no error. No numeric values are asserted.
    """
    with ops.Graph().as_default():
      embedding_matrix = tf.get_variable(
          "embedding_matrix", [5, 5],
          initializer=tf.random_normal_initializer())

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        # NOTE(review): "+ 0.0" makes the lookup read a Tensor rather than
        # the Variable itself; presumably this exercises a gradient path
        # through the add op. Confirm the intent.
        embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
        cost += tf.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.While(
          Cond, Body, [tf.constant(0), tf.constant(0.0)])
      optimizer = momentum.MomentumOptimizer(0.1, 0.9)
      train_op = optimizer.minimize(cost)
      with self.test_session() as sess:
        sess.run(tf.initialize_all_variables())
        for _ in range(10):
          sess.run([train_op])

  def testIndexedSlicesGradientInCondInWhileLoop(self):
    """Gradients through cond-in-While equal those of the unrolled graph."""
    with ops.Graph().as_default():
      embedding_matrix = tf.get_variable(
          "embedding_matrix", [5, 5],
          initializer=tf.random_normal_initializer())

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost = tf.cond(tf.equal(it, 3),
                       lambda: tf.square(cost),
                       lambda: cost + tf.reduce_sum(embedding))
        return it + 1, cost

      _, cost = control_flow_ops.While(
          Cond, Body, [tf.constant(0), tf.constant(0.0)])

      # The gradient w.r.t. the lookup is an IndexedSlices; sum duplicate
      # indices so it can be compared against the static version.
      dynamic_grads = tf.gradients(cost, [embedding_matrix])[0]
      dynamic_grads = tf.segment_sum(dynamic_grads.values,
                                     dynamic_grads.indices)

      # Unrolled equivalent of the loop, with s = reduce_sum(embedding):
      # iterations it=0,1,2 add s; it=3 squares; it=4 adds s
      # => (s + s + s)^2 + s.
      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      static = tf.square(
          tf.reduce_sum(embedding) +
          tf.reduce_sum(embedding) +
          tf.reduce_sum(embedding)) + tf.reduce_sum(embedding)
      static_grads = tf.gradients(static, [embedding_matrix])[0]
      static_grads = tf.segment_sum(static_grads.values, static_grads.indices)

      with self.test_session() as sess:
        sess.run(tf.initialize_all_variables())
        self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))


if __name__ == "__main__":
  googletest.main()