15eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
25eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen#
35eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen# Licensed under the Apache License, Version 2.0 (the "License");
45eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen# you may not use this file except in compliance with the License.
55eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen# You may obtain a copy of the License at
65eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen#
75eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen#     http://www.apache.org/licenses/LICENSE-2.0
85eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen#
95eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen# Unless required by applicable law or agreed to in writing, software
105eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen# distributed under the License is distributed on an "AS IS" BASIS,
115eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen# See the License for the specific language governing permissions and
135eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen# limitations under the License.
145eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen# ==============================================================================
155eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen"""Tests for head."""
165eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
175eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom __future__ import absolute_import
185eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom __future__ import division
195eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom __future__ import print_function
205eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
215eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.contrib.timeseries.python.timeseries import feature_keys
225eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
235eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.contrib.timeseries.python.timeseries import model
245eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.contrib.timeseries.python.timeseries import state_management
255eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
265eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.python.estimator import estimator_lib
275eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.python.framework import dtypes
285eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.python.framework import ops
295eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.python.ops import array_ops
305eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.python.ops import math_ops
315eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.python.ops import metrics
325eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.python.ops import variables
335eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.python.platform import test
345eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.python.training import coordinator as coordinator_lib
355eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.python.training import queue_runner_impl
365eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chenfrom tensorflow.python.training import training as train
375eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
385eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
class HeadTest(test.TestCase):
  """Argument validation performed by the head's `create_estimator_spec`."""

  def test_labels_provided_error(self):
    """Passing labels is rejected with a ValueError in every mode."""
    spec_fn = _stub_model_fn()
    all_modes = (estimator_lib.ModeKeys.TRAIN,
                 estimator_lib.ModeKeys.EVAL,
                 estimator_lib.ModeKeys.PREDICT)
    for current_mode in all_modes:
      with self.assertRaisesRegexp(ValueError, "labels"):
        spec_fn(features={}, labels={"a": "b"}, mode=current_mode)

  def test_unknown_mode(self):
    """An unrecognized mode string is echoed back in the ValueError."""
    spec_fn = _stub_model_fn()
    expected_message = "Unknown mode 'Not a mode'"
    with self.assertRaisesRegexp(ValueError, expected_message):
      spec_fn(features={}, labels={}, mode="Not a mode")
525eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
535eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
class _TickerModel(object):
  """Model stub whose loss, end state and predictions all echo "ticker"."""
  num_features = 1
  dtype = dtypes.float32

  def initialize_graph(self, input_statistics):
    """Nothing to build; `input_statistics` is ignored."""
    del input_statistics  # unused

  def define_loss(self, features, mode):
    """Return the "ticker" feature as every model output.

    Args:
      features: Dict holding a "ticker" entry.
      mode: Ignored.
    Returns:
      A `model.ModelOutputs` using `features["ticker"]` as the loss, as both
      elements of the end state and as the "ticker" prediction, with a scalar
      zero for the prediction times.
    """
    del mode  # unused
    ticker = features["ticker"]
    return model.ModelOutputs(
        loss=ticker,
        end_state=(ticker, ticker),
        prediction_times=array_ops.zeros(()),
        predictions={"ticker": ticker})
685eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
695eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
class EvaluationMetricsTests(test.TestCase):
  """Tests for the evaluation metrics reported by the time series head."""

  def test_metrics_consistent(self):
    """In-sample prediction metrics must agree with a standard mean metric.

    `_TickerModel` reports a local int64 counter (cast to float32 and reshaped
    to (1, 1, 1)) as its loss, its "ticker" prediction and both elements of
    its end state. After one run of every metric update op, the streaming
    mean of the loss, the "ticker" metric and the first element of the
    state-tuple metric are all expected to be 0.
    """
    g = ops.Graph()
    with g.as_default():
      features = {
          feature_keys.TrainEvalFeatures.TIMES:
              array_ops.zeros((1, 1)),
          feature_keys.TrainEvalFeatures.VALUES:
              array_ops.zeros((1, 1, 1)),
          # A local counter starting at 0 read through `count_up_to(10)`; the
          # assertions below expect its first evaluation to yield 0.
          "ticker":
              array_ops.reshape(
                  math_ops.cast(
                      variables.Variable(
                          name="ticker",
                          initial_value=0,
                          dtype=dtypes.int64,
                          collections=[ops.GraphKeys.LOCAL_VARIABLES])
                      .count_up_to(10),
                      dtype=dtypes.float32), (1, 1, 1))
      }
      model_fn = ts_head_lib.time_series_regression_head(
          model=_TickerModel(),
          state_manager=state_management.PassthroughStateManager(),
          optimizer=train.GradientDescentOptimizer(0.001)).create_estimator_spec
      outputs = model_fn(
          features=features, labels=None, mode=estimator_lib.ModeKeys.EVAL)
      # Each eval metric is a (value, update_op) pair: gather the update ops,
      # plus the update op of a standard mean over the loss for comparison.
      metric_update_ops = [
          metric[1] for metric in outputs.eval_metric_ops.values()]
      loss_mean, loss_update = metrics.mean(outputs.loss)
      metric_update_ops.append(loss_update)
      with self.test_session() as sess:
        # Initialize local variables (including the ticker) before starting
        # any queue runners so no runner thread can see uninitialized state.
        variables.local_variables_initializer().run()
        coordinator = coordinator_lib.Coordinator()
        queue_runner_impl.start_queue_runners(sess, coord=coordinator)
        try:
          sess.run(metric_update_ops)
          loss_evaled, metric_evaled, nested_metric_evaled = sess.run(
              (loss_mean, outputs.eval_metric_ops["ticker"][0],
               outputs.eval_metric_ops[
                   feature_keys.FilteringResults.STATE_TUPLE][0][0]))
          # The custom model_utils metrics for in-sample predictions should be
          # in sync with the Estimator's mean metric for model loss.
          self.assertAllClose(0., loss_evaled)
          self.assertAllClose((((0.,),),), metric_evaled)
          self.assertAllClose((((0.,),),), nested_metric_evaled)
        finally:
          # Always stop and join runner threads, even when an assertion above
          # fails, so they do not outlive the session.
          coordinator.request_stop()
          coordinator.join()
1195eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
1205eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
class _StubModel(object):
  """Model stub with three float64 features and nothing to initialize."""
  dtype = dtypes.float64
  num_features = 3

  def initialize_graph(self, input_statistics):
    """Accept and discard the input statistics."""
    del input_statistics  # unused
1275eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
1285eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
def _stub_model_fn():
  """Build a regression head around `_StubModel`.

  Returns:
    The head's bound `create_estimator_spec` method. The head uses a
    `PassthroughStateManager` and an Adam optimizer with learning rate 0.001.
  """
  regression_head = ts_head_lib.time_series_regression_head(
      model=_StubModel(),
      state_manager=state_management.PassthroughStateManager(),
      optimizer=train.AdamOptimizer(0.001))
  return regression_head.create_estimator_spec
