# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MetricSpec."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

# pylint: disable=g-bad-todo,g-import-not-at-top
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
from tensorflow.python.platform import test


class MetricSpecTest(test.TestCase):
  """Tests MetricSpec construction, string form and create_metric_ops.

  The tests use plain strings in place of tensors: each metric_fn asserts on
  the string values it receives and returns the marker "metric_fn_result".
  """

  def test_named_args_with_weights(self):
    """Keyed prediction, label and weight values reach each metric_fn."""
    features = {"f1": "f1_value", "f2": "f2_value"}
    labels_ = {"l1": "l1_value", "l2": "l2_value"}
    predictions_ = {"p1": "p1_value", "p2": "p2_value"}

    # The four functions differ only in how their arguments are spelled.
    def _fn0(predictions, labels, weights=None):
      self.assertEqual("p1_value", predictions)
      self.assertEqual("l1_value", labels)
      self.assertEqual("f2_value", weights)
      return "metric_fn_result"

    def _fn1(predictions, targets, weights=None):
      self.assertEqual("p1_value", predictions)
      self.assertEqual("l1_value", targets)
      self.assertEqual("f2_value", weights)
      return "metric_fn_result"

    def _fn2(prediction, label, weight=None):
      self.assertEqual("p1_value", prediction)
      self.assertEqual("l1_value", label)
      self.assertEqual("f2_value", weight)
      return "metric_fn_result"

    def _fn3(prediction, target, weight=None):
      self.assertEqual("p1_value", prediction)
      self.assertEqual("l1_value", target)
      self.assertEqual("f2_value", weight)
      return "metric_fn_result"

    for fn in (_fn0, _fn1, _fn2, _fn3):
      spec = MetricSpec(
          metric_fn=fn, prediction_key="p1", label_key="l1", weight_key="f2")
      self.assertEqual(
          "metric_fn_result",
          spec.create_metric_ops(features, labels_, predictions_))

  def test_no_args(self):
    """A zero-argument metric_fn makes create_metric_ops raise TypeError."""
    def _fn():
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(metric_fn=_fn)
    with self.assertRaises(TypeError):
      spec.create_metric_ops(
          {"f1": "f1_value"}, "labels_value", "predictions_value")

  def test_kwargs(self):
    """A metric_fn taking only **kwargs is expected to raise TypeError."""
    features = {"f1": "f1_value"}
    labels_ = "labels_value"
    predictions_ = "predictions_value"

    def _fn(**kwargs):
      self.assertEqual({}, kwargs)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    with self.assertRaises(TypeError):
      spec.create_metric_ops(features, labels_, predictions_)

  def test_named_labels_no_predictions(self):
    """A metric_fn taking only `labels` is expected to raise TypeError."""
    features = {"f1": "f1_value"}
    labels_ = "labels_value"
    predictions_ = "predictions_value"

    def _fn(labels):
      self.assertEqual(labels_, labels)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    with self.assertRaises(TypeError):
      spec.create_metric_ops(features, labels_, predictions_)

  def test_named_labels_no_predictions_with_kwargs(self):
    """`labels` plus **kwargs is still expected to raise TypeError."""
    features = {"f1": "f1_value"}
    labels_ = "labels_value"
    predictions_ = "predictions_value"

    def _fn(labels, **kwargs):
      self.assertEqual(labels_, labels)
      self.assertEqual({}, kwargs)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    with self.assertRaises(TypeError):
      spec.create_metric_ops(features, labels_, predictions_)

  def test_no_named_predictions_named_labels_first_arg(self):
    """With `labels` first, the other positional arg receives predictions."""
    features = {"f1": "f1_value"}
    labels_ = "labels_value"
    predictions_ = "predictions_value"

    def _fn(labels, predictions_by_another_name):
      self.assertEqual(predictions_, predictions_by_another_name)
      self.assertEqual(labels_, labels)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    self.assertEqual(
        "metric_fn_result",
        spec.create_metric_ops(features, labels_, predictions_))

  def test_no_named_predictions_named_labels_second_arg(self):
    """With `labels` second, the first positional arg receives predictions."""
    features = {"f1": "f1_value"}
    labels_ = "labels_value"
    predictions_ = "predictions_value"

    def _fn(predictions_by_another_name, labels):
      self.assertEqual(predictions_, predictions_by_another_name)
      self.assertEqual(labels_, labels)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    self.assertEqual(
        "metric_fn_result",
        spec.create_metric_ops(features, labels_, predictions_))

  def test_no_named_labels(self):
    """A metric_fn taking only `predictions` receives the predictions."""
    features = {"f1": "f1_value"}
    labels_ = "labels_value"
    predictions_ = "predictions_value"

    def _fn(predictions):
      self.assertEqual(predictions_, predictions)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    self.assertEqual(
        "metric_fn_result",
        spec.create_metric_ops(features, labels_, predictions_))

  def test_no_named_labels_or_predictions_1arg(self):
    """A single arbitrarily named argument receives the predictions."""
    features = {"f1": "f1_value"}
    labels_ = "labels_value"
    predictions_ = "predictions_value"

    def _fn(a):
      self.assertEqual(predictions_, a)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    self.assertEqual(
        "metric_fn_result",
        spec.create_metric_ops(features, labels_, predictions_))

  def test_no_named_labels_or_predictions_2args(self):
    """Two arbitrarily named arguments are expected to raise TypeError."""
    features = {"f1": "f1_value"}
    labels_ = "labels_value"
    predictions_ = "predictions_value"

    def _fn(a, b):
      del a, b
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(metric_fn=_fn)
    with self.assertRaises(TypeError):
      spec.create_metric_ops(features, labels_, predictions_)

  def test_named_args_no_weights(self):
    """Keyed predictions and labels reach metric_fns without a weight arg."""
    features = {"f1": "f1_value", "f2": "f2_value"}
    labels_ = {"l1": "l1_value", "l2": "l2_value"}
    predictions_ = {"p1": "p1_value", "p2": "p2_value"}

    # Same argument spellings as the weighted variant, minus the weight.
    def _fn0(predictions, labels):
      self.assertEqual("p1_value", predictions)
      self.assertEqual("l1_value", labels)
      return "metric_fn_result"

    def _fn1(predictions, targets):
      self.assertEqual("p1_value", predictions)
      self.assertEqual("l1_value", targets)
      return "metric_fn_result"

    def _fn2(prediction, label):
      self.assertEqual("p1_value", prediction)
      self.assertEqual("l1_value", label)
      return "metric_fn_result"

    def _fn3(prediction, target):
      self.assertEqual("p1_value", prediction)
      self.assertEqual("l1_value", target)
      return "metric_fn_result"

    for fn in (_fn0, _fn1, _fn2, _fn3):
      spec = MetricSpec(metric_fn=fn, prediction_key="p1", label_key="l1")
      self.assertEqual(
          "metric_fn_result",
          spec.create_metric_ops(features, labels_, predictions_))

  def test_predictions_dict_no_key(self):
    """A two-entry predictions dict without prediction_key is rejected."""
    features = {"f1": "f1_value", "f2": "f2_value"}
    labels = {"l1": "l1_value", "l2": "l2_value"}
    predictions = {"p1": "p1_value", "p2": "p2_value"}

    def _fn(predictions, labels, weights=None):
      del labels, predictions, weights
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(metric_fn=_fn, label_key="l1", weight_key="f2")
    with self.assertRaisesRegexp(
        ValueError,
        "MetricSpec without specified prediction_key requires predictions"
        " tensor or single element dict"):
      spec.create_metric_ops(features, labels, predictions)

  def test_labels_dict_no_key(self):
    """A two-entry labels dict without label_key is rejected."""
    features = {"f1": "f1_value", "f2": "f2_value"}
    labels = {"l1": "l1_value", "l2": "l2_value"}
    predictions = {"p1": "p1_value", "p2": "p2_value"}

    def _fn(labels, predictions, weights=None):
      del labels, predictions, weights
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(metric_fn=_fn, prediction_key="p1", weight_key="f2")
    with self.assertRaisesRegexp(
        ValueError,
        "MetricSpec without specified label_key requires labels tensor or"
        " single element dict"):
      spec.create_metric_ops(features, labels, predictions)

  def test_single_prediction(self):
    """A non-dict predictions value is passed through when no key is set."""
    features = {"f1": "f1_value", "f2": "f2_value"}
    labels_ = {"l1": "l1_value", "l2": "l2_value"}
    predictions_ = "p1_value"

    def _fn(predictions, labels, weights=None):
      self.assertEqual(predictions_, predictions)
      self.assertEqual("l1_value", labels)
      self.assertEqual("f2_value", weights)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn, label_key="l1", weight_key="f2")
    self.assertEqual(
        "metric_fn_result",
        spec.create_metric_ops(features, labels_, predictions_))

  def test_single_label(self):
    """A non-dict labels value is passed through when no key is set."""
    features = {"f1": "f1_value", "f2": "f2_value"}
    labels_ = "l1_value"
    predictions_ = {"p1": "p1_value", "p2": "p2_value"}

    def _fn(predictions, labels, weights=None):
      self.assertEqual("p1_value", predictions)
      self.assertEqual(labels_, labels)
      self.assertEqual("f2_value", weights)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn, prediction_key="p1", weight_key="f2")
    self.assertEqual(
        "metric_fn_result",
        spec.create_metric_ops(features, labels_, predictions_))

  def test_single_predictions_with_key(self):
    """prediction_key with a non-dict predictions value raises ValueError."""
    features = {"f1": "f1_value", "f2": "f2_value"}
    labels = {"l1": "l1_value", "l2": "l2_value"}
    predictions = "p1_value"

    def _fn(predictions, labels, weights=None):
      del labels, predictions, weights
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(
        metric_fn=_fn, prediction_key="p1", label_key="l1", weight_key="f2")
    with self.assertRaisesRegexp(
        ValueError,
        "MetricSpec with prediction_key specified requires predictions dict"):
      spec.create_metric_ops(features, labels, predictions)

  def test_single_labels_with_key(self):
    """label_key with a non-dict labels value raises ValueError."""
    features = {"f1": "f1_value", "f2": "f2_value"}
    labels = "l1_value"
    predictions = {"p1": "p1_value", "p2": "p2_value"}

    def _fn(predictions, labels, weights=None):
      del labels, predictions, weights
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(
        metric_fn=_fn, prediction_key="p1", label_key="l1", weight_key="f2")
    with self.assertRaisesRegexp(
        ValueError, "MetricSpec with label_key specified requires labels dict"):
      spec.create_metric_ops(features, labels, predictions)

  def test_str(self):
    """str(MetricSpec) mentions the metric_fn name and all three keys."""
    def _metric_fn(labels, predictions, weights=None):
      return predictions, labels, weights

    string = str(MetricSpec(
        metric_fn=_metric_fn,
        label_key="my_label",
        prediction_key="my_prediction",
        weight_key="my_weight"))
    self.assertIn("_metric_fn", string)
    self.assertIn("my_label", string)
    self.assertIn("my_prediction", string)
    self.assertIn("my_weight", string)

  def test_partial_str(self):
    """str(MetricSpec) names the function wrapped by functools.partial."""

    def custom_metric(predictions, labels, stuff, weights=None):
      return predictions, labels, weights, stuff

    string = str(MetricSpec(
        metric_fn=functools.partial(custom_metric, stuff=5),
        label_key="my_label",
        prediction_key="my_prediction",
        weight_key="my_weight"))
    self.assertIn("custom_metric", string)
    self.assertIn("my_label", string)
    self.assertIn("my_prediction", string)
    self.assertIn("my_weight", string)

  def test_partial(self):
    """A functools.partial metric_fn is called with its bound argument."""
    features = {"f1": "f1_value", "f2": "f2_value"}
    labels = {"l1": "l1_value"}
    predictions = {"p1": "p1_value", "p2": "p2_value"}

    def custom_metric(predictions, labels, stuff, weights=None):
      self.assertEqual("p1_value", predictions)
      self.assertEqual("l1_value", labels)
      self.assertEqual("f2_value", weights)
      if stuff:
        return "metric_fn_result"
      raise ValueError("No stuff.")

    # A truthy bound `stuff` lets the metric return normally.
    spec = MetricSpec(
        metric_fn=functools.partial(custom_metric, stuff=5),
        label_key="l1",
        prediction_key="p1",
        weight_key="f2")
    self.assertEqual(
        "metric_fn_result",
        spec.create_metric_ops(features, labels, predictions))

    # A falsy bound `stuff` makes the metric raise, and the error propagates.
    spec = MetricSpec(
        metric_fn=functools.partial(custom_metric, stuff=None),
        prediction_key="p1", label_key="l1", weight_key="f2")
    with self.assertRaisesRegexp(ValueError, "No stuff."):
      spec.create_metric_ops(features, labels, predictions)

  def test_label_key_without_label_arg(self):
    """Constructing with label_key but no label argument raises ValueError."""
    def _fn0(predictions, weights=None):
      del predictions, weights
      self.fail("Expected failure before metric_fn.")

    def _fn1(prediction, weight=None):
      del prediction, weight
      self.fail("Expected failure before metric_fn.")

    for fn in (_fn0, _fn1):
      with self.assertRaisesRegexp(ValueError, "label.*missing"):
        MetricSpec(metric_fn=fn, label_key="l1")

  def test_weight_key_without_weight_arg(self):
    """Presumably the weight_key analogue of the label_key test above.

    NOTE(review): inferred from the method name only -- confirm against the
    statements that follow.
    """
