192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa#
392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# Licensed under the Apache License, Version 2.0 (the "License");
492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# you may not use this file except in compliance with the License.
592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# You may obtain a copy of the License at
692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa#
792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa#     http://www.apache.org/licenses/LICENSE-2.0
892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa#
992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# Unless required by applicable law or agreed to in writing, software
1092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# distributed under the License is distributed on an "AS IS" BASIS,
1192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# See the License for the specific language governing permissions and
1392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# limitations under the License.
1492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# ==============================================================================
1592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa"""Tests for the experimental input pipeline ops."""
1692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom __future__ import absolute_import
1792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom __future__ import division
1892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom __future__ import print_function
1992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa
2092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsaimport itertools
2192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa
2292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.data.ops import dataset_ops
2392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.framework import dtypes
2492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.framework import errors
2592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.framework import sparse_tensor
2692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.ops import array_ops
2792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.ops import sparse_ops
2892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.platform import test
2992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa
3092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa
class InterleaveDatasetTest(test.TestCase):
  """Tests for `Dataset.interleave()` and its pure-Python reference model."""

  def _interleave(self, lists, cycle_length, block_length):
    """Yields the elements of `lists` in the order `Dataset.interleave()` does.

    This is the pure-Python reference model used as the oracle for the dataset
    tests below.

    Args:
      lists: A sequence of iterables. Each one plays the role of the dataset
        produced by the interleave function for one input element.
      cycle_length: The number of iterables consumed concurrently (slots).
      block_length: The maximum number of consecutive elements taken from one
        slot before moving on to the next slot.

    Yields:
      The interleaved elements.

    Raises:
      ValueError: If `cycle_length` or `block_length` is less than 1. With such
        values no element would ever be consumed and the loop would never end.
    """
    import collections  # pylint: disable=g-import-not-at-top

    if cycle_length < 1:
      raise ValueError(
          "cycle_length must be at least 1, got %r" % (cycle_length,))
    if block_length < 1:
      raise ValueError(
          "block_length must be at least 1, got %r" % (block_length,))

    # FIFO queue of iterators over the elements of `lists` that have not been
    # assigned to a slot yet; a deque makes taking the front O(1).
    pending = collections.deque(iter(elements) for elements in lists)

    # `open_iterators[i]` is the iterator currently interleaved in slot `i`, or
    # None when that slot is empty. Slots are filled lazily, in slot order,
    # when the cycle reaches them, so the first `cycle_length` iterators land
    # in slots 0, 1, ... exactly as an up-front fill would place them.
    open_iterators = [None] * cycle_length
    num_open = 0

    while num_open or pending:
      for i in range(cycle_length):
        if open_iterators[i] is None:
          if not pending:
            continue
          open_iterators[i] = pending.popleft()
          num_open += 1
        for _ in range(block_length):
          try:
            element = next(open_iterators[i])
          except StopIteration:
            # An exhausted slot is only refilled on the next pass of the
            # cycle, so the remaining slots get their turn first.
            open_iterators[i] = None
            num_open -= 1
            break
          yield element

  def testPythonImplementation(self):
    """Pins the reference model `_interleave` to hand-computed outputs."""
    input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
                   [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]

    # Every case compares complete lists: comparing through `zip()` would stop
    # at the shorter sequence and silently ignore missing or extra elements.

    # Cycle length 1 acts like `Dataset.flat_map()`.
    expected_elements = list(itertools.chain(*input_lists))
    self.assertEqual(expected_elements,
                     list(self._interleave(input_lists, 1, 1)))

    # Cycle length > 1.
    expected_elements = [4, 5, 4, 5, 4, 5, 4,
                         5, 5, 6, 6,  # NOTE(mrry): When we cycle back
                                      # to a list and are already at
                                      # the end of that list, we move
                                      # on to the next element.
                         4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5,
                         6, 6]
    self.assertEqual(expected_elements,
                     list(self._interleave(input_lists, 2, 1)))

    # Cycle length > 1 and block length > 1.
    expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6,
                         4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6]
    self.assertEqual(expected_elements,
                     list(self._interleave(input_lists, 2, 3)))

    # Cycle length > len(input_lists).
    expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6,
                         4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6]
    self.assertEqual(expected_elements,
                     list(self._interleave(input_lists, 7, 2)))

    # Non-positive lengths are rejected instead of looping forever.
    with self.assertRaises(ValueError):
      list(self._interleave(input_lists, 0, 1))
    with self.assertRaises(ValueError):
      list(self._interleave(input_lists, 2, 0))

  def testInterleaveDataset(self):
    """Checks `Dataset.interleave()` against the `_interleave` model."""
    input_values = array_ops.placeholder(dtypes.int64, shape=[None])
    cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
    block_length = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_count = 2

    # Each input element `x` expands to a dataset yielding `x` exactly `x`
    # times, so feeding [4, 5, 6] corresponds to the model lists
    # [[4] * 4, [5] * 5, [6] * 6] * repeat_count.
    dataset = (
        dataset_ops.Dataset.from_tensor_slices(input_values)
        .repeat(repeat_count)
        .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
                    cycle_length, block_length))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    next_element = iterator.get_next()

    with self.test_session() as sess:

      def _verify(values, cycle, block):
        """Re-initializes the iterator on `values` and checks all output."""
        sess.run(init_op, feed_dict={input_values: values,
                                     cycle_length: cycle,
                                     block_length: block})
        expected_lists = [[v] * v for v in values] * repeat_count
        for expected_element in self._interleave(expected_lists, cycle, block):
          self.assertEqual(expected_element, sess.run(next_element))
        # The dataset must be exhausted exactly where the model is.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)

      # Cycle length 1 acts like `Dataset.flat_map()`.
      _verify([4, 5, 6], 1, 3)

      # Cycle length > 1.
      # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5,
      #            6, 5, 6, 5, 6, 5, 6, 5, 6, 6]
      _verify([4, 5, 6], 2, 1)

      # Cycle length > 1 and block length > 1.
      # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5,
      #            5, 5, 6, 6, 6, 5, 5, 6, 6, 6]
      _verify([4, 5, 6], 2, 3)

      # Cycle length > len(input_values) * repeat_count.
      # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4,
      #            5, 5, 6, 6, 5, 6, 6, 5, 6, 6]
      _verify([4, 5, 6], 7, 2)

      # Empty input.
      _verify([], 2, 3)

      # Non-empty input leading to empty output.
      _verify([0, 0, 0], 2, 3)

      # Mixture of non-empty and empty interleaved datasets.
      _verify([4, 0, 6], 2, 3)

  def testSparse(self):
    """Interleaves a dataset whose elements are sparse tensors."""

    def _map_fn(i):
      # Sparse 2x2 value with `i` at (0, 0) and `-i` at (1, 1).
      return sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])

    def _interleave_fn(x):
      # Densify and emit the rows of the 2x2 matrix one at a time.
      return dataset_ops.Dataset.from_tensor_slices(
          sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

    iterator = (
        dataset_ops.Dataset.range(10).map(_map_fn).interleave(
            _interleave_fn, cycle_length=1).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        self.assertAllEqual([i, 0], sess.run(get_next))
        self.assertAllEqual([0, -i], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testEmptyInput(self):
    """An empty input repeated with `repeat(None)` still ends the iterator."""
    iterator = (
        dataset_ops.Dataset.from_tensor_slices([])
        .repeat(None)
        .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2)
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

21892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa
if __name__ == "__main__":
  # Script entry point: hand control to the TensorFlow test runner imported
  # from `tensorflow.python.platform`.
  test.main()
221