1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A multivariate TFTS example. 16 17Fits a multivariate model, exports it, and visualizes the learned correlations 18by iteratively predicting and sampling from the predictions. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25from os import path 26import tempfile 27 28import numpy 29import tensorflow as tf 30 31try: 32 import matplotlib # pylint: disable=g-import-not-at-top 33 matplotlib.use("TkAgg") # Need Tk for interactive plots. 34 from matplotlib import pyplot # pylint: disable=g-import-not-at-top 35 HAS_MATPLOTLIB = True 36except ImportError: 37 # Plotting requires matplotlib, but the unit test running this code may 38 # execute in an environment without it (i.e. matplotlib is not a build 39 # dependency). We'd still like to test the TensorFlow-dependent parts of this 40 # example, namely train_and_predict. 41 HAS_MATPLOTLIB = False 42 43_MODULE_PATH = path.dirname(__file__) 44_DATA_FILE = path.join(_MODULE_PATH, "data/multivariate_level.csv") 45 46 47def multivariate_train_and_sample( 48 csv_file_name=_DATA_FILE, export_directory=None, training_steps=500): 49 """Trains, evaluates, and exports a multivariate model.""" 50 estimator = tf.contrib.timeseries.StructuralEnsembleRegressor( 51 periodicities=[], num_features=5) 52 reader = tf.contrib.timeseries.CSVReader( 53 csv_file_name, 54 column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,) 55 + (tf.contrib.timeseries.TrainEvalFeatures.VALUES,) * 5)) 56 train_input_fn = tf.contrib.timeseries.RandomWindowInputFn( 57 # Larger window sizes generally produce a better covariance matrix. 58 reader, batch_size=4, window_size=64) 59 estimator.train(input_fn=train_input_fn, steps=training_steps) 60 evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader) 61 current_state = estimator.evaluate(input_fn=evaluation_input_fn, steps=1) 62 values = [current_state["observed"]] 63 times = [current_state[tf.contrib.timeseries.FilteringResults.TIMES]] 64 # Export the model so we can do iterative prediction and filtering without 65 # reloading model checkpoints. 66 if export_directory is None: 67 export_directory = tempfile.mkdtemp() 68 input_receiver_fn = estimator.build_raw_serving_input_receiver_fn() 69 export_location = estimator.export_savedmodel( 70 export_directory, input_receiver_fn) 71 with tf.Graph().as_default(): 72 numpy.random.seed(1) # Make the example a bit more deterministic 73 with tf.Session() as session: 74 signatures = tf.saved_model.loader.load( 75 session, [tf.saved_model.tag_constants.SERVING], export_location) 76 for _ in range(100): 77 current_prediction = ( 78 tf.contrib.timeseries.saved_model_utils.predict_continuation( 79 continue_from=current_state, signatures=signatures, 80 session=session, steps=1)) 81 next_sample = numpy.random.multivariate_normal( 82 # Squeeze out the batch and series length dimensions (both 1). 83 mean=numpy.squeeze(current_prediction["mean"], axis=[0, 1]), 84 cov=numpy.squeeze(current_prediction["covariance"], axis=[0, 1])) 85 # Update model state so that future predictions are conditional on the 86 # value we just sampled. 87 filtering_features = { 88 tf.contrib.timeseries.TrainEvalFeatures.TIMES: current_prediction[ 89 tf.contrib.timeseries.FilteringResults.TIMES], 90 tf.contrib.timeseries.TrainEvalFeatures.VALUES: next_sample[ 91 None, None, :]} 92 current_state = ( 93 tf.contrib.timeseries.saved_model_utils.filter_continuation( 94 continue_from=current_state, 95 session=session, 96 signatures=signatures, 97 features=filtering_features)) 98 values.append(next_sample[None, None, :]) 99 times.append(current_state["times"]) 100 all_observations = numpy.squeeze(numpy.concatenate(values, axis=1), axis=0) 101 all_times = numpy.squeeze(numpy.concatenate(times, axis=1), axis=0) 102 return all_times, all_observations 103 104 105def main(unused_argv): 106 if not HAS_MATPLOTLIB: 107 raise ImportError( 108 "Please install matplotlib to generate a plot from this example.") 109 all_times, all_observations = multivariate_train_and_sample() 110 # Show where sampling starts on the plot 111 pyplot.axvline(1000, linestyle="dotted") 112 pyplot.plot(all_times, all_observations) 113 pyplot.show() 114 115 116if __name__ == "__main__": 117 tf.app.run(main=main) 118