# control_flow_ops_test.py revision bcdcb78854e8dfa1b2eda813b9e2910df522abb4
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for control_flow_ops.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import numpy as np 23 24from tensorflow.core.framework import graph_pb2 25from tensorflow.core.framework import node_def_pb2 26from tensorflow.python.client import session 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import errors 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import sparse_tensor 32from tensorflow.python.framework import tensor_shape 33from tensorflow.python.framework import test_util 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import control_flow_ops 36from tensorflow.python.ops import embedding_ops 37from tensorflow.python.ops import gradients_impl 38from tensorflow.python.ops import init_ops 39from tensorflow.python.ops import math_ops 40from tensorflow.python.ops import state_ops 41from tensorflow.python.ops import tensor_array_ops 42from tensorflow.python.ops import variable_scope 43from tensorflow.python.ops import variables 44import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import 45from 
tensorflow.python.platform import googletest 46from tensorflow.python.training import momentum 47from tensorflow.python.util import nest 48 49 50TestTuple = collections.namedtuple("TestTuple", "a b") 51SingletonTestTuple = collections.namedtuple("SingletonTestTuple", "a") 52 53 54class GroupTestCase(test_util.TensorFlowTestCase): 55 56 def _StripNode(self, nd): 57 snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input) 58 if nd.device: 59 snode.device = nd.device 60 return snode 61 62 def _StripGraph(self, gd): 63 """Copy gd keeping only, node.name, node.op, node.input, and node.device.""" 64 return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node]) 65 66 def testGroup_NoDevices(self): 67 with ops.Graph().as_default() as g: 68 a = constant_op.constant(0, name="a") 69 b = constant_op.constant(0, name="b") 70 c = constant_op.constant(0, name="c") 71 control_flow_ops.group(a.op, b.op, c.op, name="root") 72 gd = g.as_graph_def() 73 self.assertProtoEquals(""" 74 node { name: "a" op: "Const"} 75 node { name: "b" op: "Const"} 76 node { name: "c" op: "Const"} 77 node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" } 78 """, self._StripGraph(gd)) 79 80 def testGroup_OneDevice(self): 81 with ops.Graph().as_default() as g: 82 with g.device("/task:0"): 83 a = constant_op.constant(0, name="a") 84 b = constant_op.constant(0, name="b") 85 control_flow_ops.group(a.op, b.op, name="root") 86 gd = g.as_graph_def() 87 self.assertProtoEquals(""" 88 node { name: "a" op: "Const" device: "/task:0" } 89 node { name: "b" op: "Const" device: "/task:0" } 90 node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" } 91 """, self._StripGraph(gd)) 92 93 def testGroup_MultiDevice(self): 94 with ops.Graph().as_default() as g: 95 with g.device("/task:0"): 96 a = constant_op.constant(0, name="a") 97 b = constant_op.constant(0, name="b") 98 with g.device("/task:1"): 99 c = constant_op.constant(0, name="c") 100 d = constant_op.constant(0, 
name="d") 101 with g.device("/task:2"): 102 control_flow_ops.group(a.op, b.op, c.op, d.op, name="root") 103 gd = g.as_graph_def() 104 self.assertProtoEquals(""" 105 node { name: "a" op: "Const" device: "/task:0"} 106 node { name: "b" op: "Const" device: "/task:0"} 107 node { name: "c" op: "Const" device: "/task:1"} 108 node { name: "d" op: "Const" device: "/task:1"} 109 node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b" 110 device: "/task:0" } 111 node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d" 112 device: "/task:1" } 113 node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1" 114 device: "/task:2" } 115 """, self._StripGraph(gd)) 116 117 def testPassingList(self): 118 with ops.Graph().as_default() as g: 119 a = constant_op.constant(0, name="a") 120 b = constant_op.constant(0, name="b") 121 control_flow_ops.group([a.op, b.op], name="root") 122 gd = g.as_graph_def() 123 self.assertProtoEquals(""" 124 node { name: "a" op: "Const"} 125 node { name: "b" op: "Const"} 126 node { name: "root" op: "NoOp" input: "^a" input: "^b" } 127 """, self._StripGraph(gd)) 128 129 def testPassingNonTensors(self): 130 with ops.Graph().as_default(): 131 with self.assertRaises(TypeError): 132 control_flow_ops.group(1, 2) 133 134 135class ShapeTestCase(test_util.TensorFlowTestCase): 136 137 def testShape(self): 138 with ops.Graph().as_default(): 139 tensor = constant_op.constant([1.0, 2.0]) 140 self.assertEquals([2], tensor.get_shape()) 141 self.assertEquals([2], 142 control_flow_ops.with_dependencies( 143 [constant_op.constant(1.0)], tensor).get_shape()) 144 145 146class WithDependenciesTestCase(test_util.TensorFlowTestCase): 147 148 def testTupleDependencies(self): 149 with ops.Graph().as_default(): 150 counter = variable_scope.get_variable( 151 "my_counter", shape=[], initializer=init_ops.zeros_initializer()) 152 increment_counter = state_ops.assign_add(counter, 1) 153 const_with_dep = control_flow_ops.with_dependencies( 154 (increment_counter, 
constant_op.constant(42)), 155 constant_op.constant(7)) 156 with self.test_session(): 157 variables.global_variables_initializer().run() 158 self.assertEquals(0, counter.eval()) 159 self.assertEquals(7, const_with_dep.eval()) 160 self.assertEquals(1, counter.eval()) 161 162 def testListDependencies(self): 163 with ops.Graph().as_default(): 164 counter = variable_scope.get_variable( 165 "my_counter", shape=[], initializer=init_ops.zeros_initializer()) 166 increment_counter = state_ops.assign_add(counter, 1) 167 const_with_dep = control_flow_ops.with_dependencies( 168 [increment_counter, constant_op.constant(42)], 169 constant_op.constant(7)) 170 with self.test_session(): 171 variables.global_variables_initializer().run() 172 self.assertEquals(0, counter.eval()) 173 self.assertEquals(7, const_with_dep.eval()) 174 self.assertEquals(1, counter.eval()) 175 176 177class SwitchTestCase(test_util.TensorFlowTestCase): 178 179 def testIndexedSlicesWithDenseShape(self): 180 with self.test_session(): 181 data = ops.IndexedSlices( 182 constant_op.constant([1, 2, 3]), 183 constant_op.constant([0, 1]), 184 dense_shape=constant_op.constant([3])) 185 zero = constant_op.constant(0) 186 one = constant_op.constant(1) 187 less_op = math_ops.less(zero, one) 188 switch_false, switch_true = control_flow_ops.switch(data, less_op) 189 self.assertAllEqual([1, 2, 3], switch_true.values.eval()) 190 self.assertAllEqual([0, 1], switch_true.indices.eval()) 191 192 def testIndexedSlicesGradient(self): 193 with ops.Graph().as_default(): 194 embedding_matrix = variable_scope.get_variable( 195 "embedding_matrix", [5, 5], 196 initializer=init_ops.random_normal_initializer()) 197 198 def Cond(it, _): 199 return it < 5 200 201 def Body(it, cost): 202 embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0]) 203 cost += math_ops.reduce_sum(embedding) 204 return it + 1, cost 205 206 _, cost = control_flow_ops.while_loop( 207 Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)]) 
208 optimizer = momentum.MomentumOptimizer(0.1, 0.9) 209 train_op = optimizer.minimize(cost) 210 with self.test_session() as sess: 211 sess.run(variables.global_variables_initializer()) 212 for _ in range(10): 213 sess.run([train_op]) 214 215 def testResourceReadInLoop(self): 216 with ops.Graph().as_default(): 217 embedding_matrix = variable_scope.get_variable( 218 "embedding_matrix", 219 initializer=[[2.0], [3.0]], 220 use_resource=True) 221 222 def Cond(it, _): 223 return it < 5 224 225 def Body(it, cost): 226 embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) 227 cost += math_ops.reduce_sum(embedding) 228 return it + 1, cost 229 230 _, cost = control_flow_ops.while_loop( 231 Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)]) 232 with self.test_session() as sess: 233 sess.run(variables.global_variables_initializer()) 234 self.assertAllEqual(10.0, cost.eval()) 235 236 def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False): 237 with ops.Graph().as_default(): 238 embedding_matrix = variable_scope.get_variable( 239 "embedding_matrix", [5, 5], 240 initializer=init_ops.random_normal_initializer(), 241 use_resource=use_resource) 242 243 def Cond(it, _): 244 return it < 5 245 246 def Body(it, cost): 247 embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) 248 cost = control_flow_ops.cond( 249 math_ops.equal(it, 3), lambda: math_ops.square(cost), 250 lambda: cost + math_ops.reduce_sum(embedding)) 251 return it + 1, cost 252 253 _, cost = control_flow_ops.while_loop( 254 Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)]) 255 256 dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0] 257 dynamic_grads = math_ops.segment_sum(dynamic_grads.values, 258 dynamic_grads.indices) 259 260 embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) 261 static = math_ops.square( 262 math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) + 263 math_ops.reduce_sum(embedding)) + 
math_ops.reduce_sum(embedding) 264 static_grads = gradients_impl.gradients(static, [embedding_matrix])[0] 265 static_grads = math_ops.segment_sum(static_grads.values, 266 static_grads.indices) 267 268 with self.test_session() as sess: 269 sess.run(variables.global_variables_initializer()) 270 self.assertAllEqual(*sess.run([static_grads, dynamic_grads])) 271 272 def testIndexedSlicesGradientInCondInWhileLoop(self): 273 self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False) 274 275 def testIndexedSlicesGradientInCondInWhileLoopResource(self): 276 self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True) 277 278 def testIndexedSlicesWithShapeGradientInWhileLoop(self): 279 for dtype in [dtypes.float32, dtypes.float64]: 280 with self.test_session() as sess: 281 num_steps = 9 282 283 inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps]) 284 initial_outputs = tensor_array_ops.TensorArray( 285 dtype=dtype, size=num_steps) 286 initial_i = constant_op.constant(0, dtype=dtypes.int32) 287 288 def Cond(i, _): 289 return i < num_steps # pylint: disable=cell-var-from-loop 290 291 def Body(i, outputs): 292 x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop 293 outputs = outputs.write(i, x) 294 return i + 1, outputs 295 296 _, outputs = control_flow_ops.while_loop(Cond, Body, 297 [initial_i, initial_outputs]) 298 299 outputs = math_ops.reduce_sum(outputs.stack()) 300 r = gradients_impl.gradients([outputs], [inputs])[0] 301 grad_wr_inputs = ops.convert_to_tensor(r) 302 o, grad = sess.run([outputs, grad_wr_inputs], 303 feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]}) 304 self.assertEquals(o, 20) 305 self.assertAllEqual(grad, [1] * num_steps) 306 307 def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self): 308 for dtype in [dtypes.float32, dtypes.float64]: 309 with self.test_session() as sess: 310 inputs = array_ops.placeholder(dtype=dtype) 311 initial_outputs = tensor_array_ops.TensorArray( 312 dtype=dtype, 
dynamic_size=True, size=1) 313 initial_i = constant_op.constant(0, dtype=dtypes.int32) 314 315 def Cond(i, _): 316 return i < array_ops.size(inputs) # pylint: disable=cell-var-from-loop 317 318 def Body(i, outputs): 319 x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop 320 outputs = outputs.write(i, x) 321 return i + 1, outputs 322 323 _, outputs = control_flow_ops.while_loop(Cond, Body, 324 [initial_i, initial_outputs]) 325 326 outputs = math_ops.reduce_sum(outputs.stack()) 327 r = gradients_impl.gradients([outputs], [inputs])[0] 328 grad_wr_inputs = ops.convert_to_tensor(r) 329 o, grad = sess.run([outputs, grad_wr_inputs], 330 feed_dict={inputs: [1, 3, 2]}) 331 self.assertEquals(o, 6) 332 self.assertAllEqual(grad, [1] * 3) 333 334 def testGradientThroughSingleBranchOutsideOfContext(self): 335 with self.test_session(): 336 x = constant_op.constant(2.) 337 s = constant_op.constant(True) 338 x_false, x_true = control_flow_ops.switch(x, s) 339 grad_x_true = gradients_impl.gradients(x_true, x)[0] 340 grad_x_false = gradients_impl.gradients(x_false, x)[0] 341 self.assertEquals(grad_x_true.eval(), 1.) 342 self.assertEquals(grad_x_false.eval(), 0.) 343 344 345@test_util.with_c_api 346class CondTest(test_util.TensorFlowTestCase): 347 348 def testCondTrue(self): 349 # Create new Graph and Session for each test so we pick up _USE_C_API 350 # correctly. 
351 with ops.Graph().as_default(): 352 with session.Session(): 353 x = constant_op.constant(2) 354 y = constant_op.constant(5) 355 z = control_flow_ops.cond( 356 math_ops.less(x, y), lambda: math_ops.multiply(x, 17), 357 lambda: math_ops.add(y, 23)) 358 self.assertEquals(z.eval(), 34) 359 360 def testCondFalse(self): 361 with ops.Graph().as_default(): 362 with session.Session(): 363 x = constant_op.constant(2) 364 y = constant_op.constant(1) 365 z = control_flow_ops.cond( 366 math_ops.less(x, y), lambda: math_ops.multiply(x, 17), 367 lambda: math_ops.add(y, 23)) 368 self.assertEquals(z.eval(), 24) 369 370 def testCondTrueLegacy(self): 371 with ops.Graph().as_default(): 372 with session.Session(): 373 x = constant_op.constant(2) 374 y = constant_op.constant(5) 375 z = control_flow_ops.cond( 376 math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17), 377 fn2=lambda: math_ops.add(y, 23)) 378 self.assertEquals(z.eval(), 34) 379 380 def testCondFalseLegacy(self): 381 with ops.Graph().as_default(): 382 with session.Session(): 383 x = constant_op.constant(2) 384 y = constant_op.constant(1) 385 z = control_flow_ops.cond( 386 math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17), 387 fn2=lambda: math_ops.add(y, 23)) 388 self.assertEquals(z.eval(), 24) 389 390 def testCondModifyBoolPred(self): 391 # This test in particular used to fail only when running in GPU, hence 392 # use_gpu=True. 
393 with ops.Graph().as_default(): 394 with session.Session() as sess: 395 bool_var = variable_scope.get_variable("bool_var", dtype=dtypes.bool, 396 initializer=True) 397 cond_on_bool_var = control_flow_ops.cond( 398 pred=bool_var, 399 true_fn=lambda: state_ops.assign(bool_var, False), 400 false_fn=lambda: True) 401 sess.run(bool_var.initializer) 402 self.assertEquals(sess.run(cond_on_bool_var), False) 403 self.assertEquals(sess.run(cond_on_bool_var), True) 404 405 def testCondMissingArg1(self): 406 with ops.Graph().as_default(): 407 with session.Session(): 408 x = constant_op.constant(1) 409 with self.assertRaises(TypeError): 410 control_flow_ops.cond(True, false_fn=lambda: x) 411 412 def testCondMissingArg2(self): 413 with ops.Graph().as_default(): 414 with session.Session(): 415 x = constant_op.constant(1) 416 with self.assertRaises(TypeError): 417 control_flow_ops.cond(True, lambda: x) 418 419 def testCondDuplicateArg1(self): 420 with ops.Graph().as_default(): 421 with session.Session(): 422 x = constant_op.constant(1) 423 with self.assertRaises(TypeError): 424 control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x) 425 426 def testCondDuplicateArg2(self): 427 with ops.Graph().as_default(): 428 with session.Session(): 429 x = constant_op.constant(1) 430 with self.assertRaises(TypeError): 431 control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x) 432 433 434class ContextTest(test_util.TensorFlowTestCase): 435 436 def testCondContext(self): 437 with self.test_session() as sess: 438 x = constant_op.constant(2) 439 y = constant_op.constant(5) 440 control_flow_ops.cond( 441 math_ops.less(x, y), lambda: math_ops.multiply(x, 17), 442 lambda: math_ops.add(y, 23)) 443 for op in sess.graph.get_operations(): 444 c = op._get_control_flow_context() 445 if c: 446 self.assertProtoEquals( 447 c.to_proto(), 448 control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto()) 449 450 def testWhileContext(self): 451 with self.test_session() as sess: 452 i = 
constant_op.constant(0) 453 c = lambda i: math_ops.less(i, 10) 454 b = lambda i: math_ops.add(i, 1) 455 control_flow_ops.while_loop(c, b, [i]) 456 for op in sess.graph.get_operations(): 457 c = op._get_control_flow_context() 458 if c: 459 self.assertProtoEquals( 460 c.to_proto(), 461 control_flow_ops.WhileContext.from_proto(c.to_proto()).to_proto()) 462 463 def testControlContextImportScope(self): 464 with self.test_session(): 465 constant_op.constant(0, name="a") 466 constant_op.constant(2, name="test_scope/a") 467 b1 = constant_op.constant(1, name="b") 468 b2 = constant_op.constant(3, name="test_scope/b") 469 470 c = control_flow_ops.ControlFlowContext() 471 c._values = ["a", "b"] 472 c._external_values = {"a": b1} 473 474 c_with_scope = control_flow_ops.ControlFlowContext._from_proto( 475 c._to_proto(), import_scope="test_scope") 476 477 # _values and _external_values should be have scope prepended. 478 self.assertEquals( 479 c_with_scope._values, set(["test_scope/a", "test_scope/b"])) 480 self.assertEquals( 481 c_with_scope._external_values, {"test_scope/a": b2}) 482 483 # Calling _to_proto() with export_scope should remove "test_scope". 
484 self.assertProtoEquals( 485 c._to_proto(), 486 c_with_scope._to_proto(export_scope="test_scope")) 487 488 489def _GetNestedShape(nested): 490 def _GetShape(tensor): 491 if isinstance(tensor, tensor_array_ops.TensorArray): 492 return tensor_array_ops.TensorArray 493 elif isinstance(tensor, ops.IndexedSlices): 494 return tensor.dense_shape 495 else: 496 return tensor.get_shape() 497 498 return nest.map_structure(_GetShape, nested) 499 500 501def _CreateTensorArray(size, shape): 502 ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size, 503 clear_after_read=False) 504 for i in range(size): 505 ta = ta.write(i, array_ops.zeros(shape)) 506 return ta 507 508 509def _RawNestedShape(nested_shape): 510 def _RawShape(shape): 511 if isinstance(shape, tensor_shape.TensorShape) and shape.ndims is not None: 512 return [x.value for x in shape] 513 else: 514 return None 515 return nest.map_structure(_RawShape, nested_shape) 516 517 518# TODO(yori): Add tests for indexed slices. 519class DataTypesTest(test_util.TensorFlowTestCase): 520 521 def assertAllEqualNested(self, a, b): 522 if isinstance(a, (list, tuple)): 523 for entry_a, entry_b in zip(a, b): 524 self.assertAllEqualNested(entry_a, entry_b) 525 else: 526 self.assertAllEqual(a, b) 527 528 def _testShape(self, fn_true, fn_false, expected_shape, 529 strict=False): 530 condition = array_ops.placeholder(dtypes.bool) 531 output_cond = control_flow_ops.cond(condition, fn_true, fn_false, 532 strict=strict) 533 self.assertEqual(_RawNestedShape(_GetNestedShape(output_cond)), 534 _RawNestedShape(expected_shape)) 535 536 output_case = control_flow_ops.case([(condition, fn_true)], fn_false, 537 strict=strict) 538 self.assertEqual(_RawNestedShape(_GetNestedShape(output_case)), 539 _RawNestedShape(expected_shape)) 540 541 def _testReturnValues(self, fn_true, fn_false, expected_value_true, 542 expected_value_false, strict=False, 543 check_cond=True, feed_dict=None): 544 if feed_dict is None: feed_dict = {} 545 546 condition 
= array_ops.placeholder(dtypes.bool) 547 output_cond = control_flow_ops.cond(condition, fn_true, fn_false, 548 strict=strict) 549 output_case = control_flow_ops.case([(condition, fn_true)], fn_false, 550 strict=strict) 551 552 with self.test_session() as sess: 553 variables.global_variables_initializer().run() 554 true_feed_dict = {condition: True} 555 true_feed_dict.update(feed_dict) 556 result_cond, result_case = sess.run([output_cond, output_case], 557 feed_dict=true_feed_dict) 558 self.assertAllEqualNested(result_cond, expected_value_true) 559 if check_cond: 560 self.assertAllEqualNested(result_case, expected_value_true) 561 false_feed_dict = {condition: False} 562 false_feed_dict.update(feed_dict) 563 result_cond, result_case = sess.run([output_cond, output_case], 564 feed_dict=false_feed_dict) 565 self.assertAllEqualNested(result_cond, expected_value_false) 566 if check_cond: 567 self.assertAllEqualNested(result_case, expected_value_false) 568 569 def test_int(self): 570 shape = tensor_shape.TensorShape([]) 571 fn_true = lambda: 1 572 fn_false = lambda: 2 573 self._testShape(fn_true, fn_false, shape) 574 self._testReturnValues(fn_true, fn_false, 1, 2) 575 self._testShape(fn_true, fn_false, shape, strict=True) 576 self._testReturnValues(fn_true, fn_false, 1, 2, strict=True) 577 578 def test_float(self): 579 shape = tensor_shape.TensorShape([]) 580 fn_true = lambda: 1.0 581 fn_false = lambda: 2.0 582 self._testShape(fn_true, fn_false, shape) 583 self._testReturnValues(fn_true, fn_false, 1.0, 2.0) 584 585 def test_noop(self): 586 shape = tensor_shape.TensorShape(None) 587 self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape) 588 self._testReturnValues(control_flow_ops.no_op, control_flow_ops.no_op, 589 True, False, check_cond=False) 590 591 def test_string(self): 592 shape = tensor_shape.TensorShape([]) 593 fn_true = lambda: "abc" 594 fn_false = lambda: "xyz" 595 self._testShape(fn_true, fn_false, shape) 596 self._testReturnValues(fn_true, 
fn_false, b"abc", b"xyz") 597 598 def test_variable(self): 599 shape = tensor_shape.TensorShape([]) 600 fn_true = lambda: variables.Variable(3.0) 601 fn_false = lambda: variables.Variable(4.0) 602 self._testShape(fn_true, fn_false, shape) 603 self._testReturnValues(fn_true, fn_false, 3.0, 4.0) 604 605 def test_none(self): 606 fn_none = lambda: None 607 fn_tensor = lambda: constant_op.constant(1) 608 609 with self.assertRaises(ValueError): 610 control_flow_ops.cond(constant_op.constant(True), fn_none, fn_tensor) 611 612 with self.assertRaises(ValueError): 613 control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none) 614 615 def test_tensors(self): 616 def _BuildTrueBranch(dtype): 617 def _Build(): 618 return (array_ops.zeros([2, 2], dtype=dtype), 619 array_ops.ones([3, 3], dtype=dtype)) 620 return _Build 621 622 def _BuildFalseBranch(dtype): 623 def _Build(): 624 return (array_ops.ones([2, 2], dtype=dtype), 625 array_ops.zeros([3, 3], dtype=dtype)) 626 return _Build 627 628 for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): 629 shape = (tensor_shape.TensorShape([2, 2]), 630 tensor_shape.TensorShape([3, 3])) 631 fn_true = _BuildTrueBranch(dtype) 632 fn_false = _BuildFalseBranch(dtype) 633 self._testShape(fn_true, fn_false, shape) 634 self._testReturnValues(fn_true, fn_false, 635 (np.zeros([2, 2]), np.ones([3, 3])), 636 (np.ones([2, 2]), np.zeros([3, 3]))) 637 638 def test_tensors_unknown_shape(self): 639 def _BuildTrueBranch(dtype): 640 tensor = array_ops.placeholder(dtype=dtype, shape=None) 641 def _Build(): 642 return tensor 643 return _Build, tensor 644 645 def _BuildFalseBranch(dtype): 646 tensor = array_ops.placeholder(dtype=dtype, shape=None) 647 def _Build(): 648 return tensor 649 return _Build, tensor 650 651 for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): 652 shape = tensor_shape.TensorShape(None) 653 fn_true, true_tensor = _BuildTrueBranch(dtype) 654 fn_false, false_tensor = _BuildFalseBranch(dtype) 
655 self._testShape(fn_true, fn_false, shape) 656 self._testReturnValues(fn_true, fn_false, 657 np.zeros([2, 2]), np.ones([2, 2]), 658 feed_dict={true_tensor: np.zeros([2, 2]), 659 false_tensor: np.ones([2, 2])}) 660 661 def test_sparse_tensors(self): 662 shape = tensor_shape.TensorShape([None, None]) 663 664 def FnTrue(): 665 return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]], 666 values=[1, 2], dense_shape=[3, 4])] 667 668 def FnFalse(): 669 return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]], 670 values=[3, 4], dense_shape=[3, 4])] 671 672 value1 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 2]], 673 values=[1, 2], dense_shape=[3, 4]) 674 value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]], 675 values=[3, 4], dense_shape=[3, 4]) 676 self._testShape(FnTrue, FnFalse, shape) 677 self._testReturnValues(FnTrue, FnFalse, value1, value2) 678 self._testShape(FnTrue, FnFalse, [shape], strict=True) 679 self._testReturnValues(FnTrue, FnFalse, [value1], [value2], strict=True) 680 681 def test_tensors_with_partially_specified_shapes(self): 682 def _BuildBranch(dtype, shape): 683 a = array_ops.placeholder(dtype=dtype, shape=shape[0]) 684 b = array_ops.placeholder(dtype=dtype, shape=shape[1]) 685 c = array_ops.placeholder(dtype=dtype, shape=shape[2]) 686 def _Build(): 687 return a, b, c 688 return _Build, (a, b, c) 689 690 for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): 691 shape = (tensor_shape.TensorShape([None, 2]), 692 tensor_shape.TensorShape([None]), 693 tensor_shape.TensorShape([3, None])) 694 fn_true, true_tensors = _BuildBranch(dtype, shape) 695 fn_false, false_tensors = _BuildBranch(dtype, shape) 696 self._testShape(fn_true, fn_false, shape) 697 self._testReturnValues(fn_true, fn_false, 698 (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])), 699 (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])), 700 feed_dict={true_tensors[0]: np.zeros([2, 2]), 701 false_tensors[0]: np.zeros([2, 2]), 702 true_tensors[1]: 
np.zeros([5]), 703 false_tensors[1]: np.zeros([5]), 704 true_tensors[2]: np.ones([3, 3]), 705 false_tensors[2]: np.ones([3, 3])}) 706 707 def test_tensor_arrays(self): 708 element_shape = tensor_shape.TensorShape([2]) 709 ta1 = _CreateTensorArray(4, element_shape) 710 ta2 = _CreateTensorArray(4, element_shape) 711 shape = tensor_array_ops.TensorArray 712 fn_true = lambda: ta1 713 fn_false = lambda: ta2 714 self._testShape(fn_true, fn_false, shape) 715 716 def test_tensor_array_reads(self): 717 shape = tensor_shape.TensorShape([2]) 718 ta = _CreateTensorArray(4, shape) 719 fn_true = lambda: ta.read(0) 720 fn_false = lambda: ta.read(1) 721 self._testShape(fn_true, fn_false, shape) 722 723 def test_list(self): 724 shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]), 725 tensor_shape.TensorShape([])] 726 fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)] 727 fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)] 728 self._testShape(fn_true, fn_false, shape) 729 self._testReturnValues(fn_true, fn_false, [1, 2, 3.0], [3, 4, 5.0]) 730 731 def test_non_strict(self): 732 shape = tensor_shape.TensorShape([]) 733 fn_tensor = lambda: constant_op.constant(1) 734 fn_list = lambda: [constant_op.constant(2)] 735 fn_tuple = lambda: (constant_op.constant(3),) 736 self._testShape(fn_tensor, fn_list, shape) 737 self._testShape(fn_tensor, fn_tuple, shape) 738 self._testShape(fn_list, fn_tuple, shape) 739 self._testReturnValues(fn_tensor, fn_list, 1, 2) 740 self._testReturnValues(fn_tensor, fn_tuple, 1, 3) 741 self._testReturnValues(fn_list, fn_tuple, 2, 3) 742 743 def test_singleton_strict(self): 744 fn_tensor = lambda: constant_op.constant(1) 745 fn_list = lambda: [constant_op.constant(2)] 746 fn_tuple = lambda: (constant_op.constant(3),) 747 748 with self.assertRaises(ValueError): 749 control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_list, 750 strict=True) 751 752 with self.assertRaises(TypeError): 753 
control_flow_ops.cond(constant_op.constant(True), fn_list, fn_tuple, 754 strict=True) 755 756 with self.assertRaises(ValueError): 757 control_flow_ops.case([(constant_op.constant(True), fn_tensor)], fn_list, 758 strict=True) 759 760 with self.assertRaises(TypeError): 761 control_flow_ops.case([(constant_op.constant(True), fn_list)], fn_tuple, 762 strict=True) 763 764 def test_singleton_list(self): 765 shape = tensor_shape.TensorShape([]) 766 fn_true = lambda: [constant_op.constant(1)] 767 fn_false = lambda: [constant_op.constant(3)] 768 self._testShape(fn_true, fn_false, shape) 769 self._testReturnValues(fn_true, fn_false, 1, 3) 770 self._testShape(fn_true, fn_false, [shape], strict=True) 771 self._testReturnValues(fn_true, fn_false, [1], [3], strict=True) 772 773 def test_singleton_tuple(self): 774 shape = tensor_shape.TensorShape([]) 775 fn_true = lambda: (constant_op.constant(1),) 776 fn_false = lambda: (constant_op.constant(3),) 777 self._testShape(fn_true, fn_false, shape) 778 self._testReturnValues(fn_true, fn_false, 1, 3) 779 self._testShape(fn_true, fn_false, (shape,), strict=True) 780 self._testReturnValues(fn_true, fn_false, (1,), (3,), 781 strict=True) 782 783 def test_singleton_namedtuple(self): 784 shape = tensor_shape.TensorShape([]) 785 fn_true = lambda: SingletonTestTuple(constant_op.constant(1)) 786 fn_false = lambda: SingletonTestTuple(constant_op.constant(3)) 787 self._testShape(fn_true, fn_false, shape) 788 self._testReturnValues(fn_true, fn_false, 1, 3) 789 self._testShape(fn_true, fn_false, SingletonTestTuple(shape), 790 strict=True) 791 self._testReturnValues(fn_true, fn_false, SingletonTestTuple(1), 792 SingletonTestTuple(3), strict=True) 793 794 def test_tuple(self): 795 shape = (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) 796 fn_true = lambda: (constant_op.constant(1), 2) 797 fn_false = lambda: (constant_op.constant(3), 4) 798 self._testShape(fn_true, fn_false, shape) 799 self._testReturnValues(fn_true, fn_false, (1, 2), 
(3, 4)) 800 801 def test_namedtuple(self): 802 shape = TestTuple(tensor_shape.TensorShape([]), 803 tensor_shape.TensorShape([])) 804 fn_true = lambda: TestTuple(constant_op.constant(1), 2) 805 fn_false = lambda: TestTuple(constant_op.constant(3), 4) 806 self._testShape(fn_true, fn_false, shape) 807 self._testReturnValues(fn_true, fn_false, TestTuple(1, 2), TestTuple(3, 4)) 808 809 def test_nested(self): 810 shape = [tensor_shape.TensorShape([]), 811 TestTuple(tensor_shape.TensorShape([]), 812 [tensor_shape.TensorShape([]), 813 tensor_shape.TensorShape([])]), 814 tensor_shape.TensorShape([5, 5]), 815 tensor_shape.TensorShape([])] 816 817 def FnTrue(): 818 return [constant_op.constant(1), 819 TestTuple(constant_op.constant(2), [3, 4]), 820 array_ops.zeros([5, 5]), 6] 821 822 def FnFalse(): 823 return [constant_op.constant(11), 824 TestTuple(constant_op.constant(12), [13, 14]), 825 array_ops.ones([5, 5]), 16] 826 827 self._testShape(FnTrue, FnFalse, shape) 828 self._testReturnValues(FnTrue, FnFalse, 829 [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6], 830 [11, TestTuple(12, [13, 14]), np.ones([5, 5]), 16]) 831 832 def test_cond_inside_while_loop(self): 833 def Body(i, matrix): 834 result_tuple, unused_matrix = control_flow_ops.cond( 835 constant_op.constant(True), 836 lambda: (TestTuple(matrix * 2, matrix * 4), matrix), 837 lambda: (TestTuple(matrix * 4, matrix * 2), matrix)) 838 return [i+1, result_tuple.a] 839 840 iteration, matrix = control_flow_ops.while_loop( 841 lambda i, matrix: i < 10, 842 Body, 843 loop_vars=[constant_op.constant(0), array_ops.ones([2, 2])]) 844 845 self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([])) 846 self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2])) 847 848 849class CaseTest(test_util.TensorFlowTestCase): 850 851 def testCase_withDefault(self): 852 x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) 853 conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), 854 (math_ops.equal(x, 
2), lambda: constant_op.constant(4))] 855 default = lambda: constant_op.constant(6) 856 output = control_flow_ops.case(conditions, default, exclusive=True) 857 with self.test_session() as sess: 858 self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) 859 self.assertEqual(sess.run(output, feed_dict={x: 2}), 4) 860 self.assertEqual(sess.run(output, feed_dict={x: 3}), 6) 861 862 def testCase_multiple_matches_exclusive(self): 863 x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) 864 conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), 865 (math_ops.equal(x, 2), lambda: constant_op.constant(4)), 866 (math_ops.equal(x, 2), lambda: constant_op.constant(6))] 867 default = lambda: constant_op.constant(8) 868 output = control_flow_ops.case(conditions, default, exclusive=True) 869 with self.test_session() as sess: 870 self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) 871 self.assertEqual(sess.run(output, feed_dict={x: 3}), 8) 872 with self.assertRaisesRegexp(errors.InvalidArgumentError, 873 "More than one condition evaluated as True"): 874 sess.run(output, feed_dict={x: 2}) 875 876 def testCase_multiple_matches_non_exclusive(self): 877 x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) 878 conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), 879 (math_ops.equal(x, 2), lambda: constant_op.constant(4)), 880 (math_ops.equal(x, 2), lambda: constant_op.constant(6))] 881 default = lambda: constant_op.constant(8) 882 output = control_flow_ops.case(conditions, default, exclusive=False) 883 with self.test_session() as sess: 884 self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) 885 self.assertEqual(sess.run(output, feed_dict={x: 2}), 4) 886 self.assertEqual(sess.run(output, feed_dict={x: 3}), 8) 887 888 def testCase_withoutDefault(self): 889 x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) 890 conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), 891 (math_ops.equal(x, 2), lambda: 
constant_op.constant(4)), 892 (math_ops.equal(x, 3), lambda: constant_op.constant(6))] 893 output = control_flow_ops.case(conditions, exclusive=True) 894 with self.test_session() as sess: 895 self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) 896 self.assertEqual(sess.run(output, feed_dict={x: 2}), 4) 897 self.assertEqual(sess.run(output, feed_dict={x: 3}), 6) 898 with self.assertRaisesRegexp( 899 errors.InvalidArgumentError, 900 r"\[None of the conditions evaluated as True. " 901 r"Conditions: \(Equal:0, Equal_1:0, Equal_2:0\), Values:\] " 902 r"\[0 0 0\]"): 903 sess.run(output, feed_dict={x: 4}) 904 905 def testCase_withoutDefault_oneCondition(self): 906 x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) 907 conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2))] 908 output = control_flow_ops.case(conditions, exclusive=True) 909 with self.test_session() as sess: 910 self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) 911 with self.assertRaisesRegexp( 912 errors.InvalidArgumentError, 913 r"\[None of the conditions evaluated as True. " 914 r"Conditions: \(Equal:0\), Values:\] \[0\]"): 915 sess.run(output, feed_dict={x: 4}) 916 917 918if __name__ == "__main__": 919 googletest.main() 920