# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for barrier ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import time
22
23import numpy as np
24
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import errors_impl
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import data_flow_ops
29from tensorflow.python.platform import test
30
31
32class BarrierTest(test.TestCase):
33
  def testConstructorWithShapes(self):
    """Constructor encodes capacity, types, shapes and names in the NodeDef."""
    with ops.Graph().as_default():
      b = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.float32),
          shapes=((1, 2, 3), (8,)),
          shared_name="B",
          name="B")
    # The barrier handle is exposed as a Tensor.
    self.assertTrue(isinstance(b.barrier_ref, ops.Tensor))
    # Capacity defaults to -1 (per the expected proto below); component
    # types and shapes mirror the constructor arguments exactly.
    self.assertProtoEquals("""
      name:'B' op:'Barrier'
      attr {
        key: "capacity"
        value {
          i: -1
        }
      }
      attr { key: 'component_types'
             value { list { type: DT_FLOAT type: DT_FLOAT } } }
      attr {
        key: 'shapes'
        value {
          list {
            shape {
              dim { size: 1 } dim { size: 2 } dim { size: 3 }
            }
            shape {
              dim { size: 8 }
            }
          }
        }
      }
      attr { key: 'container' value { s: "" } }
      attr { key: 'shared_name' value: { s: 'B' } }
      """, b.barrier_ref.op.node_def)
