1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OiR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=g-long-lambda 17"""Tests for tensorflow.ops.control_flow_ops.""" 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import collections 24import math 25import time 26 27import numpy as np 28from six.moves import xrange # pylint: disable=redefined-builtin 29 30from tensorflow.core.protobuf import config_pb2 31from tensorflow.python.client import device_lib 32from tensorflow.python.client import session 33from tensorflow.python.eager import context 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import errors_impl 37from tensorflow.python.framework import function 38from tensorflow.python.framework import ops 39from tensorflow.python.framework import sparse_tensor 40from tensorflow.python.framework import tensor_shape 41from tensorflow.python.framework import test_util 42from tensorflow.python.ops import array_ops 43from tensorflow.python.ops import control_flow_ops 44from tensorflow.python.ops import data_flow_ops 45from tensorflow.python.ops import functional_ops 46from tensorflow.python.ops import gen_array_ops 47from tensorflow.python.ops import gen_control_flow_ops 48from tensorflow.python.ops import gen_data_flow_ops 49from tensorflow.python.ops import gen_logging_ops 50from tensorflow.python.ops import gen_state_ops 51from tensorflow.python.ops import gradients_impl 52from tensorflow.python.ops import init_ops 53from tensorflow.python.ops import linalg_ops 54from tensorflow.python.ops import logging_ops 55from tensorflow.python.ops import math_ops 56from tensorflow.python.ops import nn_grad # pylint: disable=unused-import 57from tensorflow.python.ops import nn_ops 58from tensorflow.python.ops import random_ops 59from tensorflow.python.ops import resource_variable_ops 60from tensorflow.python.ops import script_ops 61from tensorflow.python.ops import state_ops 62from tensorflow.python.ops import variable_scope 63from tensorflow.python.ops import variables 64# pylint: disable=unused-import 65import tensorflow.python.ops.tensor_array_grad 66# pylint: enable=unused-import 67from tensorflow.python.platform import test 68from tensorflow.python.training import adam 69from tensorflow.python.training import gradient_descent 70from tensorflow.python.util import nest 71 72 73def check_consumers(graph): 74 """Sanity check on the consumer list of the tensors.""" 75 76 consumer_count = {} 77 for op in graph.get_operations(): 78 for v in op.inputs: 79 cnt = consumer_count.get(v, 0) 80 consumer_count[v] = cnt + 1 81 for k, v in consumer_count.items(): 82 if len(k.consumers()) != v: 83 return False 84 return True 85 86 87def all_fetchables(): 88 tensor_names = [] 89 graph = ops.get_default_graph() 90 for op in graph.get_operations(): 91 for t in op.outputs: 92 if graph.is_fetchable(t): 93 tensor_names.append(t.name) 94 return tensor_names 95 96 97def all_feedables(): 98 feedable_tensors = [] 99 graph = ops.get_default_graph() 100 for op in graph.get_operations(): 101 for t in op.inputs: 102 if graph.is_feedable(t): 103 feedable_tensors.append(t) 104 return feedable_tensors 105 106 107def opt_cfg(): 108 return config_pb2.ConfigProto( 109 allow_soft_placement=True, 110 graph_options=config_pb2.GraphOptions( 111 optimizer_options=config_pb2.OptimizerOptions( 112 opt_level=config_pb2.OptimizerOptions.L1, 113 do_function_inlining=True, 114 do_constant_folding=True))) 115 116 117def isum(s, maximum_iterations=None): 118 i = constant_op.constant(0, name="i") 119 c = lambda i, s: math_ops.less(i, 10) 120 b = lambda i, s: [math_ops.add(i, 1), math_ops.add(i, s)] 121 _, r_s = control_flow_ops.while_loop( 122 c, b, [i, s], maximum_iterations=maximum_iterations) 123 return r_s 124 125 126@test_util.with_c_api 127class ControlFlowTest(test.TestCase): 128 129 def testRefIdentity(self): 130 with self.test_session(): 131 v = variables.Variable(7) 132 133 v = control_flow_ops._Identity(v) 134 op = state_ops.assign(v, 9) 135 v2 = control_flow_ops.with_dependencies([op], v) 136 137 self.assertTrue(isinstance(v2, ops.Tensor)) 138 variables.global_variables_initializer().run() 139 self.assertEqual(9, v2.eval()) 140 141 def testRefEnter(self): 142 with self.test_session(): 143 v = variables.Variable(7) 144 145 enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True) 146 nine = constant_op.constant(9) 147 enter_nine = gen_control_flow_ops._enter(nine, "foo_1") 148 op = state_ops.assign(enter_v, enter_nine) 149 v2 = control_flow_ops.with_dependencies([op], enter_v) 150 v3 = control_flow_ops.exit(v2) 151 variables.global_variables_initializer().run() 152 self.assertEqual(9, v3.eval()) 153 154 def testRefSwitch(self): 155 with self.test_session(): 156 v = variables.Variable(7) 157 158 p = constant_op.constant(True) 159 v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p) # pylint: disable=protected-access 160 v2 = state_ops.assign(v1[1], 9) 161 variables.global_variables_initializer().run() 162 self.assertEqual(9, v2.eval()) 163 164 def testEnterMulExit(self): 165 with self.test_session(): 166 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 167 enter_data = gen_control_flow_ops._enter(data, "foo_1", False) 168 five = constant_op.constant(5) 169 enter_five = gen_control_flow_ops._enter(five, "foo_1", False) 170 mul_op = math_ops.multiply(enter_data, enter_five) 171 exit_op = control_flow_ops.exit(mul_op) 172 173 result = exit_op.eval() 174 self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result) 175 176 def testEnterShapePropagation(self): 177 with self.test_session(): 178 v = variables.Variable([0.0, 0.0], dtype=dtypes.float32) 179 180 # If is_constant=True, the shape information should be propagated. 181 enter_v_constant = gen_control_flow_ops._enter( 182 v, "frame1", is_constant=True) 183 self.assertEqual(enter_v_constant.shape, [2]) 184 185 # Otherwise, the shape should be unknown. 186 enter_v_non_constant = gen_control_flow_ops._enter( 187 v, "frame2", is_constant=False) 188 self.assertEqual(enter_v_non_constant.shape, None) 189 190 def testSwitchMergeIndexedSlices(self): 191 with self.test_session(): 192 values = constant_op.constant([1, 2, 3, 4, 5, 6]) 193 indices = constant_op.constant([0, 2, 4, 6, 8, 10]) 194 data = ops.IndexedSlices(values, indices) 195 pred = ops.convert_to_tensor(True) 196 switch_op = control_flow_ops.switch(data, pred) 197 merge_op = control_flow_ops.merge(switch_op)[0] 198 199 val = merge_op.values.eval() 200 ind = merge_op.indices.eval() 201 self.assertAllEqual(np.arange(1, 7), val) 202 self.assertAllEqual(np.arange(0, 12, 2), ind) 203 204 def testSwitchDeadBranch(self): 205 with self.test_session(): 206 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 207 ports = ops.convert_to_tensor(True, name="ports") 208 switch_op = control_flow_ops.switch(data, ports) 209 dead_branch = array_ops.identity(switch_op[0]) 210 211 with self.assertRaisesWithPredicateMatch( 212 errors_impl.InvalidArgumentError, 213 lambda e: "Retval[0] does not have value" in str(e)): 214 dead_branch.eval() 215 216 def testSwitchMergeLess(self): 217 with self.test_session(): 218 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 219 zero = ops.convert_to_tensor(0) 220 one = ops.convert_to_tensor(1) 221 less_op = math_ops.less(zero, one) 222 switch_op = control_flow_ops.switch(data, less_op) 223 merge_op = control_flow_ops.merge(switch_op)[0] 224 225 result = merge_op.eval() 226 self.assertAllEqual(np.arange(1, 7), result) 227 228 def testSwitchMergeAddIdentity(self): 229 with self.test_session(): 230 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 231 ports = ops.convert_to_tensor(False, name="ports") 232 switch_op = control_flow_ops.switch(data, ports) 233 one = constant_op.constant(1) 234 add_op = math_ops.add(switch_op[0], one) 235 id_op = array_ops.identity(switch_op[1]) 236 merge_op = control_flow_ops.merge([add_op, id_op])[0] 237 238 result = merge_op.eval() 239 self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result) 240 241 def testSwitchMergeAddMul(self): 242 with self.test_session(): 243 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 244 ports = ops.convert_to_tensor(True, name="ports") 245 switch_op = control_flow_ops.switch(data, ports) 246 one = constant_op.constant(1) 247 add_op = math_ops.add(switch_op[0], one) 248 five = constant_op.constant(5) 249 mul_op = math_ops.multiply(switch_op[1], five) 250 merge_op = control_flow_ops.merge([add_op, mul_op])[0] 251 252 result = merge_op.eval() 253 self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result) 254 255 def testLoop_false(self): 256 with self.test_session(): 257 false = ops.convert_to_tensor(False) 258 n = constant_op.constant(10) 259 260 enter_false = gen_control_flow_ops._enter(false, "foo_1", False) 261 enter_n = gen_control_flow_ops._enter(n, "foo_1", False) 262 263 merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0] 264 switch_n = control_flow_ops.switch(merge_n, enter_false) 265 exit_n = control_flow_ops.exit(switch_n[0]) 266 next_n = control_flow_ops.next_iteration(switch_n[0]) 267 merge_n.op._update_input(1, next_n) 268 269 result = exit_n.eval() 270 self.assertAllEqual(10, result) 271 272 def testLoop_1(self): 273 with self.test_session(): 274 zero = constant_op.constant(0) 275 one = constant_op.constant(1) 276 n = constant_op.constant(10) 277 278 enter_i = gen_control_flow_ops._enter(zero, "foo", False) 279 enter_one = gen_control_flow_ops._enter(one, "foo", True) 280 enter_n = gen_control_flow_ops._enter(n, "foo", True) 281 282 with ops.device(test.gpu_device_name()): 283 merge_i = control_flow_ops.merge([enter_i, enter_i])[0] 284 285 less_op = math_ops.less(merge_i, enter_n) 286 cond_op = control_flow_ops.loop_cond(less_op) 287 switch_i = control_flow_ops.switch(merge_i, cond_op) 288 289 add_i = math_ops.add(switch_i[1], enter_one) 290 291 next_i = control_flow_ops.next_iteration(add_i) 292 merge_i.op._update_input(1, next_i) 293 294 exit_i = control_flow_ops.exit(switch_i[0]) 295 result = exit_i.eval() 296 self.assertAllEqual(10, result) 297 298 def testLoop_2(self): 299 with self.test_session(): 300 zero = constant_op.constant(0) 301 one = constant_op.constant(1) 302 n = constant_op.constant(10) 303 304 enter_i = gen_control_flow_ops._enter(zero, "foo", False) 305 enter_one = gen_control_flow_ops._enter(one, "foo", True) 306 enter_n = gen_control_flow_ops._enter(n, "foo", True) 307 308 merge_i = control_flow_ops.merge([enter_i, enter_i])[0] 309 310 less_op = math_ops.less(merge_i, enter_n) 311 cond_op = control_flow_ops.loop_cond(less_op) 312 switch_i = control_flow_ops.switch(merge_i, cond_op) 313 314 add_i = math_ops.add(switch_i[1], enter_one) 315 316 with ops.device(test.gpu_device_name()): 317 next_i = control_flow_ops.next_iteration(add_i) 318 merge_i.op._update_input(1, next_i) 319 320 exit_i = control_flow_ops.exit(switch_i[0]) 321 result = exit_i.eval() 322 self.assertAllEqual(10, result) 323 324 def testDifferentFrame(self): 325 with self.test_session(): 326 data = array_ops.placeholder(dtypes.float32, shape=[]) 327 enter_1 = gen_control_flow_ops._enter(data, "foo_1", False) 328 enter_2 = gen_control_flow_ops._enter(data, "foo_2", False) 329 res = math_ops.add(enter_1, enter_2) 330 with self.assertRaisesOpError("has inputs from different frames"): 331 res.eval(feed_dict={data: 1.0}) 332 333 def testCondBool(self): 334 values = constant_op.constant(10) 335 fn1 = lambda: math_ops.add(values, 1) 336 fn2 = lambda: math_ops.subtract(values, 1) 337 with self.assertRaisesRegexp(TypeError, "must not be a Python bool"): 338 _ = control_flow_ops.cond(False, fn1, fn2) 339 340 def testCondInt(self): 341 p = array_ops.placeholder(dtypes.bool, shape=[]) 342 v = constant_op.constant(10) 343 fn1 = lambda: math_ops.add(v, 1) 344 fn2 = lambda: math_ops.subtract(v, 1) 345 y = control_flow_ops.cond(p, fn1, fn2) 346 grad = gradients_impl.gradients(y, [v]) 347 self.assertAllEqual([None], grad) 348 349 def testFetchable(self): 350 with self.test_session() as sess: 351 x = array_ops.placeholder(dtypes.float32) 352 control_flow_ops.cond( 353 constant_op.constant(True), lambda: x + 2, lambda: x + 0) 354 graph = ops.get_default_graph() 355 for op in graph.get_operations(): 356 for t in op.inputs: 357 if graph.is_fetchable(t.op): 358 sess.run(t, feed_dict={x: 3}) 359 else: 360 with self.assertRaisesRegexp(ValueError, 361 "has been marked as not fetchable"): 362 sess.run(t, feed_dict={x: 3}) 363 364 def testFeedable(self): 365 with self.test_session() as sess: 366 c = constant_op.constant(2) 367 i0 = constant_op.constant(0) 368 r = control_flow_ops.while_loop(lambda i: i < 1000, 369 lambda i: math_ops.square(c) + i, [i0]) 370 self.assertEqual(1000, r.eval(feed_dict={i0: 0})) 371 feedable_tensors = all_feedables() 372 for t in feedable_tensors: 373 sess.run(r, feed_dict={t: 3}) 374 graph = ops.get_default_graph() 375 for op in graph.get_operations(): 376 for t in op.inputs: 377 if t not in feedable_tensors and t.dtype is dtypes.int32: 378 with self.assertRaisesRegexp(ValueError, "may not be fed"): 379 sess.run(r, feed_dict={t: 3}) 380 381 def testCondIndexedSlices(self): 382 with self.test_session(): 383 values = constant_op.constant(10) 384 indices = constant_op.constant(0) 385 x = ops.IndexedSlices(values, indices) 386 pred = math_ops.less(1, 2) 387 fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices) 388 fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), indices) 389 r = control_flow_ops.cond(pred, fn1, fn2) 390 391 val = r.values.eval() 392 ind = r.indices.eval() 393 self.assertAllEqual(11, val) 394 self.assertAllEqual(0, ind) 395 396 def testCondSparseTensor(self): 397 with self.test_session(): 398 values = constant_op.constant([2.0, 4.0], name="values") 399 indices = constant_op.constant( 400 [[0], [3]], dtype=dtypes.int64, name="indices") 401 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 402 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 403 pred = math_ops.less(1, 2) 404 fn1 = lambda: sparse_tensor.SparseTensor( 405 indices + 1, x.values + 1, dense_shape=shape) 406 fn2 = lambda: sparse_tensor.SparseTensor( 407 indices, x.values - 1, dense_shape=shape) 408 r = control_flow_ops.cond(pred, fn1, fn2) 409 self.assertAllEqual([3.0, 5.0], r.values.eval()) 410 self.assertAllEqual([[1], [4]], r.indices.eval()) 411 self.assertAllEqual(r.values.get_shape(), (2,)) 412 413 def testCondResource(self): 414 with self.test_session(): 415 rv = resource_variable_ops.ResourceVariable(True) 416 variables.global_variables_initializer().run() 417 t = ops.convert_to_tensor(1.0) 418 419 def case(): 420 assign = resource_variable_ops.assign_variable_op(rv.handle, False) 421 with ops.control_dependencies([assign]): 422 return array_ops.identity(t) 423 424 self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval()) 425 426 def testCondIndexedSlicesDifferentTypes(self): 427 with self.test_session(): 428 values = constant_op.constant(10) 429 i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32) 430 i_64 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int64) 431 x = ops.IndexedSlices(values, i_32) 432 pred = math_ops.less(1, 2) 433 fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), i_32) 434 fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), i_64) 435 r = control_flow_ops.cond(pred, fn1, fn2) 436 437 val = r.values.eval() 438 ind = r.indices.eval() 439 self.assertAllEqual(11, val) 440 self.assertAllEqual(0, ind) 441 self.assertTrue(ind.dtype == np.int64) 442 443 def testCondColocation(self): 444 with self.test_session(use_gpu=True): 445 with ops.device("/cpu:0"): 446 v = variables.Variable(7.0) 447 448 x = constant_op.constant(10.0) 449 pred = math_ops.less(1.0, 2.0) 450 fn1 = lambda: math_ops.add(v, 1.0) 451 fn2 = lambda: math_ops.subtract(x, 1.0) 452 r = control_flow_ops.cond(pred, fn1, fn2) 453 454 for op in x.graph.get_operations(): 455 if op.name == "cond/Add/Switch": 456 self.assertDeviceEqual(op.device, "/cpu:0") 457 458 def _testCond_1(self, use_gpu): 459 with self.test_session(use_gpu=use_gpu): 460 x = constant_op.constant(10) 461 pred = math_ops.less(1, 2) 462 fn1 = lambda: math_ops.add(x, 1) 463 fn2 = lambda: math_ops.subtract(x, 1) 464 r = control_flow_ops.cond(pred, fn1, fn2) 465 466 result = r.eval() 467 self.assertAllEqual(11, result) 468 469 def testCond_1(self): 470 self._testCond_1(use_gpu=False) 471 self._testCond_1(use_gpu=True) 472 473 def testCond_2(self): 474 with self.test_session(): 475 x = constant_op.constant(10) 476 r = control_flow_ops.cond( 477 math_ops.less(1, 0), lambda: math_ops.add(x, 1), 478 lambda: math_ops.subtract(x, 1)) 479 result = r.eval() 480 self.assertAllEqual(9, result) 481 482 def testCond_3(self): 483 with self.test_session(): 484 x = constant_op.constant(10) 485 pred = math_ops.less(1, 2) 486 fn1 = lambda: math_ops.add(x, 1) 487 fn2 = lambda: math_ops.subtract(x, 1) 488 fn3 = lambda: math_ops.add(control_flow_ops.cond(pred, fn1, fn2), 1) 489 r = control_flow_ops.cond(pred, fn3, fn2) 490 491 result = r.eval() 492 self.assertAllEqual(12, result) 493 494 def testCond_4(self): 495 with self.test_session(): 496 v1 = variables.Variable(7) 497 v2 = variables.Variable(7) 498 v3 = variables.Variable(7) 499 500 age = constant_op.constant(3) 501 max_age = constant_op.constant(2) 502 pred = math_ops.greater(age, max_age) 503 fn1 = lambda: [state_ops.assign(v1, 1).op, state_ops.assign(v2, 2).op] 504 fn2 = lambda: [state_ops.assign(v3, 3).op, constant_op.constant(10).op] 505 r = control_flow_ops.cond(pred, fn1, fn2) 506 507 variables.global_variables_initializer().run() 508 self.assertEqual(len(r), 2) 509 result = r[1].eval() 510 self.assertAllEqual(True, result) 511 self.assertAllEqual(7, v1.eval()) 512 self.assertAllEqual(2, v2.eval()) 513 self.assertAllEqual(7, v3.eval()) 514 515 def testCond_5(self): 516 with self.test_session(): 517 alive = constant_op.constant(True, name="alive") 518 count = constant_op.constant(0, name="count") 519 520 def body(i): 521 return control_flow_ops.cond( 522 alive, lambda: [math_ops.less(i, 3), math_ops.add(count, 1)], 523 lambda: [alive, count]) 524 525 for i in range(10): 526 alive, count = body(i) 527 self.assertAllEqual(4, count.eval()) 528 529 def testCond_6(self): 530 with self.test_session(): 531 v1 = variables.Variable([7]) 532 533 age = constant_op.constant(3) 534 pred = math_ops.greater(age, 4) 535 fn1 = lambda: age 536 fn2 = lambda: v1 537 r = control_flow_ops.cond(pred, fn1, fn2) 538 539 variables.global_variables_initializer().run() 540 result = r.eval() 541 self.assertAllEqual(np.array([7]), result) 542 543 def testCond_7(self): 544 with self.test_session() as sess: 545 x = constant_op.constant(10) 546 y = constant_op.constant(200) 547 pred = math_ops.less(1, 2) 548 fn1 = lambda: [math_ops.add(x, 1), math_ops.add(x, 2)] 549 fn2 = lambda: [y, y] 550 r = control_flow_ops.cond(pred, fn1, fn2) 551 self.assertAllEqual([11, 12], sess.run(r)) 552 553 def testCondRef(self): 554 with self.test_session(): 555 x = gen_state_ops._variable( 556 shape=[1], 557 dtype=dtypes.float32, 558 name="x", 559 container="", 560 shared_name="") 561 true_fn = lambda: x 562 false_fn = lambda: constant_op.constant([2.0]) 563 r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn) 564 self.assertAllEqual([2.0], r.eval()) 565 566 def testCondWithControl(self): 567 with self.test_session() as sess: 568 control_holder = array_ops.placeholder(dtypes.float32, shape=()) 569 a = constant_op.constant(3) 570 571 def true_branch(): 572 with ops.control_dependencies([control_holder]): 573 _ = a + 1 574 return a + 2 575 576 r = control_flow_ops.cond( 577 constant_op.constant(True), true_branch, 578 lambda: constant_op.constant(1)) 579 self.assertEqual(5, r.eval()) 580 581 def testUninitializedRefIdentity(self): 582 with self.test_session() as sess: 583 v = gen_state_ops._variable( 584 shape=[1], 585 dtype=dtypes.float32, 586 name="v", 587 container="", 588 shared_name="") 589 inited = state_ops.is_variable_initialized(v) 590 v_f, v_t = control_flow_ops.ref_switch(v, inited) 591 # Both v_f and v_t are uninitialized references. However, an actual use 592 # of the reference in the 'true' branch in the 'tf.identity' op will 593 # not 'fire' when v is uninitialized, so this is a valid construction. 594 # This test tests that _ref_identity allows uninitialized ref as input 595 # so that this construction is allowed. 596 v_f_op = gen_array_ops._ref_identity(v_f) 597 v_t_op = gen_array_ops._ref_identity(v_t) 598 with ops.control_dependencies([v_f_op]): 599 assign_v = state_ops.assign(v, [1.0]) 600 with ops.control_dependencies([v_t_op]): 601 orig_v = array_ops.identity(v) 602 merged_op = control_flow_ops.merge([assign_v, orig_v]) 603 self.assertAllEqual([1.0], sess.run(merged_op.output)) 604 605 def testCondSwitchIdentity(self): 606 # Make sure the recv identity is not removed by optimization. 607 with session.Session(config=opt_cfg()) as sess: 608 pred = constant_op.constant(True) 609 610 def fn1(): 611 return control_flow_ops.no_op() 612 613 def fn2(): 614 return control_flow_ops.Assert(False, ["Wrong branch!!!"]) 615 616 r = control_flow_ops.cond(pred, fn1, fn2) 617 sess.run(r) 618 619 def testCondRecvIdentity(self): 620 # Make sure the switch identity is not removed by optimization. 621 with session.Session(config=opt_cfg()) as sess: 622 with ops.device(test.gpu_device_name()): 623 pred = constant_op.constant(True) 624 625 def fn1(): 626 return control_flow_ops.no_op() 627 628 def fn2(): 629 with ops.device("/cpu:0"): 630 return control_flow_ops.Assert(False, ["Wrong branch!!!"]) 631 632 r = control_flow_ops.cond(pred, fn1, fn2) 633 sess.run(r) 634 635 def testCondGrad_1(self): 636 with self.test_session(): 637 x = constant_op.constant(10.0, name="x") 638 pred = math_ops.less(1, 2) 639 fn1 = lambda: array_ops.identity(x) 640 fn2 = lambda: array_ops.identity(x) 641 r = control_flow_ops.cond(pred, fn1, fn2) 642 643 grad = gradients_impl.gradients(r, [x])[0] 644 result = grad.eval() 645 self.assertAllEqual(1.0, result) 646 647 def testCondGrad_2(self): 648 with self.test_session(): 649 c = array_ops.placeholder(dtypes.int32, shape=[]) 650 x = constant_op.constant(10.0) 651 pred = math_ops.less(c, 2) 652 fn1 = lambda: math_ops.multiply(x, 42.0) 653 fn2 = lambda: math_ops.multiply(x, 3.0) 654 r = control_flow_ops.cond(pred, fn1, fn2) 655 656 grad = gradients_impl.gradients(r, [x])[0] 657 self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1})) 658 self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3})) 659 660 def testNestedCond_Simple(self): 661 with self.test_session(): 662 x = constant_op.constant(0., name="X") 663 y = control_flow_ops.cond( 664 constant_op.constant(True), lambda: x, 665 lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x)) 666 result = gradients_impl.gradients(y, x)[0] 667 self.assertEqual(1.0, result.eval()) 668 669 z = control_flow_ops.cond( 670 constant_op.constant(False), lambda: x, 671 lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x)) 672 result = gradients_impl.gradients(z, x)[0] 673 self.assertEqual(1.0, result.eval()) 674 675 def testCondGrad_Gather(self): 676 with self.test_session() as sess: 677 v1 = variables.Variable([1.0, 42.0]) 678 c = array_ops.placeholder(dtypes.int32, shape=[]) 679 pred = math_ops.less(c, 2) 680 fn1 = lambda: array_ops.identity(v1) 681 fn2 = lambda: array_ops.gather(v1, [1, 1]) 682 r = control_flow_ops.cond(pred, fn1, fn2) 683 grad = gradients_impl.gradients(r, [v1])[0] 684 variables.global_variables_initializer().run() 685 # Should just be [1, 1], but possibly a sparse representation 686 gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 1}) 687 dense_gv = [ 688 sum([y for (x, y) in zip(gi, gv) if x == i]) for i in range(2) 689 ] 690 self.assertAllEqual(dense_gv, [1.0, 1.0]) 691 # Should be [0, 2], as the else forwards v1[1] twice 692 gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 3}) 693 dense_gv = [ 694 sum([y for (x, y) in zip(gi, gv) if x == i]) for i in range(2) 695 ] 696 self.assertAllEqual(dense_gv, [0.0, 2.0]) 697 698 # Microbenchmark: 256,000 iterations/s. 699 def testWhile_1(self): 700 with self.test_session(): 701 n = constant_op.constant(0) 702 c = lambda x: math_ops.less(x, 10000) 703 b = lambda x: math_ops.add(x, 1) 704 r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) 705 self.assertEqual(10000, r.eval()) 706 707 def testWhileExternalControlDependencies(self): 708 with self.test_session(): 709 v = variables.Variable(0.0) 710 v.initializer.run() 711 increment = v.assign_add(1.0) 712 713 def body_fn(i): 714 with ops.control_dependencies([increment]): 715 return i + i 716 717 result = control_flow_ops.while_loop(cond=lambda i: i < 1, 718 body=body_fn, loop_vars=[1]) 719 result.eval() 720 self.assertAllEqual(v.eval(), 1.0) 721 722 def testWhileExternalControlDependenciesNoInput(self): 723 with self.test_session(): 724 v = variables.Variable(0.0) 725 v.initializer.run() 726 increment = v.assign_add(1.0) 727 728 def body_fn(unused_i): 729 with ops.control_dependencies([increment]): 730 return constant_op.constant(5, name="five") 731 732 result = control_flow_ops.while_loop(cond=lambda i: i < 5, 733 body=body_fn, loop_vars=[0]) 734 result.eval() 735 self.assertAllEqual(v.eval(), 1.0) 736 737 def testWhileWithRefs_1(self): 738 with self.test_session() as sess: 739 x = variables.Variable(0)._ref() # pylint: disable=protected-access 740 i = constant_op.constant(0) 741 c = lambda i, x: math_ops.less(i, 100) 742 743 self.assertEqual(x.dtype, dtypes.int32_ref) 744 745 def b(i, x): 746 self.assertEqual(x.dtype, dtypes.int32_ref) 747 return (i + 1, gen_array_ops._ref_identity(x)) 748 749 r = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=5) 750 751 variables.global_variables_initializer().run() 752 753 self.assertEqual(r[0].dtype, dtypes.int32) 754 self.assertEqual(r[1].dtype, dtypes.int32_ref) 755 756 value_i, value_x = sess.run(r) 757 758 self.assertEqual(100, value_i) 759 self.assertEqual(0, value_x) 760 761 def testWhile_2(self): 762 with self.test_session(): 763 s = constant_op.constant(0) 764 r = isum(s) 765 self.assertAllEqual(45, r.eval()) 766 767 def testWhileWithMaximumIterations(self): 768 with self.test_session(): 769 s = constant_op.constant([1, 2, 3, 4, 5]) 770 r = isum(s, maximum_iterations=3) 771 self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval()) 772 773 def testWhileWithMaximumIterationsAndSingleArgument(self): 774 with self.test_session(): 775 r = control_flow_ops.while_loop( 776 lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1) 777 self.assertEqual(1, r.eval()) 778 779 def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self): 780 v = constant_op.constant(1.0) 781 782 def training_loop_with_gradient(i): 783 out = control_flow_ops.while_loop( 784 lambda i_, _: i_ < 3, 785 lambda i_, j: [i_ + 1, j * v], [0, 1.0], 786 maximum_iterations=i) 787 g = gradients_impl.gradients(out, v) 788 with ops.control_dependencies(g): 789 return i + 1 790 791 xla_context = control_flow_ops.XLAControlFlowContext() 792 xla_context.Enter() 793 # Create training loop, ensure we can call gradient() of 794 # while_loop inside the training loop. 795 loop = control_flow_ops.while_loop(lambda i: i < 3, 796 training_loop_with_gradient, [0]) 797 xla_context.Exit() 798 799 loop_execute = array_ops.identity(loop) # Because loop is not fetchable. 800 801 # Should execute without issue. 802 self.assertEqual(3, self.evaluate(loop_execute)) 803 804 def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self): 805 v = constant_op.constant(1.0) 806 807 def inner_body(i, x): 808 out = control_flow_ops.while_loop( 809 lambda i, _: i < 3, 810 lambda i, j: [i + 1, j * v], [0, x], 811 maximum_iterations=i) 812 return out 813 814 def create_while_loop(maximum_iterations=None): 815 return control_flow_ops.while_loop( 816 lambda i, _: i < 3, 817 inner_body, [0, 1.0], 818 maximum_iterations=maximum_iterations) 819 820 loop_no_xla = create_while_loop(maximum_iterations=5) 821 # maximum_iterations is fine outside of an XLA scope 822 gs = gradients_impl.gradients(loop_no_xla, v) 823 self.evaluate(gs) # This should execute without error. 824 825 xla_context = control_flow_ops.XLAControlFlowContext() 826 xla_context.Enter() 827 loop_no_maxiter = create_while_loop() 828 loop_with_maxiter = create_while_loop(maximum_iterations=2) 829 xla_context.Exit() 830 831 with self.assertRaisesRegexp( 832 ValueError, 833 r"Cannot create a gradient accumulator for tensor '.+' inside " 834 r"XLA while_loop because maximum_iterations was not passed to " 835 r"the tf.while_loop call \('.+'\)."): 836 _ = gradients_impl.gradients(loop_no_maxiter, v) 837 838 with self.assertRaisesRegexp( 839 ValueError, 840 r"Cannot create a gradient accumulator for tensor '.+' inside XLA " 841 r"while_loop. maximum_iterations tensor '.+' for while_loop context " 842 r"'.+' must be statically known \(e.g. a constant value or known " 843 r"shape dimension\), or be defined at or outside the while loop " 844 r"context '.*' \(currently defined in '.*'\)"): 845 _ = gradients_impl.gradients(loop_with_maxiter, v) 846 847 def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self): 848 v = constant_op.constant(1.0) 849 850 def create_while_loop(): 851 max_iter_holder = [] 852 853 def create_mi(): 854 max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=())) 855 return 1.0 856 857 _ = control_flow_ops.cond( 858 constant_op.constant(True), create_mi, create_mi) 859 860 return control_flow_ops.while_loop( 861 lambda i, _: i < 3, 862 lambda i, x: (i + 1, v * x), (0, 1.0), 863 maximum_iterations=max_iter_holder[0]) 864 865 xla_context = control_flow_ops.XLAControlFlowContext() 866 xla_context.Enter() 867 loop = create_while_loop() 868 xla_context.Exit() 869 870 with self.assertRaisesRegexp( 871 ValueError, 872 r"Cannot create a gradient accumulator for tensor '.+' inside XLA " 873 r"while_loop. maximum_iterations tensor '.*Placeholder:0' for " 874 r"while_loop context '.+' must be statically known \(e.g. a constant " 875 r"value or known shape dimension\), or be defined at or outside the " 876 r"while loop context '' \(currently defined in 'cond/.+'\)"): 877 _ = gradients_impl.gradients(loop, v) 878 879 def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self): 880 v = constant_op.constant(1.0) 881 882 p = array_ops.placeholder(dtype=dtypes.int32) 883 884 def mid_body_builder(iterations): 885 886 def mid_body(i, x): 887 r = control_flow_ops.while_loop( 888 lambda *_: True, 889 lambda i, x: (i + 1, v * x), (0, x), 890 maximum_iterations=iterations, 891 name="inner") 892 return (i + 1, gradients_impl.gradients(x + r[1], v)[0]) 893 894 return mid_body 895 896 def outer_body(i, x): 897 iterations = array_ops.size(p, name="iterations") 898 return (i + 1, x + control_flow_ops.while_loop( 899 lambda *_: True, 900 mid_body_builder(iterations), (0, x), 901 maximum_iterations=iterations, 902 name="mid")[1]) 903 904 def create_while_loop(): 905 with ops.device("/cpu:0"): 906 r = control_flow_ops.while_loop( 907 lambda *_: True, 908 outer_body, (0, 1.0), 909 maximum_iterations=5, 910 name="outer") 911 return array_ops.identity(r[1]) 912 913 xla_context = control_flow_ops.XLAControlFlowContext() 914 xla_context.Enter() 915 final_with_xla_context = create_while_loop() 916 xla_context.Exit() 917 918 final_without_xla_context = create_while_loop() 919 920 with self.test_session(use_gpu=False) as sess: 921 opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) 922 run_metadata = config_pb2.RunMetadata() 923 924 final_value_without_xla_context = sess.run( 925 final_without_xla_context, feed_dict={ 926 p: [0, 0, 0] 927 }) 928 929 final_value_with_xla_context = sess.run( 930 final_with_xla_context, 931 feed_dict={p: [0, 0, 0]}, 932 options=opts, 933 run_metadata=run_metadata) 934 935 node_stats = run_metadata.step_stats.dev_stats[0].node_stats 936 stack_push_count = len( 937 [x for x in node_stats if x.node_name.endswith("StackPushV2")]) 938 # Pushes to the stack = product of maximum_iterations values; 939 # the last two "3"s comes from size(p), when p == [0, 0, 0]. 940 self.assertEqual(stack_push_count, 5 * 3 * 3) 941 942 self.assertAllClose(final_value_with_xla_context, 943 final_value_without_xla_context) 944 945 # Have more than 10 parallel iterations and hence exercise k-bound 946 # most of the time. 947 def testWhile_3(self): 948 with self.test_session(): 949 950 def compute(i, m, c, o): 951 m, c = [math_ops.add(m, 1), math_ops.add(c, 1)] 952 o = math_ops.add(o, m) 953 o = math_ops.add(o, c) 954 i = math_ops.add(i, 1) 955 return [i, m, c, o] 956 957 i = ops.convert_to_tensor(0) 958 m = ops.convert_to_tensor(0) 959 c = ops.convert_to_tensor(0) 960 o = ops.convert_to_tensor(0) 961 d = ops.convert_to_tensor(100) 962 r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, d), 963 compute, [i, m, c, o]) 964 result = r[3].eval() 965 self.assertAllEqual(10100, result) 966 967 def testWhile_4(self): 968 with self.test_session(): 969 970 def compute(i, m, c, o): 971 m, c = [array_ops.gather(x, i), array_ops.gather(x, i)] 972 o = math_ops.add(o, m) 973 o = math_ops.add(o, c) 974 i = math_ops.add(i, 1) 975 return [i, m, c, o] 976 977 i = ops.convert_to_tensor(0) 978 m = ops.convert_to_tensor(0) 979 c = ops.convert_to_tensor(0) 980 o = ops.convert_to_tensor(0) 981 x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6]) 982 s = array_ops.size(x) 983 r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, s), 984 compute, [i, m, c, o]) 985 result = r[3].eval() 986 self.assertAllEqual(42, result) 987 988 def testWhile_5(self): 989 with self.test_session(): 990 991 def compute(i, c, o): 992 c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0), 993 [1] + array_ops.expand_dims(i, 0)) 994 o = array_ops.concat([o, c], 0) 995 i = math_ops.add(i, 1) 996 return [i, c, o] 997 998 i = ops.convert_to_tensor(0) 999 c = ops.convert_to_tensor([0]) 1000 o = ops.convert_to_tensor([0]) 1001 x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6]) 1002 s = array_ops.size(x) 1003 r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s), 1004 compute, [i, c, o], [ 1005 i.get_shape(), 1006 tensor_shape.unknown_shape(), 1007 tensor_shape.unknown_shape() 1008 ]) 1009 result = r[2].eval() 1010 self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result) 1011 1012 def testBufferForwarding(self): 1013 run_options = config_pb2.RunOptions( 1014 trace_level=config_pb2.RunOptions.FULL_TRACE) 1015 run_metadata = config_pb2.RunMetadata() 1016 1017 with self.test_session() as sess: 1018 with ops.device("/cpu:0"): 1019 c = constant_op.constant(2) 1020 i0 = constant_op.constant(0) 1021 r = control_flow_ops.while_loop(lambda i: i < 1000, 1022 lambda i: math_ops.square(c) + i, [i0]) 1023 r_val = sess.run(r, options=run_options, run_metadata=run_metadata) 1024 self.assertEqual(1000, r_val) 1025 self.assertTrue(run_metadata.HasField("step_stats")) 1026 unique_allocs = set() 1027 for node_stat in run_metadata.step_stats.dev_stats[0].node_stats: 1028 for output in node_stat.output: 1029 unique_allocs.add( 1030 output.tensor_description.allocation_description.ptr) 1031 # Prior to cl/147536680, the number of unique allocations was about 1005. 1032 self.assertLess(len(unique_allocs), 756) 1033 1034 def _testWhile_Gpu_1(self, use_gpu): 1035 with self.test_session(use_gpu=use_gpu): 1036 n = constant_op.constant(1.0) 1037 c = lambda x: math_ops.less(x, 10.0) 1038 b = lambda x: math_ops.add(x, 1.0) 1039 r = control_flow_ops.while_loop(c, b, [n]) 1040 self.assertAllClose(10.0, r.eval()) 1041 1042 def testWhile_Gpu_1(self): 1043 self._testWhile_Gpu_1(use_gpu=False) 1044 self._testWhile_Gpu_1(use_gpu=True) 1045 1046 def _testWhile_Gpu_2(self, use_gpu): 1047 with self.test_session(use_gpu=use_gpu): 1048 n = constant_op.constant(1.0) 1049 c = lambda x: math_ops.less(x, 10.0) 1050 1051 def b(x): 1052 with ops.device("/cpu:0"): 1053 return math_ops.add(x, 1.0) 1054 1055 r = control_flow_ops.while_loop(c, b, [n]) 1056 self.assertAllClose(10.0, r.eval()) 1057 1058 def testWhile_Gpu_2(self): 1059 self._testWhile_Gpu_1(use_gpu=False) 1060 self._testWhile_Gpu_1(use_gpu=True) 1061 1062 def testWhileShape(self): 1063 with self.test_session(): 1064 i = constant_op.constant(0) 1065 m = array_ops.ones([2, 2]) 1066 c = lambda i, j: math_ops.less(i, 2) 1067 1068 def _b(i, j): 1069 new_i = math_ops.add(i, 1) 1070 new_j = array_ops.tile(j, [2, 2]) 1071 return [new_i, new_j] 1072 1073 r = control_flow_ops.while_loop( 1074 c, _b, [i, m], 1075 [i.get_shape(), tensor_shape.unknown_shape()]) 1076 r = r[1] * array_ops.ones([8, 8]) 1077 self.assertAllEqual(np.ones((8, 8)), r.eval()) 1078 1079 def testWhileWithNonTensorInput_Scalar(self): 1080 with self.test_session(): 1081 n = 0 1082 c = lambda x: x < 10000 1083 b = lambda x: x + 1 1084 r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) 1085 self.assertEqual(10000, r.eval()) 1086 1087 def testWhileWithNonTensorInput_Vector(self): 1088 with self.test_session(): 1089 n = np.array([0]) # Note, [0] would not work here; that is a list 1090 c = lambda x: x[0] < 10000 1091 b = lambda x: array_ops.stack([x[0] + 1]) 1092 r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) 1093 self.assertEqual([10000], r.eval()) 1094 1095 def testWhileShapeInference(self): 1096 with self.test_session(): 1097 i = constant_op.constant(0) 1098 m = array_ops.ones([2, 2]) 1099 c = lambda i, j: math_ops.less(i, 2) 1100 1101 def b(i, j): 1102 new_i = math_ops.add(i, 1) 1103 new_j = array_ops.concat([j, j], 0) 1104 return [new_i, new_j] 1105 1106 r = control_flow_ops.while_loop( 1107 c, b, [i, m], 1108 [i.get_shape(), tensor_shape.TensorShape([None, 2])]) 1109 self.assertTrue(r[1].get_shape()[0].value is None) 1110 self.assertEqual(r[1].get_shape()[1], tensor_shape.Dimension(2)) 1111 1112 with self.assertRaisesRegexp( 1113 ValueError, 1114 r"The shape for while_1/Merge_1:0 is not an invariant for the loop. " 1115 r"It enters the loop with shape \(2, 2\), but has shape \(4, 2\) " 1116 r"after one iteration. Provide shape invariants using either the " 1117 r"`shape_invariants` argument of tf.while_loop or set_shape\(\) on " 1118 r"the loop variables."): 1119 r = control_flow_ops.while_loop(c, b, [i, m]) 1120 1121 def testWhileShapeInferenceSparseTensor(self): 1122 with self.test_session(): 1123 values = constant_op.constant([2.0, 4.0], name="values") 1124 indices = constant_op.constant( 1125 [[0], [3]], dtype=dtypes.int64, name="indices") 1126 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 1127 i = constant_op.constant(0) 1128 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 1129 1130 def c(i, _): 1131 return i < 10 1132 1133 def b(i, x): 1134 return [ 1135 i + 1, 1136 sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape) 1137 ] 1138 1139 _, r = control_flow_ops.while_loop(c, b, [i, x]) 1140 self.assertEqual(r.dense_shape.get_shape()[0].value, 1) 1141 1142 _, r = control_flow_ops.while_loop( 1143 c, b, [i, x], 1144 [i.get_shape(), tensor_shape.TensorShape([None])]) 1145 self.assertTrue(r.dense_shape.get_shape()[0].value is None) 1146 1147 with self.assertRaisesRegexp(ValueError, "is not compatible with"): 1148 _, r = control_flow_ops.while_loop( 1149 c, b, [i, x], 1150 [i.get_shape(), tensor_shape.TensorShape([5])]) 1151 1152 def testWhileShapeInferenceIndexedSlices(self): 1153 with self.test_session(): 1154 values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values") 1155 indices = constant_op.constant([0, 3], name="indices") 1156 shape = constant_op.constant([10, 2], name="dense_shape") 1157 i = constant_op.constant(0) 1158 x = ops.IndexedSlices(values, indices, dense_shape=shape) 1159 1160 def c(i, _): 1161 return i < 10 1162 1163 def b(i, x): 1164 return [ 1165 i + 1, 1166 ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape) 1167 ] 1168 1169 _, r = control_flow_ops.while_loop(c, b, [i, x]) 1170 self.assertEqual(r.dense_shape.get_shape()[0].value, 2) 1171 self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2])) 1172 1173 _, r = control_flow_ops.while_loop( 1174 c, b, [i, x], 1175 [i.get_shape(), tensor_shape.TensorShape([None, 2])]) 1176 self.assertEqual(r.dense_shape.get_shape()[0].value, 2) 1177 self.assertTrue(r.values.get_shape()[0].value is None) 1178 self.assertEqual(r.values.get_shape()[1].value, 2) 1179 1180 with self.assertRaisesRegexp(ValueError, "is not compatible with"): 1181 _, r = control_flow_ops.while_loop( 1182 c, b, [i, x], 1183 [i.get_shape(), tensor_shape.TensorShape([None, 5])]) 1184 1185 def _testNestedWhile_1(self, use_gpu): 1186 with self.test_session(use_gpu=use_gpu): 1187 n = constant_op.constant(0) 1188 1189 def cpu_sum(s): 1190 c = lambda i, s: math_ops.less(i, 10) 1191 1192 def b(i, s): 1193 i1 = math_ops.add(i, 1) 1194 with ops.device("/cpu:0"): 1195 s1 = math_ops.add(i, s) 1196 return i1, s1 1197 1198 _, r_s = control_flow_ops.while_loop(c, b, [n, s]) 1199 return r_s 1200 1201 c = lambda x: math_ops.less(x, 200) 1202 b = lambda x: math_ops.add(x, cpu_sum(n)) 1203 r = control_flow_ops.while_loop(c, b, [n]) 1204 self.assertEqual(225, r.eval()) 1205 1206 def testNestedWhile_1(self): 1207 self._testNestedWhile_1(use_gpu=False) 1208 self._testNestedWhile_1(use_gpu=True) 1209 1210 def _testNestedWhile_2(self, use_gpu): 1211 # Test the cases that A -> Enter and Exit -> A are partitioned. 1212 with self.test_session(use_gpu=use_gpu): 1213 s0 = constant_op.constant(2.0) 1214 1215 def inner_loop(s): 1216 c = lambda s: math_ops.less(s, 20.0) 1217 1218 def b(s): 1219 s1 = math_ops.add(s, s) 1220 return s1 1221 1222 r_s = control_flow_ops.while_loop(c, b, [s], parallel_iterations=1) 1223 return r_s 1224 1225 outer_c = lambda x: math_ops.less(x, 3000.0) 1226 1227 def outer_b(x): 1228 x = logging_ops.Print(x, [x]) # Edge "Print -> Enter" is partitioned 1229 x = inner_loop(x) 1230 with ops.device("/cpu:0"): 1231 x = math_ops.square(x) # Edge "Exit -> Square" is partitioned 1232 return x 1233 1234 r = control_flow_ops.while_loop( 1235 outer_c, outer_b, [s0], parallel_iterations=1) 1236 self.assertEqual(1048576.0, r.eval()) 1237 1238 def testNestedWhile_2(self): 1239 self._testNestedWhile_2(use_gpu=False) 1240 self._testNestedWhile_2(use_gpu=True) 1241 1242 def testWhileWithControl_1(self): 1243 with self.test_session(): 1244 n = constant_op.constant(0) 1245 r = constant_op.constant(0) 1246 condition = lambda n_, r_: math_ops.less(n_, 10) 1247 1248 def body(n_, r_): 1249 n_ = math_ops.add(n_, 1) 1250 with r_.graph.control_dependencies([r_]): 1251 r_ = constant_op.constant(12) 1252 return [n_, r_] 1253 1254 res = control_flow_ops.while_loop( 1255 condition, body, [n, r], parallel_iterations=1) 1256 self.assertAllEqual(12, res[1].eval()) 1257 1258 def testWhileWithControl_2(self): 1259 with self.test_session(): 1260 r = constant_op.constant(0) 1261 condition = lambda r_: math_ops.less(r_, 10) 1262 1263 def body(r_): 1264 with r_.graph.control_dependencies([r_]): 1265 r_ = constant_op.constant(12) 1266 return [r_] 1267 1268 res = control_flow_ops.while_loop( 1269 condition, body, [r], parallel_iterations=1) 1270 self.assertAllEqual(12, res.eval()) 1271 1272 def testWhileWithControl_3(self): 1273 with self.test_session() as sess: 1274 b = array_ops.placeholder(dtypes.bool) 1275 c = constant_op.constant(1) 1276 x0 = constant_op.constant(0) 1277 with ops.control_dependencies([b]): 1278 r = control_flow_ops.while_loop(lambda x: x < 10, lambda x: x + c, [x0]) 1279 self.assertEqual(10, sess.run(r, {b: True})) 1280 1281 def testWhileWithControl_4(self): 1282 with self.test_session() as sess: 1283 b = array_ops.placeholder(dtypes.bool) 1284 c = constant_op.constant(1) 1285 x0 = constant_op.constant(0) 1286 with ops.control_dependencies([b]): 1287 r = control_flow_ops.while_loop( 1288 lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0]) 1289 self.assertEqual(10, sess.run(r, {b: True})) 1290 1291 def testWhileWithControl_5(self): 1292 with self.test_session() as sess: 1293 b = array_ops.placeholder(dtypes.bool) 1294 c = constant_op.constant(1) 1295 x0 = constant_op.constant(0) 1296 1297 def body(x): 1298 with ops.control_dependencies([b]): 1299 return x + c 1300 1301 r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0]) 1302 self.assertEqual(10, sess.run(r, {b: True})) 1303 1304 def testWhileCondWithControl(self): 1305 # Ensure that no control edges by an outer control dependency context are 1306 # added to nodes inside cond/while contexts. 1307 with self.test_session() as sess: 1308 const_true = lambda: constant_op.constant(True) 1309 const_false = lambda: constant_op.constant(False) 1310 cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false) 1311 body = lambda i: control_flow_ops.cond(i > 0, lambda: i - 1, lambda: i) 1312 1313 with ops.control_dependencies([control_flow_ops.no_op()]): 1314 loop = control_flow_ops.while_loop(cond, body, 1315 (constant_op.constant(5),)) 1316 self.assertEqual(0, sess.run(loop)) 1317 1318 def testWhileCondWithControl_1(self): 1319 with self.test_session(): 1320 v = variable_scope.get_variable( 1321 "v", [], initializer=init_ops.constant_initializer(2)) 1322 i0 = constant_op.constant(0) 1323 with ops.control_dependencies([i0]): 1324 1325 def loop_condition(i): 1326 return i < 4 1327 1328 def loop_body(i): 1329 some_cond = control_flow_ops.cond( 1330 constant_op.constant(True), 1331 lambda: state_ops.assign(v, math_ops.square(v)), lambda: v) 1332 with ops.control_dependencies([some_cond]): 1333 return i + 1 1334 1335 r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,)) 1336 variables.global_variables_initializer().run() 1337 self.assertEqual(4, r.eval()) 1338 self.assertAllClose(65536.0, v.eval()) 1339 1340 def testWhileCondExitControl(self): 1341 with self.test_session(): 1342 v = variables.Variable(1) 1343 1344 def false_branch(): 1345 cond = lambda i: i < 100 1346 1347 def body(i): 1348 x = state_ops.assign(v, i) 1349 return x + 1 1350 1351 loop = control_flow_ops.while_loop(cond, body, [0]) 1352 # Make sure to handle correctly control edge from Exit to a node. 1353 with ops.control_dependencies([loop]): 1354 return constant_op.constant(6.0) 1355 1356 r = control_flow_ops.cond( 1357 constant_op.constant(False), lambda: constant_op.constant(1.0), 1358 false_branch) 1359 variables.global_variables_initializer().run() 1360 self.assertEqual(6.0, r.eval()) 1361 self.assertEqual(99, v.eval()) 1362 1363 def testCondWhile_1(self): 1364 with self.test_session(): 1365 n = ops.convert_to_tensor(0, name="n") 1366 c = lambda x: math_ops.less(x, 10) 1367 b = lambda x: math_ops.add(x, 1) 1368 r = control_flow_ops.cond( 1369 math_ops.less(0, 1), lambda: control_flow_ops.while_loop(c, b, [n]), 1370 lambda: n) 1371 self.assertAllEqual(10, r.eval()) 1372 1373 def testCondWhile_2(self): 1374 with self.test_session(): 1375 n = ops.convert_to_tensor(0) 1376 c = lambda x: math_ops.less(x, 10) 1377 b = lambda x: math_ops.add(x, 1) 1378 r = control_flow_ops.cond( 1379 math_ops.less(1, 0), lambda: math_ops.add(n, 1), 1380 lambda: control_flow_ops.while_loop(c, b, [n])) 1381 self.assertAllEqual(10, r.eval()) 1382 1383 def _testCondWhile_3(self, use_gpu): 1384 with self.test_session(use_gpu=use_gpu) as sess: 1385 p = array_ops.placeholder(dtypes.bool) 1386 n = constant_op.constant(0.0) 1387 1388 def c(x): 1389 return math_ops.less(x, 10.0) 1390 1391 def b(x): 1392 with ops.device("/cpu:0"): 1393 x1 = math_ops.add(x, 1.0) 1394 return x1 1395 1396 r = control_flow_ops.cond(p, 1397 lambda: control_flow_ops.while_loop(c, b, [n]), 1398 lambda: math_ops.multiply(n, 2.0)) 1399 r1 = gradients_impl.gradients(r, [n]) 1400 self.assertEqual(10, sess.run(r, {p: True})) 1401 self.assertEqual([1.0], sess.run(r1, {p: True})) 1402 self.assertEqual(0.0, sess.run(r, {p: False})) 1403 self.assertEqual([2.0], sess.run(r1, {p: False})) 1404 1405 def testCondWhile_3(self): 1406 self._testCondWhile_3(use_gpu=False) 1407 self._testCondWhile_3(use_gpu=True) 1408 1409 def testWhileCond_1(self): 1410 with self.test_session(): 1411 i = ops.convert_to_tensor(0, name="i") 1412 n = ops.convert_to_tensor(10, name="n") 1413 one = ops.convert_to_tensor(1, name="one") 1414 c = lambda x: math_ops.less(x, n) 1415 # pylint: disable=undefined-variable 1416 # for OSS build 1417 b = lambda x: control_flow_ops.cond( 1418 constant_op.constant(True), 1419 lambda: math_ops.add(x, one), lambda: math_ops.subtract(x, one)) 1420 # pylint: enable=undefined-variable 1421 r = control_flow_ops.while_loop(c, b, [i]) 1422 self.assertAllEqual(10, r.eval()) 1423 1424 def testWhileCond_2(self): 1425 with self.test_session(): 1426 n = ops.convert_to_tensor(0, name="n") 1427 c = lambda x: math_ops.less(x, 10) 1428 b = lambda x: control_flow_ops.cond(constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n) 1429 r = control_flow_ops.while_loop(c, b, [n]) 1430 self.assertAllEqual(10, r.eval()) 1431 1432 def testWhileCond_3(self): 1433 with self.test_session(): 1434 n = ops.convert_to_tensor(0) 1435 c = lambda x: math_ops.less(x, 10) 1436 # pylint: disable=undefined-variable 1437 # for OSS build 1438 b = lambda x: control_flow_ops.cond(math_ops.less(0, 1), 1439 lambda: math_ops.add(x, 1), 1440 lambda: math_ops.subtract(x, 1)) 1441 # pylint: enable=undefined-variable 1442 r = control_flow_ops.while_loop(c, b, [n]) 1443 self.assertAllEqual(10, r.eval()) 1444 1445 # NOTE: It is ok to have parallel_iterations > 1 1446 def testWhileUpdateVariable_1(self): 1447 with self.test_session(): 1448 select = variables.Variable([3.0, 4.0, 5.0]) 1449 n = constant_op.constant(0) 1450 1451 def loop_iterator(j): 1452 return math_ops.less(j, 3) 1453 1454 def loop_body(j): 1455 ns = state_ops.scatter_update(select, j, 10.0) 1456 nj = math_ops.add(j, 1) 1457 op = control_flow_ops.group(ns) 1458 nj = control_flow_ops.with_dependencies([op], nj) 1459 return [nj] 1460 1461 r = control_flow_ops.while_loop( 1462 loop_iterator, loop_body, [n], parallel_iterations=1) 1463 variables.global_variables_initializer().run() 1464 self.assertEqual(3, r.eval()) 1465 result = select.eval() 1466 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result) 1467 1468 def testWhileUpdateVariable_2(self): 1469 with self.test_session(): 1470 select1 = variables.Variable([3.0, 4.0, 5.0]) 1471 select2 = variables.Variable([3.0, 4.0, 5.0]) 1472 n = constant_op.constant(0) 1473 1474 def loop_iterator(j): 1475 return math_ops.less(j, 3) 1476 1477 def loop_body(j): 1478 ns1 = state_ops.scatter_update(select1, j, 10.0) 1479 ns2 = state_ops.scatter_update(select2, j, 10.0) 1480 nj = math_ops.add(j, 1) 1481 op = control_flow_ops.group(ns1, ns2) 1482 nj = control_flow_ops.with_dependencies([op], nj) 1483 return [nj] 1484 1485 r = control_flow_ops.while_loop( 1486 loop_iterator, loop_body, [n], parallel_iterations=1) 1487 variables.global_variables_initializer().run() 1488 self.assertEqual(3, r.eval()) 1489 result1 = select1.eval() 1490 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1) 1491 result2 = select2.eval() 1492 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2) 1493 1494 def testWhileUpdateVariable_3(self): 1495 with self.test_session(): 1496 select = variables.Variable([3.0, 4.0, 5.0]) 1497 n = constant_op.constant(0) 1498 1499 def loop_iterator(j, _): 1500 return math_ops.less(j, 3) 1501 1502 def loop_body(j, _): 1503 ns = state_ops.scatter_update(select, j, 10.0) 1504 nj = math_ops.add(j, 1) 1505 return [nj, ns] 1506 1507 r = control_flow_ops.while_loop( 1508 loop_iterator, 1509 loop_body, [n, array_ops.identity(select)], 1510 parallel_iterations=1) 1511 variables.global_variables_initializer().run() 1512 result = r[1].eval() 1513 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result) 1514 1515 # b/24814703 1516 def testWhileUpdateVariable_4(self): 1517 with self.test_session(): 1518 var_a = variables.Variable(0, name="a") 1519 var_b = variables.Variable(0, name="b") 1520 variables.global_variables_initializer().run() 1521 1522 c = constant_op.constant(0, name="c") 1523 asn1 = state_ops.assign_add(var_a, 1, name="a_add") 1524 1525 # Loop condition 1526 def pred(i): 1527 return math_ops.less(i, 10) 1528 1529 # Loop body 1530 def loop_body(i): 1531 asn2 = state_ops.assign_add(var_b, asn1, name="b_add") 1532 with ops.control_dependencies([asn2]): 1533 ni = math_ops.add(i, 1, name="i_add") 1534 return ni 1535 1536 lpa = control_flow_ops.while_loop( 1537 pred, loop_body, [c], parallel_iterations=1) 1538 1539 self.assertEqual(0, var_b.eval()) 1540 lpa.eval() # Run the loop 1541 self.assertEqual(10, var_b.eval()) 1542 1543 # b/24736492 1544 def testWhileUpdateVariable_5(self): 1545 with self.test_session(): 1546 # Create some variables. 1547 var_a = variables.Variable(0, name="a") 1548 var_b = variables.Variable(0, name="b") 1549 variables.global_variables_initializer().run() 1550 1551 # Change condition to check var_b 1552 def pred(_): 1553 return math_ops.less(var_b, 10) 1554 1555 # Change body to increment var_b 1556 def loop_body(i): 1557 asn1 = state_ops.assign_add( 1558 var_a, constant_op.constant(1), name="a_add") 1559 asn2 = state_ops.assign_add( 1560 var_b, constant_op.constant(1), name="b_add") 1561 with ops.control_dependencies([asn1, asn2]): 1562 inc_b = array_ops.identity(var_b) 1563 return inc_b 1564 1565 lpa = control_flow_ops.while_loop( 1566 pred, loop_body, [var_b], parallel_iterations=1, name="loop") 1567 1568 self.assertEqual(0, var_b.eval()) 1569 lpa.eval() # Run the loop 1570 self.assertEqual(10, var_a.eval()) 1571 self.assertEqual(10, var_b.eval()) 1572 1573 # b/24814668 1574 def testWhileUpdateVariable_6(self): 1575 with self.test_session(): 1576 # Create some variables. 1577 var_a = variables.Variable(0, name="a") 1578 var_b = variables.Variable(0, name="b") 1579 c = constant_op.constant(0) 1580 variables.global_variables_initializer().run() 1581 1582 # Loop condition 1583 def pred(i): 1584 return math_ops.less(i, 10) 1585 1586 # Loop body 1587 def loop_body(i): 1588 asn1 = state_ops.assign_add(var_a, 1, name="a_add") 1589 with ops.control_dependencies([asn1]): 1590 asn2 = state_ops.assign_add(var_b, var_a, name="b_add") 1591 with ops.control_dependencies([asn2]): 1592 ni = math_ops.add(i, 1, name="i_add") 1593 return ni 1594 1595 lpa = control_flow_ops.while_loop( 1596 pred, loop_body, [c], parallel_iterations=1, name="loop") 1597 1598 self.assertEqual(0, var_b.eval()) 1599 lpa.eval() # Run the loop 1600 self.assertEqual(55, var_b.eval()) 1601 self.assertEqual(10, var_a.eval()) 1602 1603 def testWhileQueue_1(self): 1604 with self.test_session(): 1605 q = data_flow_ops.FIFOQueue(-1, dtypes.int32) 1606 i = constant_op.constant(0) 1607 1608 def c(i): 1609 return math_ops.less(i, 10) 1610 1611 def b(i): 1612 ni = math_ops.add(i, 1) 1613 ni = control_flow_ops.with_dependencies([q.enqueue((i,))], ni) 1614 return ni 1615 1616 r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1) 1617 self.assertEqual([10], r.eval()) 1618 for i in xrange(10): 1619 self.assertEqual([i], q.dequeue().eval()) 1620 1621 def testWhileStack_1(self): 1622 with self.test_session(): 1623 s = gen_data_flow_ops._stack_v2(-1, dtypes.int32, stack_name="foo") 1624 i = constant_op.constant(0) 1625 1626 def c(i): 1627 return math_ops.less(i, 10) 1628 1629 def b(i): 1630 ni = math_ops.add(i, 1) 1631 ni = control_flow_ops.with_dependencies( 1632 [gen_data_flow_ops._stack_push_v2(s, i)], ni) 1633 return ni 1634 1635 r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1) 1636 1637 x = constant_op.constant(0) 1638 1639 def c1(i, _): 1640 return math_ops.greater(i, 0) 1641 1642 def b1(i, x): 1643 ni = math_ops.subtract(i, 1) 1644 nx = x + gen_data_flow_ops._stack_pop_v2(s, dtypes.int32) 1645 return [ni, nx] 1646 1647 _, rx = control_flow_ops.while_loop( 1648 c1, 1649 b1, [r, x], 1650 [r.get_shape(), tensor_shape.unknown_shape()], 1651 parallel_iterations=1) 1652 self.assertEqual(45, rx.eval()) 1653 1654 def _testWhileGrad_ColocateGradients(self, colocate): 1655 gpu_dev_name = test.gpu_device_name() if test.is_gpu_available( 1656 ) else "/device:GPU:0" 1657 1658 graph = ops.Graph() 1659 with graph.as_default(): 1660 v = constant_op.constant(2.0, name="v") 1661 c = lambda v: math_ops.less(v, 100.0) 1662 1663 def b(x): 1664 with ops.device(gpu_dev_name): 1665 return math_ops.square(x) 1666 1667 loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 1668 r = gradients_impl.gradients( 1669 loop, v, colocate_gradients_with_ops=colocate)[0] 1670 1671 r_ops = graph.get_operations() 1672 r_devices = [(op.name, op.device) for op in r_ops] 1673 1674 self.assertTrue(any("Square" in op.name for op in r_ops)) 1675 1676 for (name, dev) in r_devices: 1677 if not colocate and name.endswith("Square"): 1678 # Only forward graph contain gpu in Square device 1679 self.assertTrue(gpu_dev_name in dev) 1680 elif colocate and "Square" in name: 1681 # Forward and backward graphs contain gpu in Square/Square_grad devices 1682 self.assertTrue(gpu_dev_name in dev) 1683 else: 1684 self.assertFalse(gpu_dev_name in dev) 1685 1686 with self.test_session(graph=graph) as sess: 1687 self.assertAllClose(1024.0, sess.run(r)) 1688 1689 def testWhileGrad_ColocateGradients(self): 1690 self._testWhileGrad_ColocateGradients(colocate=False) 1691 self._testWhileGrad_ColocateGradients(colocate=True) 1692 1693 def testWhileGrad_Square(self): 1694 with self.test_session(): 1695 v = constant_op.constant(2.0, name="v") 1696 c = lambda v: math_ops.less(v, 100.0) 1697 b = math_ops.square 1698 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 1699 r = control_flow_ops.cond(math_ops.less(1, 2), lambda: r, lambda: v) 1700 1701 r = gradients_impl.gradients(r, v)[0] 1702 self.assertAllClose(1024.0, r.eval()) 1703 1704 def testWhileGrad_Shape(self): 1705 with self.test_session(): 1706 x = array_ops.placeholder(dtypes.float32, shape=[None]) 1707 v = constant_op.constant([2.0], name="v") 1708 n = constant_op.constant(0, name="n") 1709 c = lambda i, v: math_ops.less(i, 5) 1710 b = lambda i, v: [i + 1, math_ops.multiply(x, v)] 1711 r = control_flow_ops.while_loop( 1712 c, 1713 b, [n, v], 1714 [n.get_shape(), tensor_shape.unknown_shape()], 1715 parallel_iterations=1) 1716 1717 r = gradients_impl.gradients(r[1], x)[0] 1718 self.assertEqual([None], r.get_shape().as_list()) 1719 self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]})) 1720 1721 def testWhileGrad_BaseShape(self): 1722 with self.test_session() as sess: 1723 x = array_ops.placeholder(dtypes.float32, [None]) 1724 v0 = constant_op.constant([2.0, 2.0], name="v") 1725 c = lambda v: constant_op.constant(False) 1726 b = lambda v: math_ops.multiply(v, x) 1727 r = control_flow_ops.while_loop(c, b, [v0]) 1728 y = math_ops.square(x) 1729 1730 r = gradients_impl.gradients([r, y], x)[0] 1731 self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]})) 1732 1733 def testWhileGrad_MultipleUses(self): 1734 with self.test_session(): 1735 v = constant_op.constant(2.0, name="v") 1736 c = lambda v: math_ops.less(v, 100.0) 1737 b = math_ops.square 1738 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 1739 r = math_ops.multiply(r, r) 1740 1741 r = gradients_impl.gradients(r, v)[0] 1742 self.assertEqual(524288.0, r.eval()) 1743 1744 def testWhileGrad_LoopAdd(self): 1745 with self.test_session(): 1746 v = constant_op.constant(2.0, name="v") 1747 c = lambda v: math_ops.less(v, 100.0) 1748 b = math_ops.square 1749 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 1750 r = math_ops.add(r, r) 1751 1752 r = gradients_impl.gradients(r, v)[0] 1753 self.assertAllClose(2048.0, r.eval()) 1754 1755 def _testWhileGrad_Mul(self, use_gpu, p_iters): 1756 with self.test_session(use_gpu=use_gpu) as sess: 1757 a = constant_op.constant(3.0, name="a") 1758 v = constant_op.constant(2.0, name="v") 1759 c = lambda v: math_ops.less(v, 100.0) 1760 b = lambda v: math_ops.multiply(v, a) 1761 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=p_iters) 1762 1763 grad_a, grad_v = gradients_impl.gradients(r, [a, v]) 1764 grad_a_val, grad_v_val = sess.run([grad_a, grad_v]) 1765 self.assertAllClose(216.0, grad_a_val) 1766 self.assertAllClose(81.0, grad_v_val) 1767 1768 def testWhileGrad_Mul(self): 1769 self._testWhileGrad_Mul(use_gpu=False, p_iters=1) 1770 self._testWhileGrad_Mul(use_gpu=False, p_iters=10) 1771 self._testWhileGrad_Mul(use_gpu=True, p_iters=1) 1772 self._testWhileGrad_Mul(use_gpu=True, p_iters=10) 1773 1774 def _testNestedWhileCondWhileGrad(self, use_gpu): 1775 with self.test_session(use_gpu=use_gpu): 1776 v = constant_op.constant(1.0) 1777 1778 def inner_loop(s): 1779 z = constant_op.constant(0) 1780 c = lambda i, x: math_ops.less(i, 4) 1781 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 1782 return control_flow_ops.while_loop(c, b, [z, s]) 1783 1784 c = lambda x: math_ops.less(x, 128.0) 1785 1786 def b(x): 1787 return control_flow_ops.cond( 1788 constant_op.constant(True), 1789 lambda: math_ops.square(inner_loop(x)[1]), 1790 lambda: math_ops.multiply(x, 2.0)) 1791 1792 r = control_flow_ops.while_loop(c, b, [v]) 1793 r = gradients_impl.gradients(r, v)[0] 1794 self.assertAllClose(512.0, r.eval()) 1795 1796 def testNestedWhileCondWhileGrad(self): 1797 self._testNestedWhileCondWhileGrad(use_gpu=False) 1798 self._testNestedWhileCondWhileGrad(use_gpu=True) 1799 1800 def testWhileGrad_Variable(self): 1801 with self.test_session(): 1802 a = variables.Variable(3.0) 1803 v = constant_op.constant(2.0, name="v") 1804 c = lambda v: math_ops.less(v, 100.0) 1805 b = lambda v: math_ops.multiply(v, a) 1806 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 1807 1808 r = gradients_impl.gradients(r, a) 1809 variables.global_variables_initializer().run() 1810 self.assertAllClose(216.0, r[0].eval()) 1811 1812 def testWhileGradInCond(self): 1813 with self.test_session(): 1814 n = ops.convert_to_tensor(1.0, name="n") 1815 x = array_ops.placeholder(dtypes.float32, shape=None) 1816 c = lambda n: math_ops.less(n, 10.0) 1817 b = lambda n: math_ops.add(n, x) 1818 1819 def fn1(): 1820 r = control_flow_ops.while_loop(c, b, [n], 1821 [tensor_shape.unknown_shape()]) 1822 return gradients_impl.gradients(r, x) 1823 1824 r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x) 1825 self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) 1826 1827 def testWhileGradInWhile(self): 1828 with self.test_session(): 1829 n = ops.convert_to_tensor(1.0, name="n") 1830 x = array_ops.placeholder(dtypes.float32, shape=None) 1831 c = lambda n: math_ops.less(n, 10.0) 1832 b = lambda n: math_ops.add(n, x) 1833 1834 def b1(n): 1835 r = control_flow_ops.while_loop(c, b, [n], 1836 [tensor_shape.unknown_shape()]) 1837 return gradients_impl.gradients(r, x) 1838 1839 r = control_flow_ops.while_loop(lambda n: n < 6.0, b1, [n], 1840 [tensor_shape.unknown_shape()]) 1841 self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) 1842 1843 def testWhile_NestedInput(self): 1844 with self.test_session() as sess: 1845 named = collections.namedtuple("named", ("a", "b")) 1846 loop_vars = [ 1847 named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)), 1848 (constant_op.constant(2.0), constant_op.constant(3.0)), 1849 constant_op.constant(4.0) 1850 ] 1851 c = lambda lv0, _1, _2: lv0.a < 100.0 1852 1853 def b(lv0, lv1, lv2): 1854 lv0 = named(a=lv0.a + 1, b=lv0.b) 1855 lv1 = (lv1[0] + 1, lv1[1]) 1856 lv2 += 2 1857 return [lv0, lv1, lv2] 1858 1859 r = control_flow_ops.while_loop(c, b, loop_vars) 1860 1861 self.assertTrue(isinstance(r, list)) 1862 self.assertTrue(isinstance(r[0], named)) 1863 self.assertTrue(isinstance(r[1], tuple)) 1864 self.assertTrue(isinstance(r[2], ops.Tensor)) 1865 1866 r_flattened = nest.flatten(r) 1867 self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0], 1868 sess.run(r_flattened)) 1869 1870 def testWhile_NestedBadArityFails(self): 1871 with self.test_session(): 1872 named = collections.namedtuple("named", ("a", "b")) 1873 loop_vars = [ 1874 named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)), 1875 (constant_op.constant(2.0), constant_op.constant(3.0)), 1876 constant_op.constant(4.0) 1877 ] 1878 c = lambda lv0, _1, _2: lv0.a < 100.0 1879 1880 def b(lv0, lv1, _): 1881 return [lv0, lv1] 1882 1883 with self.assertRaisesRegexp(ValueError, "the same number of elements"): 1884 control_flow_ops.while_loop(c, b, loop_vars) 1885 1886 def testWhileGrad_ys_xs(self): 1887 with self.test_session(): 1888 x = constant_op.constant(3.0, name="x") 1889 y = constant_op.constant(2.0, name="y") 1890 1891 c = lambda x, y: math_ops.less(x, 100.0) 1892 1893 def b(x, y): 1894 y1 = math_ops.add(x, y) 1895 x1 = math_ops.multiply(x, y1) 1896 return x1, y1 1897 1898 rx, ry = control_flow_ops.while_loop(c, b, [x, y], parallel_iterations=1) 1899 1900 r = gradients_impl.gradients([rx, ry], x) 1901 self.assertAllClose(304.0, r[0].eval()) 1902 r = gradients_impl.gradients([rx, ry], y) 1903 self.assertAllClose(124.0, r[0].eval()) 1904 r = gradients_impl.gradients([rx], x) 1905 self.assertAllClose(295.0, r[0].eval()) 1906 r = gradients_impl.gradients([rx], y) 1907 self.assertAllClose(120.0, r[0].eval()) 1908 1909 def testWhileGrad_Dependency(self): 1910 with self.test_session(): 1911 i = constant_op.constant(0, name="i") 1912 x = constant_op.constant(2.0, name="x") 1913 1914 c = lambda i, x: math_ops.less(i, 10) 1915 1916 def b(i, x): 1917 x = math_ops.multiply(x, 2.0) 1918 i = math_ops.add(i, 1) 1919 return i, x 1920 1921 ri, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 1922 1923 r = gradients_impl.gradients([ri, rx], x) 1924 self.assertAllClose(1024.0, r[0].eval()) 1925 r = gradients_impl.gradients([rx], x) 1926 self.assertAllClose(1024.0, r[0].eval()) 1927 1928 def testWhileGrad_NoGradient(self): 1929 with self.test_session(): 1930 v = constant_op.constant(2.0, name="v") 1931 c = lambda v: math_ops.less(v, 100.0) 1932 b = math_ops.square 1933 r = control_flow_ops.while_loop(c, b, [v], back_prop=False) 1934 r = math_ops.add(r, v) 1935 r = gradients_impl.gradients(r, v) 1936 self.assertAllClose(1.0, r[0].eval()) 1937 1938 def testWhileGrad_NoDependency(self): 1939 with self.test_session() as sess: 1940 variable = variables.Variable(array_ops.ones([2, 3])) 1941 duration = array_ops.zeros([], dtype=dtypes.int32) 1942 1943 def cond(duration, tensor, _): 1944 del tensor 1945 return duration < 10 1946 1947 def body(duration, tensor, _): 1948 return (duration + 1, tensor, tensor) 1949 1950 loop_vars = [duration, variable, variable] 1951 tensors = control_flow_ops.while_loop( 1952 cond=cond, body=body, loop_vars=loop_vars) 1953 cost = math_ops.reduce_sum(tensors[2]) 1954 grad = gradients_impl.gradients(cost, [variable]) 1955 variables.global_variables_initializer().run() 1956 self.assertAllClose(np.ones([2, 3]), sess.run(grad[0])) 1957 1958 def testWhileGrad_Const(self): 1959 with self.test_session() as sess: 1960 c0 = constant_op.constant(0.0, name="c0") 1961 c1 = constant_op.constant(1.0, name="c1") 1962 duration = constant_op.constant(0, name="t") 1963 1964 def cond(duration, _): 1965 return duration < 1 1966 1967 def body(duration, _): 1968 return duration + 1, c1 1969 1970 loop_vars = [duration, c0] 1971 tensors = control_flow_ops.while_loop( 1972 cond=cond, body=body, loop_vars=loop_vars) 1973 cost = math_ops.reduce_sum(tensors[1]) 1974 grad = gradients_impl.gradients(cost, [c0]) 1975 self.assertAllClose(0.0, sess.run(grad[0])) 1976 1977 def testWhileGrad_SerialTwoLoops(self): 1978 with self.test_session(): 1979 i = constant_op.constant(0, name="i") 1980 x = constant_op.constant(2.0, name="x") 1981 1982 c = lambda i, x: math_ops.less(i, 5) 1983 1984 def b(i, x): 1985 x = math_ops.multiply(x, 2.0) 1986 i = math_ops.add(i, 1) 1987 return i, x 1988 1989 _, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 1990 _, rx = control_flow_ops.while_loop(c, b, [i, rx], parallel_iterations=1) 1991 1992 r = gradients_impl.gradients([rx], x) 1993 self.assertAllClose(1024.0, r[0].eval()) 1994 1995 def testWhileGrad_ParallelTwoLoops(self): 1996 with self.test_session(): 1997 i = constant_op.constant(0, name="i") 1998 x = constant_op.constant(2.0, name="x") 1999 2000 c = lambda i, x: math_ops.less(i, 5) 2001 2002 def b(i, x): 2003 x = math_ops.multiply(x, 2.0) 2004 i = math_ops.add(i, 1) 2005 return i, x 2006 2007 _, r1 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 2008 _, r2 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 2009 rx = math_ops.add(r1, r2) 2010 2011 r = gradients_impl.gradients([rx], x) 2012 self.assertAllClose(64.0, r[0].eval()) 2013 2014 def testWhileGrad_OneOutputWithControlDependencyOnSecond(self): 2015 with self.test_session(): 2016 i = constant_op.constant(0, name="i") 2017 x = constant_op.constant(1.0, name="x") 2018 y = constant_op.constant(1.0, name="y") 2019 c = lambda i, *_: math_ops.less(i, 1, name="cond_less") 2020 2021 def b(i, xi, yi): 2022 # return (i + 1, xi, xi + yi) 2023 return (math_ops.add(i, 1, name="inc"), array_ops.identity( 2024 xi, name="xi"), math_ops.add(xi, yi, name="xi_plus_yi")) 2025 2026 _, x_f, y_f = control_flow_ops.while_loop(c, b, [i, x, y]) 2027 with ops.control_dependencies([x_f]): 2028 y_f_d = array_ops.identity(y_f, name="y_f_d") 2029 2030 self.assertAllClose(2.0, y_f_d.eval()) # y_f_d = 1.0 + 1.0 2031 g = gradients_impl.gradients([y_f_d], [x])[0] 2032 self.assertTrue(g is not None) 2033 self.assertAllClose(1.0, g.eval()) # y_f_d = x + 1.0, dy_f_d/dx = 1.0 2034 2035 def _testNestedWhileGrad_Simple(self, use_gpu): 2036 with self.test_session(use_gpu=use_gpu): 2037 v = constant_op.constant(1.0) 2038 2039 def inner_loop(s): 2040 c = lambda x: math_ops.less(x, 4.0) 2041 b = lambda x: math_ops.multiply(x, 2.0) 2042 return control_flow_ops.while_loop(c, b, [s]) 2043 2044 c = lambda x: math_ops.less(x, 2.0) 2045 b = lambda x: math_ops.multiply(inner_loop(x), 2.0) 2046 r = control_flow_ops.while_loop(c, b, [v]) 2047 2048 r = gradients_impl.gradients(r, v)[0] 2049 self.assertAllClose(8.0, r.eval()) 2050 2051 def testNestedWhileGrad_Simple(self): 2052 self._testNestedWhileGrad_Simple(use_gpu=False) 2053 self._testNestedWhileGrad_Simple(use_gpu=True) 2054 2055 def testNestedWhileGrad_SerialInner(self): 2056 with self.test_session(): 2057 v = constant_op.constant(1.0) 2058 2059 def inner_loop1(s): 2060 z = constant_op.constant(0) 2061 c = lambda i, x: math_ops.less(i, 4) 2062 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 2063 return control_flow_ops.while_loop(c, b, [z, s]) 2064 2065 def inner_loop2(s): 2066 z = constant_op.constant(0) 2067 c = lambda i, x: math_ops.less(i, 4) 2068 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 2069 return control_flow_ops.while_loop(c, b, [z, s]) 2070 2071 c = lambda x: math_ops.less(x, 128.0) 2072 b = lambda x: inner_loop2(inner_loop1(x)[1])[1] 2073 r = control_flow_ops.while_loop(c, b, [v]) 2074 2075 r = gradients_impl.gradients(r, v)[0] 2076 self.assertAllClose(256.0, r.eval()) 2077 2078 def testNestedWhileGrad_ParallelInner(self): 2079 with self.test_session(): 2080 v = constant_op.constant(1.0) 2081 2082 def inner_loop1(s): 2083 z = constant_op.constant(0) 2084 c = lambda i, x: math_ops.less(i, 4) 2085 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 2086 return control_flow_ops.while_loop(c, b, [z, s]) 2087 2088 def inner_loop2(s): 2089 z = constant_op.constant(0) 2090 c = lambda i, x: math_ops.less(i, 4) 2091 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 2092 return control_flow_ops.while_loop(c, b, [z, s]) 2093 2094 c = lambda x: math_ops.less(x, 128.0) 2095 b = lambda x: math_ops.multiply(inner_loop1(x)[1], inner_loop2(x)[1]) 2096 r = control_flow_ops.while_loop(c, b, [v]) 2097 2098 r = gradients_impl.gradients(r, v)[0] 2099 self.assertAllClose(512.0, r.eval()) 2100 2101 def testNestedWhileGrad_ParallelIterations(self): 2102 # Make sure the stack pushes and pops of an inner loop are executed in 2103 # the sequential order of the iterations of its outer loop. 2104 with self.test_session() as sess: 2105 2106 def inner_loop(t): 2107 fn = lambda n: n + math_ops.square(var) 2108 return functional_ops.map_fn(fn=fn, elems=t, parallel_iterations=10) 2109 2110 def outer_loop(inp): 2111 return functional_ops.map_fn( 2112 fn=inner_loop, elems=inp, parallel_iterations=10) 2113 2114 var = variables.Variable(constant_op.constant(3.0)) 2115 inp = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) 2116 res = outer_loop(inp) 2117 optimizer = adam.AdamOptimizer(learning_rate=0.001) 2118 train_op = optimizer.minimize(math_ops.reduce_mean(math_ops.square(res))) 2119 sess.run(variables.global_variables_initializer()) 2120 sess.run(train_op) 2121 self.assertAllClose(2.999, var.eval()) 2122 2123 def _testWhileCondGrad_Simple(self, use_gpu): 2124 with self.test_session(use_gpu=use_gpu): 2125 v = ops.convert_to_tensor(2.0, name="v") 2126 n = ops.convert_to_tensor(100.0, name="n") 2127 one = ops.convert_to_tensor(1.0, name="one") 2128 c = lambda x: math_ops.less(x, n) 2129 # pylint: disable=undefined-variable 2130 # for OSS build 2131 b = lambda x: control_flow_ops.cond(constant_op.constant(True), 2132 lambda: math_ops.square(x), 2133 lambda: math_ops.subtract(x, one)) 2134 # pylint: enable=undefined-variable 2135 r = control_flow_ops.while_loop(c, b, [v]) 2136 r = gradients_impl.gradients(r, v)[0] 2137 self.assertAllClose(1024.0, r.eval()) 2138 2139 def testWhileCondGrad_Simple(self): 2140 self._testWhileCondGrad_Simple(use_gpu=False) 2141 self._testWhileCondGrad_Simple(use_gpu=True) 2142 2143 def testWhileCondGrad_UnknownShape(self): 2144 with self.test_session() as sess: 2145 v = array_ops.placeholder(dtypes.float32) 2146 n = ops.convert_to_tensor(100.0, name="n") 2147 one = ops.convert_to_tensor(1.0, name="one") 2148 c = lambda x: math_ops.less(x, n) 2149 # pylint: disable=undefined-variable 2150 # for OSS build 2151 b = lambda x: control_flow_ops.cond(constant_op.constant(True), 2152 lambda: math_ops.square(x), 2153 lambda: math_ops.subtract(x, one)) 2154 # pylint: enable=undefined-variable 2155 r = control_flow_ops.while_loop(c, b, [v]) 2156 r = gradients_impl.gradients(r, v)[0] 2157 r = sess.run(r, feed_dict={v: 2.0}) 2158 self.assertAllClose(1024.0, r) 2159 2160 def testWhileGrad_Concat(self): 2161 with self.test_session() as sess: 2162 x = variable_scope.get_variable("x", initializer=[[1., 2.]]) 2163 i0 = constant_op.constant(0) 2164 h0 = array_ops.zeros([0, 2]) 2165 2166 def condition(i, _): 2167 return i < 2 2168 2169 def body(i, h): 2170 return i + 1, array_ops.concat([h, x], 0) 2171 2172 _, h = control_flow_ops.while_loop( 2173 condition, body, [i0, h0], 2174 [i0.get_shape(), tensor_shape.TensorShape([None, 2])]) 2175 s = math_ops.reduce_sum(h) 2176 2177 sess.run(variables.global_variables_initializer()) 2178 optimizer = gradient_descent.GradientDescentOptimizer(0.01) 2179 op = optimizer.minimize(s) 2180 sess.run(op) 2181 self.assertAllClose([[0.98000002, 1.98000002]], sess.run(x)) 2182 2183 def testWhileWithRefsWithGradients_1(self): 2184 with self.test_session() as sess: 2185 x = variables.Variable(0)._ref() # pylint: disable=protected-access 2186 i = constant_op.constant(0) 2187 c = lambda i, x: math_ops.less(i, 10) 2188 2189 self.assertEqual(x.dtype, dtypes.int32_ref) 2190 2191 # pylint: disable=protected-access 2192 def body(i, x): 2193 self.assertEqual(x.dtype, dtypes.int32_ref) 2194 return [i + 1, gen_array_ops._ref_identity(x)] 2195 2196 # pylint: enable=protected-access 2197 2198 r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5) 2199 2200 grad_ys = [variables.Variable(73)._ref()] # pylint: disable=protected-access 2201 grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys) 2202 2203 variables.global_variables_initializer().run() 2204 2205 self.assertEqual(r[0].dtype, dtypes.int32) 2206 self.assertEqual(r[1].dtype, dtypes.int32_ref) 2207 2208 value_i, value_x, value_x_grad = sess.run(r + grad) 2209 2210 self.assertEqual(10, value_i) 2211 self.assertEqual(0, value_x) 2212 self.assertEqual(73, value_x_grad) 2213 2214 def testWhileGrad_IndexedSlices(self): 2215 with self.test_session(): 2216 values = constant_op.constant([2.0, 4.0], name="values") 2217 indices = constant_op.constant([0, 3], name="indices") 2218 shape = constant_op.constant([10], name="dense_shape") 2219 i = constant_op.constant(0) 2220 x = ops.IndexedSlices(values, indices, dense_shape=shape) 2221 2222 def c(i, _): 2223 return i < 10 2224 2225 def b(i, x): 2226 return [ 2227 i + 1, 2228 ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape) 2229 ] 2230 2231 _, r = control_flow_ops.while_loop(c, b, [i, x]) 2232 r = gradients_impl.gradients(r.values, values)[0] 2233 self.assertAllClose(np.array([1024.0, 1024.0]), r.eval()) 2234 2235 def testWhileGrad_SparseTensor(self): 2236 with self.test_session(): 2237 values = constant_op.constant([2.0, 4.0], name="values") 2238 indices = constant_op.constant( 2239 [[0], [3]], dtype=dtypes.int64, name="indices") 2240 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 2241 i = constant_op.constant(0) 2242 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 2243 2244 def c(i, _): 2245 return i < 10 2246 2247 def b(i, x): 2248 return [ 2249 i + 1, 2250 sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape) 2251 ] 2252 2253 _, r = control_flow_ops.while_loop(c, b, [i, x]) 2254 r = gradients_impl.gradients(r.values, values)[0] 2255 self.assertAllClose(np.array([1024.0, 1024.0]), r.eval()) 2256 2257 def testCallGradInLoop(self): 2258 with self.test_session() as sess: 2259 i0 = constant_op.constant(0) 2260 params = constant_op.constant(5.0) 2261 params_1 = math_ops.square(params) 2262 2263 def c(i, _): 2264 return i < 10 2265 2266 def b(i, x): 2267 data = constant_op.constant([1.0, 2.0, 3.0]) 2268 data = math_ops.multiply(data, params_1) 2269 x1 = x + gradients_impl.gradients(data, params)[0] 2270 return i + 1, x1 2271 2272 output_grad = control_flow_ops.while_loop( 2273 c, b, [i0, constant_op.constant(0.0)]) 2274 self.assertAllClose(600.0, sess.run(output_grad)[1]) 2275 2276 def testWhileAndTensorArray(self): 2277 with self.test_session() as sess: 2278 param = constant_op.constant(2.0) 2279 n0 = constant_op.constant(0) 2280 y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems") 2281 2282 def c(i, _): 2283 return i < 10 2284 2285 def b(i, y): 2286 return [ 2287 i + 1, 2288 functional_ops.map_fn(lambda x: math_ops.multiply(x, param), y) 2289 ] 2290 2291 r = control_flow_ops.while_loop(c, b, [n0, y0], parallel_iterations=1) 2292 r = gradients_impl.gradients(r, param)[0] 2293 self.assertAllClose(107520.0, sess.run(r)) 2294 2295 def testWhileGrad_StopGrad(self): 2296 with self.test_session(): 2297 x = constant_op.constant(3.0, name="x") 2298 y = constant_op.constant(2.0, name="y") 2299 2300 c = lambda x, y: math_ops.less(x, 100.0) 2301 2302 def b(x, y): 2303 y1 = math_ops.square(y) 2304 x1 = math_ops.add(math_ops.square(x), y1) 2305 return x1, y1 2306 2307 rx, ry = control_flow_ops.while_loop(c, b, [x, y]) 2308 2309 r = gradients_impl.gradients(rx, y)[0] 2310 self.assertEqual(136.0, r.eval()) 2311 r = gradients_impl.gradients(ry, y)[0] 2312 self.assertEqual(32.0, r.eval()) 2313 2314 r = gradients_impl.gradients(array_ops.stop_gradient(rx), y)[0] 2315 self.assertEqual(r, None) 2316 r = gradients_impl.gradients(array_ops.stop_gradient(ry), y)[0] 2317 self.assertEqual(r, None) 2318 2319 r = gradients_impl.gradients( 2320 array_ops.stop_gradient(math_ops.square(rx)), y)[0] 2321 self.assertEqual(r, None) 2322 r = gradients_impl.gradients( 2323 array_ops.stop_gradient(math_ops.add(rx, ry)), x)[0] 2324 self.assertEqual(r, None) 2325 r = gradients_impl.gradients( 2326 array_ops.stop_gradient(math_ops.add(rx, ry)), y)[0] 2327 self.assertEqual(r, None) 2328 2329 r = gradients_impl.gradients(math_ops.add(rx, ry), y)[0] 2330 self.assertEqual(168.0, r.eval()) 2331 r = gradients_impl.gradients( 2332 math_ops.add(rx, array_ops.stop_gradient(ry)), y)[0] 2333 self.assertEqual(136.0, r.eval()) 2334 r = gradients_impl.gradients( 2335 math_ops.add(array_ops.stop_gradient(rx), ry), y)[0] 2336 self.assertEqual(32.0, r.eval()) 2337 2338 def testWhileGrad_StopGradInside(self): 2339 with self.test_session(): 2340 x = constant_op.constant(3.0, name="x") 2341 y = constant_op.constant(2.0, name="y") 2342 2343 c = lambda x, y: math_ops.less(x, 100.0) 2344 2345 def b(x, y): 2346 y1 = array_ops.stop_gradient(math_ops.square(y)) 2347 x1 = math_ops.add(math_ops.square(x), y1) 2348 return x1, y1 2349 2350 rx, _ = control_flow_ops.while_loop(c, b, [x, y]) 2351 2352 r = gradients_impl.gradients(rx, y)[0] 2353 self.assertAllClose(0.0, r.eval()) 2354 r = gradients_impl.gradients(rx, x)[0] 2355 self.assertAllClose(156.0, r.eval()) 2356 2357 def testWhileGrad_StopGradInsideNoShape(self): 2358 with self.test_session() as sess: 2359 x = array_ops.placeholder(dtypes.float32) 2360 y = array_ops.placeholder(dtypes.float32) 2361 2362 c = lambda x, y: math_ops.less(math_ops.reduce_sum(x), 100.0) 2363 2364 def b(x, y): 2365 y1 = array_ops.stop_gradient(math_ops.square(y, name="stopped")) 2366 x1 = math_ops.add(math_ops.square(x), y1) 2367 return x1, y1 2368 2369 rx, _ = control_flow_ops.while_loop(c, b, [x, y]) 2370 2371 r = gradients_impl.gradients(rx, y)[0] 2372 feed_dict = {x: [3.0, 4.0], y: [2.0, 3.0]} 2373 self.assertAllClose([0.0, 0.0], sess.run(r, feed_dict=feed_dict)) 2374 r = gradients_impl.gradients(rx, x)[0] 2375 self.assertAllClose([156.0, 400.0], sess.run(r, feed_dict=feed_dict)) 2376 name = "gradients/while/stopped_grad" 2377 all_ops = x.graph.get_operations() 2378 self.assertFalse(any([name in op.name for op in all_ops])) 2379 2380 def testWhileGradGradFail(self): 2381 theta = variables.Variable(initial_value=1.) 2382 2383 def fn(prev, x): 2384 return prev + x * theta 2385 2386 result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32)) 2387 grad_theta = gradients_impl.gradients(result, theta) 2388 with self.assertRaisesRegexp(TypeError, "Second-order gradient"): 2389 gradients_impl.gradients(grad_theta, theta) 2390 grad_theta_stopped = array_ops.stop_gradient(grad_theta) 2391 gradients_impl.gradients(grad_theta_stopped, theta) 2392 2393 def testStopGradOnWhileGrad(self): 2394 with self.test_session(): 2395 x = constant_op.constant(2.0, name="x") 2396 y = constant_op.constant(2.0, name="y") 2397 2398 c = lambda x: math_ops.less(x, 100.0) 2399 b = lambda x: math_ops.multiply(x, y) 2400 rx = control_flow_ops.while_loop(c, b, [x]) 2401 2402 rg = gradients_impl.gradients(rx, y)[0] 2403 rg = array_ops.stop_gradient(rg) 2404 r = math_ops.add(math_ops.square(y), rx) 2405 r = math_ops.add(r, rg) 2406 r = gradients_impl.gradients(r, y)[0] 2407 self.assertEqual(388.0, r.eval()) 2408 2409 def testStopGradMultiFlows(self): 2410 with self.test_session(): 2411 2412 def body(i, y, r): 2413 x = variable_scope.get_variable( 2414 "x", 2415 shape=(), 2416 dtype=dtypes.float32, 2417 initializer=init_ops.ones_initializer()) 2418 y *= x 2419 return [i + 1, y, r + math_ops.reduce_sum(y)] 2420 2421 i0 = constant_op.constant(0) 2422 y0 = array_ops.ones(5) 2423 r0 = constant_op.constant(0.0) 2424 cond = lambda i, y, r: i < 1 2425 _, _, r = control_flow_ops.while_loop( 2426 cond, body, [i0, y0, r0], back_prop=True) 2427 2428 vars_ = variables.global_variables() 2429 grads = linalg_ops.norm(gradients_impl.gradients(r, vars_)[0]) 2430 z = math_ops.add(r, array_ops.stop_gradient(math_ops.reduce_sum(grads))) 2431 result = gradients_impl.gradients(z, vars_)[0] 2432 variables.global_variables_initializer().run() 2433 self.assertEqual(5.0, result.eval()) 2434 2435 def testOneValueCond(self): 2436 with self.test_session(): 2437 c = array_ops.placeholder(dtypes.int32, shape=[]) 2438 one = ops.convert_to_tensor(1, name="one") 2439 two = ops.convert_to_tensor(2, name="two") 2440 p = math_ops.greater_equal(c, 1) 2441 i = control_flow_ops.cond(p, lambda: one, lambda: two) 2442 self.assertTrue(isinstance(i, ops.Tensor)) 2443 2444 # True case: c = 2 is >= 1 2445 self.assertEqual([1], i.eval(feed_dict={c: 2})) 2446 2447 # False case: c = 0 is not >= 1 2448 self.assertEqual([2], i.eval(feed_dict={c: 0})) 2449 2450 def testExampleCond(self): 2451 with self.test_session(): 2452 x = ops.convert_to_tensor([-2.0, 2.0], name="x") 2453 d = array_ops.placeholder(dtypes.int32, shape=[]) 2454 2455 def l2(): 2456 return math_ops.sqrt(math_ops.reduce_sum(math_ops.square(x))) 2457 2458 def l1(): 2459 return math_ops.reduce_sum(math_ops.abs(x)) 2460 2461 i = control_flow_ops.cond(math_ops.equal(d, 2), l2, l1) 2462 self.assertAllClose(4.0, i.eval(feed_dict={d: 1})) 2463 self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2})) 2464 2465 def testCase(self): 2466 with self.test_session(): 2467 x = constant_op.constant(1) 2468 y = constant_op.constant(2) 2469 z = constant_op.constant(3) 2470 f1 = lambda: constant_op.constant(17) 2471 f2 = lambda: constant_op.constant(23) 2472 f3 = lambda: constant_op.constant(-1) 2473 2474 r1 = control_flow_ops.case( 2475 { 2476 x < y: f1, 2477 x > z: f2 2478 }, default=f3, exclusive=True) 2479 self.assertAllEqual(r1.eval(), 17) 2480 2481 r2 = control_flow_ops.case([(y > z, f1), (y > x, f2)], default=f3) 2482 self.assertAllEqual(r2.eval(), 23) 2483 2484 # Duplicate events can happen, first one is selected 2485 r3 = control_flow_ops.case([(x < y, f1), (x < y, f2)], default=f3) 2486 self.assertAllEqual(r3.eval(), 17) 2487 2488 # Duplicate events cause an error if exclusive = True 2489 r4 = control_flow_ops.case( 2490 [(x < y, f1), (x < y, f2)], default=f3, exclusive=True) 2491 with self.assertRaisesOpError("Input error:"): 2492 r4.eval() 2493 2494 # Check that the default is called if none of the others are 2495 r5 = control_flow_ops.case({x > y: f1}, default=f3) 2496 self.assertAllEqual(r5.eval(), -1) 2497 2498 ran_once = [False, False, False] 2499 2500 def break_run_twice(ix): 2501 2502 def _break(): 2503 ran_once[ix] = True 2504 return constant_op.constant(ix) 2505 2506 return _break 2507 2508 # Should not fail - each conditional gets called exactly once 2509 # except default. Default gets called twice: once to create an 2510 # empty output and once for the actual cond switch. 2511 r6 = control_flow_ops.case( 2512 [(x < y, break_run_twice(0)), (x > y, break_run_twice(1))], 2513 default=lambda: constant_op.constant(2)) 2514 2515 self.assertAllEqual(r6.eval(), 0) 2516 2517 def testCaseSideEffects(self): 2518 with self.test_session() as sess: 2519 v0 = variables.Variable(-1) 2520 v1 = variables.Variable(-1) 2521 v2 = variables.Variable(-1) 2522 2523 a = lambda: control_flow_ops.with_dependencies([state_ops.assign(v0, 0)], 0) 2524 b = lambda: control_flow_ops.with_dependencies([state_ops.assign(v1, 1)], 1) 2525 c = lambda: control_flow_ops.with_dependencies([state_ops.assign(v2, 2)], 2) 2526 2527 x = constant_op.constant(1) 2528 y = constant_op.constant(2) 2529 2530 r0 = control_flow_ops.case( 2531 ((x < y, a), (x > y, b)), default=c, exclusive=True) 2532 r1 = control_flow_ops.case( 2533 ((x > y, a), (x < y, b)), default=c, exclusive=True) 2534 r2 = control_flow_ops.case( 2535 ((x > y, a), (x > y, b)), default=c, exclusive=True) 2536 2537 variables.global_variables_initializer().run() 2538 self.assertAllEqual(sess.run([v0, v1, v2]), [-1] * 3) 2539 self.assertEqual(2, r2.eval()) 2540 self.assertAllEqual(sess.run([v0, v1, v2]), [-1, -1, 2]) 2541 2542 variables.global_variables_initializer().run() 2543 self.assertAllEqual(sess.run([v0, v1, v2]), [-1] * 3) 2544 self.assertEqual(1, r1.eval()) 2545 self.assertAllEqual(sess.run([v0, v1, v2]), [-1, 1, -1]) 2546 2547 variables.global_variables_initializer().run() 2548 self.assertAllEqual(sess.run([v0, v1, v2]), [-1] * 3) 2549 self.assertEqual(0, r0.eval()) 2550 self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1]) 2551 2552 def testOneOpCond(self): 2553 with self.test_session(): 2554 v = variables.Variable(0) 2555 c = ops.convert_to_tensor(0) 2556 one = ops.convert_to_tensor(1) 2557 two = ops.convert_to_tensor(2) 2558 p = math_ops.greater_equal(c, 1) 2559 2560 def a(): 2561 return state_ops.assign(v, one) 2562 2563 def b(): 2564 return state_ops.assign(v, two) 2565 2566 i = control_flow_ops.cond(p, a, b) 2567 self.assertTrue(isinstance(i, ops.Tensor)) 2568 variables.global_variables_initializer().run() 2569 2570 self.assertEqual(0, v.eval()) 2571 2572 # True case: c = 2 is >= 1, v is set to 1. 2573 self.assertEqual(1, i.eval(feed_dict={c.name: 2})) 2574 self.assertEqual(1, v.eval()) 2575 2576 # False case: c = 0 is not >= 1, v is set to 2. 2577 self.assertEqual(2, i.eval(feed_dict={c.name: 0})) 2578 self.assertEqual(2, v.eval()) 2579 2580 def testWithOpsDependencies(self): 2581 with self.test_session() as sess: 2582 v = variables.Variable(0.0) 2583 c = constant_op.constant(10) 2584 2585 # Fetching v directly will result in an uninitialized error 2586 with self.assertRaisesOpError("Attempting to use uninitialized value"): 2587 sess.run([c, v]) 2588 2589 # Use a control dependency to ensure init_variable is run 2590 # while asking for c 2591 real_v = control_flow_ops.with_dependencies( 2592 name="real_tensor", 2593 output_tensor=v._ref(), # pylint: disable=protected-access 2594 dependencies=[v.initializer]) 2595 c_val, real_v_val = sess.run([c, real_v]) 2596 2597 # Ensure the result of 'real_c' is the same as 'c' 2598 self.assertAllEqual(10, c_val) 2599 2600 # Ensure that 'v' is initialized 2601 self.assertAllClose(0.0, real_v_val) 2602 2603 def testWithTensorDependencies(self): 2604 with self.test_session(): 2605 v = variables.Variable(0.0) 2606 c1 = constant_op.constant(10) 2607 c2 = constant_op.constant(20) 2608 2609 # c1_with_init_v depends on the init op for v 2610 c1_with_init_v = control_flow_ops.with_dependencies( 2611 name="c1_with_init_v", output_tensor=c1, dependencies=[v.initializer]) 2612 # c2_with_c1 depends on the value of c1_with_init_v 2613 c2_with_c1_dep = control_flow_ops.with_dependencies( 2614 name="c2_with_c1_dep", 2615 output_tensor=c2, 2616 dependencies=[c1_with_init_v]) 2617 2618 # Fetching v directly will result in an uninitialized error 2619 with self.assertRaisesOpError("Attempting to use uninitialized value"): 2620 v.eval() 2621 2622 # Get the value of 'c2_with_c1_dep', which should cause 'v' 2623 # to be initialized. 2624 self.assertAllEqual(20, c2_with_c1_dep.eval()) 2625 2626 # Ensure that 'v' is initialized 2627 self.assertAllClose(0.0, v.eval()) 2628 2629 def testWithIndexedSlicesDependencies(self): 2630 with self.test_session(): 2631 v = variables.Variable( 2632 np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32)) 2633 v_at_1 = ops.IndexedSlices(v, constant_op.constant([1])) 2634 gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices) 2635 v_at_1_after_init = control_flow_ops.with_dependencies([v.initializer], 2636 v_at_1) 2637 gather_v_at_1_after_init = array_ops.gather(v_at_1_after_init.values, 2638 v_at_1_after_init.indices) 2639 2640 # Fetching gather_v_at_1 will result in an uninitialized error 2641 with self.assertRaisesOpError("Attempting to use uninitialized value"): 2642 gather_v_at_1.eval() 2643 2644 # Getting gather_v_at_1_after_init will work, and initialize v. 2645 self.assertAllEqual([[10.0, 11.0]], gather_v_at_1_after_init.eval()) 2646 2647 # Double check that 'v' is initialized 2648 self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]], v.eval()) 2649 2650 def testDependenciesDevice(self): 2651 with ops.Graph().as_default(): 2652 # device set on tensor => same device on dep. 2653 with ops.device("/job:ps"): 2654 vd = variables.Variable([0.0]) 2655 with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd) 2656 self.assertTrue("/job:ps" in with_vd_dep.device) 2657 2658 # No device set on tensor => no device on dep. 2659 vnod = variables.Variable([0.0]) 2660 with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer], 2661 vnod) 2662 self.assertDeviceEqual(None, with_vnod_dep.device) 2663 2664 # device set on tensor, default device on graph => default device on dep. 2665 vdef = variables.Variable([0.0], name="vdef") 2666 with ops.device("/job:worker/device:GPU:1"): 2667 with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer], 2668 vdef) 2669 # The device is empty, but the colocation constraint is set. 2670 self.assertDeviceEqual("", with_vdef_dep.device) 2671 self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups()) 2672 2673 def testGroup(self): 2674 with self.test_session() as sess: 2675 v1 = variables.Variable([0.0]) 2676 v2 = variables.Variable([1.0]) 2677 2678 # Group init1 and init2 and run. 2679 init = control_flow_ops.group(v1.initializer, v2.initializer) 2680 # Fetching v1 directly will result in an uninitialized error 2681 with self.assertRaisesOpError("Attempting to use uninitialized value"): 2682 v1.eval() 2683 2684 # Runs "init" before fetching v1 and v2. 2685 init.run() 2686 v1_val, v2_val = sess.run([v1, v2]) 2687 2688 # Ensure that v1 and v2 are initialized 2689 self.assertAllClose([0.0], v1_val) 2690 self.assertAllClose([1.0], v2_val) 2691 2692 def testGroupEmpty(self): 2693 op = control_flow_ops.group() 2694 self.assertEqual(op.type, "NoOp") 2695 self.assertEqual(op.control_inputs, []) 2696 2697 def testMergeShapes(self): 2698 # All inputs unknown. 2699 p1 = array_ops.placeholder(dtypes.float32) 2700 p2 = array_ops.placeholder(dtypes.float32) 2701 p3 = array_ops.placeholder(dtypes.float32) 2702 m, index = control_flow_ops.merge([p1, p2, p3]) 2703 self.assertIs(None, m.get_shape().ndims) 2704 self.assertEqual([], index.get_shape()) 2705 2706 # All inputs known with different ranks. 2707 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 2708 p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2, 3]) 2709 m, index = control_flow_ops.merge([p1, p2]) 2710 self.assertIs(None, m.get_shape().ndims) 2711 self.assertEqual([], index.get_shape()) 2712 2713 # All inputs known with some dimensions different. 2714 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 2715 p2 = array_ops.placeholder(dtypes.float32, shape=[2, 1]) 2716 m, index = control_flow_ops.merge([p1, p2]) 2717 self.assertEqual([None, None], m.get_shape().as_list()) 2718 self.assertEqual([], index.get_shape()) 2719 2720 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 2721 p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 2722 m, index = control_flow_ops.merge([p1, p2]) 2723 self.assertEqual([None, 2], m.get_shape().as_list()) 2724 self.assertEqual([], index.get_shape()) 2725 2726 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 2727 p2 = array_ops.placeholder(dtypes.float32, shape=[2, 2]) 2728 m, index = control_flow_ops.merge([p1, p2]) 2729 self.assertEqual([None, 2], m.get_shape().as_list()) 2730 self.assertEqual([], index.get_shape()) 2731 2732 # All inputs known with same dimensions. 2733 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 2734 p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 2735 m, index = control_flow_ops.merge([p1, p2]) 2736 self.assertEqual([1, 2], m.get_shape().as_list()) 2737 self.assertEqual([], index.get_shape()) 2738 2739 p1 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 2740 p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 2741 m, index = control_flow_ops.merge([p1, p2]) 2742 self.assertEqual([None, 2], m.get_shape().as_list()) 2743 self.assertEqual([], index.get_shape()) 2744 2745 p1 = array_ops.placeholder(dtypes.float32, shape=[None, None]) 2746 p2 = array_ops.placeholder(dtypes.float32, shape=[None, None]) 2747 m, index = control_flow_ops.merge([p1, p2]) 2748 self.assertEqual([None, None], m.get_shape().as_list()) 2749 self.assertEqual([], index.get_shape()) 2750 2751 def testRefSelect(self): 2752 index = array_ops.placeholder(dtypes.int32) 2753 2754 # All inputs unknown. 2755 p1 = array_ops.placeholder(dtypes.float32) 2756 p2 = array_ops.placeholder(dtypes.float32) 2757 p3 = array_ops.placeholder(dtypes.float32) 2758 v1 = variables.Variable(p1, validate_shape=False) 2759 v2 = variables.Variable(p2, validate_shape=False) 2760 v3 = variables.Variable(p3, validate_shape=False) 2761 self.assertIs(None, v1.get_shape().ndims) 2762 s = control_flow_ops.ref_select(index, [v1, v2, v3]) 2763 self.assertIs(None, s.get_shape().ndims) 2764 2765 # All inputs known but different. 2766 v1 = variables.Variable([[1, 2]]) 2767 v2 = variables.Variable([[2], [1]]) 2768 s = control_flow_ops.ref_select(index, [v1, v2]) 2769 self.assertIs(None, s.get_shape().ndims) 2770 2771 # All inputs known and same. 2772 v1 = variables.Variable([[1, 2]]) 2773 v2 = variables.Variable([[1, 2]]) 2774 s = control_flow_ops.ref_select(index, [v1, v2]) 2775 self.assertEqual([1, 2], s.get_shape()) 2776 2777 # Possibly the same but not guaranteed. 2778 v1 = variables.Variable([[1., 2.]]) 2779 p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 2780 v2 = variables.Variable(p2, validate_shape=False) 2781 s = control_flow_ops.ref_select(index, [v1, v2]) 2782 self.assertEqual(None, s.get_shape()) 2783 2784 def testRunLoopTensor(self): 2785 with self.test_session() as sess: 2786 tensor_list = [] 2787 2788 def condition(t): 2789 return t < constant_op.constant(5) 2790 2791 def body(_): 2792 tensor_list.append(constant_op.constant(5)) 2793 return constant_op.constant(10) 2794 2795 result = control_flow_ops.while_loop(condition, body, 2796 [constant_op.constant(4)]) 2797 self.assertEqual(10, sess.run(result)) 2798 2799 # Ensure that we cannot run a tensor that escapes the loop body 2800 # accidentally. 2801 with self.assertRaises(ValueError): 2802 sess.run(tensor_list[0]) 2803 2804 def testWhilePyFuncBasic(self): 2805 2806 def func(x): 2807 return np.square(x) 2808 2809 with self.test_session(): 2810 r = control_flow_ops.while_loop( 2811 lambda i, v: i < 4, 2812 lambda i, v: [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]], 2813 [constant_op.constant(0), constant_op.constant(2.0, dtypes.float32)], 2814 [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()]) 2815 self.assertEqual(r[1].eval(), 65536.0) 2816 2817 def testWhileFuncBasic(self): 2818 2819 @function.Defun(dtypes.float32) 2820 def func(x): 2821 return math_ops.square(math_ops.square(x)) 2822 2823 with self.test_session(): 2824 x = constant_op.constant(2.0, dtypes.float32) 2825 r = control_flow_ops.while_loop( 2826 lambda i, v: i < 2, lambda i, v: [i + 1, func(v)], 2827 [constant_op.constant(0), x], 2828 [tensor_shape.unknown_shape(), 2829 tensor_shape.unknown_shape()]) 2830 self.assertEqual(r[1].eval(), 65536.0) 2831 2832 r = gradients_impl.gradients(r, x)[0] 2833 self.assertEqual(r.eval(), 524288.0) 2834 self.assertEqual( 2835 len([op for op in x.graph.get_operations() if op.type == "StackV2"]), 2836 1) 2837 2838 2839@test_util.with_c_api 2840class ControlFlowContextCheckTest(test.TestCase): 2841 2842 def _getWhileTensor(self): 2843 """Creates and returns a tensor from a while context.""" 2844 tensor = [] 2845 2846 def body(i): 2847 if not tensor: 2848 tensor.append(constant_op.constant(1)) 2849 return i + tensor[0] 2850 2851 control_flow_ops.while_loop(lambda i: i < 10, body, [0]) 2852 return tensor[0] 2853 2854 def _getCondTensor(self): 2855 cond_tensor = [] 2856 2857 def true_fn(): 2858 if not cond_tensor: 2859 cond_tensor.append(constant_op.constant(1)) 2860 return cond_tensor[0] 2861 2862 control_flow_ops.cond( 2863 math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0)) 2864 return cond_tensor[0] 2865 2866 def testInvalidContext(self): 2867 # Accessing a while loop tensor outside of control flow is illegal. 2868 while_tensor = self._getWhileTensor() 2869 with self.assertRaisesRegexp( 2870 ValueError, 2871 "Cannot use 'while/Const_1' as input to 'Add' because 'while/Const_1' " 2872 "is in a while loop. See info log for more details."): 2873 math_ops.add(1, while_tensor) 2874 2875 def testInvalidContextInCond(self): 2876 # Accessing a while loop tensor in cond is illegal. 2877 while_tensor = self._getWhileTensor() 2878 with self.assertRaisesRegexp( 2879 ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because " 2880 "'while/Const_1' is in a while loop. See info log for more details."): 2881 # TODO(skyewm): this passes if we return while_tensor directly instead 2882 # of using it as input to another op. 2883 control_flow_ops.cond( 2884 math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor), 2885 lambda: constant_op.constant(0)) 2886 2887 def testInvalidContextInWhile(self): 2888 # Accessing a while loop tensor in a different while loop is illegal. 2889 while_tensor = self._getWhileTensor() 2890 with self.assertRaisesRegexp( 2891 ValueError, 2892 "Cannot use 'while_1/Add' as input to 'while/Const_1' because they are " 2893 "in different while loops. See info log for more details."): 2894 control_flow_ops.while_loop(lambda i: i < 10, 2895 lambda x: math_ops.add(1, while_tensor), [0]) 2896 2897 with self.assertRaisesRegexp( 2898 ValueError, 2899 "Cannot use 'while_2/NextIteration' as input to 'while/Const_1' " 2900 "because they are in different while loops. See info log for more " 2901 "details."): 2902 control_flow_ops.while_loop(lambda i: i < 10, lambda i: while_tensor, [0]) 2903 2904 def testValidCondContext(self): 2905 # Accessing a tensor from a cond context is OK (although dangerous). 2906 cond_tensor = self._getCondTensor() 2907 math_ops.add(1, cond_tensor) 2908 2909 def testValidCondContextBranches(self): 2910 # Accessing a tensor from a cond context from the other branch's cond 2911 # context is OK (although dangerous). 2912 cond_tensor = [] 2913 2914 def branch_fn(): 2915 if not cond_tensor: 2916 cond_tensor.append(constant_op.constant(1)) 2917 return cond_tensor[0] 2918 2919 control_flow_ops.cond(math_ops.less(1, 2), branch_fn, branch_fn) 2920 2921 def testValidWhileContext(self): 2922 # Accessing a tensor in a nested while is OK. 2923 def body(_): 2924 c = constant_op.constant(1) 2925 return control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + c, [0]) 2926 2927 control_flow_ops.while_loop(lambda i: i < 5, body, [0]) 2928 2929 def testValidNestedContexts(self): 2930 # Accessing a tensor from a cond context in a while context, all inside an 2931 # outer while context, is OK. 2932 def body(_): 2933 cond_tensor = self._getCondTensor() 2934 # Create another cond containing the while loop for good measure 2935 return control_flow_ops.cond( 2936 math_ops.less(1, 2), 2937 lambda: control_flow_ops.while_loop(lambda i: i < 3, 2938 lambda i: i + cond_tensor, [0]), 2939 lambda: constant_op.constant(0)) 2940 2941 control_flow_ops.while_loop(lambda i: i < 5, body, [0]) 2942 2943 def testInvalidNestedContexts(self): 2944 # Accessing a tensor from a while context in a different while context, all 2945 # inside a cond context, is illegal. 2946 def true_fn(): 2947 while_tensor = self._getWhileTensor() 2948 return control_flow_ops.while_loop(lambda i: i < 3, 2949 lambda i: i + while_tensor, [0]) 2950 2951 with self.assertRaisesRegexp( 2952 ValueError, 2953 "Cannot use 'cond/while_1/add' as input to 'cond/while/Const_1' because" 2954 " they are in different while loops. See info log for more details."): 2955 control_flow_ops.cond( 2956 math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0)) 2957 2958 2959@test_util.with_c_api 2960class TupleTest(test.TestCase): 2961 2962 def testTensors(self): 2963 for v1_first in [True, False]: 2964 with self.test_session(): 2965 v1 = variables.Variable([1.0]) 2966 add1 = math_ops.add( 2967 control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access 2968 2.0) 2969 v2 = variables.Variable([10.0]) 2970 add2 = math_ops.add( 2971 control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access 2972 20.0) 2973 t1, _, t2 = control_flow_ops.tuple([add1, None, add2]) 2974 2975 # v1 is not initialized. 2976 with self.assertRaisesOpError("Attempting to use uninitialized value"): 2977 v1.eval() 2978 2979 # v2 is not initialized. 2980 with self.assertRaisesOpError("Attempting to use uninitialized value"): 2981 v2.eval() 2982 2983 if v1_first: 2984 # Getting t1 initializes v2. 2985 self.assertAllClose([3.0], t1.eval()) 2986 self.assertAllClose([10.0], v2.eval()) 2987 else: 2988 # Getting t2 initializes v1. 2989 self.assertAllClose([30.0], t2.eval()) 2990 self.assertAllClose([1.0], v1.eval()) 2991 2992 def testIndexedSlices(self): 2993 for v1_first in [True, False]: 2994 with self.test_session(): 2995 v1 = variables.Variable( 2996 np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype( 2997 np.float32)) 2998 v1_at_1 = ops.IndexedSlices( 2999 control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access 3000 constant_op.constant([1])) 3001 3002 v2 = variables.Variable( 3003 np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype( 3004 np.float32)) 3005 v2_at_1 = ops.IndexedSlices( 3006 control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access 3007 constant_op.constant([1])) 3008 3009 st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1]) 3010 g1 = array_ops.gather(st1.values, st1.indices) 3011 g2 = array_ops.gather(st2.values, st2.indices) 3012 3013 # v1 is not initialized. 3014 with self.assertRaisesOpError("Attempting to use uninitialized value"): 3015 v1.eval() 3016 3017 # v2 is not initialized. 3018 with self.assertRaisesOpError("Attempting to use uninitialized value"): 3019 v2.eval() 3020 3021 if v1_first: 3022 # Getting g1 initializes v2. 3023 self.assertAllClose([[10.0, 11.0]], g1.eval()) 3024 self.assertAllClose([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]], 3025 v2.eval()) 3026 else: 3027 # Getting g2 initializes v1. 3028 self.assertAllClose([[10.1, 11.1]], g2.eval()) 3029 self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]], 3030 v1.eval()) 3031 3032 def testAcceptTensorsAsControlInputs(self): 3033 with self.test_session(): 3034 var = variables.Variable(0) 3035 assign = state_ops.assign(var, 1) 3036 t, = control_flow_ops.tuple( 3037 [constant_op.constant(0)], control_inputs=[assign]) 3038 3039 # Should trigger the assign. 3040 t.eval() 3041 3042 self.assertEquals(1, var.eval()) 3043 3044 3045@test_util.with_c_api 3046class AssertTest(test.TestCase): 3047 3048 def testGuardedAssertDoesNotCopyWhenTrue(self): 3049 with self.test_session(use_gpu=True) as sess: 3050 with ops.device(test.gpu_device_name()): 3051 value = constant_op.constant(1.0) 3052 with ops.device("/cpu:0"): 3053 true = constant_op.constant(True) 3054 guarded_assert = control_flow_ops.Assert(true, [value], name="guarded") 3055 unguarded_assert = gen_logging_ops._assert( 3056 true, [value], name="unguarded") 3057 opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) 3058 guarded_metadata = config_pb2.RunMetadata() 3059 sess.run(guarded_assert, options=opts, run_metadata=guarded_metadata) 3060 unguarded_metadata = config_pb2.RunMetadata() 3061 sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata) 3062 guarded_nodestat_names = [ 3063 n.node_name 3064 for d in guarded_metadata.step_stats.dev_stats 3065 for n in d.node_stats 3066 ] 3067 unguarded_nodestat_names = [ 3068 n.node_name 3069 for d in unguarded_metadata.step_stats.dev_stats 3070 for n in d.node_stats 3071 ] 3072 guarded_memcpy_nodestat_names = [ 3073 n for n in guarded_nodestat_names if "MEMCPYDtoH" in n 3074 ] 3075 unguarded_memcpy_nodestat_names = [ 3076 n for n in unguarded_nodestat_names if "MEMCPYDtoH" in n 3077 ] 3078 if "GPU" in [d.device_type for d in device_lib.list_local_devices()]: 3079 # A copy was performed for the unguarded assert 3080 self.assertLess(0, len(unguarded_memcpy_nodestat_names)) 3081 # No copy was performed for the guarded assert 3082 self.assertEqual([], guarded_memcpy_nodestat_names) 3083 3084 3085@test_util.with_c_api 3086class WhileOpBenchmark(test.Benchmark): 3087 """Evaluate the performance of while_loop op.""" 3088 3089 def _getInitVariables(self): 3090 batch_size = 10 3091 image_size = 256 3092 kernel_size = 3 3093 depth = 16 3094 3095 init_step = constant_op.constant(-1) 3096 image = variable_scope.get_variable( 3097 "image", 3098 initializer=random_ops.random_normal( 3099 [batch_size, image_size, image_size, depth], 3100 dtype=dtypes.float32, 3101 stddev=1e-1)) 3102 kernel = variable_scope.get_variable( 3103 "weights", 3104 initializer=random_ops.truncated_normal( 3105 [kernel_size, kernel_size, depth, depth], 3106 dtype=dtypes.float32, 3107 stddev=1e-1)) 3108 return init_step, image, kernel 3109 3110 def _runOneBenchmark(self, 3111 default_device, 3112 num_iters=10, 3113 static_unroll=False, 3114 steps=10): 3115 """Evaluate the while loop performance. 3116 3117 Args: 3118 default_device: The default device to run all ops except the loop_body. 3119 loop_body is always run on GPU. 3120 num_iters: Number of iterations to run. 3121 static_unroll: If true, run unrolled version; otherwise, run while_loop. 3122 steps: Total number of repeated steps to run the loop. 3123 3124 Returns: 3125 The duration of the run in seconds. 3126 """ 3127 3128 def loop_body(i, x): 3129 with ops.device("/gpu:0"): 3130 # Always put loop body on GPU. 3131 nx = nn_ops.conv2d( 3132 input=x, 3133 filter=kernel, 3134 strides=[1, 1, 1, 1], 3135 padding="SAME", 3136 data_format="NHWC", 3137 name="conv2d") 3138 ni = math_ops.add(i, 1) 3139 return ni, nx 3140 3141 ops.reset_default_graph() 3142 with session.Session() as sess, ops.device(default_device): 3143 # Get the initial id i, input x, and kernel. 3144 i, x, kernel = self._getInitVariables() 3145 sess.run(variables.global_variables_initializer()) 3146 3147 if static_unroll: 3148 for _ in xrange(steps): 3149 i, x = loop_body(i, x) 3150 else: 3151 i, x = control_flow_ops.while_loop( 3152 lambda i, _: i < steps, 3153 loop_body, [i, x], 3154 parallel_iterations=steps, 3155 swap_memory=True) 3156 3157 r = math_ops.reduce_sum(x) 3158 dx, dk = gradients_impl.gradients(r, [x, kernel]) 3159 # Use group to avoid fetching back results. 3160 r = control_flow_ops.group(dx, dk) 3161 3162 for _ in xrange(3): 3163 # exclude warm up time 3164 sess.run(r) 3165 3166 start_time = time.time() 3167 for _ in xrange(num_iters): 3168 sess.run(r) 3169 return (time.time() - start_time) / num_iters 3170 3171 def benchmarkWhileOpCrossDevicePlacement(self): 3172 iters = 10 3173 # Run loop body on GPU, but other ops on CPU. 3174 duration = self._runOneBenchmark("cpu", iters, static_unroll=False) 3175 self.report_benchmark( 3176 name="while_op_cross_device", iters=iters, wall_time=duration) 3177 3178 def benchmarkWhileOpSameDevicePlacement(self): 3179 iters = 10 3180 # Run all ops on the same GPU device. 3181 duration = self._runOneBenchmark("gpu", iters, static_unroll=False) 3182 self.report_benchmark( 3183 name="while_op_same_device", iters=iters, wall_time=duration) 3184 3185 def benchmarkWhileOpUnrollCrossDevicePlacement(self): 3186 iters = 10 3187 # Run loop body on GPU, but other ops on CPU. 3188 duration = self._runOneBenchmark("cpu", iters, static_unroll=True) 3189 self.report_benchmark( 3190 name="unroll_cross_device_cpu", iters=iters, wall_time=duration) 3191 3192 def benchmarkWhileOpUnrollSameDevicePlacement(self): 3193 iters = 10 3194 # Run all ops on GPU. 3195 duration = self._runOneBenchmark("gpu", iters, static_unroll=True) 3196 self.report_benchmark( 3197 name="unroll_same_device", iters=iters, wall_time=duration) 3198 3199 3200@test_util.with_c_api 3201class EagerTest(test.TestCase): 3202 3203 def testCond(self): 3204 with context.eager_mode(): 3205 pred = math_ops.less(1, 2) 3206 fn1 = lambda: [constant_op.constant(10)] 3207 fn2 = lambda: [constant_op.constant(20)] 3208 r = control_flow_ops.cond(pred, fn1, fn2) 3209 3210 self.assertAllEqual(r.numpy(), 10) 3211 self.assertFalse(isinstance(r, list)) 3212 3213 def testWhileLoop(self): 3214 with context.eager_mode(): 3215 tensor = constant_op.constant([1, 2, 3, 4, 5]) 3216 self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50]) 3217 3218 def testWhileLoopWithMaxIterations(self): 3219 with context.eager_mode(): 3220 tensor = constant_op.constant([1, 2, 3, 4, 5]) 3221 self.assertAllEqual( 3222 isum(tensor, maximum_iterations=3).numpy(), 3223 [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3]) 3224 3225 def testWhileWithMaximumIterationsAndSingleArgument(self): 3226 with context.eager_mode(): 3227 tensor = constant_op.constant(0) 3228 r = control_flow_ops.while_loop( 3229 lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1) 3230 self.assertEqual(1, r.numpy()) 3231 3232 def testWithDependencies(self): 3233 with context.eager_mode(): 3234 t1 = constant_op.constant(1) 3235 t2 = constant_op.constant(2) 3236 t3 = control_flow_ops.with_dependencies(t1, t2) 3237 self.assertAllEqual(t2.numpy(), t3.numpy()) 3238 3239 def testTuple(self): 3240 with context.eager_mode(): 3241 t1 = constant_op.constant(1) 3242 t2 = constant_op.constant(2) 3243 tup1, tup2 = control_flow_ops.tuple([t1, t2]) 3244 self.assertAllEqual(t1.numpy(), tup1.numpy()) 3245 self.assertAllEqual(t2.numpy(), tup2.numpy()) 3246 3247 def testCase(self): 3248 with context.eager_mode(): 3249 x = constant_op.constant(1) 3250 y = constant_op.constant(2) 3251 z = constant_op.constant(3) 3252 f1 = lambda: constant_op.constant(17) 3253 f2 = lambda: constant_op.constant(23) 3254 f3 = lambda: constant_op.constant(-1) 3255 3256 r1 = control_flow_ops.case( 3257 [(x < y, f1), (x > z, f2)], default=f3, exclusive=True) 3258 self.assertAllEqual(r1.numpy(), 17) 3259 3260 3261if __name__ == "__main__": 3262 test.main() 3263