Unique TensorFlower def _fn0(predictions, labels): 384e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower del predictions, labels 385e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower self.fail("Expected failure before metric_fn.") 386e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower 387e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower def _fn1(prediction, label): 388e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower del prediction, label 389e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower self.fail("Expected failure before metric_fn.") 390e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower 391e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower def _fn2(predictions, targets): 392e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower del predictions, targets 393e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower self.fail("Expected failure before metric_fn.") 394e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower 395e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower def _fn3(prediction, target): 396e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower del prediction, target 397e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower self.fail("Expected failure before metric_fn.") 398e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower 399e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower for fn in (_fn0, _fn1, _fn2, _fn3): 400e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower with self.assertRaisesRegexp(ValueError, "weight.*missing"): 401e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower MetricSpec(metric_fn=fn, weight_key="f2") 402e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower 403e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower def test_multiple_label_args(self): 404e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. 
Unique TensorFlower def _fn0(predictions, labels, targets): 405e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower del predictions, labels, targets 406e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower self.fail("Expected failure before metric_fn.") 407e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower 408e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower def _fn1(prediction, label, target): 409e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower del prediction, label, target 410e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower self.fail("Expected failure before metric_fn.") 411e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower 412e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower for fn in (_fn0, _fn1): 413e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower with self.assertRaisesRegexp(ValueError, "provide only one of.*label"): 414e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower MetricSpec(metric_fn=fn) 415e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower 416e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower def test_multiple_prediction_args(self): 417e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower def _fn(predictions, prediction, labels): 418e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower del predictions, prediction, labels 419e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower self.fail("Expected failure before metric_fn.") 420e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower 421e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower with self.assertRaisesRegexp(ValueError, "provide only one of.*prediction"): 422e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower MetricSpec(metric_fn=_fn) 423e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower 424e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. 
Unique TensorFlower def test_multiple_weight_args(self): 425e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower def _fn(predictions, labels, weights=None, weight=None): 426e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower del predictions, labels, weights, weight 427e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower self.fail("Expected failure before metric_fn.") 428e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower 429e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower with self.assertRaisesRegexp(ValueError, "provide only one of.*weight"): 430e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower MetricSpec(metric_fn=_fn) 431a274c662b3f090193ead4138791896ffb65d680eMartin Wicke 432a274c662b3f090193ead4138791896ffb65d680eMartin Wickeif __name__ == "__main__": 433e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney test.main() 434