# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for head."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.timeseries.python.timeseries import feature_keys
from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
from tensorflow.contrib.timeseries.python.timeseries import model
from tensorflow.contrib.timeseries.python.timeseries import state_management

from tensorflow.python.estimator import estimator_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator as coordinator_lib
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import training as train


# Modes in which the head validates the training/evaluation feature set.
_TRAIN_EVAL_MODES = (estimator_lib.ModeKeys.TRAIN,
                     estimator_lib.ModeKeys.EVAL)
# All three standard Estimator modes.
_ALL_MODES = _TRAIN_EVAL_MODES + (estimator_lib.ModeKeys.PREDICT,)


class HeadTest(test.TestCase):
  """Checks the head's handling of the `labels` and `mode` arguments."""

  def test_labels_provided_error(self):
    # Supplying non-empty `labels` must raise in every mode.
    model_fn = _stub_model_fn()
    for mode in _ALL_MODES:
      with self.assertRaisesRegexp(ValueError, "labels"):
        model_fn(features={}, labels={"a": "b"}, mode=mode)

  def test_unknown_mode(self):
    model_fn = _stub_model_fn()
    with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"):
      model_fn(features={}, labels={}, mode="Not a mode")


class _TickerModel(object):
  """Stub model whose loss and predictions are the "ticker" feature itself.

  Because outputs are copied straight from the input, metric values computed
  by the head can be compared against a known input value.
  """
  num_features = 1
  dtype = dtypes.float32

  def initialize_graph(self, input_statistics):
    del input_statistics  # unused

  def define_loss(self, features, mode):
    del mode  # unused
    return model.ModelOutputs(
        loss=features["ticker"],
        end_state=(features["ticker"], features["ticker"]),
        prediction_times=array_ops.zeros(()),
        predictions={"ticker": features["ticker"]})


class EvaluationMetricsTests(test.TestCase):
  """Checks the evaluation metrics reported by the head."""

  def test_metrics_consistent(self):
    # Tests that the identity metrics used to report in-sample predictions match
    # the behavior of standard metrics.
    g = ops.Graph()
    with g.as_default():
      features = {
          feature_keys.TrainEvalFeatures.TIMES:
              array_ops.zeros((1, 1)),
          feature_keys.TrainEvalFeatures.VALUES:
              array_ops.zeros((1, 1, 1)),
          # A local int64 counter starting at 0; `count_up_to` increments it
          # each time the op is evaluated, so the first evaluation sees 0.
          "ticker":
              array_ops.reshape(
                  math_ops.cast(
                      variables.Variable(
                          name="ticker",
                          initial_value=0,
                          dtype=dtypes.int64,
                          collections=[ops.GraphKeys.LOCAL_VARIABLES])
                      .count_up_to(10),
                      dtype=dtypes.float32), (1, 1, 1))
      }
      model_fn = ts_head_lib.time_series_regression_head(
          model=_TickerModel(),
          state_manager=state_management.PassthroughStateManager(),
          optimizer=train.GradientDescentOptimizer(0.001)).create_estimator_spec
      outputs = model_fn(
          features=features, labels=None, mode=estimator_lib.ModeKeys.EVAL)
      # Each eval metric is a (value, update_op) pair; collect the update ops
      # so they all run together (and see the same ticker value).
      metric_update_ops = [
          metric[1] for metric in outputs.eval_metric_ops.values()]
      loss_mean, loss_update = metrics.mean(outputs.loss)
      metric_update_ops.append(loss_update)
      with self.test_session() as sess:
        coordinator = coordinator_lib.Coordinator()
        queue_runner_impl.start_queue_runners(sess, coord=coordinator)
        try:
          variables.local_variables_initializer().run()
          sess.run(metric_update_ops)
          loss_evaled, metric_evaled, nested_metric_evaled = sess.run(
              (loss_mean, outputs.eval_metric_ops["ticker"][0],
               outputs.eval_metric_ops[
                   feature_keys.FilteringResults.STATE_TUPLE][0][0]))
          # The custom model_utils metrics for in-sample predictions should be
          # in sync with the Estimator's mean metric for model loss.
          self.assertAllClose(0., loss_evaled)
          self.assertAllClose((((0.,),),), metric_evaled)
          self.assertAllClose((((0.,),),), nested_metric_evaled)
        finally:
          # Stop and join any queue runner threads even when an assertion
          # above fails, so they do not outlive the session.
          coordinator.request_stop()
          coordinator.join()


class _StubModel(object):
  """Minimal three-feature float64 model with no `define_loss`.

  Only suitable for tests that expect the head to raise before it asks the
  model for a loss.
  """
  num_features = 3
  dtype = dtypes.float64

  def initialize_graph(self, input_statistics):
    del input_statistics  # unused


def _stub_model_fn():
  """Returns `create_estimator_spec` of a regression head over `_StubModel`."""
  return ts_head_lib.time_series_regression_head(
      model=_StubModel(),
      state_manager=state_management.PassthroughStateManager(),
      optimizer=train.AdamOptimizer(0.001)).create_estimator_spec


