1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for barrier ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import time 22 23import numpy as np 24 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import errors_impl 27from tensorflow.python.framework import ops 28from tensorflow.python.ops import data_flow_ops 29from tensorflow.python.platform import test 30 31 32class BarrierTest(test.TestCase): 33 34 def testConstructorWithShapes(self): 35 with ops.Graph().as_default(): 36 b = data_flow_ops.Barrier( 37 (dtypes.float32, dtypes.float32), 38 shapes=((1, 2, 3), (8,)), 39 shared_name="B", 40 name="B") 41 self.assertTrue(isinstance(b.barrier_ref, ops.Tensor)) 42 self.assertProtoEquals(""" 43 name:'B' op:'Barrier' 44 attr { 45 key: "capacity" 46 value { 47 i: -1 48 } 49 } 50 attr { key: 'component_types' 51 value { list { type: DT_FLOAT type: DT_FLOAT } } } 52 attr { 53 key: 'shapes' 54 value { 55 list { 56 shape { 57 dim { size: 1 } dim { size: 2 } dim { size: 3 } 58 } 59 shape { 60 dim { size: 8 } 61 } 62 } 63 } 64 } 65 attr { key: 'container' value { s: "" } } 66 attr { key: 'shared_name' value: { s: 'B' } } 67 """, b.barrier_ref.op.node_def) 68 69 def testInsertMany(self): 70 with self.test_session(): 71 b = data_flow_ops.Barrier( 72 (dtypes.float32, dtypes.float32), shapes=((), ()), name="B") 73 size_t = b.ready_size() 74 self.assertEqual([], size_t.get_shape()) 75 keys = [b"a", b"b", b"c"] 76 insert_0_op = b.insert_many(0, keys, [10.0, 20.0, 30.0]) 77 insert_1_op = b.insert_many(1, keys, [100.0, 200.0, 300.0]) 78 79 self.assertEquals(size_t.eval(), [0]) 80 insert_0_op.run() 81 self.assertEquals(size_t.eval(), [0]) 82 insert_1_op.run() 83 self.assertEquals(size_t.eval(), [3]) 84 85 def testInsertManyEmptyTensor(self): 86 with self.test_session(): 87 error_message = ("Empty tensors are not supported, but received shape " 88 r"\'\(0,\)\' at index 1") 89 with self.assertRaisesRegexp(ValueError, error_message): 90 data_flow_ops.Barrier( 91 (dtypes.float32, dtypes.float32), shapes=((1,), (0,)), name="B") 92 93 def testInsertManyEmptyTensorUnknown(self): 94 with self.test_session(): 95 b = data_flow_ops.Barrier((dtypes.float32, dtypes.float32), name="B") 96 size_t = b.ready_size() 97 self.assertEqual([], size_t.get_shape()) 98 keys = [b"a", b"b", b"c"] 99 insert_0_op = b.insert_many(0, keys, np.array([[], [], []], np.float32)) 100 self.assertEquals(size_t.eval(), [0]) 101 with self.assertRaisesOpError( 102 ".*Tensors with no elements are not supported.*"): 103 insert_0_op.run() 104 105 def testTakeMany(self): 106 with self.test_session() as sess: 107 b = data_flow_ops.Barrier( 108 (dtypes.float32, dtypes.float32), shapes=((), ()), name="B") 109 size_t = b.ready_size() 110 keys = [b"a", b"b", b"c"] 111 values_0 = [10.0, 20.0, 30.0] 112 values_1 = [100.0, 200.0, 300.0] 113 insert_0_op = b.insert_many(0, keys, values_0) 114 insert_1_op = b.insert_many(1, keys, values_1) 115 take_t = b.take_many(3) 116 117 insert_0_op.run() 118 insert_1_op.run() 119 self.assertEquals(size_t.eval(), [3]) 120 121 indices_val, keys_val, values_0_val, values_1_val = sess.run( 122 [take_t[0], take_t[1], take_t[2][0], take_t[2][1]]) 123 124 self.assertAllEqual(indices_val, [-2**63] * 3) 125 for k, v0, v1 in zip(keys, values_0, values_1): 126 idx = keys_val.tolist().index(k) 127 self.assertEqual(values_0_val[idx], v0) 128 self.assertEqual(values_1_val[idx], v1) 129 130 def testTakeManySmallBatch(self): 131 with self.test_session() as sess: 132 b = data_flow_ops.Barrier( 133 (dtypes.float32, dtypes.float32), shapes=((), ()), name="B") 134 size_t = b.ready_size() 135 size_i = b.incomplete_size() 136 keys = [b"a", b"b", b"c", b"d"] 137 values_0 = [10.0, 20.0, 30.0, 40.0] 138 values_1 = [100.0, 200.0, 300.0, 400.0] 139 insert_0_op = b.insert_many(0, keys, values_0) 140 # Split adding of the second component into two independent operations. 141 # After insert_1_1_op, we'll have two ready elements in the barrier, 142 # 2 will still be incomplete. 143 insert_1_1_op = b.insert_many(1, keys[0:2], values_1[0:2]) # add "a", "b" 144 insert_1_2_op = b.insert_many(1, keys[2:3], values_1[2:3]) # add "c" 145 insert_1_3_op = b.insert_many(1, keys[3:], values_1[3:]) # add "d" 146 insert_empty_op = b.insert_many(0, [], []) 147 close_op = b.close() 148 close_op_final = b.close(cancel_pending_enqueues=True) 149 index_t, key_t, value_list_t = b.take_many(3, allow_small_batch=True) 150 insert_0_op.run() 151 insert_1_1_op.run() 152 close_op.run() 153 # Now we have a closed barrier with 2 ready elements. Running take_t 154 # should return a reduced batch with 2 elements only. 155 self.assertEquals(size_i.eval(), [2]) # assert that incomplete size = 2 156 self.assertEquals(size_t.eval(), [2]) # assert that ready size = 2 157 _, keys_val, values_0_val, values_1_val = sess.run( 158 [index_t, key_t, value_list_t[0], value_list_t[1]]) 159 # Check that correct values have been returned. 160 for k, v0, v1 in zip(keys[0:2], values_0[0:2], values_1[0:2]): 161 idx = keys_val.tolist().index(k) 162 self.assertEqual(values_0_val[idx], v0) 163 self.assertEqual(values_1_val[idx], v1) 164 165 # The next insert completes the element with key "c". The next take_t 166 # should return a batch with just 1 element. 167 insert_1_2_op.run() 168 self.assertEquals(size_i.eval(), [1]) # assert that incomplete size = 1 169 self.assertEquals(size_t.eval(), [1]) # assert that ready size = 1 170 _, keys_val, values_0_val, values_1_val = sess.run( 171 [index_t, key_t, value_list_t[0], value_list_t[1]]) 172 # Check that correct values have been returned. 173 for k, v0, v1 in zip(keys[2:3], values_0[2:3], values_1[2:3]): 174 idx = keys_val.tolist().index(k) 175 self.assertEqual(values_0_val[idx], v0) 176 self.assertEqual(values_1_val[idx], v1) 177 178 # Adding nothing ought to work, even if the barrier is closed. 179 insert_empty_op.run() 180 181 # currently keys "a" and "b" are not in the barrier, adding them 182 # again after it has been closed, ought to cause failure. 183 with self.assertRaisesOpError("is closed"): 184 insert_1_1_op.run() 185 close_op_final.run() 186 187 # These ops should fail because the barrier has now been closed with 188 # cancel_pending_enqueues = True. 189 with self.assertRaisesOpError("is closed"): 190 insert_empty_op.run() 191 with self.assertRaisesOpError("is closed"): 192 insert_1_3_op.run() 193 194 def testUseBarrierWithShape(self): 195 with self.test_session() as sess: 196 b = data_flow_ops.Barrier( 197 (dtypes.float32, dtypes.float32), shapes=((2, 2), (8,)), name="B") 198 size_t = b.ready_size() 199 keys = [b"a", b"b", b"c"] 200 values_0 = np.array( 201 [[[10.0] * 2] * 2, [[20.0] * 2] * 2, [[30.0] * 2] * 2], np.float32) 202 values_1 = np.array([[100.0] * 8, [200.0] * 8, [300.0] * 8], np.float32) 203 insert_0_op = b.insert_many(0, keys, values_0) 204 insert_1_op = b.insert_many(1, keys, values_1) 205 take_t = b.take_many(3) 206 207 insert_0_op.run() 208 insert_1_op.run() 209 self.assertEquals(size_t.eval(), [3]) 210 211 indices_val, keys_val, values_0_val, values_1_val = sess.run( 212 [take_t[0], take_t[1], take_t[2][0], take_t[2][1]]) 213 self.assertAllEqual(indices_val, [-2**63] * 3) 214 self.assertShapeEqual(keys_val, take_t[1]) 215 self.assertShapeEqual(values_0_val, take_t[2][0]) 216 self.assertShapeEqual(values_1_val, take_t[2][1]) 217 218 for k, v0, v1 in zip(keys, values_0, values_1): 219 idx = keys_val.tolist().index(k) 220 self.assertAllEqual(values_0_val[idx], v0) 221 self.assertAllEqual(values_1_val[idx], v1) 222 223 def testParallelInsertMany(self): 224 with self.test_session() as sess: 225 b = data_flow_ops.Barrier(dtypes.float32, shapes=()) 226 size_t = b.ready_size() 227 keys = [str(x).encode("ascii") for x in range(10)] 228 values = [float(x) for x in range(10)] 229 insert_ops = [b.insert_many(0, [k], [v]) for k, v in zip(keys, values)] 230 take_t = b.take_many(10) 231 232 sess.run(insert_ops) 233 self.assertEquals(size_t.eval(), [10]) 234 235 indices_val, keys_val, values_val = sess.run( 236 [take_t[0], take_t[1], take_t[2][0]]) 237 238 self.assertAllEqual(indices_val, [-2**63 + x for x in range(10)]) 239 for k, v in zip(keys, values): 240 idx = keys_val.tolist().index(k) 241 self.assertEqual(values_val[idx], v) 242 243 def testParallelTakeMany(self): 244 with self.test_session() as sess: 245 b = data_flow_ops.Barrier(dtypes.float32, shapes=()) 246 size_t = b.ready_size() 247 keys = [str(x).encode("ascii") for x in range(10)] 248 values = [float(x) for x in range(10)] 249 insert_op = b.insert_many(0, keys, values) 250 take_t = [b.take_many(1) for _ in keys] 251 252 insert_op.run() 253 self.assertEquals(size_t.eval(), [10]) 254 255 index_fetches = [] 256 key_fetches = [] 257 value_fetches = [] 258 for ix_t, k_t, v_t in take_t: 259 index_fetches.append(ix_t) 260 key_fetches.append(k_t) 261 value_fetches.append(v_t[0]) 262 vals = sess.run(index_fetches + key_fetches + value_fetches) 263 264 index_vals = vals[:len(keys)] 265 key_vals = vals[len(keys):2 * len(keys)] 266 value_vals = vals[2 * len(keys):] 267 268 taken_elems = [] 269 for k, v in zip(key_vals, value_vals): 270 taken_elems.append((k[0], v[0])) 271 272 self.assertAllEqual(np.hstack(index_vals), [-2**63] * 10) 273 274 self.assertItemsEqual( 275 zip(keys, values), [(k[0], v[0]) for k, v in zip(key_vals, value_vals)]) 276 277 def testBlockingTakeMany(self): 278 with self.test_session() as sess: 279 b = data_flow_ops.Barrier(dtypes.float32, shapes=()) 280 keys = [str(x).encode("ascii") for x in range(10)] 281 values = [float(x) for x in range(10)] 282 insert_ops = [b.insert_many(0, [k], [v]) for k, v in zip(keys, values)] 283 take_t = b.take_many(10) 284 285 def take(): 286 indices_val, keys_val, values_val = sess.run( 287 [take_t[0], take_t[1], take_t[2][0]]) 288 self.assertAllEqual(indices_val, 289 [int(x.decode("ascii")) - 2**63 for x in keys_val]) 290 self.assertItemsEqual(zip(keys, values), zip(keys_val, values_val)) 291 292 t = self.checkedThread(target=take) 293 t.start() 294 time.sleep(0.1) 295 for insert_op in insert_ops: 296 insert_op.run() 297 t.join() 298 299 def testParallelInsertManyTakeMany(self): 300 with self.test_session() as sess: 301 b = data_flow_ops.Barrier( 302 (dtypes.float32, dtypes.int64), shapes=((), (2,))) 303 num_iterations = 100 304 keys = [str(x) for x in range(10)] 305 values_0 = np.asarray(range(10), dtype=np.float32) 306 values_1 = np.asarray([[x + 1, x + 2] for x in range(10)], dtype=np.int64) 307 keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii") for k in keys] 308 insert_0_ops = [ 309 b.insert_many(0, keys_i(i), values_0 + i) 310 for i in range(num_iterations) 311 ] 312 insert_1_ops = [ 313 b.insert_many(1, keys_i(i), values_1 + i) 314 for i in range(num_iterations) 315 ] 316 take_ops = [b.take_many(10) for _ in range(num_iterations)] 317 318 def take(sess, i, taken): 319 indices_val, keys_val, values_0_val, values_1_val = sess.run([ 320 take_ops[i][0], take_ops[i][1], take_ops[i][2][0], take_ops[i][2][1] 321 ]) 322 taken.append({ 323 "indices": indices_val, 324 "keys": keys_val, 325 "values_0": values_0_val, 326 "values_1": values_1_val 327 }) 328 329 def insert(sess, i): 330 sess.run([insert_0_ops[i], insert_1_ops[i]]) 331 332 taken = [] 333 334 take_threads = [ 335 self.checkedThread( 336 target=take, args=(sess, i, taken)) for i in range(num_iterations) 337 ] 338 insert_threads = [ 339 self.checkedThread( 340 target=insert, args=(sess, i)) for i in range(num_iterations) 341 ] 342 343 for t in take_threads: 344 t.start() 345 time.sleep(0.1) 346 for t in insert_threads: 347 t.start() 348 for t in take_threads: 349 t.join() 350 for t in insert_threads: 351 t.join() 352 353 self.assertEquals(len(taken), num_iterations) 354 flatten = lambda l: [item for sublist in l for item in sublist] 355 all_indices = sorted(flatten([t_i["indices"] for t_i in taken])) 356 all_keys = sorted(flatten([t_i["keys"] for t_i in taken])) 357 358 expected_keys = sorted( 359 flatten([keys_i(i) for i in range(num_iterations)])) 360 expected_indices = sorted( 361 flatten([-2**63 + j] * 10 for j in range(num_iterations))) 362 363 self.assertAllEqual(all_indices, expected_indices) 364 self.assertAllEqual(all_keys, expected_keys) 365 366 for taken_i in taken: 367 outer_indices_from_keys = np.array( 368 [int(k.decode("ascii").split(":")[0]) for k in taken_i["keys"]]) 369 inner_indices_from_keys = np.array( 370 [int(k.decode("ascii").split(":")[1]) for k in taken_i["keys"]]) 371 self.assertAllEqual(taken_i["values_0"], 372 outer_indices_from_keys + inner_indices_from_keys) 373 expected_values_1 = np.vstack( 374 (1 + outer_indices_from_keys + inner_indices_from_keys, 375 2 + outer_indices_from_keys + inner_indices_from_keys)).T 376 self.assertAllEqual(taken_i["values_1"], expected_values_1) 377 378 def testClose(self): 379 with self.test_session() as sess: 380 b = data_flow_ops.Barrier( 381 (dtypes.float32, dtypes.float32), shapes=((), ()), name="B") 382 size_t = b.ready_size() 383 incomplete_t = b.incomplete_size() 384 keys = [b"a", b"b", b"c"] 385 values_0 = [10.0, 20.0, 30.0] 386 values_1 = [100.0, 200.0, 300.0] 387 insert_0_op = b.insert_many(0, keys, values_0) 388 insert_1_op = b.insert_many(1, keys, values_1) 389 close_op = b.close() 390 fail_insert_op = b.insert_many(0, ["f"], [60.0]) 391 take_t = b.take_many(3) 392 take_too_many_t = b.take_many(4) 393 394 self.assertEquals(size_t.eval(), [0]) 395 self.assertEquals(incomplete_t.eval(), [0]) 396 insert_0_op.run() 397 self.assertEquals(size_t.eval(), [0]) 398 self.assertEquals(incomplete_t.eval(), [3]) 399 close_op.run() 400 401 # This op should fail because the barrier is closed. 402 with self.assertRaisesOpError("is closed"): 403 fail_insert_op.run() 404 405 # This op should succeed because the barrier has not canceled 406 # pending enqueues 407 insert_1_op.run() 408 self.assertEquals(size_t.eval(), [3]) 409 self.assertEquals(incomplete_t.eval(), [0]) 410 411 # This op should fail because the barrier is closed. 412 with self.assertRaisesOpError("is closed"): 413 fail_insert_op.run() 414 415 # This op should fail because we requested more elements than are 416 # available in incomplete + ready queue. 417 with self.assertRaisesOpError(r"is closed and has insufficient elements " 418 r"\(requested 4, total size 3\)"): 419 sess.run(take_too_many_t[0]) # Sufficient to request just the indices 420 421 # This op should succeed because there are still completed elements 422 # to process. 423 indices_val, keys_val, values_0_val, values_1_val = sess.run( 424 [take_t[0], take_t[1], take_t[2][0], take_t[2][1]]) 425 self.assertAllEqual(indices_val, [-2**63] * 3) 426 for k, v0, v1 in zip(keys, values_0, values_1): 427 idx = keys_val.tolist().index(k) 428 self.assertEqual(values_0_val[idx], v0) 429 self.assertEqual(values_1_val[idx], v1) 430 431 # This op should fail because there are no more completed elements and 432 # the queue is closed. 433 with self.assertRaisesOpError("is closed and has insufficient elements"): 434 sess.run(take_t[0]) 435 436 def testCancel(self): 437 with self.test_session() as sess: 438 b = data_flow_ops.Barrier( 439 (dtypes.float32, dtypes.float32), shapes=((), ()), name="B") 440 size_t = b.ready_size() 441 incomplete_t = b.incomplete_size() 442 keys = [b"a", b"b", b"c"] 443 values_0 = [10.0, 20.0, 30.0] 444 values_1 = [100.0, 200.0, 300.0] 445 insert_0_op = b.insert_many(0, keys, values_0) 446 insert_1_op = b.insert_many(1, keys[0:2], values_1[0:2]) 447 insert_2_op = b.insert_many(1, keys[2:], values_1[2:]) 448 cancel_op = b.close(cancel_pending_enqueues=True) 449 fail_insert_op = b.insert_many(0, ["f"], [60.0]) 450 take_t = b.take_many(2) 451 take_too_many_t = b.take_many(3) 452 453 self.assertEquals(size_t.eval(), [0]) 454 insert_0_op.run() 455 insert_1_op.run() 456 self.assertEquals(size_t.eval(), [2]) 457 self.assertEquals(incomplete_t.eval(), [1]) 458 cancel_op.run() 459 460 # This op should fail because the queue is closed. 461 with self.assertRaisesOpError("is closed"): 462 fail_insert_op.run() 463 464 # This op should fail because the queue is canceled. 465 with self.assertRaisesOpError("is closed"): 466 insert_2_op.run() 467 468 # This op should fail because we requested more elements than are 469 # available in incomplete + ready queue. 470 with self.assertRaisesOpError(r"is closed and has insufficient elements " 471 r"\(requested 3, total size 2\)"): 472 sess.run(take_too_many_t[0]) # Sufficient to request just the indices 473 474 # This op should succeed because there are still completed elements 475 # to process. 476 indices_val, keys_val, values_0_val, values_1_val = sess.run( 477 [take_t[0], take_t[1], take_t[2][0], take_t[2][1]]) 478 self.assertAllEqual(indices_val, [-2**63] * 2) 479 for k, v0, v1 in zip(keys[0:2], values_0[0:2], values_1[0:2]): 480 idx = keys_val.tolist().index(k) 481 self.assertEqual(values_0_val[idx], v0) 482 self.assertEqual(values_1_val[idx], v1) 483 484 # This op should fail because there are no more completed elements and 485 # the queue is closed. 486 with self.assertRaisesOpError("is closed and has insufficient elements"): 487 sess.run(take_t[0]) 488 489 def _testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(self, cancel): 490 with self.test_session() as sess: 491 b = data_flow_ops.Barrier( 492 (dtypes.float32, dtypes.float32), shapes=((), ()), name="B") 493 take_t = b.take_many(1, allow_small_batch=True) 494 sess.run(b.close(cancel)) 495 with self.assertRaisesOpError("is closed and has insufficient elements"): 496 sess.run(take_t) 497 498 def testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(self): 499 self._testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(cancel=False) 500 self._testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(cancel=True) 501 502 def _testParallelInsertManyTakeManyCloseHalfwayThrough(self, cancel): 503 with self.test_session() as sess: 504 b = data_flow_ops.Barrier( 505 (dtypes.float32, dtypes.int64), shapes=((), (2,))) 506 num_iterations = 50 507 keys = [str(x) for x in range(10)] 508 values_0 = np.asarray(range(10), dtype=np.float32) 509 values_1 = np.asarray([[x + 1, x + 2] for x in range(10)], dtype=np.int64) 510 keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii") for k in keys] 511 insert_0_ops = [ 512 b.insert_many(0, keys_i(i), values_0 + i) 513 for i in range(num_iterations) 514 ] 515 insert_1_ops = [ 516 b.insert_many(1, keys_i(i), values_1 + i) 517 for i in range(num_iterations) 518 ] 519 take_ops = [b.take_many(10) for _ in range(num_iterations)] 520 close_op = b.close(cancel_pending_enqueues=cancel) 521 522 def take(sess, i, taken): 523 try: 524 indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run([ 525 take_ops[i][0], take_ops[i][1], take_ops[i][2][0], 526 take_ops[i][2][1] 527 ]) 528 taken.append(len(indices_val)) 529 except errors_impl.OutOfRangeError: 530 taken.append(0) 531 532 def insert(sess, i): 533 try: 534 sess.run([insert_0_ops[i], insert_1_ops[i]]) 535 except errors_impl.CancelledError: 536 pass 537 538 taken = [] 539 540 take_threads = [ 541 self.checkedThread( 542 target=take, args=(sess, i, taken)) for i in range(num_iterations) 543 ] 544 insert_threads = [ 545 self.checkedThread( 546 target=insert, args=(sess, i)) for i in range(num_iterations) 547 ] 548 549 first_half_insert_threads = insert_threads[:num_iterations // 2] 550 second_half_insert_threads = insert_threads[num_iterations // 2:] 551 552 for t in take_threads: 553 t.start() 554 for t in first_half_insert_threads: 555 t.start() 556 for t in first_half_insert_threads: 557 t.join() 558 559 close_op.run() 560 561 for t in second_half_insert_threads: 562 t.start() 563 for t in take_threads: 564 t.join() 565 for t in second_half_insert_threads: 566 t.join() 567 568 self.assertEqual( 569 sorted(taken), 570 [0] * (num_iterations // 2) + [10] * (num_iterations // 2)) 571 572 def testParallelInsertManyTakeManyCloseHalfwayThrough(self): 573 self._testParallelInsertManyTakeManyCloseHalfwayThrough(cancel=False) 574 575 def testParallelInsertManyTakeManyCancelHalfwayThrough(self): 576 self._testParallelInsertManyTakeManyCloseHalfwayThrough(cancel=True) 577 578 def _testParallelPartialInsertManyTakeManyCloseHalfwayThrough(self, cancel): 579 with self.test_session() as sess: 580 b = data_flow_ops.Barrier( 581 (dtypes.float32, dtypes.int64), shapes=((), (2,))) 582 num_iterations = 100 583 keys = [str(x) for x in range(10)] 584 values_0 = np.asarray(range(10), dtype=np.float32) 585 values_1 = np.asarray([[x + 1, x + 2] for x in range(10)], dtype=np.int64) 586 keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii") for k in keys] 587 insert_0_ops = [ 588 b.insert_many( 589 0, keys_i(i), values_0 + i, name="insert_0_%d" % i) 590 for i in range(num_iterations) 591 ] 592 593 close_op = b.close(cancel_pending_enqueues=cancel) 594 595 take_ops = [ 596 b.take_many( 597 10, name="take_%d" % i) for i in range(num_iterations) 598 ] 599 # insert_1_ops will only run after closure 600 insert_1_ops = [ 601 b.insert_many( 602 1, keys_i(i), values_1 + i, name="insert_1_%d" % i) 603 for i in range(num_iterations) 604 ] 605 606 def take(sess, i, taken): 607 if cancel: 608 try: 609 indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run( 610 [ 611 take_ops[i][0], take_ops[i][1], take_ops[i][2][0], 612 take_ops[i][2][1] 613 ]) 614 taken.append(len(indices_val)) 615 except errors_impl.OutOfRangeError: 616 taken.append(0) 617 else: 618 indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run([ 619 take_ops[i][0], take_ops[i][1], take_ops[i][2][0], 620 take_ops[i][2][1] 621 ]) 622 taken.append(len(indices_val)) 623 624 def insert_0(sess, i): 625 insert_0_ops[i].run(session=sess) 626 627 def insert_1(sess, i): 628 if cancel: 629 try: 630 insert_1_ops[i].run(session=sess) 631 except errors_impl.CancelledError: 632 pass 633 else: 634 insert_1_ops[i].run(session=sess) 635 636 taken = [] 637 638 take_threads = [ 639 self.checkedThread( 640 target=take, args=(sess, i, taken)) for i in range(num_iterations) 641 ] 642 insert_0_threads = [ 643 self.checkedThread( 644 target=insert_0, args=(sess, i)) for i in range(num_iterations) 645 ] 646 insert_1_threads = [ 647 self.checkedThread( 648 target=insert_1, args=(sess, i)) for i in range(num_iterations) 649 ] 650 651 for t in insert_0_threads: 652 t.start() 653 for t in insert_0_threads: 654 t.join() 655 for t in take_threads: 656 t.start() 657 658 close_op.run() 659 660 for t in insert_1_threads: 661 t.start() 662 for t in take_threads: 663 t.join() 664 for t in insert_1_threads: 665 t.join() 666 667 if cancel: 668 self.assertEqual(taken, [0] * num_iterations) 669 else: 670 self.assertEqual(taken, [10] * num_iterations) 671 672 def testParallelPartialInsertManyTakeManyCloseHalfwayThrough(self): 673 self._testParallelPartialInsertManyTakeManyCloseHalfwayThrough(cancel=False) 674 675 def testParallelPartialInsertManyTakeManyCancelHalfwayThrough(self): 676 self._testParallelPartialInsertManyTakeManyCloseHalfwayThrough(cancel=True) 677 678 def testIncompatibleSharedBarrierErrors(self): 679 with self.test_session(): 680 # Do component types and shapes. 681 b_a_1 = data_flow_ops.Barrier( 682 (dtypes.float32,), shapes=(()), shared_name="b_a") 683 b_a_2 = data_flow_ops.Barrier( 684 (dtypes.int32,), shapes=(()), shared_name="b_a") 685 b_a_1.barrier_ref.eval() 686 with self.assertRaisesOpError("component types"): 687 b_a_2.barrier_ref.eval() 688 689 b_b_1 = data_flow_ops.Barrier( 690 (dtypes.float32,), shapes=(()), shared_name="b_b") 691 b_b_2 = data_flow_ops.Barrier( 692 (dtypes.float32, dtypes.int32), shapes=((), ()), shared_name="b_b") 693 b_b_1.barrier_ref.eval() 694 with self.assertRaisesOpError("component types"): 695 b_b_2.barrier_ref.eval() 696 697 b_c_1 = data_flow_ops.Barrier( 698 (dtypes.float32, dtypes.float32), 699 shapes=((2, 2), (8,)), 700 shared_name="b_c") 701 b_c_2 = data_flow_ops.Barrier( 702 (dtypes.float32, dtypes.float32), shared_name="b_c") 703 b_c_1.barrier_ref.eval() 704 with self.assertRaisesOpError("component shapes"): 705 b_c_2.barrier_ref.eval() 706 707 b_d_1 = data_flow_ops.Barrier( 708 (dtypes.float32, dtypes.float32), shapes=((), ()), shared_name="b_d") 709 b_d_2 = data_flow_ops.Barrier( 710 (dtypes.float32, dtypes.float32), 711 shapes=((2, 2), (8,)), 712 shared_name="b_d") 713 b_d_1.barrier_ref.eval() 714 with self.assertRaisesOpError("component shapes"): 715 b_d_2.barrier_ref.eval() 716 717 b_e_1 = data_flow_ops.Barrier( 718 (dtypes.float32, dtypes.float32), 719 shapes=((2, 2), (8,)), 720 shared_name="b_e") 721 b_e_2 = data_flow_ops.Barrier( 722 (dtypes.float32, dtypes.float32), 723 shapes=((2, 5), (8,)), 724 shared_name="b_e") 725 b_e_1.barrier_ref.eval() 726 with self.assertRaisesOpError("component shapes"): 727 b_e_2.barrier_ref.eval() 728 729 730if __name__ == "__main__": 731 test.main() 732