control_flow_ops_test.py revision 6f898c6b2cbbc257d0966ee313a3670e88919463
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for control_flow_ops.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import numpy as np 23 24from tensorflow.core.framework import graph_pb2 25from tensorflow.core.framework import node_def_pb2 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework.test_util import TensorFlowTestCase 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import embedding_ops 36from tensorflow.python.ops import gradients_impl 37from tensorflow.python.ops import init_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import state_ops 40from tensorflow.python.ops import tensor_array_ops 41from tensorflow.python.ops import variable_scope 42from tensorflow.python.ops import variables 43import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import 44from tensorflow.python.platform import 
googletest 45from tensorflow.python.training import momentum 46from tensorflow.python.util import nest 47 48 49TestTuple = collections.namedtuple("TestTuple", "a b") 50SingletonTestTuple = collections.namedtuple("SingletonTestTuple", "a") 51 52 53class GroupTestCase(TensorFlowTestCase): 54 55 def _StripNode(self, nd): 56 snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input) 57 if nd.device: 58 snode.device = nd.device 59 return snode 60 61 def _StripGraph(self, gd): 62 """Copy gd keeping only, node.name, node.op, node.input, and node.device.""" 63 return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node]) 64 65 def testGroup_NoDevices(self): 66 with ops.Graph().as_default() as g: 67 a = constant_op.constant(0, name="a") 68 b = constant_op.constant(0, name="b") 69 c = constant_op.constant(0, name="c") 70 control_flow_ops.group(a.op, b.op, c.op, name="root") 71 gd = g.as_graph_def() 72 self.assertProtoEquals(""" 73 node { name: "a" op: "Const"} 74 node { name: "b" op: "Const"} 75 node { name: "c" op: "Const"} 76 node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" } 77 """, self._StripGraph(gd)) 78 79 def testGroup_OneDevice(self): 80 with ops.Graph().as_default() as g: 81 with g.device("/task:0"): 82 a = constant_op.constant(0, name="a") 83 b = constant_op.constant(0, name="b") 84 control_flow_ops.group(a.op, b.op, name="root") 85 gd = g.as_graph_def() 86 self.assertProtoEquals(""" 87 node { name: "a" op: "Const" device: "/task:0" } 88 node { name: "b" op: "Const" device: "/task:0" } 89 node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" } 90 """, self._StripGraph(gd)) 91 92 def testGroup_MultiDevice(self): 93 with ops.Graph().as_default() as g: 94 with g.device("/task:0"): 95 a = constant_op.constant(0, name="a") 96 b = constant_op.constant(0, name="b") 97 with g.device("/task:1"): 98 c = constant_op.constant(0, name="c") 99 d = constant_op.constant(0, name="d") 100 with g.device("/task:2"): 101 
control_flow_ops.group(a.op, b.op, c.op, d.op, name="root") 102 gd = g.as_graph_def() 103 self.assertProtoEquals(""" 104 node { name: "a" op: "Const" device: "/task:0"} 105 node { name: "b" op: "Const" device: "/task:0"} 106 node { name: "c" op: "Const" device: "/task:1"} 107 node { name: "d" op: "Const" device: "/task:1"} 108 node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b" 109 device: "/task:0" } 110 node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d" 111 device: "/task:1" } 112 node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1" 113 device: "/task:2" } 114 """, self._StripGraph(gd)) 115 116 117class ShapeTestCase(TensorFlowTestCase): 118 119 def testShape(self): 120 with ops.Graph().as_default(): 121 tensor = constant_op.constant([1.0, 2.0]) 122 self.assertEquals([2], tensor.get_shape()) 123 self.assertEquals([2], 124 control_flow_ops.with_dependencies( 125 [constant_op.constant(1.0)], tensor).get_shape()) 126 127 128class WithDependenciesTestCase(TensorFlowTestCase): 129 130 def testTupleDependencies(self): 131 with ops.Graph().as_default(): 132 counter = variable_scope.get_variable( 133 "my_counter", shape=[], initializer=init_ops.zeros_initializer()) 134 increment_counter = state_ops.assign_add(counter, 1) 135 const_with_dep = control_flow_ops.with_dependencies( 136 (increment_counter, constant_op.constant(42)), 137 constant_op.constant(7)) 138 with self.test_session(): 139 variables.global_variables_initializer().run() 140 self.assertEquals(0, counter.eval()) 141 self.assertEquals(7, const_with_dep.eval()) 142 self.assertEquals(1, counter.eval()) 143 144 def testListDependencies(self): 145 with ops.Graph().as_default(): 146 counter = variable_scope.get_variable( 147 "my_counter", shape=[], initializer=init_ops.zeros_initializer()) 148 increment_counter = state_ops.assign_add(counter, 1) 149 const_with_dep = control_flow_ops.with_dependencies( 150 [increment_counter, constant_op.constant(42)], 151 
constant_op.constant(7)) 152 with self.test_session(): 153 variables.global_variables_initializer().run() 154 self.assertEquals(0, counter.eval()) 155 self.assertEquals(7, const_with_dep.eval()) 156 self.assertEquals(1, counter.eval()) 157 158 159class SwitchTestCase(TensorFlowTestCase): 160 161 def testIndexedSlicesWithDenseShape(self): 162 with self.test_session(): 163 data = ops.IndexedSlices( 164 constant_op.constant([1, 2, 3]), 165 constant_op.constant([0, 1]), 166 dense_shape=constant_op.constant([3])) 167 zero = constant_op.constant(0) 168 one = constant_op.constant(1) 169 less_op = math_ops.less(zero, one) 170 switch_false, switch_true = control_flow_ops.switch(data, less_op) 171 self.assertAllEqual([1, 2, 3], switch_true.values.eval()) 172 self.assertAllEqual([0, 1], switch_true.indices.eval()) 173 174 def testIndexedSlicesGradient(self): 175 with ops.Graph().as_default(): 176 embedding_matrix = variable_scope.get_variable( 177 "embedding_matrix", [5, 5], 178 initializer=init_ops.random_normal_initializer()) 179 180 def Cond(it, _): 181 return it < 5 182 183 def Body(it, cost): 184 embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0]) 185 cost += math_ops.reduce_sum(embedding) 186 return it + 1, cost 187 188 _, cost = control_flow_ops.while_loop( 189 Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)]) 190 optimizer = momentum.MomentumOptimizer(0.1, 0.9) 191 train_op = optimizer.minimize(cost) 192 with self.test_session() as sess: 193 sess.run(variables.global_variables_initializer()) 194 for _ in range(10): 195 sess.run([train_op]) 196 197 def testResourceReadInLoop(self): 198 with ops.Graph().as_default(): 199 embedding_matrix = variable_scope.get_variable( 200 "embedding_matrix", 201 initializer=[[2.0], [3.0]], 202 use_resource=True) 203 204 def Cond(it, _): 205 return it < 5 206 207 def Body(it, cost): 208 embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) 209 cost += math_ops.reduce_sum(embedding) 210 
return it + 1, cost 211 212 _, cost = control_flow_ops.while_loop( 213 Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)]) 214 with self.test_session() as sess: 215 sess.run(variables.global_variables_initializer()) 216 self.assertAllEqual(10.0, cost.eval()) 217 218 def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False): 219 with ops.Graph().as_default(): 220 embedding_matrix = variable_scope.get_variable( 221 "embedding_matrix", [5, 5], 222 initializer=init_ops.random_normal_initializer(), 223 use_resource=use_resource) 224 225 def Cond(it, _): 226 return it < 5 227 228 def Body(it, cost): 229 embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) 230 cost = control_flow_ops.cond( 231 math_ops.equal(it, 3), lambda: math_ops.square(cost), 232 lambda: cost + math_ops.reduce_sum(embedding)) 233 return it + 1, cost 234 235 _, cost = control_flow_ops.while_loop( 236 Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)]) 237 238 dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0] 239 dynamic_grads = math_ops.segment_sum(dynamic_grads.values, 240 dynamic_grads.indices) 241 242 embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) 243 static = math_ops.square( 244 math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) + 245 math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding) 246 static_grads = gradients_impl.gradients(static, [embedding_matrix])[0] 247 static_grads = math_ops.segment_sum(static_grads.values, 248 static_grads.indices) 249 250 with self.test_session() as sess: 251 sess.run(variables.global_variables_initializer()) 252 self.assertAllEqual(*sess.run([static_grads, dynamic_grads])) 253 254 def testIndexedSlicesGradientInCondInWhileLoop(self): 255 self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False) 256 257 def testIndexedSlicesGradientInCondInWhileLoopResource(self): 258 
self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True) 259 260 def testIndexedSlicesWithShapeGradientInWhileLoop(self): 261 for dtype in [dtypes.float32, dtypes.float64]: 262 with self.test_session() as sess: 263 num_steps = 9 264 265 inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps]) 266 initial_outputs = tensor_array_ops.TensorArray( 267 dtype=dtype, size=num_steps) 268 initial_i = constant_op.constant(0, dtype=dtypes.int32) 269 270 def Cond(i, _): 271 return i < num_steps # pylint: disable=cell-var-from-loop 272 273 def Body(i, outputs): 274 x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop 275 outputs = outputs.write(i, x) 276 return i + 1, outputs 277 278 _, outputs = control_flow_ops.while_loop(Cond, Body, 279 [initial_i, initial_outputs]) 280 281 outputs = math_ops.reduce_sum(outputs.stack()) 282 r = gradients_impl.gradients([outputs], [inputs])[0] 283 grad_wr_inputs = ops.convert_to_tensor(r) 284 o, grad = sess.run([outputs, grad_wr_inputs], 285 feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]}) 286 self.assertEquals(o, 20) 287 self.assertAllEqual(grad, [1] * num_steps) 288 289 def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self): 290 for dtype in [dtypes.float32, dtypes.float64]: 291 with self.test_session() as sess: 292 inputs = array_ops.placeholder(dtype=dtype) 293 initial_outputs = tensor_array_ops.TensorArray( 294 dtype=dtype, dynamic_size=True, size=1) 295 initial_i = constant_op.constant(0, dtype=dtypes.int32) 296 297 def Cond(i, _): 298 return i < array_ops.size(inputs) # pylint: disable=cell-var-from-loop 299 300 def Body(i, outputs): 301 x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop 302 outputs = outputs.write(i, x) 303 return i + 1, outputs 304 305 _, outputs = control_flow_ops.while_loop(Cond, Body, 306 [initial_i, initial_outputs]) 307 308 outputs = math_ops.reduce_sum(outputs.stack()) 309 r = gradients_impl.gradients([outputs], [inputs])[0] 310 grad_wr_inputs = 
ops.convert_to_tensor(r) 311 o, grad = sess.run([outputs, grad_wr_inputs], 312 feed_dict={inputs: [1, 3, 2]}) 313 self.assertEquals(o, 6) 314 self.assertAllEqual(grad, [1] * 3) 315 316 def testGradientThroughSingleBranchOutsideOfContext(self): 317 with self.test_session(): 318 x = constant_op.constant(2.) 319 s = constant_op.constant(True) 320 x_false, x_true = control_flow_ops.switch(x, s) 321 grad_x_true = gradients_impl.gradients(x_true, x)[0] 322 grad_x_false = gradients_impl.gradients(x_false, x)[0] 323 self.assertEquals(grad_x_true.eval(), 1.) 324 self.assertEquals(grad_x_false.eval(), 0.) 325 326 327class CondTest(TensorFlowTestCase): 328 329 def testCondTrue(self): 330 with self.test_session(): 331 x = constant_op.constant(2) 332 y = constant_op.constant(5) 333 z = control_flow_ops.cond( 334 math_ops.less(x, y), lambda: math_ops.multiply(x, 17), 335 lambda: math_ops.add(y, 23)) 336 self.assertEquals(z.eval(), 34) 337 338 def testCondFalse(self): 339 with self.test_session(): 340 x = constant_op.constant(2) 341 y = constant_op.constant(1) 342 z = control_flow_ops.cond( 343 math_ops.less(x, y), lambda: math_ops.multiply(x, 17), 344 lambda: math_ops.add(y, 23)) 345 self.assertEquals(z.eval(), 24) 346 347 def testCondTrueLegacy(self): 348 with self.test_session(): 349 x = constant_op.constant(2) 350 y = constant_op.constant(5) 351 z = control_flow_ops.cond( 352 math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17), 353 fn2=lambda: math_ops.add(y, 23)) 354 self.assertEquals(z.eval(), 34) 355 356 def testCondFalseLegacy(self): 357 with self.test_session(): 358 x = constant_op.constant(2) 359 y = constant_op.constant(1) 360 z = control_flow_ops.cond( 361 math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17), 362 fn2=lambda: math_ops.add(y, 23)) 363 self.assertEquals(z.eval(), 24) 364 365 def testCondMissingArg1(self): 366 with self.test_session(): 367 x = constant_op.constant(1) 368 with self.assertRaises(TypeError): 369 control_flow_ops.cond(True, 
false_fn=lambda: x) 370 371 def testCondMissingArg2(self): 372 with self.test_session(): 373 x = constant_op.constant(1) 374 with self.assertRaises(TypeError): 375 control_flow_ops.cond(True, lambda: x) 376 377 def testCondDuplicateArg1(self): 378 with self.test_session(): 379 x = constant_op.constant(1) 380 with self.assertRaises(TypeError): 381 control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x) 382 383 def testCondDuplicateArg2(self): 384 with self.test_session(): 385 x = constant_op.constant(1) 386 with self.assertRaises(TypeError): 387 control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x) 388 389 390class ContextTest(TensorFlowTestCase): 391 392 def testCondContext(self): 393 with self.test_session() as sess: 394 x = constant_op.constant(2) 395 y = constant_op.constant(5) 396 control_flow_ops.cond( 397 math_ops.less(x, y), lambda: math_ops.multiply(x, 17), 398 lambda: math_ops.add(y, 23)) 399 for op in sess.graph.get_operations(): 400 c = op._get_control_flow_context() 401 if c: 402 self.assertProtoEquals( 403 c.to_proto(), 404 control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto()) 405 406 def testWhileContext(self): 407 with self.test_session() as sess: 408 i = constant_op.constant(0) 409 c = lambda i: math_ops.less(i, 10) 410 b = lambda i: math_ops.add(i, 1) 411 control_flow_ops.while_loop(c, b, [i]) 412 for op in sess.graph.get_operations(): 413 c = op._get_control_flow_context() 414 if c: 415 self.assertProtoEquals( 416 c.to_proto(), 417 control_flow_ops.WhileContext.from_proto(c.to_proto()).to_proto()) 418 419 def testControlContextImportScope(self): 420 with self.test_session(): 421 constant_op.constant(0, name="a") 422 constant_op.constant(2, name="test_scope/a") 423 b1 = constant_op.constant(1, name="b") 424 b2 = constant_op.constant(3, name="test_scope/b") 425 426 c = control_flow_ops.ControlFlowContext() 427 c._values = ["a", "b"] 428 c._external_values = {"a": b1} 429 430 c_with_scope = 
control_flow_ops.ControlFlowContext._from_proto( 431 c._to_proto(), import_scope="test_scope") 432 433 # _values and _external_values should be have scope prepended. 434 self.assertEquals( 435 c_with_scope._values, set(["test_scope/a", "test_scope/b"])) 436 self.assertEquals( 437 c_with_scope._external_values, {"test_scope/a": b2}) 438 439 # Calling _to_proto() with export_scope should remove "test_scope". 440 self.assertProtoEquals( 441 c._to_proto(), 442 c_with_scope._to_proto(export_scope="test_scope")) 443 444 445def _GetNestedShape(nested): 446 def _GetShape(tensor): 447 if isinstance(tensor, tensor_array_ops.TensorArray): 448 return tensor_array_ops.TensorArray 449 elif isinstance(tensor, ops.IndexedSlices): 450 return tensor.dense_shape 451 else: 452 return tensor.get_shape() 453 454 return nest.map_structure(_GetShape, nested) 455 456 457def _CreateTensorArray(size, shape): 458 ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size, 459 clear_after_read=False) 460 for i in range(size): 461 ta = ta.write(i, array_ops.zeros(shape)) 462 return ta 463 464 465def _RawNestedShape(nested_shape): 466 def _RawShape(shape): 467 if isinstance(shape, tensor_shape.TensorShape) and shape.ndims is not None: 468 return [x.value for x in shape] 469 else: 470 return None 471 return nest.map_structure(_RawShape, nested_shape) 472 473 474# TODO(yori): Add tests for indexed slices. 
class DataTypesTest(TensorFlowTestCase):
  """Checks the shapes and values that cond() and case() return for many types.

  Each test builds a pair of branch functions. It runs them through both
  `control_flow_ops.cond` and `control_flow_ops.case`, in the default
  (non-strict) mode and, where relevant, in strict mode.
  """

  def assertAllEqualNested(self, a, b):
    """Recursively asserts that nested lists/tuples `a` and `b` are equal.

    Leaves (values of `a` that are not lists or tuples) are compared with
    `assertAllEqual`.
    """
    if isinstance(a, (list, tuple)):
      # zip() stops at the shorter input. Compare lengths first so that a
      # result with missing or extra entries cannot pass silently.
      self.assertEqual(len(a), len(b))
      for entry_a, entry_b in zip(a, b):
        self.assertAllEqualNested(entry_a, entry_b)
    else:
      self.assertAllEqual(a, b)

  def _testShape(self, fn_true, fn_false, expected_shape,
                 strict=False):
    """Asserts that cond() and case() outputs have `expected_shape`.

    Args:
      fn_true: branch taken when the predicate is True.
      fn_false: branch taken otherwise (the default branch for case()).
      expected_shape: a (possibly nested) structure of TensorShapes, or the
        TensorArray class for TensorArray outputs.
      strict: forwarded to cond() and case().
    """
    condition = array_ops.placeholder(dtypes.bool)
    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
                                        strict=strict)
    self.assertEqual(_RawNestedShape(_GetNestedShape(output_cond)),
                     _RawNestedShape(expected_shape))

    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
                                        strict=strict)
    self.assertEqual(_RawNestedShape(_GetNestedShape(output_case)),
                     _RawNestedShape(expected_shape))

  def _testReturnValues(self, fn_true, fn_false, expected_value_true,
                        expected_value_false, strict=False,
                        check_cond=True):
    """Runs cond() and case() for both predicate values and checks results.

    Args:
      fn_true: branch taken when the predicate is True.
      fn_false: branch taken otherwise (the default branch for case()).
      expected_value_true: expected result when the predicate is fed True.
      expected_value_false: expected result when the predicate is fed False.
      strict: forwarded to cond() and case().
      check_cond: despite its name, this flag gates the comparison of the
        *case()* results. The cond() results are always checked.
    """
    condition = array_ops.placeholder(dtypes.bool)
    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
                                        strict=strict)
    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
                                        strict=strict)

    with self.test_session() as sess:
      variables.global_variables_initializer().run()
      result_cond, result_case = sess.run([output_cond, output_case],
                                          feed_dict={condition: True})
      self.assertAllEqualNested(result_cond, expected_value_true)
      if check_cond:
        self.assertAllEqualNested(result_case, expected_value_true)
      result_cond, result_case = sess.run([output_cond, output_case],
                                          feed_dict={condition: False})
      self.assertAllEqualNested(result_cond, expected_value_false)
      if check_cond:
        self.assertAllEqualNested(result_case, expected_value_false)

  def test_int(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: 1
    fn_false = lambda: 2
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 2)
    self._testShape(fn_true, fn_false, shape, strict=True)
    self._testReturnValues(fn_true, fn_false, 1, 2, strict=True)

  def test_float(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: 1.0
    fn_false = lambda: 2.0
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1.0, 2.0)

  def test_noop(self):
    shape = tensor_shape.TensorShape(None)
    self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape)
    # Both branches return an Operation. The expected cond() values equal the
    # fed predicate, and case() results are not compared (check_cond=False).
    self._testReturnValues(control_flow_ops.no_op, control_flow_ops.no_op,
                           True, False, check_cond=False)

  def test_string(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: "abc"
    fn_false = lambda: "xyz"
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, b"abc", b"xyz")

  def test_variable(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: variables.Variable(3.0)
    fn_false = lambda: variables.Variable(4.0)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 3.0, 4.0)

  def test_none(self):
    fn_none = lambda: None
    fn_tensor = lambda: constant_op.constant(1)

    # A branch returning None cannot be combined with one returning a tensor.
    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_none, fn_tensor)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none)

  def test_tensors(self):
    def _BuildTrueBranch(dtype):
      def _Build():
        return (array_ops.zeros([2, 2], dtype=dtype),
                array_ops.ones([3, 3], dtype=dtype))
      return _Build

    def _BuildFalseBranch(dtype):
      def _Build():
        return (array_ops.ones([2, 2], dtype=dtype),
                array_ops.zeros([3, 3], dtype=dtype))
      return _Build

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = (tensor_shape.TensorShape([2, 2]),
               tensor_shape.TensorShape([3, 3]))
      fn_true = _BuildTrueBranch(dtype)
      fn_false = _BuildFalseBranch(dtype)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             (np.zeros([2, 2]), np.ones([3, 3])),
                             (np.ones([2, 2]), np.zeros([3, 3])))

  def test_tensors_unknown_shape(self):
    # The branches overwrite the private `_shape` with an unknown-rank shape,
    # so cond()/case() see outputs without static shape information.
    def _BuildTrueBranch(dtype):
      def _Build():
        tensor = array_ops.zeros([2, 2], dtype=dtype)
        tensor._shape = tensor_shape.TensorShape(None)
        return tensor
      return _Build

    def _BuildFalseBranch(dtype):
      def _Build():
        tensor = array_ops.ones([2, 2], dtype=dtype)
        tensor._shape = tensor_shape.TensorShape(None)
        return tensor
      return _Build

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = tensor_shape.TensorShape(None)
      fn_true = _BuildTrueBranch(dtype)
      fn_false = _BuildFalseBranch(dtype)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             np.zeros([2, 2]), np.ones([2, 2]))

  def test_sparse_tensors(self):
    shape = tensor_shape.TensorShape([None, None])

    def FnTrue():
      return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],
                                         values=[1, 2], dense_shape=[3, 4])]

    def FnFalse():
      return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]],
                                         values=[3, 4], dense_shape=[3, 4])]

    value1 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 2]],
                                             values=[1, 2], dense_shape=[3, 4])
    value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]],
                                             values=[3, 4], dense_shape=[3, 4])
    self._testShape(FnTrue, FnFalse, shape)
    self._testReturnValues(FnTrue, FnFalse, value1, value2)
    self._testShape(FnTrue, FnFalse, [shape], strict=True)
    self._testReturnValues(FnTrue, FnFalse, [value1], [value2], strict=True)

  def test_tensors_with_partially_specified_shapes(self):
    def _BuildBranch(dtype, shape):
      def _Build():
        a = array_ops.zeros([2, 2], dtype=dtype)
        b = array_ops.zeros([5], dtype=dtype)
        c = array_ops.ones([3, 3], dtype=dtype)
        a._shape = tensor_shape.TensorShape(shape[0])
        b._shape = tensor_shape.TensorShape(shape[1])
        c._shape = tensor_shape.TensorShape(shape[2])
        return a, b, c
      return _Build

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = (tensor_shape.TensorShape([None, 2]),
               tensor_shape.TensorShape([None]),
               tensor_shape.TensorShape([3, None]))
      fn_true = _BuildBranch(dtype, shape)
      fn_false = _BuildBranch(dtype, shape)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])))

  def test_tensor_arrays(self):
    element_shape = tensor_shape.TensorShape([2])
    ta1 = _CreateTensorArray(4, element_shape)
    ta2 = _CreateTensorArray(4, element_shape)
    shape = tensor_array_ops.TensorArray
    fn_true = lambda: ta1
    fn_false = lambda: ta2
    self._testShape(fn_true, fn_false, shape)

  def test_tensor_array_reads(self):
    shape = tensor_shape.TensorShape([2])
    ta = _CreateTensorArray(4, shape)
    fn_true = lambda: ta.read(0)
    fn_false = lambda: ta.read(1)
    self._testShape(fn_true, fn_false, shape)

  def test_list(self):
    shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
             tensor_shape.TensorShape([])]
    fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)]
    fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)]
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, [1, 2, 3.0], [3, 4, 5.0])

  def test_non_strict(self):
    # In non-strict mode a tensor, a 1-element list and a 1-element tuple are
    # interchangeable.
    shape = tensor_shape.TensorShape([])
    fn_tensor = lambda: constant_op.constant(1)
    fn_list = lambda: [constant_op.constant(2)]
    fn_tuple = lambda: (constant_op.constant(3),)
    self._testShape(fn_tensor, fn_list, shape)
    self._testShape(fn_tensor, fn_tuple, shape)
    self._testShape(fn_list, fn_tuple, shape)
    self._testReturnValues(fn_tensor, fn_list, 1, 2)
    self._testReturnValues(fn_tensor, fn_tuple, 1, 3)
    self._testReturnValues(fn_list, fn_tuple, 2, 3)

  def test_singleton_strict(self):
    fn_tensor = lambda: constant_op.constant(1)
    fn_list = lambda: [constant_op.constant(2)]
    fn_tuple = lambda: (constant_op.constant(3),)

    # In strict mode mismatched structures are rejected: a tensor vs. a list
    # raises ValueError, and a list vs. a tuple raises TypeError.
    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_list,
                            strict=True)

    with self.assertRaises(TypeError):
      control_flow_ops.cond(constant_op.constant(True), fn_list, fn_tuple,
                            strict=True)

    with self.assertRaises(ValueError):
      control_flow_ops.case([(constant_op.constant(True), fn_tensor)], fn_list,
                            strict=True)

    with self.assertRaises(TypeError):
      control_flow_ops.case([(constant_op.constant(True), fn_list)], fn_tuple,
                            strict=True)

  def test_singleton_list(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: [constant_op.constant(1)]
    fn_false = lambda: [constant_op.constant(3)]
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, [shape], strict=True)
    self._testReturnValues(fn_true, fn_false, [1], [3], strict=True)

  def test_singleton_tuple(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: (constant_op.constant(1),)
    fn_false = lambda: (constant_op.constant(3),)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, (shape,), strict=True)
    self._testReturnValues(fn_true, fn_false, (1,), (3,),
                           strict=True)

  def test_singleton_namedtuple(self):
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: SingletonTestTuple(constant_op.constant(1))
    fn_false = lambda: SingletonTestTuple(constant_op.constant(3))
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, SingletonTestTuple(shape),
                    strict=True)
    self._testReturnValues(fn_true, fn_false, SingletonTestTuple(1),
                           SingletonTestTuple(3), strict=True)

  def test_tuple(self):
    shape = (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
    fn_true = lambda: (constant_op.constant(1), 2)
    fn_false = lambda: (constant_op.constant(3), 4)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, (1, 2), (3, 4))

  def test_namedtuple(self):
    shape = TestTuple(tensor_shape.TensorShape([]),
                      tensor_shape.TensorShape([]))
    fn_true = lambda: TestTuple(constant_op.constant(1), 2)
    fn_false = lambda: TestTuple(constant_op.constant(3), 4)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, TestTuple(1, 2), TestTuple(3, 4))

  def test_nested(self):
    shape = [tensor_shape.TensorShape([]),
             TestTuple(tensor_shape.TensorShape([]),
                       [tensor_shape.TensorShape([]),
                        tensor_shape.TensorShape([])]),
             tensor_shape.TensorShape([5, 5]),
             tensor_shape.TensorShape([])]

    def FnTrue():
      return [constant_op.constant(1),
              TestTuple(constant_op.constant(2), [3, 4]),
              array_ops.zeros([5, 5]), 6]

    def FnFalse():
      return [constant_op.constant(11),
              TestTuple(constant_op.constant(12), [13, 14]),
              array_ops.ones([5, 5]), 16]

    self._testShape(FnTrue, FnFalse, shape)
    self._testReturnValues(FnTrue, FnFalse,
                           [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6],
                           [11, TestTuple(12, [13, 14]), np.ones([5, 5]), 16])

  def test_cond_inside_while_loop(self):
    def Body(i, matrix):
      result_tuple, unused_matrix = control_flow_ops.cond(
          constant_op.constant(True),
          lambda: (TestTuple(matrix * 2, matrix * 4), matrix),
          lambda: (TestTuple(matrix * 4, matrix * 2), matrix))
      return [i+1, result_tuple.a]

    iteration, matrix = control_flow_ops.while_loop(
        lambda i, matrix: i < 10,
        Body,
        loop_vars=[constant_op.constant(0), array_ops.ones([2, 2])])

    # Only static shapes are checked; the loop is never executed.
    self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([]))
    self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2]))