class TrainEvalFeatureCheckingTests(test.TestCase):
  """Feature validation performed by the head in TRAIN and EVAL modes."""

  def test_no_time_feature(self):
    model_fn = _stub_model_fn()
    for mode in _TRAIN_EVAL_MODES:
      with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
          feature_keys.TrainEvalFeatures.TIMES)):
        model_fn(
            features={feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]},
            labels=None,
            mode=mode)

  def test_no_value_feature(self):
    model_fn = _stub_model_fn()
    for mode in _TRAIN_EVAL_MODES:
      with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
          feature_keys.TrainEvalFeatures.VALUES)):
        model_fn(
            features={feature_keys.TrainEvalFeatures.TIMES: [[1]]},
            labels=None,
            mode=mode)

  def test_bad_time_rank(self):
    model_fn = _stub_model_fn()
    for mode in _TRAIN_EVAL_MODES:
      with self.assertRaisesRegexp(ValueError,
                                   "Expected shape.*for feature '{}'".format(
                                       feature_keys.TrainEvalFeatures.TIMES)):
        # Times carry an extra dimension ([[[1]]] rather than [[1]]).
        model_fn(
            features={
                feature_keys.TrainEvalFeatures.TIMES: [[[1]]],
                feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
            },
            labels=None,
            mode=mode)

  def test_bad_value_rank(self):
    model_fn = _stub_model_fn()
    for mode in _TRAIN_EVAL_MODES:
      with self.assertRaisesRegexp(ValueError,
                                   "Expected shape.*for feature '{}'".format(
                                       feature_keys.TrainEvalFeatures.VALUES)):
        # Values are missing a dimension ([[1.]] rather than [[[1.]]]).
        model_fn(
            features={
                feature_keys.TrainEvalFeatures.TIMES: [[1]],
                feature_keys.TrainEvalFeatures.VALUES: [[1.]]
            },
            labels=None,
            mode=mode)

  def test_bad_value_num_features(self):
    model_fn = _stub_model_fn()
    for mode in _TRAIN_EVAL_MODES:
      # _StubModel declares 3 features, but only one value is supplied.
      with self.assertRaisesRegexp(
          ValueError, "Expected shape.*, 3.*for feature '{}'".format(
              feature_keys.TrainEvalFeatures.VALUES)):
        model_fn(
            features={
                feature_keys.TrainEvalFeatures.TIMES: [[1]],
                feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
            },
            labels=None,
            mode=mode)

  def test_bad_exogenous_shape(self):
    model_fn = _stub_model_fn()
    for mode in _TRAIN_EVAL_MODES:
      # The exogenous feature's shape ([[1], [2]]) does not match the times.
      with self.assertRaisesRegexp(
          ValueError,
          "Features must have shape.*for feature 'exogenous'"):
        model_fn(
            features={
                feature_keys.TrainEvalFeatures.TIMES: [[1]],
                feature_keys.TrainEvalFeatures.VALUES: [[[1., 2., 3.]]],
                "exogenous": [[1], [2]]
            },
            labels=None,
            mode=mode)


class PredictFeatureCheckingTests(test.TestCase):
  """Feature validation performed by the head in PREDICT mode."""

  def test_no_time_feature(self):
    model_fn = _stub_model_fn()
    with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
        feature_keys.PredictionFeatures.TIMES)):
      model_fn(
          features={
              feature_keys.PredictionFeatures.STATE_TUPLE: ([[[1.]]], 1.)
          },
          labels=None,
          mode=estimator_lib.ModeKeys.PREDICT)

  def test_no_start_state_feature(self):
    model_fn = _stub_model_fn()
    with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
        feature_keys.PredictionFeatures.STATE_TUPLE)):
      model_fn(
          features={feature_keys.PredictionFeatures.TIMES: [[1]]},
          labels=None,
          mode=estimator_lib.ModeKeys.PREDICT)

  def test_bad_time_rank(self):
    model_fn = _stub_model_fn()
    with self.assertRaisesRegexp(ValueError,
                                 "Expected shape.*for feature '{}'".format(
                                     feature_keys.PredictionFeatures.TIMES)):
      # Times given as a scalar.
      model_fn(
          features={
              feature_keys.PredictionFeatures.TIMES: 1,
              feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.))
          },
          labels=None,
          mode=estimator_lib.ModeKeys.PREDICT)

  def test_bad_exogenous_shape(self):
    model_fn = _stub_model_fn()
    with self.assertRaisesRegexp(
        ValueError,
        "Features must have shape.*for feature 'exogenous'"):
      # Exogenous feature given as a scalar.
      model_fn(
          features={
              feature_keys.PredictionFeatures.TIMES: [[1]],
              feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.)),
              "exogenous": 1.
          },
          labels=None,
          mode=estimator_lib.ModeKeys.PREDICT)


if __name__ == "__main__":
  test.main()