118f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
218f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng#
318f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng# Licensed under the Apache License, Version 2.0 (the "License");
418f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng# you may not use this file except in compliance with the License.
518f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng# You may obtain a copy of the License at
618f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng#
718f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng#     http://www.apache.org/licenses/LICENSE-2.0
818f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng#
918f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng# Unless required by applicable law or agreed to in writing, software
1018f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng# distributed under the License is distributed on an "AS IS" BASIS,
1118f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1218f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng# See the License for the specific language governing permissions and
1318f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng# limitations under the License.
1418f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng# ==============================================================================
1518f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng"""Tests for io_utils."""
1618f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng
1718f36927160d05b941c056f10dc7f9aecaa05e23Yifei Fengfrom __future__ import absolute_import
1818f36927160d05b941c056f10dc7f9aecaa05e23Yifei Fengfrom __future__ import division
1918f36927160d05b941c056f10dc7f9aecaa05e23Yifei Fengfrom __future__ import print_function
2018f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng
2118f36927160d05b941c056f10dc7f9aecaa05e23Yifei Fengimport os
2218f36927160d05b941c056f10dc7f9aecaa05e23Yifei Fengimport shutil
2318f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng
2418f36927160d05b941c056f10dc7f9aecaa05e23Yifei Fengimport numpy as np
2518f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng
2618f36927160d05b941c056f10dc7f9aecaa05e23Yifei Fengfrom tensorflow.python.keras._impl import keras
2718f36927160d05b941c056f10dc7f9aecaa05e23Yifei Fengfrom tensorflow.python.platform import test
2818f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng
2918f36927160d05b941c056f10dc7f9aecaa05e23Yifei Fengtry:
3018f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng  import h5py  # pylint:disable=g-import-not-at-top
3118f36927160d05b941c056f10dc7f9aecaa05e23Yifei Fengexcept ImportError:
3218f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng  h5py = None
3318f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng
3418f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng
def create_dataset(h5_path='test.h5'):
  """Write a small random binary-classification dataset to an HDF5 file.

  Two datasets are created in the file:
    * 'my_data':   shape (200, 10), float32, values from np.random.randn.
    * 'my_labels': shape (200, 1), int32 ('i'), values in {0, 1}.
  The file is opened in mode 'w', so any existing file at `h5_path` is
  overwritten.

  Args:
    h5_path: Destination path of the HDF5 file.
  """
  x = np.random.randn(200, 10).astype('float32')
  y = np.random.randint(0, 2, size=(200, 1))
  # Use the file as a context manager so the handle is released even if one of
  # the writes below raises; a manual close() would be skipped in that case.
  with h5py.File(h5_path, 'w') as f:
    # Creating dataset to store features (explicit float32 storage).
    x_dset = f.create_dataset('my_data', (200, 10), dtype='f')
    x_dset[:] = x
    # Creating dataset to store labels. The storage dtype is pinned to 'i'
    # rather than whatever integer dtype randint produced; the test checks
    # HDF5Matrix.dtype against np.dtype('i').
    y_dset = f.create_dataset('my_labels', (200, 1), dtype='i')
    y_dset[:] = y
4618f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng
4718f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng
class TestIOUtils(test.TestCase):

  def test_HDF5Matrix(self):
    """Feeds HDF5Matrix slices of an HDF5 file through a small Keras model."""
    if h5py is None:
      # Report the test as skipped rather than silently passing when the
      # optional h5py dependency is not installed.
      self.skipTest('h5py is required to test HDF5Matrix.')

    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir)

    h5_path = os.path.join(temp_dir, 'test.h5')
    create_dataset(h5_path)

    # Instantiating HDF5Matrix for the training set,
    # which is a slice of the first 150 elements
    x_train = keras.utils.io_utils.HDF5Matrix(
        h5_path, 'my_data', start=0, end=150)
    y_train = keras.utils.io_utils.HDF5Matrix(
        h5_path, 'my_labels', start=0, end=150)

    # Likewise for the test set
    x_test = keras.utils.io_utils.HDF5Matrix(
        h5_path, 'my_data', start=150, end=200)
    y_test = keras.utils.io_utils.HDF5Matrix(
        h5_path, 'my_labels', start=150, end=200)

    # HDF5Matrix objects behave more or less like NumPy arrays
    # with regard to indexing and metadata.
    self.assertEqual(x_train.shape, (150, 10))
    self.assertEqual(y_train.shape, (150, 1))
    # But they do not support negative indices, so don't try print(x_train[-1])

    self.assertEqual(y_train.dtype, np.dtype('i'))
    self.assertEqual(y_train.ndim, 2)
    self.assertEqual(y_train.size, 150)

    model = keras.models.Sequential()
    model.add(keras.layers.Dense(64, input_shape=(10,), activation='relu'))
    model.add(keras.layers.Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='sgd')

    # Note: you have to use shuffle='batch' or False with HDF5Matrix
    model.fit(x_train, y_train, batch_size=32, shuffle='batch', verbose=False)
    # Test that evaluation and prediction don't crash and return
    # reasonable results.
    out_pred = model.predict(x_test, batch_size=32, verbose=False)
    out_eval = model.evaluate(x_test, y_test, batch_size=32, verbose=False)

    self.assertEqual(out_pred.shape, (50, 1))
    self.assertEqual(out_eval.shape, ())
    self.assertGreater(out_eval, 0)
9718f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng
9818f36927160d05b941c056f10dc7f9aecaa05e23Yifei Feng
if __name__ == '__main__':
  # Only when executed as a script (not when imported): hand control to the
  # test entry point imported above as `test`.
  test.main()
101