1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""An example of training and predicting with a TFTS estimator.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import argparse 22import sys 23 24import numpy as np 25import tensorflow as tf 26 27 28try: 29 import matplotlib # pylint: disable=g-import-not-at-top 30 matplotlib.use("TkAgg") # Need Tk for interactive plots. 31 from matplotlib import pyplot # pylint: disable=g-import-not-at-top 32 HAS_MATPLOTLIB = True 33except ImportError: 34 # Plotting requires matplotlib, but the unit test running this code may 35 # execute in an environment without it (i.e. matplotlib is not a build 36 # dependency). We'd still like to test the TensorFlow-dependent parts of this 37 # example, namely train_and_predict. 38 HAS_MATPLOTLIB = False 39 40FLAGS = None 41 42 43def structural_ensemble_train_and_predict(csv_file_name): 44 # Cycle between 5 latent values over a period of 100. This leads to a very 45 # smooth periodic component (and a small model), which is a good fit for our 46 # example data. Modeling high-frequency periodic variations will require a 47 # higher cycle_num_latent_values. 48 structural = tf.contrib.timeseries.StructuralEnsembleRegressor( 49 periodicities=100, num_features=1, cycle_num_latent_values=5) 50 return train_and_predict(structural, csv_file_name, training_steps=150) 51 52 53def ar_train_and_predict(csv_file_name): 54 # An autoregressive model, with periodicity handled as a time-based 55 # regression. Note that this requires windows of size 16 (input_window_size + 56 # output_window_size) for training. 57 ar = tf.contrib.timeseries.ARRegressor( 58 periodicities=100, input_window_size=10, output_window_size=6, 59 num_features=1, 60 # Use the (default) normal likelihood loss to adaptively fit the 61 # variance. SQUARED_LOSS overestimates variance when there are trends in 62 # the series. 63 loss=tf.contrib.timeseries.ARModel.NORMAL_LIKELIHOOD_LOSS) 64 return train_and_predict(ar, csv_file_name, training_steps=600) 65 66 67def train_and_predict(estimator, csv_file_name, training_steps): 68 """A simple example of training and predicting.""" 69 # Read data in the default "time,value" CSV format with no header 70 reader = tf.contrib.timeseries.CSVReader(csv_file_name) 71 # Set up windowing and batching for training 72 train_input_fn = tf.contrib.timeseries.RandomWindowInputFn( 73 reader, batch_size=16, window_size=16) 74 # Fit model parameters to data 75 estimator.train(input_fn=train_input_fn, steps=training_steps) 76 # Evaluate on the full dataset sequentially, collecting in-sample predictions 77 # for a qualitative evaluation. Note that this loads the whole dataset into 78 # memory. For quantitative evaluation, use RandomWindowChunker. 79 evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader) 80 evaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1) 81 # Predict starting after the evaluation 82 (predictions,) = tuple(estimator.predict( 83 input_fn=tf.contrib.timeseries.predict_continuation_input_fn( 84 evaluation, steps=200))) 85 times = evaluation["times"][0] 86 observed = evaluation["observed"][0, :, 0] 87 mean = np.squeeze(np.concatenate( 88 [evaluation["mean"][0], predictions["mean"]], axis=0)) 89 variance = np.squeeze(np.concatenate( 90 [evaluation["covariance"][0], predictions["covariance"]], axis=0)) 91 all_times = np.concatenate([times, predictions["times"]], axis=0) 92 upper_limit = mean + np.sqrt(variance) 93 lower_limit = mean - np.sqrt(variance) 94 return times, observed, all_times, mean, upper_limit, lower_limit 95 96 97def make_plot(name, training_times, observed, all_times, mean, 98 upper_limit, lower_limit): 99 """Plot a time series in a new figure.""" 100 pyplot.figure() 101 pyplot.plot(training_times, observed, "b", label="training series") 102 pyplot.plot(all_times, mean, "r", label="forecast") 103 pyplot.plot(all_times, upper_limit, "g", label="forecast upper bound") 104 pyplot.plot(all_times, lower_limit, "g", label="forecast lower bound") 105 pyplot.fill_between(all_times, lower_limit, upper_limit, color="grey", 106 alpha="0.2") 107 pyplot.axvline(training_times[-1], color="k", linestyle="--") 108 pyplot.xlabel("time") 109 pyplot.ylabel("observations") 110 pyplot.legend(loc=0) 111 pyplot.title(name) 112 113 114def main(unused_argv): 115 if not HAS_MATPLOTLIB: 116 raise ImportError( 117 "Please install matplotlib to generate a plot from this example.") 118 make_plot("Structural ensemble", 119 *structural_ensemble_train_and_predict(FLAGS.input_filename)) 120 make_plot("AR", *ar_train_and_predict(FLAGS.input_filename)) 121 pyplot.show() 122 123 124if __name__ == "__main__": 125 parser = argparse.ArgumentParser() 126 parser.add_argument( 127 "--input_filename", 128 type=str, 129 required=True, 130 help="Input csv file.") 131 FLAGS, unparsed = parser.parse_known_args() 132 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 133