# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for numpy_io."""
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

import numpy as np

from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import queue_runner_impl
29
30
class NumpyIoTest(test.TestCase):
  """Tests for `numpy_io.numpy_input_fn`.

  The tests cover batching and epoch handling, dict and array forms of `x`
  and `y`, and the errors raised for malformed arguments.
  """

  @contextlib.contextmanager
  def _queue_runners(self, session):
    """Runs the queue runners of `session` for the duration of a `with` block.

    The coordinator is stopped and the runner threads are joined in a
    `finally` clause, so the threads are cleaned up even when an assertion
    inside the block fails.

    Args:
      session: The session for which queue runners are started.

    Yields:
      None.
    """
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(session, coord=coord)
    try:
      yield
    finally:
      coord.request_stop()
      coord.join(threads)

  def testNumpyInputFn(self):
    """Yields ordered batches for one epoch, then raises OutOfRangeError."""
    a = np.arange(4) * 1.0
    b = np.arange(32, 36)
    x = {'a': a, 'b': b}
    y = np.arange(-32, -28)

    with self.test_session() as session:
      input_fn = numpy_io.numpy_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)
      features, target = input_fn()

      with self._queue_runners(session):
        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [0, 1])
        self.assertAllEqual(res[0]['b'], [32, 33])
        self.assertAllEqual(res[1], [-32, -31])

        # Consume the second batch; the following read must hit end of input.
        session.run([features, target])
        with self.assertRaises(errors.OutOfRangeError):
          session.run([features, target])

  def testNumpyInputFnWithVeryLargeBatchSizeAndMultipleEpochs(self):
    """A batch size above the total data size returns all epochs at once."""
    a = np.arange(2) * 1.0
    b = np.arange(32, 34)
    x = {'a': a, 'b': b}
    y = np.arange(-32, -30)

    with self.test_session() as session:
      input_fn = numpy_io.numpy_input_fn(
          x, y, batch_size=128, shuffle=False, num_epochs=2)
      features, target = input_fn()

      with self._queue_runners(session):
        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [0, 1, 0, 1])
        self.assertAllEqual(res[0]['b'], [32, 33, 32, 33])
        self.assertAllEqual(res[1], [-32, -31, -32, -31])

        with self.assertRaises(errors.OutOfRangeError):
          session.run([features, target])

  def testNumpyInputFnWithZeroEpochs(self):
    """With num_epochs=0 the very first read raises OutOfRangeError."""
    a = np.arange(4) * 1.0
    b = np.arange(32, 36)
    x = {'a': a, 'b': b}
    y = np.arange(-32, -28)

    with self.test_session() as session:
      input_fn = numpy_io.numpy_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=0)
      features, target = input_fn()

      with self._queue_runners(session):
        with self.assertRaises(errors.OutOfRangeError):
          session.run([features, target])

  def testNumpyInputFnWithBatchSizeNotDividedByDataSize(self):
    """The last batch is smaller when batch_size does not divide the data."""
    batch_size = 2
    a = np.arange(5) * 1.0
    b = np.arange(32, 37)
    x = {'a': a, 'b': b}
    y = np.arange(-32, -27)

    with self.test_session() as session:
      input_fn = numpy_io.numpy_input_fn(
          x, y, batch_size=batch_size, shuffle=False, num_epochs=1)
      features, target = input_fn()

      with self._queue_runners(session):
        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [0, 1])
        self.assertAllEqual(res[0]['b'], [32, 33])
        self.assertAllEqual(res[1], [-32, -31])

        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [2, 3])
        self.assertAllEqual(res[0]['b'], [34, 35])
        self.assertAllEqual(res[1], [-30, -29])

        # Five elements with batch_size=2 leave a single-element final batch.
        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [4])
        self.assertAllEqual(res[0]['b'], [36])
        self.assertAllEqual(res[1], [-28])

        with self.assertRaises(errors.OutOfRangeError):
          session.run([features, target])

  def testNumpyInputFnWithBatchSizeNotDividedByDataSizeAndMultipleEpochs(self):
    """Batches wrap across epoch boundaries; only the last one is short."""
    batch_size = 2
    a = np.arange(3) * 1.0
    b = np.arange(32, 35)
    x = {'a': a, 'b': b}
    y = np.arange(-32, -29)

    with self.test_session() as session:
      input_fn = numpy_io.numpy_input_fn(
          x, y, batch_size=batch_size, shuffle=False, num_epochs=3)
      features, target = input_fn()

      with self._queue_runners(session):
        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [0, 1])
        self.assertAllEqual(res[0]['b'], [32, 33])
        self.assertAllEqual(res[1], [-32, -31])

        # This batch mixes the end of epoch 1 with the start of epoch 2.
        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [2, 0])
        self.assertAllEqual(res[0]['b'], [34, 32])
        self.assertAllEqual(res[1], [-30, -32])

        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [1, 2])
        self.assertAllEqual(res[0]['b'], [33, 34])
        self.assertAllEqual(res[1], [-31, -30])

        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [0, 1])
        self.assertAllEqual(res[0]['b'], [32, 33])
        self.assertAllEqual(res[1], [-32, -31])

        # Nine elements in total leave one element for the final batch.
        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [2])
        self.assertAllEqual(res[0]['b'], [34])
        self.assertAllEqual(res[1], [-30])

        with self.assertRaises(errors.OutOfRangeError):
          session.run([features, target])

  def testNumpyInputFnWithBatchSizeLargerThanDataSize(self):
    """A batch size above the data size yields one batch with all the data."""
    batch_size = 10
    a = np.arange(4) * 1.0
    b = np.arange(32, 36)
    x = {'a': a, 'b': b}
    y = np.arange(-32, -28)

    with self.test_session() as session:
      input_fn = numpy_io.numpy_input_fn(
          x, y, batch_size=batch_size, shuffle=False, num_epochs=1)
      features, target = input_fn()

      with self._queue_runners(session):
        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [0, 1, 2, 3])
        self.assertAllEqual(res[0]['b'], [32, 33, 34, 35])
        self.assertAllEqual(res[1], [-32, -31, -30, -29])

        with self.assertRaises(errors.OutOfRangeError):
          session.run([features, target])

  def testNumpyInputFnWithDifferentDimensionsOfFeatures(self):
    """Features of different rank can be batched together."""
    a = np.array([[1, 2], [3, 4]])
    b = np.array([5, 6])
    x = {'a': a, 'b': b}
    y = np.arange(-32, -30)

    with self.test_session() as session:
      input_fn = numpy_io.numpy_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)
      features, target = input_fn()

      with self._queue_runners(session):
        res = session.run([features, target])
        self.assertAllEqual(res[0]['a'], [[1, 2], [3, 4]])
        self.assertAllEqual(res[0]['b'], [5, 6])
        self.assertAllEqual(res[1], [-32, -31])

  def testNumpyInputFnWithXAsNonDict(self):
    """A plain Python list for x is rejected with TypeError."""
    x = list(range(32, 36))
    y = np.arange(4)
    with self.test_session():
      with self.assertRaisesRegexp(TypeError, 'x must be a dict or array'):
        failing_input_fn = numpy_io.numpy_input_fn(
            x, y, batch_size=2, shuffle=False, num_epochs=1)
        failing_input_fn()

  def testNumpyInputFnWithXIsEmptyDict(self):
    """An empty dict for x is rejected with ValueError."""
    x = {}
    y = np.arange(4)
    with self.test_session():
      with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):
        failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
        failing_input_fn()

  def testNumpyInputFnWithXIsEmptyArray(self):
    """An array without elements for x is rejected with ValueError."""
    x = np.array([[], []])
    y = np.arange(4)
    with self.test_session():
      with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):
        failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
        failing_input_fn()

  def testNumpyInputFnWithYIsNone(self):
    """With y=None the input_fn returns only the features."""
    a = np.arange(4) * 1.0
    b = np.arange(32, 36)
    x = {'a': a, 'b': b}
    y = None

    with self.test_session() as session:
      input_fn = numpy_io.numpy_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)
      features_tensor = input_fn()

      with self._queue_runners(session):
        feature = session.run(features_tensor)
        self.assertEqual(len(feature), 2)
        self.assertAllEqual(feature['a'], [0, 1])
        self.assertAllEqual(feature['b'], [32, 33])

        # Consume the second batch; the following read must hit end of input.
        session.run([features_tensor])
        with self.assertRaises(errors.OutOfRangeError):
          session.run([features_tensor])

  def testNumpyInputFnWithNonBoolShuffle(self):
    """Omitting shuffle is rejected with TypeError at construction time."""
    x = np.arange(32, 36)
    y = np.arange(4)
    with self.test_session():
      with self.assertRaisesRegexp(TypeError,
                                   'shuffle must be explicitly set as boolean'):
        # Default shuffle is None.
        numpy_io.numpy_input_fn(x, y)

  def testNumpyInputFnWithTargetKeyAlreadyInX(self):
    """Calling the input_fn leaves a caller-supplied x dict untouched.

    The key '__target_key__' presumably collides with a key numpy_io uses
    internally for the target -- TODO confirm against numpy_io.
    """
    array = np.arange(32, 36)
    # Snapshot the values so that an in-place mutation of `array` is detected;
    # comparing x['__target_key__'] with `array` itself would compare an
    # object against itself.
    expected = array.copy()
    x = {'__target_key__': array}
    y = np.arange(4)

    with self.test_session():
      input_fn = numpy_io.numpy_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)
      input_fn()
      self.assertAllEqual(x['__target_key__'], expected)
      # The input x should not be mutated.
      self.assertItemsEqual(x.keys(), ['__target_key__'])

  def testNumpyInputFnWithMismatchLengthOfInputs(self):
    """Length mismatches between x and y, or within x, raise ValueError."""
    a = np.arange(4) * 1.0
    b = np.arange(32, 36)
    x = {'a': a, 'b': b}
    x_mismatch_length = {'a': np.arange(1), 'b': b}
    y_longer_length = np.arange(10)

    with self.test_session():
      # y is longer than the arrays in x.
      with self.assertRaisesRegexp(
          ValueError, 'Length of tensors in x and y is mismatched.'):
        failing_input_fn = numpy_io.numpy_input_fn(
            x, y_longer_length, batch_size=2, shuffle=False, num_epochs=1)
        failing_input_fn()

      # The arrays inside x differ in length; y is not involved.
      with self.assertRaisesRegexp(
          ValueError, 'Length of tensors in x and y is mismatched.'):
        failing_input_fn = numpy_io.numpy_input_fn(
            x=x_mismatch_length,
            y=None,
            batch_size=2,
            shuffle=False,
            num_epochs=1)
        failing_input_fn()

  def testNumpyInputFnWithYAsDict(self):
    """A dict y yields a dict of target batches keyed like y."""
    a = np.arange(4) * 1.0
    b = np.arange(32, 36)
    x = {'a': a, 'b': b}
    y = {'y1': np.arange(-32, -28), 'y2': np.arange(32, 28, -1)}

    with self.test_session() as session:
      input_fn = numpy_io.numpy_input_fn(
          x, y, batch_size=2, shuffle=False, num_epochs=1)
      features_tensor, targets_tensor = input_fn()

      with self._queue_runners(session):
        features, targets = session.run([features_tensor, targets_tensor])
        self.assertEqual(len(features), 2)
        self.assertAllEqual(features['a'], [0, 1])
        self.assertAllEqual(features['b'], [32, 33])
        self.assertEqual(len(targets), 2)
        self.assertAllEqual(targets['y1'], [-32, -31])
        self.assertAllEqual(targets['y2'], [32, 31])

        # Consume the second batch; the following read must hit end of input.
        session.run([features_tensor, targets_tensor])
        with self.assertRaises(errors.OutOfRangeError):
          session.run([features_tensor, targets_tensor])

  def testNumpyInputFnWithYIsEmptyDict(self):
    """An empty dict for y is rejected with ValueError."""
    a = np.arange(4) * 1.0
    b = np.arange(32, 36)
    x = {'a': a, 'b': b}
    y = {}
    with self.test_session():
      with self.assertRaisesRegexp(ValueError, 'y cannot be empty'):
        failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
        failing_input_fn()

  def testNumpyInputFnWithDuplicateKeysInXAndY(self):
    """Keys present in both x and y are rejected with ValueError."""
    a = np.arange(4) * 1.0
    b = np.arange(32, 36)
    x = {'a': a, 'b': b}
    y = {'y1': np.arange(-32, -28), 'a': a, 'y2': np.arange(32, 28, -1), 'b': b}
    with self.test_session():
      with self.assertRaisesRegexp(
          ValueError, '2 duplicate keys are found in both x and y'):
        failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
        failing_input_fn()

  def testNumpyInputFnWithXIsArray(self):
    """A 1-D array x is batched directly instead of as a dict."""
    x = np.arange(4) * 1.0
    y = np.arange(-32, -28)

    input_fn = numpy_io.numpy_input_fn(
        x, y, batch_size=2, shuffle=False, num_epochs=1)
    features, target = input_fn()

    with monitored_session.MonitoredSession() as session:
      res = session.run([features, target])
      self.assertAllEqual(res[0], [0, 1])
      self.assertAllEqual(res[1], [-32, -31])

      # Consume the second batch; the following read must hit end of input.
      session.run([features, target])
      with self.assertRaises(errors.OutOfRangeError):
        session.run([features, target])

  def testNumpyInputFnWithXIsNDArray(self):
    """Multi-dimensional x and y are batched along their first axis."""
    x = np.arange(16).reshape(4, 2, 2) * 1.0
    y = np.arange(-48, -32).reshape(4, 2, 2)

    input_fn = numpy_io.numpy_input_fn(
        x, y, batch_size=2, shuffle=False, num_epochs=1)
    features, target = input_fn()

    with monitored_session.MonitoredSession() as session:
      res = session.run([features, target])
      self.assertAllEqual(res[0], [[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
      self.assertAllEqual(
          res[1], [[[-48, -47], [-46, -45]], [[-44, -43], [-42, -41]]])

      # Consume the second batch; the following read must hit end of input.
      session.run([features, target])
      with self.assertRaises(errors.OutOfRangeError):
        session.run([features, target])

  def testNumpyInputFnWithXIsArrayYIsDict(self):
    """An array x can be combined with a dict y."""
    x = np.arange(4) * 1.0
    y = {'y1': np.arange(-32, -28)}

    input_fn = numpy_io.numpy_input_fn(
        x, y, batch_size=2, shuffle=False, num_epochs=1)
    features_tensor, targets_tensor = input_fn()

    with monitored_session.MonitoredSession() as session:
      features, targets = session.run([features_tensor, targets_tensor])
      self.assertEqual(len(features), 2)
      self.assertAllEqual(features, [0, 1])
      self.assertEqual(len(targets), 1)
      self.assertAllEqual(targets['y1'], [-32, -31])

      # Consume the second batch; the following read must hit end of input.
      session.run([features_tensor, targets_tensor])
      with self.assertRaises(errors.OutOfRangeError):
        session.run([features_tensor, targets_tensor])

  def testArrayAndDictGiveSameOutput(self):
    """An array x and a one-key dict wrapping it produce the same batches."""
    a = np.arange(4) * 1.0
    b = np.arange(32, 36)
    x_arr = np.vstack((a, b))
    x_dict = {'feature1': x_arr}
    y = np.arange(-48, -40).reshape(2, 4)

    input_fn_arr = numpy_io.numpy_input_fn(
        x_arr, y, batch_size=2, shuffle=False, num_epochs=1)
    features_arr, targets_arr = input_fn_arr()

    input_fn_dict = numpy_io.numpy_input_fn(
        x_dict, y, batch_size=2, shuffle=False, num_epochs=1)
    features_dict, targets_dict = input_fn_dict()

    with monitored_session.MonitoredSession() as session:
      res_arr, res_dict = session.run([
          (features_arr, targets_arr), (features_dict, targets_dict)])

      self.assertAllEqual(res_arr[0], res_dict[0]['feature1'])
      self.assertAllEqual(res_arr[1], res_dict[1])
457
if __name__ == '__main__':
  # Invoke the test entry point only when this file is run as a script.
  test.main()
460