# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
141b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie# ============================================================================== 151b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie"""Tests for the time series input pipeline.""" 161b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 171b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom __future__ import absolute_import 181b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom __future__ import division 191b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom __future__ import print_function 201b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 211b7d13181843631f59c515a00a73c21711ec5802Allen Lavoieimport csv 221b7d13181843631f59c515a00a73c21711ec5802Allen Lavoieimport tempfile 231b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 241b7d13181843631f59c515a00a73c21711ec5802Allen Lavoieimport numpy 251b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 261b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.contrib.timeseries.python.timeseries import input_pipeline 271b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.contrib.timeseries.python.timeseries import test_utils 281b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures 291b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 3000ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoiefrom tensorflow.core.example import example_pb2 3100ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoiefrom tensorflow.python.framework import dtypes 321b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.python.framework import errors 3300ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoiefrom tensorflow.python.lib.io import tf_record 3400ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoiefrom tensorflow.python.ops import parsing_ops 351b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.python.ops import variables 
361b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.python.platform import test 371b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.python.training import coordinator as coordinator_lib 381b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.python.training import queue_runner_impl 391b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 401b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 411b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiedef _make_csv_temp_file(to_write, test_tmpdir): 421b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie _, data_file = tempfile.mkstemp(dir=test_tmpdir) 431b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie with open(data_file, "w") as f: 441b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie csvwriter = csv.writer(f) 451b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie for record in to_write: 461b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie csvwriter.writerow(record) 471b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie return data_file 481b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 491b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 501b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiedef _make_csv_time_series(num_features, num_samples, test_tmpdir): 511b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie filename = _make_csv_temp_file( 521b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie [[i] + [float(i) * 2. 
+ feature_number 531b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie for feature_number in range(num_features)] 541b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie for i in range(num_samples)], 551b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie test_tmpdir=test_tmpdir) 561b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie return filename 571b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 581b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 5900ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoiedef _make_tfexample_series(num_features, num_samples, test_tmpdir): 6000ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie _, data_file = tempfile.mkstemp(dir=test_tmpdir) 6100ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie with tf_record.TFRecordWriter(data_file) as writer: 6200ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie for i in range(num_samples): 6300ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie example = example_pb2.Example() 6400ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie times = example.features.feature[TrainEvalFeatures.TIMES] 6500ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie times.int64_list.value.append(i) 6600ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie values = example.features.feature[TrainEvalFeatures.VALUES] 6700ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie values.float_list.value.extend( 6800ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie [float(i) * 2. 
+ feature_number 6900ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie for feature_number in range(num_features)]) 7000ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie writer.write(example.SerializeToString()) 7100ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie return data_file 7200ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie 7300ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie 741b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiedef _make_numpy_time_series(num_features, num_samples): 751b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie times = numpy.arange(num_samples) 761b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie values = times[:, None] * 2. + numpy.arange(num_features)[None, :] 771b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie return {TrainEvalFeatures.TIMES: times, 781b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie TrainEvalFeatures.VALUES: values} 791b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 801b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 811b7d13181843631f59c515a00a73c21711ec5802Allen Lavoieclass RandomWindowInputFnTests(test.TestCase): 821b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 831b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def _random_window_input_fn_test_template( 841b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self, time_series_reader, window_size, batch_size, num_features, 851b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie discard_out_of_order=False): 861b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie input_fn = input_pipeline.RandomWindowInputFn( 871b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader=time_series_reader, 881b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie window_size=window_size, batch_size=batch_size) 891b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie result, _ = input_fn() 901b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie init_op = variables.local_variables_initializer() 911b7d13181843631f59c515a00a73c21711ec5802Allen 
Lavoie with self.test_session() as session: 921b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie coordinator = coordinator_lib.Coordinator() 931b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie queue_runner_impl.start_queue_runners(session, coord=coordinator) 941b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie session.run(init_op) 951b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie features = session.run(result) 961b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie coordinator.request_stop() 971b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie coordinator.join() 981b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self.assertAllEqual([batch_size, window_size], 991b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie features[TrainEvalFeatures.TIMES].shape) 1001b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie for window_position in range(window_size - 1): 1011b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie for batch_position in range(batch_size): 1021b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie # Checks that all times are contiguous 1031b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self.assertEqual( 1041b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie features[TrainEvalFeatures.TIMES][batch_position, 1051b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie window_position + 1], 1061b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie features[TrainEvalFeatures.TIMES][batch_position, 1071b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie window_position] + 1) 1081b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self.assertAllEqual([batch_size, window_size, num_features], 1091b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie features[TrainEvalFeatures.VALUES].shape) 1101b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self.assertEqual("int64", features[TrainEvalFeatures.TIMES].dtype) 1111b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie for feature_number in range(num_features): 
1121b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self.assertAllEqual( 1131b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie features[TrainEvalFeatures.TIMES] * 2. + feature_number, 1141b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie features[TrainEvalFeatures.VALUES][:, :, feature_number]) 1151b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie return features 1161b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 1171b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def _test_out_of_order(self, time_series_reader, discard_out_of_order): 1181b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._random_window_input_fn_test_template( 1191b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader=time_series_reader, 1201b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie num_features=1, window_size=2, batch_size=5, 1211b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie discard_out_of_order=discard_out_of_order) 1221b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 1231b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_csv_sort_out_of_order(self): 1241b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie filename = _make_csv_time_series(num_features=1, num_samples=50, 1251b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie test_tmpdir=self.get_temp_dir()) 1261b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.CSVReader([filename]) 1271b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._test_out_of_order(time_series_reader, discard_out_of_order=False) 1281b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 12900ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie def test_tfexample_sort_out_of_order(self): 13000ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie filename = _make_tfexample_series( 13100ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie num_features=1, num_samples=50, 13200ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie test_tmpdir=self.get_temp_dir()) 
13300ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie time_series_reader = input_pipeline.TFExampleReader( 13400ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie [filename], 13500ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie features={ 13600ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature( 13700ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie shape=[], dtype=dtypes.int64), 13800ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature( 13900ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie shape=[1], dtype=dtypes.float32)}) 14000ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie self._test_out_of_order(time_series_reader, discard_out_of_order=False) 14100ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie 1421b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_numpy_sort_out_of_order(self): 1431b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie data = _make_numpy_time_series(num_features=1, num_samples=50) 1441b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.NumpyReader(data) 1451b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._test_out_of_order(time_series_reader, discard_out_of_order=False) 1461b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 1471b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_csv_discard_out_of_order(self): 1481b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie filename = _make_csv_time_series(num_features=1, num_samples=50, 1491b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie test_tmpdir=self.get_temp_dir()) 1501b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.CSVReader([filename]) 1511b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._test_out_of_order(time_series_reader, discard_out_of_order=True) 1521b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 1531b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def 
test_csv_discard_out_of_order_window_equal(self): 1541b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie filename = _make_csv_time_series(num_features=1, num_samples=3, 1551b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie test_tmpdir=self.get_temp_dir()) 1561b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.CSVReader([filename]) 1571b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._random_window_input_fn_test_template( 1581b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader=time_series_reader, 1591b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie num_features=1, window_size=3, batch_size=5, 1601b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie discard_out_of_order=True) 1611b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 1621b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_csv_discard_out_of_order_window_too_large(self): 1631b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie filename = _make_csv_time_series(num_features=1, num_samples=2, 1641b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie test_tmpdir=self.get_temp_dir()) 1651b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.CSVReader([filename]) 1661b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie with self.assertRaises(errors.OutOfRangeError): 1671b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._random_window_input_fn_test_template( 1681b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader=time_series_reader, 1691b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie num_features=1, window_size=3, batch_size=5, 1701b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie discard_out_of_order=True) 1711b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 1721b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_csv_no_data(self): 1731b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie filename = _make_csv_time_series(num_features=1, num_samples=0, 
1741b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie test_tmpdir=self.get_temp_dir()) 1751b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.CSVReader([filename]) 1761b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie with self.assertRaises(errors.OutOfRangeError): 1771b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._test_out_of_order(time_series_reader, discard_out_of_order=True) 1781b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 1791b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_numpy_discard_out_of_order(self): 1801b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie data = _make_numpy_time_series(num_features=1, num_samples=50) 1811b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.NumpyReader(data) 1821b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._test_out_of_order(time_series_reader, discard_out_of_order=True) 1831b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 1841b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_numpy_discard_out_of_order_window_equal(self): 1851b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie data = _make_numpy_time_series(num_features=1, num_samples=3) 1861b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.NumpyReader(data) 1871b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._random_window_input_fn_test_template( 1881b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader=time_series_reader, 1891b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie num_features=1, window_size=3, batch_size=5, 1901b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie discard_out_of_order=True) 1911b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 1921b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_numpy_discard_out_of_order_window_too_large(self): 1931b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie data = _make_numpy_time_series(num_features=1, 
num_samples=2) 1941b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.NumpyReader(data) 1951b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie with self.assertRaisesRegexp(ValueError, "only 2 records were available"): 1961b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._random_window_input_fn_test_template( 1971b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader=time_series_reader, 1981b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie num_features=1, window_size=3, batch_size=5, 1991b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie discard_out_of_order=True) 2001b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 2011b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def _test_multivariate(self, time_series_reader, num_features): 2021b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._random_window_input_fn_test_template( 2031b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader=time_series_reader, 2041b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie num_features=num_features, 2051b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie window_size=2, 2061b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie batch_size=5) 2071b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 2081b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_csv_multivariate(self): 2091b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie filename = _make_csv_time_series(num_features=2, num_samples=50, 2101b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie test_tmpdir=self.get_temp_dir()) 2111b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.CSVReader( 2121b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie [filename], 2131b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie column_names=(TrainEvalFeatures.TIMES, TrainEvalFeatures.VALUES, 2141b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie TrainEvalFeatures.VALUES)) 
2151b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._test_multivariate(time_series_reader=time_series_reader, 2161b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie num_features=2) 2171b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 21800ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie def test_tfexample_multivariate(self): 21900ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie filename = _make_tfexample_series( 22000ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie num_features=2, num_samples=50, 22100ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie test_tmpdir=self.get_temp_dir()) 22200ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie time_series_reader = input_pipeline.TFExampleReader( 22300ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie [filename], 22400ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie features={ 22500ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature( 22600ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie shape=[], dtype=dtypes.int64), 22700ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature( 22800ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie shape=[2], dtype=dtypes.float32)}) 22900ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie self._test_multivariate(time_series_reader=time_series_reader, 23000ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie num_features=2) 23100ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie 2321b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_numpy_multivariate(self): 2331b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie data = _make_numpy_time_series(num_features=3, num_samples=50) 2341b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.NumpyReader(data) 2351b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._test_multivariate(time_series_reader, num_features=3) 2361b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 
2371b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_numpy_withbatch(self): 2381b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie data_nobatch = _make_numpy_time_series(num_features=4, num_samples=100) 2391b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie data = {feature_name: feature_value[None] 2401b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie for feature_name, feature_value in data_nobatch.items()} 2411b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.NumpyReader(data) 2421b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._random_window_input_fn_test_template( 2431b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader=time_series_reader, 2441b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie num_features=4, 2451b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie window_size=3, 2461b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie batch_size=5) 2471b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 2481b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie def test_numpy_nobatch_nofeatures(self): 2491b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie data = _make_numpy_time_series(num_features=1, num_samples=100) 2501b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie data[TrainEvalFeatures.VALUES] = data[TrainEvalFeatures.VALUES][:, 0] 2511b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader = input_pipeline.NumpyReader(data) 2521b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie self._random_window_input_fn_test_template( 2531b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie time_series_reader=time_series_reader, 2541b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie num_features=1, 2551b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie window_size=16, 2561b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie batch_size=16) 2571b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 2581b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie 
class WholeDatasetInputFnTests(test.TestCase):
  """Tests for `input_pipeline.WholeDatasetInputFn` over each reader type."""

  def _whole_dataset_input_fn_test_template(
      self, time_series_reader, num_features, num_samples):
    """Reads the whole dataset once and checks times and values.

    Args:
      time_series_reader: Reader wrapped by WholeDatasetInputFn.
      num_features: Number of value features to check.
      num_samples: Expected number of time steps, with times 0..num_samples-1.
    """
    result, _ = input_pipeline.WholeDatasetInputFn(time_series_reader)()
    with self.test_session() as session:
      session.run(variables.local_variables_initializer())
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(session, coord=coordinator)
      try:
        features = session.run(result)
      finally:
        # Stop the queue runner threads even when session.run raises
        # (test_csv_no_data expects an OutOfRangeError here).
        coordinator.request_stop()
        coordinator.join()
    times = features[TrainEvalFeatures.TIMES]
    values = features[TrainEvalFeatures.VALUES]
    self.assertEqual("int64", times.dtype)
    self.assertAllEqual(numpy.arange(num_samples, dtype=numpy.int64)[None, :],
                        times)
    for feature_number in range(num_features):
      self.assertAllEqual(times * 2. + feature_number,
                          values[:, :, feature_number])

  def test_csv(self):
    filename = _make_csv_time_series(num_features=3, num_samples=50,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader(
        [filename],
        column_names=(TrainEvalFeatures.TIMES, TrainEvalFeatures.VALUES,
                      TrainEvalFeatures.VALUES, TrainEvalFeatures.VALUES))
    self._whole_dataset_input_fn_test_template(
        time_series_reader=time_series_reader, num_features=3, num_samples=50)

  def test_csv_no_data(self):
    filename = _make_csv_time_series(num_features=1, num_samples=0,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader([filename])
    with self.assertRaises(errors.OutOfRangeError):
      self._whole_dataset_input_fn_test_template(
          time_series_reader=time_series_reader, num_features=1, num_samples=50)

  def test_tfexample(self):
    filename = _make_tfexample_series(
        num_features=4, num_samples=100,
        test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.TFExampleReader(
        [filename],
        features={
            TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature(
                shape=[], dtype=dtypes.int64),
            TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature(
                shape=[4], dtype=dtypes.float32)})
    self._whole_dataset_input_fn_test_template(
        time_series_reader=time_series_reader, num_features=4, num_samples=100)

  def test_numpy(self):
    data = _make_numpy_time_series(num_features=4, num_samples=100)
    time_series_reader = input_pipeline.NumpyReader(data)
    self._whole_dataset_input_fn_test_template(
        time_series_reader=time_series_reader, num_features=4, num_samples=100)

  def test_numpy_withbatch(self):
    data_nobatch = _make_numpy_time_series(num_features=4, num_samples=100)
    # Add a leading batch dimension of size one to every feature.
    data = {feature_name: feature_value[None]
            for feature_name, feature_value in data_nobatch.items()}
    time_series_reader = input_pipeline.NumpyReader(data)
    self._whole_dataset_input_fn_test_template(
        time_series_reader=time_series_reader, num_features=4, num_samples=100)

  def test_numpy_nobatch_nofeatures(self):
    data = _make_numpy_time_series(num_features=1, num_samples=100)
    # Drop the feature dimension so values are a flat [num_samples] array.
    data[TrainEvalFeatures.VALUES] = data[TrainEvalFeatures.VALUES][:, 0]
    time_series_reader = input_pipeline.NumpyReader(data)
    self._whole_dataset_input_fn_test_template(
        time_series_reader=time_series_reader, num_features=1, num_samples=100)


class AllWindowInputFnTests(test.TestCase):
  """Tests for `test_utils.AllWindowInputFn`."""

  def _all_window_input_fn_test_template(
      self, time_series_reader, num_samples, window_size,
      original_numpy_features=None):
    """Fetches every window once and checks the chunking.

    Args:
      time_series_reader: Reader wrapped by AllWindowInputFn.
      num_samples: Number of time steps the reader provides.
      window_size: Length of each window.
      original_numpy_features: Optional dict holding the source arrays. When
        given, window contents are compared against it, using the chunked
        times as indices into the original values.
    """
    input_fn = test_utils.AllWindowInputFn(
        time_series_reader=time_series_reader,
        window_size=window_size)
    features, _ = input_fn()
    init_op = variables.local_variables_initializer()
    with self.test_session() as session:
      # Initialize local variables before any queue runner thread can use
      # them.
      session.run(init_op)
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(session, coord=coordinator)
      try:
        chunked_times, chunked_values = session.run(
            [features[TrainEvalFeatures.TIMES],
             features[TrainEvalFeatures.VALUES]])
      finally:
        coordinator.request_stop()
        coordinator.join()
    # One window starts at every position that leaves room for a full window.
    self.assertAllEqual([num_samples - window_size + 1, window_size],
                        chunked_times.shape)
    if original_numpy_features is not None:
      original_times = original_numpy_features[TrainEvalFeatures.TIMES]
      original_values = original_numpy_features[TrainEvalFeatures.VALUES]
      self.assertAllEqual(original_times, numpy.unique(chunked_times))
      self.assertAllEqual(original_values[chunked_times],
                          chunked_values)

  def test_csv(self):
    filename = _make_csv_time_series(num_features=1, num_samples=50,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader(
        [filename],
        column_names=(TrainEvalFeatures.TIMES, TrainEvalFeatures.VALUES))
    self._all_window_input_fn_test_template(
        time_series_reader=time_series_reader, num_samples=50, window_size=10)

  def test_numpy(self):
    data = _make_numpy_time_series(num_features=2, num_samples=31)
    time_series_reader = input_pipeline.NumpyReader(data)
    self._all_window_input_fn_test_template(
        time_series_reader=time_series_reader, original_numpy_features=data,
        num_samples=31, window_size=5)


if __name__ == "__main__":
  test.main()