# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An example of training and predicting with a TFTS estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import numpy as np
import tensorflow as tf


try:
  import matplotlib  # pylint: disable=g-import-not-at-top
  matplotlib.use("TkAgg")  # Need Tk for interactive plots.
  from matplotlib import pyplot  # pylint: disable=g-import-not-at-top
  HAS_MATPLOTLIB = True
except ImportError:
  # Plotting requires matplotlib, but the unit test running this code may
  # execute in an environment without it (i.e. matplotlib is not a build
  # dependency). We'd still like to test the TensorFlow-dependent parts of this
  # example, namely train_and_predict.
  HAS_MATPLOTLIB = False

FLAGS = None


def structural_ensemble_train_and_predict(csv_file_name):
  # Cycle between 5 latent values over a period of 100. This leads to a very
  # smooth periodic component (and a small model), which is a good fit for our
  # example data. Modeling high-frequency periodic variations will require a
  # higher cycle_num_latent_values.
  structural = tf.contrib.timeseries.StructuralEnsembleRegressor(
      periodicities=100, num_features=1, cycle_num_latent_values=5)
  return train_and_predict(structural, csv_file_name, training_steps=150)


def ar_train_and_predict(csv_file_name):
  # An autoregressive model, with periodicity handled as a time-based
  # regression. Note that this requires windows of size 16 (input_window_size +
  # output_window_size) for training.
  ar = tf.contrib.timeseries.ARRegressor(
      periodicities=100, input_window_size=10, output_window_size=6,
      num_features=1,
      # Use the (default) normal likelihood loss to adaptively fit the
      # variance. SQUARED_LOSS overestimates variance when there are trends in
      # the series.
      loss=tf.contrib.timeseries.ARModel.NORMAL_LIKELIHOOD_LOSS)
  return train_and_predict(ar, csv_file_name, training_steps=600)


def train_and_predict(estimator, csv_file_name, training_steps):
  """A simple example of training and predicting."""
  # Read data in the default "time,value" CSV format with no header
  reader = tf.contrib.timeseries.CSVReader(csv_file_name)
  # Set up windowing and batching for training
  train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(
      reader, batch_size=16, window_size=16)
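  # As noted above, the ARRegressor needs windows of at least
  # input_window_size + output_window_size (16) elements, which this
  # window_size satisfies.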
  # Fit model parameters to data
  estimator.train(input_fn=train_input_fn, steps=training_steps)
  # Evaluate on the full dataset sequentially, collecting in-sample predictions
  # for a qualitative evaluation. Note that this loads the whole dataset into
  # memory. For quantitative evaluation, use a windowed input function such as
  # RandomWindowInputFn instead.
  evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)
  evaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1)
  # Predict starting after the evaluation
  (predictions,) = tuple(estimator.predict(
      input_fn=tf.contrib.timeseries.predict_continuation_input_fn(
          evaluation, steps=200)))
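  # The evaluation covers a single contiguous batch element, so indexing with
  # [0] below drops the batch dimension; in-sample statistics are then stitched
  # together with the out-of-sample forecast from `predictions`.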
  times = evaluation["times"][0]
  observed = evaluation["observed"][0, :, 0]
  mean = np.squeeze(np.concatenate(
      [evaluation["mean"][0], predictions["mean"]], axis=0))
  variance = np.squeeze(np.concatenate(
      [evaluation["covariance"][0], predictions["covariance"]], axis=0))
  all_times = np.concatenate([times, predictions["times"]], axis=0)
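  # Bound the forecast by one standard deviation (the square root of the
  # predicted variance) on either side of the mean.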
  upper_limit = mean + np.sqrt(variance)
  lower_limit = mean - np.sqrt(variance)
  return times, observed, all_times, mean, upper_limit, lower_limit


def make_plot(name, training_times, observed, all_times, mean,
              upper_limit, lower_limit):
  """Plot a time series in a new figure."""
  pyplot.figure()
  pyplot.plot(training_times, observed, "b", label="training series")
  pyplot.plot(all_times, mean, "r", label="forecast")
  pyplot.plot(all_times, upper_limit, "g", label="forecast upper bound")
  pyplot.plot(all_times, lower_limit, "g", label="forecast lower bound")
  pyplot.fill_between(all_times, lower_limit, upper_limit, color="grey",
                      alpha=0.2)
  pyplot.axvline(training_times[-1], color="k", linestyle="--")
  pyplot.xlabel("time")
  pyplot.ylabel("observations")
  pyplot.legend(loc=0)
  pyplot.title(name)


def main(unused_argv):
  if not HAS_MATPLOTLIB:
    raise ImportError(
        "Please install matplotlib to generate a plot from this example.")
  make_plot("Structural ensemble",
            *structural_ensemble_train_and_predict(FLAGS.input_filename))
  make_plot("AR", *ar_train_and_predict(FLAGS.input_filename))
  pyplot.show()


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--input_filename",
      type=str,
      required=True,
      help="Input csv file.")
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)