11b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie#
31b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie# Licensed under the Apache License, Version 2.0 (the "License");
41b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie# you may not use this file except in compliance with the License.
51b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie# You may obtain a copy of the License at
61b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie#
71b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie#     http://www.apache.org/licenses/LICENSE-2.0
81b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie#
91b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie# Unless required by applicable law or agreed to in writing, software
101b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie# distributed under the License is distributed on an "AS IS" BASIS,
111b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie# See the License for the specific language governing permissions and
131b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie# limitations under the License.
141b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie# ==============================================================================
151b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie"""Tests for the time series input pipeline."""
161b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
171b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom __future__ import absolute_import
181b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom __future__ import division
191b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom __future__ import print_function
201b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
import csv
import os
import tempfile
231b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
241b7d13181843631f59c515a00a73c21711ec5802Allen Lavoieimport numpy
251b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
261b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.contrib.timeseries.python.timeseries import input_pipeline
271b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.contrib.timeseries.python.timeseries import test_utils
281b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures
291b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
3000ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoiefrom tensorflow.core.example import example_pb2
3100ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoiefrom tensorflow.python.framework import dtypes
321b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.python.framework import errors
3300ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoiefrom tensorflow.python.lib.io import tf_record
3400ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoiefrom tensorflow.python.ops import parsing_ops
351b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.python.ops import variables
361b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.python.platform import test
371b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.python.training import coordinator as coordinator_lib
381b7d13181843631f59c515a00a73c21711ec5802Allen Lavoiefrom tensorflow.python.training import queue_runner_impl
391b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
401b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
def _make_csv_temp_file(to_write, test_tmpdir):
  """Writes rows to a new CSV file inside `test_tmpdir` and returns its path.

  Args:
    to_write: Iterable of rows; each row is a sequence of cell values handed
      to `csv.writer`.
    test_tmpdir: Directory in which the temporary file is created.

  Returns:
    Path of the written file. The file is not deleted by this function.
  """
  fd, data_file = tempfile.mkstemp(dir=test_tmpdir)
  # mkstemp already returns an open descriptor for the new file; wrap and
  # close it instead of discarding it and re-opening the path by name, which
  # would leak one OS file descriptor per call.
  with os.fdopen(fd, "w") as f:
    csv.writer(f).writerows(to_write)
  return data_file
481b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
491b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
def _make_csv_time_series(num_features, num_samples, test_tmpdir):
  """Writes a synthetic CSV time series and returns the file path.

  Row number `step` (for `step` in `range(num_samples)`) holds the integer
  time `step`, followed by `num_features` floats where feature `k` equals
  `2 * step + k`.

  Args:
    num_features: Number of value columns after the time column.
    num_samples: Number of rows to write.
    test_tmpdir: Directory in which the file is created.

  Returns:
    Path of the CSV file.
  """
  rows = []
  for step in range(num_samples):
    row = [step]
    row.extend(2. * step + feature for feature in range(num_features))
    rows.append(row)
  return _make_csv_temp_file(rows, test_tmpdir=test_tmpdir)
571b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
581b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
def _make_tfexample_series(num_features, num_samples, test_tmpdir):
  """Writes a synthetic time series as serialized `Example` TFRecords.

  Record `i` (for `i` in `range(num_samples)`) stores the int64 time `i`
  under `TrainEvalFeatures.TIMES`, and `num_features` floats under
  `TrainEvalFeatures.VALUES`, where feature `k` equals `2 * i + k`.

  Args:
    num_features: Number of float values per record.
    num_samples: Number of records to write.
    test_tmpdir: Directory in which the file is created.

  Returns:
    Path of the TFRecord file.
  """
  fd, data_file = tempfile.mkstemp(dir=test_tmpdir)
  # TFRecordWriter opens the path itself, so release the descriptor that
  # mkstemp opened; discarding it without closing leaks it.
  os.close(fd)
  with tf_record.TFRecordWriter(data_file) as writer:
    for i in range(num_samples):
      example = example_pb2.Example()
      times = example.features.feature[TrainEvalFeatures.TIMES]
      times.int64_list.value.append(i)
      values = example.features.feature[TrainEvalFeatures.VALUES]
      values.float_list.value.extend(
          [float(i) * 2. + feature_number
           for feature_number in range(num_features)])
      writer.write(example.SerializeToString())
  return data_file