68
69  def testInsertMany(self):
70    with self.test_session():
71      b = data_flow_ops.Barrier(
72          (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
73      size_t = b.ready_size()
74      self.assertEqual([], size_t.get_shape())
75      keys = [b"a", b"b", b"c"]
76      insert_0_op = b.insert_many(0, keys, [10.0, 20.0, 30.0])
77      insert_1_op = b.insert_many(1, keys, [100.0, 200.0, 300.0])
78
79      self.assertEquals(size_t.eval(), [0])
80      insert_0_op.run()
81      self.assertEquals(size_t.eval(), [0])
82      insert_1_op.run()
83      self.assertEquals(size_t.eval(), [3])
84
85  def testInsertManyEmptyTensor(self):
86    with self.test_session():
87      error_message = ("Empty tensors are not supported, but received shape "
88                       r"\'\(0,\)\' at index 1")
89      with self.assertRaisesRegexp(ValueError, error_message):
90        data_flow_ops.Barrier(
91            (dtypes.float32, dtypes.float32), shapes=((1,), (0,)), name="B")
92
93  def testInsertManyEmptyTensorUnknown(self):
94    with self.test_session():
95      b = data_flow_ops.Barrier((dtypes.float32, dtypes.float32), name="B")
96      size_t = b.ready_size()
97      self.assertEqual([], size_t.get_shape())
98      keys = [b"a", b"b", b"c"]
99      insert_0_op = b.insert_many(0, keys, np.array([[], [], []], np.float32))
100      self.assertEquals(size_t.eval(), [0])
101      with self.assertRaisesOpError(
102          ".*Tensors with no elements are not supported.*"):
103        insert_0_op.run()
104
105  def testTakeMany(self):
106    with self.test_session() as sess:
107      b = data_flow_ops.Barrier(
108          (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
109      size_t = b.ready_size()
110      keys = [b"a", b"b", b"c"]
111      values_0 = [10.0, 20.0, 30.0]
112      values_1 = [100.0, 200.0, 300.0]
113      insert_0_op = b.insert_many(0, keys, values_0)
114      insert_1_op = b.insert_many(1, keys, values_1)
115      take_t = b.take_many(3)
116
117      insert_0_op.run()
118      insert_1_op.run()
119      self.assertEquals(size_t.eval(), [3])
120
121      indices_val, keys_val, values_0_val, values_1_val = sess.run(
122          [take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
123
124    self.assertAllEqual(indices_val, [-2**63] * 3)
125    for k, v0, v1 in zip(keys, values_0, values_1):
126      idx = keys_val.tolist().index(k)
127      self.assertEqual(values_0_val[idx], v0)
128      self.assertEqual(values_1_val[idx], v1)
129
130  def testTakeManySmallBatch(self):
131    with self.test_session() as sess:
132      b = data_flow_ops.Barrier(
133          (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
134      size_t = b.ready_size()
135      size_i = b.incomplete_size()
136      keys = [b"a", b"b", b"c", b"d"]
137      values_0 = [10.0, 20.0, 30.0, 40.0]
138      values_1 = [100.0, 200.0, 300.0, 400.0]
139      insert_0_op = b.insert_many(0, keys, values_0)
140      # Split adding of the second component into two independent operations.
141      # After insert_1_1_op, we'll have two ready elements in the barrier,
142      # 2 will still be incomplete.
143      insert_1_1_op = b.insert_many(1, keys[0:2], values_1[0:2])  # add "a", "b"
144      insert_1_2_op = b.insert_many(1, keys[2:3], values_1[2:3])  # add "c"
145      insert_1_3_op = b.insert_many(1, keys[3:], values_1[3:])  # add "d"
146      insert_empty_op = b.insert_many(0, [], [])
147      close_op = b.close()
148      close_op_final = b.close(cancel_pending_enqueues=True)
149      index_t, key_t, value_list_t = b.take_many(3, allow_small_batch=True)
150      insert_0_op.run()
151      insert_1_1_op.run()
152      close_op.run()
153      # Now we have a closed barrier with 2 ready elements. Running take_t
154      # should return a reduced batch with 2 elements only.
155      self.assertEquals(size_i.eval(), [2])  # assert that incomplete size = 2
156      self.assertEquals(size_t.eval(), [2])  # assert that ready size = 2
157      _, keys_val, values_0_val, values_1_val = sess.run(
158          [index_t, key_t, value_list_t[0], value_list_t[1]])
159      # Check that correct values have been returned.
160      for k, v0, v1 in zip(keys[0:2], values_0[0:2], values_1[0:2]):
161        idx = keys_val.tolist().index(k)
162        self.assertEqual(values_0_val[idx], v0)
163        self.assertEqual(values_1_val[idx], v1)
164
165      # The next insert completes the element with key "c". The next take_t
166      # should return a batch with just 1 element.
167      insert_1_2_op.run()
168      self.assertEquals(size_i.eval(), [1])  # assert that incomplete size = 1
169      self.assertEquals(size_t.eval(), [1])  # assert that ready size = 1
170      _, keys_val, values_0_val, values_1_val = sess.run(
171          [index_t, key_t, value_list_t[0], value_list_t[1]])
172      # Check that correct values have been returned.
173      for k, v0, v1 in zip(keys[2:3], values_0[2:3], values_1[2:3]):
174        idx = keys_val.tolist().index(k)
175        self.assertEqual(values_0_val[idx], v0)
176        self.assertEqual(values_1_val[idx], v1)
177
178      # Adding nothing ought to work, even if the barrier is closed.
179      insert_empty_op.run()
180
181      # currently keys "a" and "b" are not in the barrier, adding them
182      # again after it has been closed, ought to cause failure.
183      with self.assertRaisesOpError("is closed"):
184        insert_1_1_op.run()
185      close_op_final.run()
186
187      # These ops should fail because the barrier has now been closed with
188      # cancel_pending_enqueues = True.
189      with self.assertRaisesOpError("is closed"):
190        insert_empty_op.run()
191      with self.assertRaisesOpError("is closed"):
192        insert_1_3_op.run()
193
194  def testUseBarrierWithShape(self):
195    with self.test_session() as sess:
196      b = data_flow_ops.Barrier(
197          (dtypes.float32, dtypes.float32), shapes=((2, 2), (8,)), name="B")
198      size_t = b.ready_size()
199      keys = [b"a", b"b", b"c"]
200      values_0 = np.array(
201          [[[10.0] * 2] * 2, [[20.0] * 2] * 2, [[30.0] * 2] * 2], np.float32)
202      values_1 = np.array([[100.0] * 8, [200.0] * 8, [300.0] * 8], np.float32)
203      insert_0_op = b.insert_many(0, keys, values_0)
204      insert_1_op = b.insert_many(1, keys, values_1)
205      take_t = b.take_many(3)
206
207      insert_0_op.run()
208      insert_1_op.run()
209      self.assertEquals(size_t.eval(), [3])
210
211      indices_val, keys_val, values_0_val, values_1_val = sess.run(
212          [take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
213      self.assertAllEqual(indices_val, [-2**63] * 3)
214      self.assertShapeEqual(keys_val, take_t[1])
215      self.assertShapeEqual(values_0_val, take_t[2][0])
216      self.assertShapeEqual(values_1_val, take_t[2][1])
217
218    for k, v0, v1 in zip(keys, values_0, values_1):
219      idx = keys_val.tolist().index(k)
220      self.assertAllEqual(values_0_val[idx], v0)
221      self.assertAllEqual(values_1_val[idx], v1)
222
223  def testParallelInsertMany(self):
224    with self.test_session() as sess:
225      b = data_flow_ops.Barrier(dtypes.float32, shapes=())
226      size_t = b.ready_size()
227      keys = [str(x).encode("ascii") for x in range(10)]
228      values = [float(x) for x in range(10)]
229      insert_ops = [b.insert_many(0, [k], [v]) for k, v in zip(keys, values)]
230      take_t = b.take_many(10)
231
232      sess.run(insert_ops)
233      self.assertEquals(size_t.eval(), [10])
234
235      indices_val, keys_val, values_val = sess.run(
236          [take_t[0], take_t[1], take_t[2][0]])
237
238    self.assertAllEqual(indices_val, [-2**63 + x for x in range(10)])
239    for k, v in zip(keys, values):
240      idx = keys_val.tolist().index(k)
241      self.assertEqual(values_val[idx], v)
242
243  def testParallelTakeMany(self):
244    with self.test_session() as sess:
245      b = data_flow_ops.Barrier(dtypes.float32, shapes=())
246      size_t = b.ready_size()
247      keys = [str(x).encode("ascii") for x in range(10)]
248      values = [float(x) for x in range(10)]
249      insert_op = b.insert_many(0, keys, values)
250      take_t = [b.take_many(1) for _ in keys]
251
252      insert_op.run()
253      self.assertEquals(size_t.eval(), [10])
254
255      index_fetches = []
256      key_fetches = []
257      value_fetches = []
258      for ix_t, k_t, v_t in take_t:
259        index_fetches.append(ix_t)
260        key_fetches.append(k_t)
261        value_fetches.append(v_t[0])
262      vals = sess.run(index_fetches + key_fetches + value_fetches)
263
264    index_vals = vals[:len(keys)]
265    key_vals = vals[len(keys):2 * len(keys)]
266    value_vals = vals[2 * len(keys):]
267
268    taken_elems = []
269    for k, v in zip(key_vals, value_vals):
270      taken_elems.append((k[0], v[0]))
271
272    self.assertAllEqual(np.hstack(index_vals), [-2**63] * 10)
273
274    self.assertItemsEqual(
275        zip(keys, values), [(k[0], v[0]) for k, v in zip(key_vals, value_vals)])
276
  def testBlockingTakeMany(self):
    """take_many blocks until all requested elements become complete."""
    with self.test_session() as sess:
      b = data_flow_ops.Barrier(dtypes.float32, shapes=())
      keys = [str(x).encode("ascii") for x in range(10)]
      values = [float(x) for x in range(10)]
      insert_ops = [b.insert_many(0, [k], [v]) for k, v in zip(keys, values)]
      take_t = b.take_many(10)

      def take():
        # Runs on a separate thread; this sess.run blocks until all ten
        # inserts below have completed.
        indices_val, keys_val, values_val = sess.run(
            [take_t[0], take_t[1], take_t[2][0]])
        # Keys were inserted sequentially as "0".."9", so the assigned index
        # of each element is its key's integer value offset by -2**63.
        self.assertAllEqual(indices_val,
                            [int(x.decode("ascii")) - 2**63 for x in keys_val])
        self.assertItemsEqual(zip(keys, values), zip(keys_val, values_val))

      t = self.checkedThread(target=take)
      t.start()
      # Give the taker a chance to block before the inserts start.
      time.sleep(0.1)
      for insert_op in insert_ops:
        insert_op.run()
      t.join()
298
299  def testParallelInsertManyTakeMany(self):
300    with self.test_session() as sess:
301      b = data_flow_ops.Barrier(
302          (dtypes.float32, dtypes.int64), shapes=((), (2,)))
303      num_iterations = 100
304      keys = [str(x) for x in range(10)]
305      values_0 = np.asarray(range(10), dtype=np.float32)
306      values_1 = np.asarray([[x + 1, x + 2] for x in range(10)], dtype=np.int64)
307      keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii") for k in keys]
308      insert_0_ops = [
309          b.insert_many(0, keys_i(i), values_0 + i)
310          for i in range(num_iterations)
311      ]
312      insert_1_ops = [
313          b.insert_many(1, keys_i(i), values_1 + i)
314          for i in range(num_iterations)
315      ]
316      take_ops = [b.take_many(10) for _ in range(num_iterations)]
317
318      def take(sess, i, taken):
319        indices_val, keys_val, values_0_val, values_1_val = sess.run([
320            take_ops[i][0], take_ops[i][1], take_ops[i][2][0], take_ops[i][2][1]
321        ])
322        taken.append({
323            "indices": indices_val,
324            "keys": keys_val,
325            "values_0": values_0_val,
326            "values_1": values_1_val
327        })
328
329      def insert(sess, i):
330        sess.run([insert_0_ops[i], insert_1_ops[i]])
331
332      taken = []
333
334      take_threads = [
335          self.checkedThread(
336              target=take, args=(sess, i, taken)) for i in range(num_iterations)
337      ]
338      insert_threads = [
339          self.checkedThread(
340              target=insert, args=(sess, i)) for i in range(num_iterations)
341      ]
342
343      for t in take_threads:
344        t.start()
345      time.sleep(0.1)
346      for t in insert_threads:
347        t.start()
348      for t in take_threads:
349        t.join()
350      for t in insert_threads:
351        t.join()
352
353      self.assertEquals(len(taken), num_iterations)
354      flatten = lambda l: [item for sublist in l for item in sublist]
355      all_indices = sorted(flatten([t_i["indices"] for t_i in taken]))
356      all_keys = sorted(flatten([t_i["keys"] for t_i in taken]))
357
358      expected_keys = sorted(
359          flatten([keys_i(i) for i in range(num_iterations)]))
360      expected_indices = sorted(
361          flatten([-2**63 + j] * 10 for j in range(num_iterations)))
362
363      self.assertAllEqual(all_indices, expected_indices)
364      self.assertAllEqual(all_keys, expected_keys)
365
366      for taken_i in taken:
367        outer_indices_from_keys = np.array(
368            [int(k.decode("ascii").split(":")[0]) for k in taken_i["keys"]])
369        inner_indices_from_keys = np.array(
370            [int(k.decode("ascii").split(":")[1]) for k in taken_i["keys"]])
371        self.assertAllEqual(taken_i["values_0"],
372                            outer_indices_from_keys + inner_indices_from_keys)
373        expected_values_1 = np.vstack(
374            (1 + outer_indices_from_keys + inner_indices_from_keys,
375             2 + outer_indices_from_keys + inner_indices_from_keys)).T
376        self.assertAllEqual(taken_i["values_1"], expected_values_1)
377
378  def testClose(self):
379    with self.test_session() as sess:
380      b = data_flow_ops.Barrier(
381          (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
382      size_t = b.ready_size()
383      incomplete_t = b.incomplete_size()
384      keys = [b"a", b"b", b"c"]
385      values_0 = [10.0, 20.0, 30.0]
386      values_1 = [100.0, 200.0, 300.0]
387      insert_0_op = b.insert_many(0, keys, values_0)
388      insert_1_op = b.insert_many(1, keys, values_1)
389      close_op = b.close()
390      fail_insert_op = b.insert_many(0, ["f"], [60.0])
391      take_t = b.take_many(3)
392      take_too_many_t = b.take_many(4)
393
394      self.assertEquals(size_t.eval(), [0])
395      self.assertEquals(incomplete_t.eval(), [0])
396      insert_0_op.run()
397      self.assertEquals(size_t.eval(), [0])
398      self.assertEquals(incomplete_t.eval(), [3])
399      close_op.run()
400
401      # This op should fail because the barrier is closed.
402      with self.assertRaisesOpError("is closed"):
403        fail_insert_op.run()
404
405      # This op should succeed because the barrier has not canceled
406      # pending enqueues
407      insert_1_op.run()
408      self.assertEquals(size_t.eval(), [3])
409      self.assertEquals(incomplete_t.eval(), [0])
410
411      # This op should fail because the barrier is closed.
412      with self.assertRaisesOpError("is closed"):
413        fail_insert_op.run()
414
415      # This op should fail because we requested more elements than are
416      # available in incomplete + ready queue.
417      with self.assertRaisesOpError(r"is closed and has insufficient elements "
418                                    r"\(requested 4, total size 3\)"):
419        sess.run(take_too_many_t[0])  # Sufficient to request just the indices
420
421      # This op should succeed because there are still completed elements
422      # to process.
423      indices_val, keys_val, values_0_val, values_1_val = sess.run(
424          [take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
425      self.assertAllEqual(indices_val, [-2**63] * 3)
426      for k, v0, v1 in zip(keys, values_0, values_1):
427        idx = keys_val.tolist().index(k)
428        self.assertEqual(values_0_val[idx], v0)
429        self.assertEqual(values_1_val[idx], v1)
430
431      # This op should fail because there are no more completed elements and
432      # the queue is closed.
433      with self.assertRaisesOpError("is closed and has insufficient elements"):
434        sess.run(take_t[0])
435
436  def testCancel(self):
437    with self.test_session() as sess:
438      b = data_flow_ops.Barrier(
439          (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
440      size_t = b.ready_size()
441      incomplete_t = b.incomplete_size()
442      keys = [b"a", b"b", b"c"]
443      values_0 = [10.0, 20.0, 30.0]
444      values_1 = [100.0, 200.0, 300.0]
445      insert_0_op = b.insert_many(0, keys, values_0)
446      insert_1_op = b.insert_many(1, keys[0:2], values_1[0:2])
447      insert_2_op = b.insert_many(1, keys[2:], values_1[2:])
448      cancel_op = b.close(cancel_pending_enqueues=True)
449      fail_insert_op = b.insert_many(0, ["f"], [60.0])
450      take_t = b.take_many(2)
451      take_too_many_t = b.take_many(3)
452
453      self.assertEquals(size_t.eval(), [0])
454      insert_0_op.run()
455      insert_1_op.run()
456      self.assertEquals(size_t.eval(), [2])
457      self.assertEquals(incomplete_t.eval(), [1])
458      cancel_op.run()
459
460      # This op should fail because the queue is closed.
461      with self.assertRaisesOpError("is closed"):
462        fail_insert_op.run()
463
464      # This op should fail because the queue is canceled.
465      with self.assertRaisesOpError("is closed"):
466        insert_2_op.run()
467
468      # This op should fail because we requested more elements than are
469      # available in incomplete + ready queue.
470      with self.assertRaisesOpError(r"is closed and has insufficient elements "
471                                    r"\(requested 3, total size 2\)"):
472        sess.run(take_too_many_t[0])  # Sufficient to request just the indices
473
474      # This op should succeed because there are still completed elements
475      # to process.
476      indices_val, keys_val, values_0_val, values_1_val = sess.run(
477          [take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
478      self.assertAllEqual(indices_val, [-2**63] * 2)
479      for k, v0, v1 in zip(keys[0:2], values_0[0:2], values_1[0:2]):
480        idx = keys_val.tolist().index(k)
481        self.assertEqual(values_0_val[idx], v0)
482        self.assertEqual(values_1_val[idx], v1)
483
484      # This op should fail because there are no more completed elements and
485      # the queue is closed.
486      with self.assertRaisesOpError("is closed and has insufficient elements"):
487        sess.run(take_t[0])
488
489  def _testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(self, cancel):
490    with self.test_session() as sess:
491      b = data_flow_ops.Barrier(
492          (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
493      take_t = b.take_many(1, allow_small_batch=True)
494      sess.run(b.close(cancel))
495      with self.assertRaisesOpError("is closed and has insufficient elements"):
496        sess.run(take_t)
497
498  def testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(self):
499    self._testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(cancel=False)
500    self._testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(cancel=True)
501
  def _testParallelInsertManyTakeManyCloseHalfwayThrough(self, cancel):
    """Half the inserts land, then close(); takers see full or empty batches.

    Args:
      cancel: Whether close() should also cancel pending enqueues.
    """
    with self.test_session() as sess:
      b = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.int64), shapes=((), (2,)))
      num_iterations = 50
      keys = [str(x) for x in range(10)]
      values_0 = np.asarray(range(10), dtype=np.float32)
      values_1 = np.asarray([[x + 1, x + 2] for x in range(10)], dtype=np.int64)
      # Keys are namespaced per iteration as "<iteration>:<key>".
      keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii") for k in keys]
      insert_0_ops = [
          b.insert_many(0, keys_i(i), values_0 + i)
          for i in range(num_iterations)
      ]
      insert_1_ops = [
          b.insert_many(1, keys_i(i), values_1 + i)
          for i in range(num_iterations)
      ]
      take_ops = [b.take_many(10) for _ in range(num_iterations)]
      close_op = b.close(cancel_pending_enqueues=cancel)

      def take(sess, i, taken):
        # Records the batch size actually taken; 0 when the barrier closed
        # before this taker's batch could be filled.
        try:
          indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run([
              take_ops[i][0], take_ops[i][1], take_ops[i][2][0],
              take_ops[i][2][1]
          ])
          taken.append(len(indices_val))
        except errors_impl.OutOfRangeError:
          taken.append(0)

      def insert(sess, i):
        # Inserts racing with close() may be cancelled; that is expected.
        try:
          sess.run([insert_0_ops[i], insert_1_ops[i]])
        except errors_impl.CancelledError:
          pass

      taken = []

      take_threads = [
          self.checkedThread(
              target=take, args=(sess, i, taken)) for i in range(num_iterations)
      ]
      insert_threads = [
          self.checkedThread(
              target=insert, args=(sess, i)) for i in range(num_iterations)
      ]

      first_half_insert_threads = insert_threads[:num_iterations // 2]
      second_half_insert_threads = insert_threads[num_iterations // 2:]

      # All takers start first; only the first half of inserters run to
      # completion before the barrier is closed.
      for t in take_threads:
        t.start()
      for t in first_half_insert_threads:
        t.start()
      for t in first_half_insert_threads:
        t.join()

      close_op.run()

      for t in second_half_insert_threads:
        t.start()
      for t in take_threads:
        t.join()
      for t in second_half_insert_threads:
        t.join()

      # Exactly half the takers got a full batch of 10; the rest got nothing.
      self.assertEqual(
          sorted(taken),
          [0] * (num_iterations // 2) + [10] * (num_iterations // 2))
571
  def testParallelInsertManyTakeManyCloseHalfwayThrough(self):
    """Close halfway through without cancelling pending enqueues."""
    self._testParallelInsertManyTakeManyCloseHalfwayThrough(cancel=False)
574
  def testParallelInsertManyTakeManyCancelHalfwayThrough(self):
    """Close halfway through with cancel_pending_enqueues=True."""
    self._testParallelInsertManyTakeManyCloseHalfwayThrough(cancel=True)
577
  def _testParallelPartialInsertManyTakeManyCloseHalfwayThrough(self, cancel):
    """First components inserted, close(), then second components race takers.

    Args:
      cancel: Whether close() should also cancel pending enqueues.
    """
    with self.test_session() as sess:
      b = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.int64), shapes=((), (2,)))
      num_iterations = 100
      keys = [str(x) for x in range(10)]
      values_0 = np.asarray(range(10), dtype=np.float32)
      values_1 = np.asarray([[x + 1, x + 2] for x in range(10)], dtype=np.int64)
      # Keys are namespaced per iteration as "<iteration>:<key>".
      keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii") for k in keys]
      insert_0_ops = [
          b.insert_many(
              0, keys_i(i), values_0 + i, name="insert_0_%d" % i)
          for i in range(num_iterations)
      ]

      close_op = b.close(cancel_pending_enqueues=cancel)

      take_ops = [
          b.take_many(
              10, name="take_%d" % i) for i in range(num_iterations)
      ]
      # insert_1_ops will only run after closure
      insert_1_ops = [
          b.insert_many(
              1, keys_i(i), values_1 + i, name="insert_1_%d" % i)
          for i in range(num_iterations)
      ]

      def take(sess, i, taken):
        # When cancelling, pending takes abort with OutOfRangeError and
        # record a batch size of 0; otherwise a full batch is expected.
        if cancel:
          try:
            indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run(
                [
                    take_ops[i][0], take_ops[i][1], take_ops[i][2][0],
                    take_ops[i][2][1]
                ])
            taken.append(len(indices_val))
          except errors_impl.OutOfRangeError:
            taken.append(0)
        else:
          indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run([
              take_ops[i][0], take_ops[i][1], take_ops[i][2][0],
              take_ops[i][2][1]
          ])
          taken.append(len(indices_val))

      def insert_0(sess, i):
        insert_0_ops[i].run(session=sess)

      def insert_1(sess, i):
        # After a cancelling close, these second-component inserts are
        # expected to fail with CancelledError.
        if cancel:
          try:
            insert_1_ops[i].run(session=sess)
          except errors_impl.CancelledError:
            pass
        else:
          insert_1_ops[i].run(session=sess)

      taken = []

      take_threads = [
          self.checkedThread(
              target=take, args=(sess, i, taken)) for i in range(num_iterations)
      ]
      insert_0_threads = [
          self.checkedThread(
              target=insert_0, args=(sess, i)) for i in range(num_iterations)
      ]
      insert_1_threads = [
          self.checkedThread(
              target=insert_1, args=(sess, i)) for i in range(num_iterations)
      ]

      # All first-component inserts complete before any taker starts, so
      # every element is incomplete when close_op runs.
      for t in insert_0_threads:
        t.start()
      for t in insert_0_threads:
        t.join()
      for t in take_threads:
        t.start()

      close_op.run()

      for t in insert_1_threads:
        t.start()
      for t in take_threads:
        t.join()
      for t in insert_1_threads:
        t.join()

      if cancel:
        self.assertEqual(taken, [0] * num_iterations)
      else:
        self.assertEqual(taken, [10] * num_iterations)
671
  def testParallelPartialInsertManyTakeManyCloseHalfwayThrough(self):
    """Partial inserts, then close without cancelling pending enqueues."""
    self._testParallelPartialInsertManyTakeManyCloseHalfwayThrough(cancel=False)
674
  def testParallelPartialInsertManyTakeManyCancelHalfwayThrough(self):
    """Partial inserts, then close with cancel_pending_enqueues=True."""
    self._testParallelPartialInsertManyTakeManyCloseHalfwayThrough(cancel=True)
677
  def testIncompatibleSharedBarrierErrors(self):
    """Reusing a shared_name with mismatched types or shapes raises on init.

    In each pair below the first barrier claims the shared_name; evaluating
    the second barrier's ref must then fail with an incompatibility error.
    """
    with self.test_session():
      # Mismatched component dtypes under the same shared_name.
      b_a_1 = data_flow_ops.Barrier(
          (dtypes.float32,), shapes=(()), shared_name="b_a")
      b_a_2 = data_flow_ops.Barrier(
          (dtypes.int32,), shapes=(()), shared_name="b_a")
      b_a_1.barrier_ref.eval()
      with self.assertRaisesOpError("component types"):
        b_a_2.barrier_ref.eval()

      # Mismatched number of components under the same shared_name.
      b_b_1 = data_flow_ops.Barrier(
          (dtypes.float32,), shapes=(()), shared_name="b_b")
      b_b_2 = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.int32), shapes=((), ()), shared_name="b_b")
      b_b_1.barrier_ref.eval()
      with self.assertRaisesOpError("component types"):
        b_b_2.barrier_ref.eval()

      # Fully-specified shapes first, unspecified shapes second.
      b_c_1 = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.float32),
          shapes=((2, 2), (8,)),
          shared_name="b_c")
      b_c_2 = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.float32), shared_name="b_c")
      b_c_1.barrier_ref.eval()
      with self.assertRaisesOpError("component shapes"):
        b_c_2.barrier_ref.eval()

      # Scalar shapes first, non-scalar shapes second.
      b_d_1 = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.float32), shapes=((), ()), shared_name="b_d")
      b_d_2 = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.float32),
          shapes=((2, 2), (8,)),
          shared_name="b_d")
      b_d_1.barrier_ref.eval()
      with self.assertRaisesOpError("component shapes"):
        b_d_2.barrier_ref.eval()

      # Same rank but different dimension sizes.
      b_e_1 = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.float32),
          shapes=((2, 2), (8,)),
          shared_name="b_e")
      b_e_2 = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.float32),
          shapes=((2, 5), (8,)),
          shared_name="b_e")
      b_e_1.barrier_ref.eval()
      with self.assertRaisesOpError("component shapes"):
        b_e_2.barrier_ref.eval()
728
729
# Run the tests under the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()
732