1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A multivariate TFTS example.
16
17Fits a multivariate model, exports it, and visualizes the learned correlations
18by iteratively predicting and sampling from the predictions.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
from os import path
import shutil
import tempfile
27
28import numpy
29import tensorflow as tf
30
try:
  # pylint: disable=g-import-not-at-top
  import matplotlib
  # Select the Tk backend before pyplot is imported: interactive plots need Tk.
  matplotlib.use("TkAgg")
  from matplotlib import pyplot
  # pylint: enable=g-import-not-at-top
except ImportError:
  # matplotlib is only required for plotting and is not a build dependency, so
  # the unit test that runs this code may execute without it. The
  # TensorFlow-dependent part of the example (multivariate_train_and_sample)
  # should still be testable in that environment.
  HAS_MATPLOTLIB = False
else:
  HAS_MATPLOTLIB = True

# The bundled CSV lives in a "data" directory next to this module.
_MODULE_PATH = path.dirname(__file__)
_DATA_FILE = path.join(_MODULE_PATH, "data/multivariate_level.csv")
45
46
def multivariate_train_and_sample(
    csv_file_name=_DATA_FILE, export_directory=None, training_steps=500,
    num_features=5, num_samples=100):
  """Trains a multivariate model, then extends the data by iterative sampling.

  A `StructuralEnsembleRegressor` is trained on `csv_file_name` and evaluated
  over the whole dataset to obtain the filtering state after the observed
  data. The model is exported as a SavedModel and reloaded. Then `num_samples`
  further time steps are generated one at a time. Each step predicts a mean
  and covariance, draws from the corresponding multivariate normal
  distribution, and filters the draw back into the model state. The next
  prediction is therefore conditioned on it.

  Args:
    csv_file_name: Path to a CSV file whose columns are one time column
      followed by `num_features` value columns.
    export_directory: Directory to export the SavedModel into. If None, a
      temporary directory is created and removed again before returning; a
      caller-supplied directory is left in place.
    training_steps: Number of training steps to run.
    num_features: Number of value columns (features) in the CSV file.
    num_samples: Number of time steps to sample after the observed data.

  Returns:
    A tuple `(all_times, all_observations)`: the observed times/values followed
    by the sampled ones, concatenated along the time axis with the leading
    batch dimension squeezed out. Observations are presumably shaped
    [num_times, num_features]; TODO confirm the exact shape of the times.
  """
  estimator = tf.contrib.timeseries.StructuralEnsembleRegressor(
      periodicities=[], num_features=num_features)
  reader = tf.contrib.timeseries.CSVReader(
      csv_file_name,
      column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,)
                    + (tf.contrib.timeseries.TrainEvalFeatures.VALUES,)
                    * num_features))
  train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(
      # Larger window sizes generally produce a better covariance matrix.
      reader, batch_size=4, window_size=64)
  estimator.train(input_fn=train_input_fn, steps=training_steps)
  evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)
  # The evaluation results double as the starting state for sampling.
  current_state = estimator.evaluate(input_fn=evaluation_input_fn, steps=1)
  values = [current_state["observed"]]
  times = [current_state[tf.contrib.timeseries.FilteringResults.TIMES]]
  # Only delete the export directory if it was created here; the caller never
  # sees a temporary directory, so keeping it would just leak disk space.
  owns_export_directory = export_directory is None
  if owns_export_directory:
    export_directory = tempfile.mkdtemp()
  try:
    # Export the model so we can do iterative prediction and filtering without
    # reloading model checkpoints.
    input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
    export_location = estimator.export_savedmodel(
        export_directory, input_receiver_fn)
    sampled_values, sampled_times = _sample_from_export(
        export_location, current_state, num_samples)
  finally:
    if owns_export_directory:
      shutil.rmtree(export_directory, ignore_errors=True)
  values.extend(sampled_values)
  times.extend(sampled_times)
  all_observations = numpy.squeeze(numpy.concatenate(values, axis=1), axis=0)
  all_times = numpy.squeeze(numpy.concatenate(times, axis=1), axis=0)
  return all_times, all_observations


def _sample_from_export(export_location, current_state, num_samples):
  """Loads a SavedModel and draws `num_samples` sequential one-step samples.

  Args:
    export_location: Path returned by `export_savedmodel`.
    current_state: Filtering state to continue from (e.g. evaluation output).
    num_samples: Number of one-step predict/sample/filter iterations.

  Returns:
    A tuple `(sampled_values, sampled_times)` of lists with one entry per
    iteration. Each value is the drawn sample reshaped to [1, 1, features];
    each time is read from the state after filtering that sample.
  """
  # A private generator keeps the example deterministic (it yields the same
  # draws as seeding numpy's global RNG with 1) without clobbering the
  # caller's global numpy random state.
  random_state = numpy.random.RandomState(1)
  sampled_values = []
  sampled_times = []
  with tf.Graph().as_default():
    with tf.Session() as session:
      signatures = tf.saved_model.loader.load(
          session, [tf.saved_model.tag_constants.SERVING], export_location)
      for _ in range(num_samples):
        current_prediction = (
            tf.contrib.timeseries.saved_model_utils.predict_continuation(
                continue_from=current_state, signatures=signatures,
                session=session, steps=1))
        next_sample = random_state.multivariate_normal(
            # Squeeze out the batch and series length dimensions (both 1).
            # numpy only accepts an int or a *tuple* of axes here; a list
            # raises TypeError.
            mean=numpy.squeeze(current_prediction["mean"], axis=(0, 1)),
            cov=numpy.squeeze(current_prediction["covariance"], axis=(0, 1)))
        # Update model state so that future predictions are conditional on the
        # value we just sampled.
        filtering_features = {
            tf.contrib.timeseries.TrainEvalFeatures.TIMES: current_prediction[
                tf.contrib.timeseries.FilteringResults.TIMES],
            tf.contrib.timeseries.TrainEvalFeatures.VALUES: next_sample[
                None, None, :]}
        current_state = (
            tf.contrib.timeseries.saved_model_utils.filter_continuation(
                continue_from=current_state,
                session=session,
                signatures=signatures,
                features=filtering_features))
        sampled_values.append(next_sample[None, None, :])
        sampled_times.append(
            current_state[tf.contrib.timeseries.FilteringResults.TIMES])
  return sampled_values, sampled_times
103
104
def main(unused_argv):
  """Trains the example model, samples a continuation, and plots everything.

  Args:
    unused_argv: Command-line arguments passed in by `tf.app.run`; ignored.

  Raises:
    ImportError: If matplotlib could not be imported when the module loaded.
  """
  if HAS_MATPLOTLIB:
    times, observations = multivariate_train_and_sample()
    # Dotted marker for where sampling starts on the plot.
    # NOTE(review): assumes the observed data ends at time 1000 — confirm
    # against the CSV.
    pyplot.axvline(1000, linestyle="dotted")
    pyplot.plot(times, observations)
    pyplot.show()
  else:
    raise ImportError(
        "Please install matplotlib to generate a plot from this example.")
114
115
if __name__ == "__main__":
  # Hand control to tf.app.run, which calls `main` with the program arguments.
  tf.app.run(main)
118