7200ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie
7300ff4f56d54bff3cc6f078ceb64da23b45021d42Allen Lavoie
def _make_numpy_time_series(num_features, num_samples):
  """Builds an in-memory synthetic time series as numpy arrays.

  Args:
    num_features: Size of the feature axis of the values array.
    num_samples: Number of time steps.

  Returns:
    A dict mapping `TrainEvalFeatures.TIMES` to `arange(num_samples)` and
    `TrainEvalFeatures.VALUES` to a float array of shape
    `(num_samples, num_features)` whose entry `[t, k]` is `2 * t + k`.
  """
  times = numpy.arange(num_samples)
  # Outer sum of the doubled times with the feature offsets.
  values = numpy.add.outer(2. * times, numpy.arange(num_features))
  return {TrainEvalFeatures.TIMES: times, TrainEvalFeatures.VALUES: values}
791b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
801b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
class RandomWindowInputFnTests(test.TestCase):
  """Tests for `input_pipeline.RandomWindowInputFn` with several readers."""

  def _random_window_input_fn_test_template(
      self, time_series_reader, window_size, batch_size, num_features,
      discard_out_of_order=False):
    """Builds a RandomWindowInputFn, fetches one batch and checks it.

    Checks that fetched times have shape [batch_size, window_size], dtype
    int64 and are contiguous within each window, and that values have shape
    [batch_size, window_size, num_features] with feature `k` equal to
    `2 * time + k` (the pattern written by this file's data helpers).

    Args:
      time_series_reader: Reader passed to `RandomWindowInputFn`.
      window_size: Number of records per window.
      batch_size: Number of windows per batch.
      num_features: Expected size of the last axis of the values.
      discard_out_of_order: Accepted from callers but NOT forwarded to
        `RandomWindowInputFn`, so it currently has no effect.
        NOTE(review): forwarding it would change which code path the
        `*_sort_out_of_order` and multivariate tests exercise (and may break
        the contiguity check); confirm intent before wiring it through.

    Returns:
      The dict of fetched numpy feature arrays.
    """
    input_fn = input_pipeline.RandomWindowInputFn(
        time_series_reader=time_series_reader,
        window_size=window_size, batch_size=batch_size)
    result, _ = input_fn()
    # Created after input_fn() so it covers any local variables the input
    # pipeline added to the graph.
    init_op = variables.local_variables_initializer()
    with self.test_session() as session:
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(session, coord=coordinator)
      try:
        session.run(init_op)
        features = session.run(result)
      finally:
        # Stop and join the queue-runner threads even when a fetch raises
        # (several tests below expect OutOfRangeError), so they are not left
        # running after the test.
        coordinator.request_stop()
        coordinator.join()
    self.assertAllEqual([batch_size, window_size],
                        features[TrainEvalFeatures.TIMES].shape)
    for window_position in range(window_size - 1):
      for batch_position in range(batch_size):
        # Checks that all times are contiguous
        self.assertEqual(
            features[TrainEvalFeatures.TIMES][batch_position,
                                              window_position + 1],
            features[TrainEvalFeatures.TIMES][batch_position,
                                              window_position] + 1)
    self.assertAllEqual([batch_size, window_size, num_features],
                        features[TrainEvalFeatures.VALUES].shape)
    self.assertEqual("int64", features[TrainEvalFeatures.TIMES].dtype)
    for feature_number in range(num_features):
      self.assertAllEqual(
          features[TrainEvalFeatures.TIMES] * 2. + feature_number,
          features[TrainEvalFeatures.VALUES][:, :, feature_number])
    return features

  def _test_out_of_order(self, time_series_reader, discard_out_of_order):
    """Runs the template with one feature, window_size=2, batch_size=5."""
    self._random_window_input_fn_test_template(
        time_series_reader=time_series_reader,
        num_features=1, window_size=2, batch_size=5,
        discard_out_of_order=discard_out_of_order)

  def test_csv_sort_out_of_order(self):
    filename = _make_csv_time_series(num_features=1, num_samples=50,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader([filename])
    self._test_out_of_order(time_series_reader, discard_out_of_order=False)

  def test_tfexample_sort_out_of_order(self):
    filename = _make_tfexample_series(
        num_features=1, num_samples=50,
        test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.TFExampleReader(
        [filename],
        features={
            TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature(
                shape=[], dtype=dtypes.int64),
            TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature(
                shape=[1], dtype=dtypes.float32)})
    self._test_out_of_order(time_series_reader, discard_out_of_order=False)

  def test_numpy_sort_out_of_order(self):
    data = _make_numpy_time_series(num_features=1, num_samples=50)
    time_series_reader = input_pipeline.NumpyReader(data)
    self._test_out_of_order(time_series_reader, discard_out_of_order=False)

  def test_csv_discard_out_of_order(self):
    filename = _make_csv_time_series(num_features=1, num_samples=50,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader([filename])
    self._test_out_of_order(time_series_reader, discard_out_of_order=True)

  def test_csv_discard_out_of_order_window_equal(self):
    """Three CSV records exactly fill a window of size 3."""
    filename = _make_csv_time_series(num_features=1, num_samples=3,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader([filename])
    self._random_window_input_fn_test_template(
        time_series_reader=time_series_reader,
        num_features=1, window_size=3, batch_size=5,
        discard_out_of_order=True)

  def test_csv_discard_out_of_order_window_too_large(self):
    """Two CSV records cannot fill a window of 3; expect OutOfRangeError."""
    filename = _make_csv_time_series(num_features=1, num_samples=2,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader([filename])
    with self.assertRaises(errors.OutOfRangeError):
      self._random_window_input_fn_test_template(
          time_series_reader=time_series_reader,
          num_features=1, window_size=3, batch_size=5,
          discard_out_of_order=True)

  def test_csv_no_data(self):
    """An empty CSV file yields OutOfRangeError."""
    filename = _make_csv_time_series(num_features=1, num_samples=0,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader([filename])
    with self.assertRaises(errors.OutOfRangeError):
      self._test_out_of_order(time_series_reader, discard_out_of_order=True)

  def test_numpy_discard_out_of_order(self):
    data = _make_numpy_time_series(num_features=1, num_samples=50)
    time_series_reader = input_pipeline.NumpyReader(data)
    self._test_out_of_order(time_series_reader, discard_out_of_order=True)

  def test_numpy_discard_out_of_order_window_equal(self):
    """Three numpy records exactly fill a window of size 3."""
    data = _make_numpy_time_series(num_features=1, num_samples=3)
    time_series_reader = input_pipeline.NumpyReader(data)
    self._random_window_input_fn_test_template(
        time_series_reader=time_series_reader,
        num_features=1, window_size=3, batch_size=5,
        discard_out_of_order=True)

  def test_numpy_discard_out_of_order_window_too_large(self):
    """Two numpy records cannot fill a window of 3; expect a ValueError."""
    data = _make_numpy_time_series(num_features=1, num_samples=2)
    time_series_reader = input_pipeline.NumpyReader(data)
    with self.assertRaisesRegexp(ValueError, "only 2 records were available"):
      self._random_window_input_fn_test_template(
          time_series_reader=time_series_reader,
          num_features=1, window_size=3, batch_size=5,
          discard_out_of_order=True)

  def _test_multivariate(self, time_series_reader, num_features):
    """Runs the template with window_size=2, batch_size=5."""
    self._random_window_input_fn_test_template(
        time_series_reader=time_series_reader,
        num_features=num_features,
        window_size=2,
        batch_size=5)

  def test_csv_multivariate(self):
    filename = _make_csv_time_series(num_features=2, num_samples=50,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader(
        [filename],
        column_names=(TrainEvalFeatures.TIMES, TrainEvalFeatures.VALUES,
                      TrainEvalFeatures.VALUES))
    self._test_multivariate(time_series_reader=time_series_reader,
                            num_features=2)

  def test_tfexample_multivariate(self):
    filename = _make_tfexample_series(
        num_features=2, num_samples=50,
        test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.TFExampleReader(
        [filename],
        features={
            TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature(
                shape=[], dtype=dtypes.int64),
            TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature(
                shape=[2], dtype=dtypes.float32)})
    self._test_multivariate(time_series_reader=time_series_reader,
                            num_features=2)

  def test_numpy_multivariate(self):
    data = _make_numpy_time_series(num_features=3, num_samples=50)
    time_series_reader = input_pipeline.NumpyReader(data)
    self._test_multivariate(time_series_reader, num_features=3)

  def test_numpy_withbatch(self):
    """Numpy data given with an explicit leading batch dimension of 1."""
    data_nobatch = _make_numpy_time_series(num_features=4, num_samples=100)
    data = {feature_name: feature_value[None]
            for feature_name, feature_value in data_nobatch.items()}
    time_series_reader = input_pipeline.NumpyReader(data)
    self._random_window_input_fn_test_template(
        time_series_reader=time_series_reader,
        num_features=4,
        window_size=3,
        batch_size=5)

  def test_numpy_nobatch_nofeatures(self):
    """Numpy values given as a 1-D array (no feature axis)."""
    data = _make_numpy_time_series(num_features=1, num_samples=100)
    data[TrainEvalFeatures.VALUES] = data[TrainEvalFeatures.VALUES][:, 0]
    time_series_reader = input_pipeline.NumpyReader(data)
    self._random_window_input_fn_test_template(
        time_series_reader=time_series_reader,
        num_features=1,
        window_size=16,
        batch_size=16)
2571b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
2581b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
class WholeDatasetInputFnTests(test.TestCase):
  """Tests for `input_pipeline.WholeDatasetInputFn` with several readers."""

  def _whole_dataset_input_fn_test_template(
      self, time_series_reader, num_features, num_samples):
    """Fetches the whole dataset once and checks times and values.

    Expects times to be int64 and equal to `[[0, 1, ..., num_samples - 1]]`
    (one leading batch dimension), and `values[:, :, k]` to equal
    `2 * time + k` for each feature `k`.

    Args:
      time_series_reader: Reader passed to `WholeDatasetInputFn`.
      num_features: Number of features to check in the values.
      num_samples: Expected number of time steps.
    """
    result, _ = input_pipeline.WholeDatasetInputFn(time_series_reader)()
    with self.test_session() as session:
      session.run(variables.local_variables_initializer())
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(session, coord=coordinator)
      try:
        features = session.run(result)
      finally:
        # Stop and join the queue-runner threads even when the fetch raises
        # (test_csv_no_data expects OutOfRangeError), so they are not left
        # running after the test.
        coordinator.request_stop()
        coordinator.join()
    self.assertEqual("int64", features[TrainEvalFeatures.TIMES].dtype)
    self.assertAllEqual(numpy.arange(num_samples, dtype=numpy.int64)[None, :],
                        features[TrainEvalFeatures.TIMES])
    for feature_number in range(num_features):
      self.assertAllEqual(
          features[TrainEvalFeatures.TIMES] * 2. + feature_number,
          features[TrainEvalFeatures.VALUES][:, :, feature_number])

  def test_csv(self):
    filename = _make_csv_time_series(num_features=3, num_samples=50,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader(
        [filename],
        column_names=(TrainEvalFeatures.TIMES, TrainEvalFeatures.VALUES,
                      TrainEvalFeatures.VALUES, TrainEvalFeatures.VALUES))
    self._whole_dataset_input_fn_test_template(
        time_series_reader=time_series_reader, num_features=3, num_samples=50)

  def test_csv_no_data(self):
    """An empty CSV file yields OutOfRangeError."""
    filename = _make_csv_time_series(num_features=1, num_samples=0,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader([filename])
    with self.assertRaises(errors.OutOfRangeError):
      self._whole_dataset_input_fn_test_template(
          time_series_reader=time_series_reader, num_features=1, num_samples=50)

  def test_tfexample(self):
    filename = _make_tfexample_series(
        num_features=4, num_samples=100,
        test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.TFExampleReader(
        [filename],
        features={
            TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature(
                shape=[], dtype=dtypes.int64),
            TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature(
                shape=[4], dtype=dtypes.float32)})
    self._whole_dataset_input_fn_test_template(
        time_series_reader=time_series_reader, num_features=4, num_samples=100)

  def test_numpy(self):
    data = _make_numpy_time_series(num_features=4, num_samples=100)
    time_series_reader = input_pipeline.NumpyReader(data)
    self._whole_dataset_input_fn_test_template(
        time_series_reader=time_series_reader, num_features=4, num_samples=100)

  def test_numpy_withbatch(self):
    """Numpy data given with an explicit leading batch dimension of 1."""
    data_nobatch = _make_numpy_time_series(num_features=4, num_samples=100)
    data = {feature_name: feature_value[None]
            for feature_name, feature_value in data_nobatch.items()}
    time_series_reader = input_pipeline.NumpyReader(data)
    self._whole_dataset_input_fn_test_template(
        time_series_reader=time_series_reader, num_features=4, num_samples=100)

  def test_numpy_nobatch_nofeatures(self):
    """Numpy values given as a 1-D array (no feature axis)."""
    data = _make_numpy_time_series(num_features=1, num_samples=100)
    data[TrainEvalFeatures.VALUES] = data[TrainEvalFeatures.VALUES][:, 0]
    time_series_reader = input_pipeline.NumpyReader(data)
    self._whole_dataset_input_fn_test_template(
        time_series_reader=time_series_reader, num_features=1, num_samples=100)
3311b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
3321b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
class AllWindowInputFnTests(test.TestCase):
  """Tests for test_utils.AllWindowInputFn with CSV and numpy readers."""

  def _all_window_input_fn_test_template(
      self, time_series_reader, num_samples, window_size,
      original_numpy_features=None):
    """Reads all windows from `time_series_reader` and checks them.

    Args:
      time_series_reader: Reader wrapped by test_utils.AllWindowInputFn.
      num_samples: Number of time steps the reader is expected to produce.
      window_size: Window length passed to AllWindowInputFn.
      original_numpy_features: Optional dict holding the arrays the reader was
        built from (keyed by TrainEvalFeatures.TIMES / VALUES). When given,
        the windowed times and values are also compared against it.
    """
    input_fn = test_utils.AllWindowInputFn(
        time_series_reader=time_series_reader,
        window_size=window_size)
    features, _ = input_fn()
    init_op = variables.local_variables_initializer()
    with self.test_session() as session:
      # Initialize local variables *before* starting the queue runners.
      # Otherwise the runner threads can race with the initializer, and
      # coordinator.join() would re-raise their errors.
      session.run(init_op)
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(session, coord=coordinator)
      try:
        chunked_times, chunked_values = session.run(
            [features[TrainEvalFeatures.TIMES],
             features[TrainEvalFeatures.VALUES]])
      finally:
        # Always stop and join the runner threads, even if the read fails,
        # so they do not outlive the session.
        coordinator.request_stop()
        coordinator.join()
    # Expect num_samples - window_size + 1 windows of window_size steps each.
    self.assertAllEqual([num_samples - window_size + 1, window_size],
                        chunked_times.shape)
    if original_numpy_features is not None:
      original_times = original_numpy_features[TrainEvalFeatures.TIMES]
      original_values = original_numpy_features[TrainEvalFeatures.VALUES]
      # The distinct windowed times must be exactly the original times.
      self.assertAllEqual(original_times, numpy.unique(chunked_times))
      # Times are used directly as row indices into the original values.
      # NOTE(review): assumes the generated times are 0..num_samples-1;
      # confirm in _make_numpy_time_series.
      self.assertAllEqual(original_values[chunked_times],
                          chunked_values)

  def test_csv(self):
    """Windows a single-feature CSV time series (shape check only)."""
    filename = _make_csv_time_series(num_features=1, num_samples=50,
                                     test_tmpdir=self.get_temp_dir())
    time_series_reader = input_pipeline.CSVReader(
        [filename],
        column_names=(TrainEvalFeatures.TIMES, TrainEvalFeatures.VALUES))
    self._all_window_input_fn_test_template(
        time_series_reader=time_series_reader, num_samples=50, window_size=10)

  def test_numpy(self):
    """Windows a two-feature numpy series and checks window contents."""
    data = _make_numpy_time_series(num_features=2, num_samples=31)
    time_series_reader = input_pipeline.NumpyReader(data)
    self._all_window_input_fn_test_template(
        time_series_reader=time_series_reader, original_numpy_features=data,
        num_samples=31, window_size=5)
3761b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
3771b7d13181843631f59c515a00a73c21711ec5802Allen Lavoie
if __name__ == "__main__":
  # When run as a script, hand control to test.main() from the file-level
  # `test` import.
  test.main()
380