class CaseTest(TensorFlowTestCase):
  """Tests for case() with and without a default branch."""

  def testCase_withDefault(self):
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(4))]
    default = lambda: constant_op.constant(6)
    output = control_flow_ops.case(conditions, default, exclusive=True)
    with self.test_session() as sess:
      self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
      self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
      self.assertEqual(sess.run(output, feed_dict={x: 3}), 6)

  def testCase_multiple_matches_exclusive(self):
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(6))]
    default = lambda: constant_op.constant(8)
    output = control_flow_ops.case(conditions, default, exclusive=True)
    with self.test_session() as sess:
      self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
      self.assertEqual(sess.run(output, feed_dict={x: 3}), 8)
      # x == 2 satisfies two predicates, which exclusive=True rejects at
      # run time.
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   "More than one condition evaluated as True"):
        sess.run(output, feed_dict={x: 2})

  def testCase_multiple_matches_non_exclusive(self):
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(6))]
    default = lambda: constant_op.constant(8)
    output = control_flow_ops.case(conditions, default, exclusive=False)
    with self.test_session() as sess:
      self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
      # Without exclusivity the first matching predicate wins.
      self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
      self.assertEqual(sess.run(output, feed_dict={x: 3}), 8)

  def testCase_withoutDefault(self):
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
                  (math_ops.equal(x, 3), lambda: constant_op.constant(6))]
    output = control_flow_ops.case(conditions, exclusive=True)
    with self.test_session() as sess:
      self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
      self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
      self.assertEqual(sess.run(output, feed_dict={x: 3}), 6)
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r"\[None of the conditions evaluated as True. "
          r"Conditions: \(Equal:0, Equal_1:0, Equal_2:0\), Values:\] "
          r"\[0 0 0\]"):
        sess.run(output, feed_dict={x: 4})

  def testCase_withoutDefault_oneCondition(self):
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2))]
    output = control_flow_ops.case(conditions, exclusive=True)
    with self.test_session() as sess:
      self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r"\[None of the conditions evaluated as True. "
          r"Conditions: \(Equal:0\), Values:\] \[0\]"):
        sess.run(output, feed_dict={x: 4})


if __name__ == "__main__":
  googletest.main()