1345eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
1355eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
class TrainEvalFeatureCheckingTests(test.TestCase):
  """Feature validation performed by the head in TRAIN and EVAL modes."""

  def test_no_time_feature(self):
    """A missing times feature raises a ValueError naming that key."""
    spec_fn = _stub_model_fn()
    pattern = "Expected a '{}' feature".format(
        feature_keys.TrainEvalFeatures.TIMES)
    for current_mode in (estimator_lib.ModeKeys.TRAIN,
                         estimator_lib.ModeKeys.EVAL):
      with self.assertRaisesRegexp(ValueError, pattern):
        spec_fn(
            features={feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]},
            labels=None,
            mode=current_mode)

  def test_no_value_feature(self):
    """A missing values feature raises a ValueError naming that key."""
    spec_fn = _stub_model_fn()
    pattern = "Expected a '{}' feature".format(
        feature_keys.TrainEvalFeatures.VALUES)
    for current_mode in (estimator_lib.ModeKeys.TRAIN,
                         estimator_lib.ModeKeys.EVAL):
      with self.assertRaisesRegexp(ValueError, pattern):
        spec_fn(
            features={feature_keys.TrainEvalFeatures.TIMES: [[1]]},
            labels=None,
            mode=current_mode)

  def test_bad_time_rank(self):
    """A rank-3 times feature is rejected with a shape error."""
    spec_fn = _stub_model_fn()
    pattern = "Expected shape.*for feature '{}'".format(
        feature_keys.TrainEvalFeatures.TIMES)
    for current_mode in (estimator_lib.ModeKeys.TRAIN,
                         estimator_lib.ModeKeys.EVAL):
      with self.assertRaisesRegexp(ValueError, pattern):
        spec_fn(
            features={
                feature_keys.TrainEvalFeatures.TIMES: [[[1]]],
                feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
            },
            labels=None,
            mode=current_mode)

  def test_bad_value_rank(self):
    """A rank-2 values feature is rejected with a shape error."""
    spec_fn = _stub_model_fn()
    pattern = "Expected shape.*for feature '{}'".format(
        feature_keys.TrainEvalFeatures.VALUES)
    for current_mode in (estimator_lib.ModeKeys.TRAIN,
                         estimator_lib.ModeKeys.EVAL):
      with self.assertRaisesRegexp(ValueError, pattern):
        spec_fn(
            features={
                feature_keys.TrainEvalFeatures.TIMES: [[1]],
                feature_keys.TrainEvalFeatures.VALUES: [[1.]]
            },
            labels=None,
            mode=current_mode)

  def test_bad_value_num_features(self):
    """Values with 1 feature fail against the stub's 3 declared features."""
    spec_fn = _stub_model_fn()
    pattern = "Expected shape.*, 3.*for feature '{}'".format(
        feature_keys.TrainEvalFeatures.VALUES)
    for current_mode in (estimator_lib.ModeKeys.TRAIN,
                         estimator_lib.ModeKeys.EVAL):
      with self.assertRaisesRegexp(ValueError, pattern):
        spec_fn(
            features={
                feature_keys.TrainEvalFeatures.TIMES: [[1]],
                feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
            },
            labels=None,
            mode=current_mode)

  def test_bad_exogenous_shape(self):
    """An exogenous feature whose shape disagrees with times is rejected."""
    spec_fn = _stub_model_fn()
    pattern = "Features must have shape.*for feature 'exogenous'"
    for current_mode in (estimator_lib.ModeKeys.TRAIN,
                         estimator_lib.ModeKeys.EVAL):
      with self.assertRaisesRegexp(ValueError, pattern):
        spec_fn(
            features={
                feature_keys.TrainEvalFeatures.TIMES: [[1]],
                feature_keys.TrainEvalFeatures.VALUES: [[[1., 2., 3.]]],
                "exogenous": [[1], [2]]
            },
            labels=None,
            mode=current_mode)
2145eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
2155eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
class PredictFeatureCheckingTests(test.TestCase):
  """Feature validation performed by the head in PREDICT mode."""

  def test_no_time_feature(self):
    """A missing times feature raises a ValueError naming that key."""
    spec_fn = _stub_model_fn()
    pattern = "Expected a '{}' feature".format(
        feature_keys.PredictionFeatures.TIMES)
    with self.assertRaisesRegexp(ValueError, pattern):
      spec_fn(
          features={
              feature_keys.PredictionFeatures.STATE_TUPLE: ([[[1.]]], 1.)
          },
          labels=None,
          mode=estimator_lib.ModeKeys.PREDICT)

  def test_no_start_state_feature(self):
    """A missing state-tuple feature raises a ValueError naming that key."""
    spec_fn = _stub_model_fn()
    pattern = "Expected a '{}' feature".format(
        feature_keys.PredictionFeatures.STATE_TUPLE)
    with self.assertRaisesRegexp(ValueError, pattern):
      spec_fn(
          features={feature_keys.PredictionFeatures.TIMES: [[1]]},
          labels=None,
          mode=estimator_lib.ModeKeys.PREDICT)

  def test_bad_time_rank(self):
    """A scalar times feature is rejected with a shape error."""
    spec_fn = _stub_model_fn()
    pattern = "Expected shape.*for feature '{}'".format(
        feature_keys.PredictionFeatures.TIMES)
    with self.assertRaisesRegexp(ValueError, pattern):
      spec_fn(
          features={
              feature_keys.PredictionFeatures.TIMES: 1,
              feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.))
          },
          labels=None,
          mode=estimator_lib.ModeKeys.PREDICT)

  def test_bad_exogenous_shape(self):
    """A scalar exogenous feature is rejected with a shape error."""
    spec_fn = _stub_model_fn()
    pattern = "Features must have shape.*for feature 'exogenous'"
    with self.assertRaisesRegexp(ValueError, pattern):
      spec_fn(
          features={
              feature_keys.PredictionFeatures.TIMES: [[1]],
              feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.)),
              "exogenous": 1.
          },
          labels=None,
          mode=estimator_lib.ModeKeys.PREDICT)
2645eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
2655eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen
if __name__ == "__main__":
  # Entry point when run as a script: hand control to the test runner from
  # tensorflow.python.platform.test.
  test.main()
268