192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# 392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# Licensed under the Apache License, Version 2.0 (the "License"); 492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# you may not use this file except in compliance with the License. 592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# You may obtain a copy of the License at 692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# 792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# http://www.apache.org/licenses/LICENSE-2.0 892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# 992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# Unless required by applicable law or agreed to in writing, software 1092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# distributed under the License is distributed on an "AS IS" BASIS, 1192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# See the License for the specific language governing permissions and 1392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# limitations under the License. 1492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa# ============================================================================== 1592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa"""Tests for the experimental input pipeline ops.""" 1692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom __future__ import absolute_import 1792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom __future__ import division 1892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom __future__ import print_function 1992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 2092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsaimport itertools 2192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 2292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.data.ops import dataset_ops 2392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.framework import dtypes 2492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.framework import errors 2592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.framework import sparse_tensor 2692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.ops import array_ops 2792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.ops import sparse_ops 2892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsafrom tensorflow.python.platform import test 2992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 3092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 3192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsaclass InterleaveDatasetTest(test.TestCase): 3292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 3392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa def _interleave(self, lists, cycle_length, block_length): 3492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa num_open = 0 3592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 3692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # `all_iterators` acts as a queue of iterators over each element of `lists`. 3792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa all_iterators = [iter(l) for l in lists] 3892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 3992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # `open_iterators` are the iterators whose elements are currently being 4092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # interleaved. 4192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa open_iterators = [] 4292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for i in range(cycle_length): 4392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa if all_iterators: 4492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa open_iterators.append(all_iterators.pop(0)) 4592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa num_open += 1 4692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa else: 4792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa open_iterators.append(None) 4892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 4992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa while num_open or all_iterators: 5092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for i in range(cycle_length): 5192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa if open_iterators[i] is None: 5292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa if all_iterators: 5392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa open_iterators[i] = all_iterators.pop(0) 5492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa num_open += 1 5592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa else: 5692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa continue 5792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for _ in range(block_length): 5892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa try: 5992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa yield next(open_iterators[i]) 6092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa except StopIteration: 6192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa open_iterators[i] = None 6292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa num_open -= 1 6392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa break 6492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 6592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa def testPythonImplementation(self): 6692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], 6792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] 6892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 6992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # Cycle length 1 acts like `Dataset.flat_map()`. 7092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa expected_elements = itertools.chain(*input_lists) 7192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for expected, produced in zip( 7292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa expected_elements, self._interleave(input_lists, 1, 1)): 7392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa self.assertEqual(expected, produced) 7492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 7592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # Cycle length > 1. 7692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa expected_elements = [4, 5, 4, 5, 4, 5, 4, 7792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 5, 5, 6, 6, # NOTE(mrry): When we cycle back 7892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # to a list and are already at 7992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # the end of that list, we move 8092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # on to the next element. 8192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] 8292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for expected, produced in zip( 8392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa expected_elements, self._interleave(input_lists, 2, 1)): 8492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa self.assertEqual(expected, produced) 8592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 8692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # Cycle length > 1 and block length > 1. 8792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 8892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] 8992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for expected, produced in zip( 9092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa expected_elements, self._interleave(input_lists, 2, 3)): 9192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa self.assertEqual(expected, produced) 9292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 9392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # Cycle length > len(input_values). 9492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 9592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] 9692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for expected, produced in zip( 9792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa expected_elements, self._interleave(input_lists, 7, 2)): 9892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa self.assertEqual(expected, produced) 9992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 10092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa def testInterleaveDataset(self): 10192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa input_values = array_ops.placeholder(dtypes.int64, shape=[None]) 10292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) 10392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa block_length = array_ops.placeholder(dtypes.int64, shape=[]) 10492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 10592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa repeat_count = 2 10692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 10792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa dataset = ( 10892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa dataset_ops.Dataset.from_tensor_slices(input_values) 10992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa .repeat(repeat_count) 11092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), 11192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa cycle_length, block_length)) 11292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa iterator = dataset.make_initializable_iterator() 11392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa init_op = iterator.initializer 11492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa next_element = iterator.get_next() 11592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 11692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa with self.test_session() as sess: 11792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # Cycle length 1 acts like `Dataset.flat_map()`. 11892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(init_op, feed_dict={input_values: [4, 5, 6], 11992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa cycle_length: 1, block_length: 3}) 12092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 12192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for expected_element in self._interleave( 12292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): 12392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa self.assertEqual(expected_element, sess.run(next_element)) 12492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 12592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # Cycle length > 1. 12692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 12792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # 6, 5, 6, 5, 6, 5, 6, 5] 12892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(init_op, feed_dict={input_values: [4, 5, 6], 12992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa cycle_length: 2, block_length: 1}) 13092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for expected_element in self._interleave( 13192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): 13292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa self.assertEqual(expected_element, sess.run(next_element)) 13392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa with self.assertRaises(errors.OutOfRangeError): 13492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(next_element) 13592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 13692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # Cycle length > 1 and block length > 1. 13792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, 13892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] 13992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(init_op, feed_dict={input_values: [4, 5, 6], 14092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa cycle_length: 2, block_length: 3}) 14192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for expected_element in self._interleave( 14292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): 14392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa self.assertEqual(expected_element, sess.run(next_element)) 14492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa with self.assertRaises(errors.OutOfRangeError): 14592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(next_element) 14692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 14792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # Cycle length > len(input_values) * repeat_count. 14892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 14992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] 15092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(init_op, feed_dict={input_values: [4, 5, 6], 15192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa cycle_length: 7, block_length: 2}) 15292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for expected_element in self._interleave( 15392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): 15492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa self.assertEqual(expected_element, sess.run(next_element)) 15592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa with self.assertRaises(errors.OutOfRangeError): 15692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(next_element) 15792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 15892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # Empty input. 15992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(init_op, feed_dict={input_values: [], 16092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa cycle_length: 2, block_length: 3}) 16192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa with self.assertRaises(errors.OutOfRangeError): 16292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(next_element) 16392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 16492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # Non-empty input leading to empty output. 16592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(init_op, feed_dict={input_values: [0, 0, 0], 16692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa cycle_length: 2, block_length: 3}) 16792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa with self.assertRaises(errors.OutOfRangeError): 16892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(next_element) 16992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 17092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa # Mixture of non-empty and empty interleaved datasets. 17192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(init_op, feed_dict={input_values: [4, 0, 6], 17292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa cycle_length: 2, block_length: 3}) 17392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for expected_element in self._interleave( 17492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): 17592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa self.assertEqual(expected_element, sess.run(next_element)) 17692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa with self.assertRaises(errors.OutOfRangeError): 17792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(next_element) 17892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 17992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa def testSparse(self): 18082fa1e1ae5b2f8af642979fafb1cab455db1882fJiri Simsa 18192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa def _map_fn(i): 18282fa1e1ae5b2f8af642979fafb1cab455db1882fJiri Simsa return sparse_tensor.SparseTensorValue( 18392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) 18492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 18592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa def _interleave_fn(x): 18692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa return dataset_ops.Dataset.from_tensor_slices( 18792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) 18892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 18992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa iterator = ( 19092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa dataset_ops.Dataset.range(10).map(_map_fn).interleave( 19192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa _interleave_fn, cycle_length=1).make_initializable_iterator()) 19292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa init_op = iterator.initializer 19392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa get_next = iterator.get_next() 19492f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 19592f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa with self.test_session() as sess: 19692f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(init_op) 19792f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for i in range(10): 19892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa for j in range(2): 19992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa expected = [i, 0] if j % 2 == 0 else [0, -i] 20092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa self.assertAllEqual(expected, sess.run(get_next)) 20192f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa with self.assertRaises(errors.OutOfRangeError): 20292f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa sess.run(get_next) 20392f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 204a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray def testEmptyInput(self): 205a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray iterator = ( 206a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray dataset_ops.Dataset.from_tensor_slices([]) 207a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray .repeat(None) 208a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2) 209a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray .make_initializable_iterator()) 210a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray init_op = iterator.initializer 211a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray get_next = iterator.get_next() 212a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray 213a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray with self.test_session() as sess: 214a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray sess.run(init_op) 215a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray with self.assertRaises(errors.OutOfRangeError): 216a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray sess.run(get_next) 217a805116366eddcaa8eb6a602398f8efae076e0b5Derek Murray 21892f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa 21992f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsaif __name__ == "__main__": 22092f40bfe3beaf087efbda8412bf129a12bcd9db2Jiri Simsa test.main() 221