1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for numpy_io.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.estimator.inputs import numpy_io 24from tensorflow.python.framework import errors 25from tensorflow.python.platform import test 26from tensorflow.python.training import coordinator 27from tensorflow.python.training import monitored_session 28from tensorflow.python.training import queue_runner_impl 29 30 31class NumpyIoTest(test.TestCase): 32 33 def testNumpyInputFn(self): 34 a = np.arange(4) * 1.0 35 b = np.arange(32, 36) 36 x = {'a': a, 'b': b} 37 y = np.arange(-32, -28) 38 39 with self.test_session() as session: 40 input_fn = numpy_io.numpy_input_fn( 41 x, y, batch_size=2, shuffle=False, num_epochs=1) 42 features, target = input_fn() 43 44 coord = coordinator.Coordinator() 45 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 46 47 res = session.run([features, target]) 48 self.assertAllEqual(res[0]['a'], [0, 1]) 49 self.assertAllEqual(res[0]['b'], [32, 33]) 50 self.assertAllEqual(res[1], [-32, -31]) 51 52 session.run([features, target]) 53 with self.assertRaises(errors.OutOfRangeError): 54 session.run([features, target]) 55 56 coord.request_stop() 57 coord.join(threads) 58 59 def testNumpyInputFnWithVeryLargeBatchSizeAndMultipleEpochs(self): 60 a = np.arange(2) * 1.0 61 b = np.arange(32, 34) 62 x = {'a': a, 'b': b} 63 y = np.arange(-32, -30) 64 65 with self.test_session() as session: 66 input_fn = numpy_io.numpy_input_fn( 67 x, y, batch_size=128, shuffle=False, num_epochs=2) 68 features, target = input_fn() 69 70 coord = coordinator.Coordinator() 71 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 72 73 res = session.run([features, target]) 74 self.assertAllEqual(res[0]['a'], [0, 1, 0, 1]) 75 self.assertAllEqual(res[0]['b'], [32, 33, 32, 33]) 76 self.assertAllEqual(res[1], [-32, -31, -32, -31]) 77 78 with self.assertRaises(errors.OutOfRangeError): 79 session.run([features, target]) 80 81 coord.request_stop() 82 coord.join(threads) 83 84 def testNumpyInputFnWithZeroEpochs(self): 85 a = np.arange(4) * 1.0 86 b = np.arange(32, 36) 87 x = {'a': a, 'b': b} 88 y = np.arange(-32, -28) 89 90 with self.test_session() as session: 91 input_fn = numpy_io.numpy_input_fn( 92 x, y, batch_size=2, shuffle=False, num_epochs=0) 93 features, target = input_fn() 94 95 coord = coordinator.Coordinator() 96 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 97 98 with self.assertRaises(errors.OutOfRangeError): 99 session.run([features, target]) 100 101 coord.request_stop() 102 coord.join(threads) 103 104 def testNumpyInputFnWithBatchSizeNotDividedByDataSize(self): 105 batch_size = 2 106 a = np.arange(5) * 1.0 107 b = np.arange(32, 37) 108 x = {'a': a, 'b': b} 109 y = np.arange(-32, -27) 110 111 with self.test_session() as session: 112 input_fn = numpy_io.numpy_input_fn( 113 x, y, batch_size=batch_size, shuffle=False, num_epochs=1) 114 features, target = input_fn() 115 116 coord = coordinator.Coordinator() 117 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 118 119 res = session.run([features, target]) 120 self.assertAllEqual(res[0]['a'], [0, 1]) 121 self.assertAllEqual(res[0]['b'], [32, 33]) 122 self.assertAllEqual(res[1], [-32, -31]) 123 124 res = session.run([features, target]) 125 self.assertAllEqual(res[0]['a'], [2, 3]) 126 self.assertAllEqual(res[0]['b'], [34, 35]) 127 self.assertAllEqual(res[1], [-30, -29]) 128 129 res = session.run([features, target]) 130 self.assertAllEqual(res[0]['a'], [4]) 131 self.assertAllEqual(res[0]['b'], [36]) 132 self.assertAllEqual(res[1], [-28]) 133 134 with self.assertRaises(errors.OutOfRangeError): 135 session.run([features, target]) 136 137 coord.request_stop() 138 coord.join(threads) 139 140 def testNumpyInputFnWithBatchSizeNotDividedByDataSizeAndMultipleEpochs(self): 141 batch_size = 2 142 a = np.arange(3) * 1.0 143 b = np.arange(32, 35) 144 x = {'a': a, 'b': b} 145 y = np.arange(-32, -29) 146 147 with self.test_session() as session: 148 input_fn = numpy_io.numpy_input_fn( 149 x, y, batch_size=batch_size, shuffle=False, num_epochs=3) 150 features, target = input_fn() 151 152 coord = coordinator.Coordinator() 153 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 154 155 res = session.run([features, target]) 156 self.assertAllEqual(res[0]['a'], [0, 1]) 157 self.assertAllEqual(res[0]['b'], [32, 33]) 158 self.assertAllEqual(res[1], [-32, -31]) 159 160 res = session.run([features, target]) 161 self.assertAllEqual(res[0]['a'], [2, 0]) 162 self.assertAllEqual(res[0]['b'], [34, 32]) 163 self.assertAllEqual(res[1], [-30, -32]) 164 165 res = session.run([features, target]) 166 self.assertAllEqual(res[0]['a'], [1, 2]) 167 self.assertAllEqual(res[0]['b'], [33, 34]) 168 self.assertAllEqual(res[1], [-31, -30]) 169 170 res = session.run([features, target]) 171 self.assertAllEqual(res[0]['a'], [0, 1]) 172 self.assertAllEqual(res[0]['b'], [32, 33]) 173 self.assertAllEqual(res[1], [-32, -31]) 174 175 res = session.run([features, target]) 176 self.assertAllEqual(res[0]['a'], [2]) 177 self.assertAllEqual(res[0]['b'], [34]) 178 self.assertAllEqual(res[1], [-30]) 179 180 with self.assertRaises(errors.OutOfRangeError): 181 session.run([features, target]) 182 183 coord.request_stop() 184 coord.join(threads) 185 186 def testNumpyInputFnWithBatchSizeLargerThanDataSize(self): 187 batch_size = 10 188 a = np.arange(4) * 1.0 189 b = np.arange(32, 36) 190 x = {'a': a, 'b': b} 191 y = np.arange(-32, -28) 192 193 with self.test_session() as session: 194 input_fn = numpy_io.numpy_input_fn( 195 x, y, batch_size=batch_size, shuffle=False, num_epochs=1) 196 features, target = input_fn() 197 198 coord = coordinator.Coordinator() 199 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 200 201 res = session.run([features, target]) 202 self.assertAllEqual(res[0]['a'], [0, 1, 2, 3]) 203 self.assertAllEqual(res[0]['b'], [32, 33, 34, 35]) 204 self.assertAllEqual(res[1], [-32, -31, -30, -29]) 205 206 with self.assertRaises(errors.OutOfRangeError): 207 session.run([features, target]) 208 209 coord.request_stop() 210 coord.join(threads) 211 212 def testNumpyInputFnWithDifferentDimensionsOfFeatures(self): 213 a = np.array([[1, 2], [3, 4]]) 214 b = np.array([5, 6]) 215 x = {'a': a, 'b': b} 216 y = np.arange(-32, -30) 217 218 with self.test_session() as session: 219 input_fn = numpy_io.numpy_input_fn( 220 x, y, batch_size=2, shuffle=False, num_epochs=1) 221 features, target = input_fn() 222 223 coord = coordinator.Coordinator() 224 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 225 226 res = session.run([features, target]) 227 self.assertAllEqual(res[0]['a'], [[1, 2], [3, 4]]) 228 self.assertAllEqual(res[0]['b'], [5, 6]) 229 self.assertAllEqual(res[1], [-32, -31]) 230 231 coord.request_stop() 232 coord.join(threads) 233 234 def testNumpyInputFnWithXAsNonDict(self): 235 x = list(range(32, 36)) 236 y = np.arange(4) 237 with self.test_session(): 238 with self.assertRaisesRegexp(TypeError, 'x must be a dict or array'): 239 failing_input_fn = numpy_io.numpy_input_fn( 240 x, y, batch_size=2, shuffle=False, num_epochs=1) 241 failing_input_fn() 242 243 def testNumpyInputFnWithXIsEmptyDict(self): 244 x = {} 245 y = np.arange(4) 246 with self.test_session(): 247 with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'): 248 failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False) 249 failing_input_fn() 250 251 def testNumpyInputFnWithXIsEmptyArray(self): 252 x = np.array([[], []]) 253 y = np.arange(4) 254 with self.test_session(): 255 with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'): 256 failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False) 257 failing_input_fn() 258 259 def testNumpyInputFnWithYIsNone(self): 260 a = np.arange(4) * 1.0 261 b = np.arange(32, 36) 262 x = {'a': a, 'b': b} 263 y = None 264 265 with self.test_session() as session: 266 input_fn = numpy_io.numpy_input_fn( 267 x, y, batch_size=2, shuffle=False, num_epochs=1) 268 features_tensor = input_fn() 269 270 coord = coordinator.Coordinator() 271 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 272 273 feature = session.run(features_tensor) 274 self.assertEqual(len(feature), 2) 275 self.assertAllEqual(feature['a'], [0, 1]) 276 self.assertAllEqual(feature['b'], [32, 33]) 277 278 session.run([features_tensor]) 279 with self.assertRaises(errors.OutOfRangeError): 280 session.run([features_tensor]) 281 282 coord.request_stop() 283 coord.join(threads) 284 285 def testNumpyInputFnWithNonBoolShuffle(self): 286 x = np.arange(32, 36) 287 y = np.arange(4) 288 with self.test_session(): 289 with self.assertRaisesRegexp(TypeError, 290 'shuffle must be explicitly set as boolean'): 291 # Default shuffle is None. 292 numpy_io.numpy_input_fn(x, y) 293 294 def testNumpyInputFnWithTargetKeyAlreadyInX(self): 295 array = np.arange(32, 36) 296 x = {'__target_key__': array} 297 y = np.arange(4) 298 299 with self.test_session(): 300 input_fn = numpy_io.numpy_input_fn( 301 x, y, batch_size=2, shuffle=False, num_epochs=1) 302 input_fn() 303 self.assertAllEqual(x['__target_key__'], array) 304 # The input x should not be mutated. 305 self.assertItemsEqual(x.keys(), ['__target_key__']) 306 307 def testNumpyInputFnWithMismatchLengthOfInputs(self): 308 a = np.arange(4) * 1.0 309 b = np.arange(32, 36) 310 x = {'a': a, 'b': b} 311 x_mismatch_length = {'a': np.arange(1), 'b': b} 312 y_longer_length = np.arange(10) 313 314 with self.test_session(): 315 with self.assertRaisesRegexp( 316 ValueError, 'Length of tensors in x and y is mismatched.'): 317 failing_input_fn = numpy_io.numpy_input_fn( 318 x, y_longer_length, batch_size=2, shuffle=False, num_epochs=1) 319 failing_input_fn() 320 321 with self.assertRaisesRegexp( 322 ValueError, 'Length of tensors in x and y is mismatched.'): 323 failing_input_fn = numpy_io.numpy_input_fn( 324 x=x_mismatch_length, 325 y=None, 326 batch_size=2, 327 shuffle=False, 328 num_epochs=1) 329 failing_input_fn() 330 331 def testNumpyInputFnWithYAsDict(self): 332 a = np.arange(4) * 1.0 333 b = np.arange(32, 36) 334 x = {'a': a, 'b': b} 335 y = {'y1': np.arange(-32, -28), 'y2': np.arange(32, 28, -1)} 336 337 with self.test_session() as session: 338 input_fn = numpy_io.numpy_input_fn( 339 x, y, batch_size=2, shuffle=False, num_epochs=1) 340 features_tensor, targets_tensor = input_fn() 341 342 coord = coordinator.Coordinator() 343 threads = queue_runner_impl.start_queue_runners(session, coord=coord) 344 345 features, targets = session.run([features_tensor, targets_tensor]) 346 self.assertEqual(len(features), 2) 347 self.assertAllEqual(features['a'], [0, 1]) 348 self.assertAllEqual(features['b'], [32, 33]) 349 self.assertEqual(len(targets), 2) 350 self.assertAllEqual(targets['y1'], [-32, -31]) 351 self.assertAllEqual(targets['y2'], [32, 31]) 352 353 session.run([features_tensor, targets_tensor]) 354 with self.assertRaises(errors.OutOfRangeError): 355 session.run([features_tensor, targets_tensor]) 356 357 coord.request_stop() 358 coord.join(threads) 359 360 def testNumpyInputFnWithYIsEmptyDict(self): 361 a = np.arange(4) * 1.0 362 b = np.arange(32, 36) 363 x = {'a': a, 'b': b} 364 y = {} 365 with self.test_session(): 366 with self.assertRaisesRegexp(ValueError, 'y cannot be empty'): 367 failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False) 368 failing_input_fn() 369 370 def testNumpyInputFnWithDuplicateKeysInXAndY(self): 371 a = np.arange(4) * 1.0 372 b = np.arange(32, 36) 373 x = {'a': a, 'b': b} 374 y = {'y1': np.arange(-32, -28), 'a': a, 'y2': np.arange(32, 28, -1), 'b': b} 375 with self.test_session(): 376 with self.assertRaisesRegexp( 377 ValueError, '2 duplicate keys are found in both x and y'): 378 failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False) 379 failing_input_fn() 380 381 def testNumpyInputFnWithXIsArray(self): 382 x = np.arange(4) * 1.0 383 y = np.arange(-32, -28) 384 385 input_fn = numpy_io.numpy_input_fn( 386 x, y, batch_size=2, shuffle=False, num_epochs=1) 387 features, target = input_fn() 388 389 with monitored_session.MonitoredSession() as session: 390 res = session.run([features, target]) 391 self.assertAllEqual(res[0], [0, 1]) 392 self.assertAllEqual(res[1], [-32, -31]) 393 394 session.run([features, target]) 395 with self.assertRaises(errors.OutOfRangeError): 396 session.run([features, target]) 397 398 def testNumpyInputFnWithXIsNDArray(self): 399 x = np.arange(16).reshape(4, 2, 2) * 1.0 400 y = np.arange(-48, -32).reshape(4, 2, 2) 401 402 input_fn = numpy_io.numpy_input_fn( 403 x, y, batch_size=2, shuffle=False, num_epochs=1) 404 features, target = input_fn() 405 406 with monitored_session.MonitoredSession() as session: 407 res = session.run([features, target]) 408 self.assertAllEqual(res[0], [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) 409 self.assertAllEqual( 410 res[1], [[[-48, -47], [-46, -45]], [[-44, -43], [-42, -41]]]) 411 412 session.run([features, target]) 413 with self.assertRaises(errors.OutOfRangeError): 414 session.run([features, target]) 415 416 def testNumpyInputFnWithXIsArrayYIsDict(self): 417 x = np.arange(4) * 1.0 418 y = {'y1': np.arange(-32, -28)} 419 420 input_fn = numpy_io.numpy_input_fn( 421 x, y, batch_size=2, shuffle=False, num_epochs=1) 422 features_tensor, targets_tensor = input_fn() 423 424 with monitored_session.MonitoredSession() as session: 425 features, targets = session.run([features_tensor, targets_tensor]) 426 self.assertEqual(len(features), 2) 427 self.assertAllEqual(features, [0, 1]) 428 self.assertEqual(len(targets), 1) 429 self.assertAllEqual(targets['y1'], [-32, -31]) 430 431 session.run([features_tensor, targets_tensor]) 432 with self.assertRaises(errors.OutOfRangeError): 433 session.run([features_tensor, targets_tensor]) 434 435 def testArrayAndDictGiveSameOutput(self): 436 a = np.arange(4) * 1.0 437 b = np.arange(32, 36) 438 x_arr = np.vstack((a, b)) 439 x_dict = {'feature1': x_arr} 440 y = np.arange(-48, -40).reshape(2, 4) 441 442 input_fn_arr = numpy_io.numpy_input_fn( 443 x_arr, y, batch_size=2, shuffle=False, num_epochs=1) 444 features_arr, targets_arr = input_fn_arr() 445 446 input_fn_dict = numpy_io.numpy_input_fn( 447 x_dict, y, batch_size=2, shuffle=False, num_epochs=1) 448 features_dict, targets_dict = input_fn_dict() 449 450 with monitored_session.MonitoredSession() as session: 451 res_arr, res_dict = session.run([ 452 (features_arr, targets_arr), (features_dict, targets_dict)]) 453 454 self.assertAllEqual(res_arr[0], res_dict[0]['feature1']) 455 self.assertAllEqual(res_arr[1], res_dict[1]) 456 457 458if __name__ == '__main__': 459 test.main() 460