# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test PrefetchDataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class PrefetchDatasetTest(test.TestCase):
  """Checks `Dataset.prefetch` with valid and invalid buffer sizes."""

  def _make_prefetch_iterator(self):
    """Builds an initializable iterator over `range(10).prefetch(...)`.

    The buffer size is a scalar int64 placeholder, so each test picks the
    concrete value when it runs the iterator's initializer via `feed_dict`.

    Returns:
      A `(buffer_size_placeholder, iterator)` pair.
    """
    size_ph = array_ops.placeholder(dtypes.int64, shape=[])
    prefetched = dataset_ops.Dataset.range(10).prefetch(buffer_size=size_ph)
    return size_ph, prefetched.make_initializable_iterator()

  def testBufferSize(self):
    size_ph, iterator = self._make_prefetch_iterator()
    init_op = iterator.initializer
    next_element = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op, feed_dict={size_ph: 5})
      # Prefetching must neither reorder nor drop elements: expect 0..9.
      produced = [sess.run(next_element) for _ in range(10)]
      self.assertEqual(list(range(10)), produced)
      # One more fetch past the end must report exhaustion.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testInvalidBufferSize(self):
    size_ph, iterator = self._make_prefetch_iterator()
    init_op = iterator.initializer

    # Both a zero and a negative buffer size are expected to fail when the
    # initializer runs, with an error message mentioning "buffer_size".
    for bad_size in (0, -5):
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"):
        with self.test_session() as sess:
          sess.run(init_op, feed_dict={size_ph: bad_size})


if __name__ == "__main__":
  test.main()