# control_flow_ops_test.py revision a1d2a4ab90bd9df7312408f3971a2236810a1074
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for control_flow_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import googletest
from tensorflow.python.training import momentum
from tensorflow.python.util.protobuf import compare


class GroupTestCase(TensorFlowTestCase):
  """Checks the graph structure produced by `control_flow_ops.group`."""

  def _StripNode(self, nd):
    """Returns a copy of `nd` holding only its name, op, input and device."""
    snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
    # Only copy the device when one is set, so nodes without a device compare
    # equal to expected protos that omit the field.
    if nd.device:
      snode.device = nd.device
    return snode

  def _StripGraph(self, gd):
    """Copy gd keeping only, node.name, node.op, node.input, and node.device."""
    return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])

  def testGroup_NoDevices(self):
    """Without device scopes, group() emits one NoOp depending on all ops."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(0, name="a")
      b = constant_op.constant(0, name="b")
      c = constant_op.constant(0, name="c")
      control_flow_ops.group(a.op, b.op, c.op, name="root")
    gd = g.as_graph_def()
    self.assertProtoEquals("""
      node { name: "a" op: "Const"}
      node { name: "b" op: "Const"}
      node { name: "c" op: "Const"}
      node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
    """, self._StripGraph(gd))

  def testGroup_OneDevice(self):
    """When all inputs share a device, the grouping NoOp is placed on it."""
    with ops.Graph().as_default() as g:
      with g.device("/task:0"):
        a = constant_op.constant(0, name="a")
        b = constant_op.constant(0, name="b")
      control_flow_ops.group(a.op, b.op, name="root")
    gd = g.as_graph_def()
    self.assertProtoEquals("""
      node { name: "a" op: "Const" device: "/task:0" }
      node { name: "b" op: "Const" device: "/task:0" }
      node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
    """, self._StripGraph(gd))

  def testGroup_MultiDevice(self):
    """Inputs on several devices get one NoOp per device plus a root NoOp."""
    with ops.Graph().as_default() as g:
      with g.device("/task:0"):
        a = constant_op.constant(0, name="a")
        b = constant_op.constant(0, name="b")
      with g.device("/task:1"):
        c = constant_op.constant(0, name="c")
        d = constant_op.constant(0, name="d")
      with g.device("/task:2"):
        control_flow_ops.group(a.op, b.op, c.op, d.op, name="root")
    gd = g.as_graph_def()
    self.assertProtoEquals("""
      node { name: "a" op: "Const" device: "/task:0"}
      node { name: "b" op: "Const" device: "/task:0"}
      node { name: "c" op: "Const" device: "/task:1"}
      node { name: "d" op: "Const" device: "/task:1"}
      node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
             device: "/task:0" }
      node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
             device: "/task:1" }
      node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
             device: "/task:2" }
    """, self._StripGraph(gd))


class ShapeTestCase(TensorFlowTestCase):
  """Checks static shape propagation through `with_dependencies`."""

  def testShape(self):
    """with_dependencies() reports the same static shape as its input."""
    with ops.Graph().as_default():
      tensor = constant_op.constant([1.0, 2.0])
      self.assertEqual([2], tensor.get_shape())
      self.assertEqual([2],
                       control_flow_ops.with_dependencies(
                           [constant_op.constant(1.0)], tensor).get_shape())


class WithDependenciesTestCase(TensorFlowTestCase):
  """Checks that `with_dependencies` runs its dependencies when evaluated."""

  def testTupleDependencies(self):
    """Dependencies passed as a tuple run when the output is evaluated."""
    with ops.Graph().as_default():
      counter = variable_scope.get_variable(
          "my_counter", shape=[], initializer=init_ops.zeros_initializer())
      increment_counter = state_ops.assign_add(counter, 1)
      const_with_dep = control_flow_ops.with_dependencies(
          (increment_counter, constant_op.constant(42)),
          constant_op.constant(7))
      with self.test_session():
        variables.global_variables_initializer().run()
        self.assertEqual(0, counter.eval())
        # Evaluating the output must trigger the increment exactly once.
        self.assertEqual(7, const_with_dep.eval())
        self.assertEqual(1, counter.eval())

  def testListDependencies(self):
    """Dependencies passed as a list run when the output is evaluated."""
    with ops.Graph().as_default():
      counter = variable_scope.get_variable(
          "my_counter", shape=[], initializer=init_ops.zeros_initializer())
      increment_counter = state_ops.assign_add(counter, 1)
      const_with_dep = control_flow_ops.with_dependencies(
          [increment_counter, constant_op.constant(42)],
          constant_op.constant(7))
      with self.test_session():
        variables.global_variables_initializer().run()
        self.assertEqual(0, counter.eval())
        # Evaluating the output must trigger the increment exactly once.
        self.assertEqual(7, const_with_dep.eval())
        self.assertEqual(1, counter.eval())


class SwitchTestCase(TensorFlowTestCase):
  """Tests for switch, and for gradients through while_loop and cond."""

  def testIndexedSlicesWithDenseShape(self):
    """switch() forwards IndexedSlices values and indices on the true branch."""
    with self.test_session():
      data = ops.IndexedSlices(
          constant_op.constant([1, 2, 3]),
          constant_op.constant([0, 1]),
          dense_shape=constant_op.constant([3]))
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      less_op = math_ops.less(zero, one)
      # The predicate (0 < 1) is true, so only the true output is inspected.
      _, switch_true = control_flow_ops.switch(data, less_op)
      self.assertAllEqual([1, 2, 3], switch_true.values.eval())
      self.assertAllEqual([0, 1], switch_true.indices.eval())

  def testIndexedSlicesGradient(self):
    """Training through an embedding lookup inside a while_loop runs cleanly.

    There are no value assertions; the test passes if building the training op
    and running it ten times raises no error.
    """
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix", [5, 5],
          initializer=init_ops.random_normal_initializer())

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
        cost += math_ops.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])
      optimizer = momentum.MomentumOptimizer(0.1, 0.9)
      train_op = optimizer.minimize(cost)
      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        for _ in range(10):
          sess.run([train_op])

  def testResourceReadInLoop(self):
    """A resource variable can be read by an embedding lookup in a while_loop."""
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix",
          initializer=[[2.0], [3.0]],
          use_resource=True)

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost += math_ops.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])
      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        # Row 0 is [2.0] and the loop body runs five times: 5 * 2.0 == 10.0.
        self.assertAllEqual(10.0, cost.eval())

  def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
    """Compares gradients through cond-in-while_loop with an unrolled graph.

    Args:
      use_resource: Passed to `get_variable` as `use_resource` when creating
        the embedding matrix.
    """
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix", [5, 5],
          initializer=init_ops.random_normal_initializer(),
          use_resource=use_resource)

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost = control_flow_ops.cond(
            math_ops.equal(it, 3), lambda: math_ops.square(cost),
            lambda: cost + math_ops.reduce_sum(embedding))
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])

      dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
      dynamic_grads = math_ops.segment_sum(dynamic_grads.values,
                                           dynamic_grads.indices)

      # Hand-unrolled equivalent of the loop: three additions (it = 0, 1, 2),
      # a square (it = 3), then one more addition (it = 4).
      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      static = math_ops.square(
          math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
          math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
      static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
      static_grads = math_ops.segment_sum(static_grads.values,
                                          static_grads.indices)

      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))

  def testIndexedSlicesGradientInCondInWhileLoop(self):
    """Runs the cond-in-while_loop gradient check with use_resource=False."""
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False)

  def testIndexedSlicesGradientInCondInWhileLoopResource(self):
    """Runs the cond-in-while_loop gradient check with use_resource=True."""
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True)

  def testIndexedSlicesWithShapeGradientInWhileLoop(self):
    """Gradient w.r.t. statically shaped inputs gathered in a while_loop."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        num_steps = 9

        inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, size=num_steps)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          return i < num_steps  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(Cond, Body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
        # The fed values sum to 20; each element contributes once to the sum,
        # so the expected gradient is all ones.
        self.assertEqual(o, 20)
        self.assertAllEqual(grad, [1] * num_steps)

  def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
    """Gradient w.r.t. inputs of unknown shape gathered in a while_loop."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        inputs = array_ops.placeholder(dtype=dtype)
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, dynamic_size=True, size=1)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          return i < array_ops.size(inputs)  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(Cond, Body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [1, 3, 2]})
        # The fed values sum to 6; the expected gradient is all ones.
        self.assertEqual(o, 6)
        self.assertAllEqual(grad, [1] * 3)


class ContextTest(TensorFlowTestCase):
  """Checks that control flow contexts survive a to_proto/from_proto cycle."""

  def testCondContext(self):
    """CondContext protos are unchanged by a from_proto/to_proto round trip."""
    with self.test_session() as sess:
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      control_flow_ops.cond(
          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
          lambda: math_ops.add(y, 23))
      for op in sess.graph.get_operations():
        context = op._get_control_flow_context()  # pylint: disable=protected-access
        if context:
          # ProtoEq only returns a bool, so the result has to be asserted for
          # a mismatch to fail the test.
          self.assertTrue(
              compare.ProtoEq(
                  context.to_proto(),
                  control_flow_ops.CondContext.from_proto(
                      context.to_proto()).to_proto()),
              "CondContext round trip changed the proto for op %s" % op.name)

  def testWhileContext(self):
    """WhileContext protos are unchanged by a from_proto/to_proto round trip."""
    with self.test_session() as sess:
      i = constant_op.constant(0)
      c = lambda i: math_ops.less(i, 10)
      b = lambda i: math_ops.add(i, 1)
      control_flow_ops.while_loop(c, b, [i])
      for op in sess.graph.get_operations():
        context = op._get_control_flow_context()  # pylint: disable=protected-access
        if context:
          # ProtoEq only returns a bool, so the result has to be asserted for
          # a mismatch to fail the test.
          self.assertTrue(
              compare.ProtoEq(
                  context.to_proto(),
                  control_flow_ops.WhileContext.from_proto(
                      context.to_proto()).to_proto()),
              "WhileContext round trip changed the proto for op %s" % op.name)


if __name__ == "__main__":
  googletest.main()