1a274c662b3f090193ead4138791896ffb65d680eMartin Wicke# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2a274c662b3f090193ead4138791896ffb65d680eMartin Wicke#
3a274c662b3f090193ead4138791896ffb65d680eMartin Wicke# Licensed under the Apache License, Version 2.0 (the "License");
4a274c662b3f090193ead4138791896ffb65d680eMartin Wicke# you may not use this file except in compliance with the License.
5a274c662b3f090193ead4138791896ffb65d680eMartin Wicke# You may obtain a copy of the License at
6a274c662b3f090193ead4138791896ffb65d680eMartin Wicke#
7a274c662b3f090193ead4138791896ffb65d680eMartin Wicke#     http://www.apache.org/licenses/LICENSE-2.0
8a274c662b3f090193ead4138791896ffb65d680eMartin Wicke#
9a274c662b3f090193ead4138791896ffb65d680eMartin Wicke# Unless required by applicable law or agreed to in writing, software
10a274c662b3f090193ead4138791896ffb65d680eMartin Wicke# distributed under the License is distributed on an "AS IS" BASIS,
11a274c662b3f090193ead4138791896ffb65d680eMartin Wicke# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12a274c662b3f090193ead4138791896ffb65d680eMartin Wicke# See the License for the specific language governing permissions and
13a274c662b3f090193ead4138791896ffb65d680eMartin Wicke# limitations under the License.
14a274c662b3f090193ead4138791896ffb65d680eMartin Wicke# ==============================================================================
15a274c662b3f090193ead4138791896ffb65d680eMartin Wicke"""Tests for MetricSpec."""
16a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
17a274c662b3f090193ead4138791896ffb65d680eMartin Wickefrom __future__ import absolute_import
18a274c662b3f090193ead4138791896ffb65d680eMartin Wickefrom __future__ import division
19a274c662b3f090193ead4138791896ffb65d680eMartin Wickefrom __future__ import print_function
20a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
2166e69e28ca912a9f379e8dc8aa9445986001945fMartin Wickeimport functools
2266e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke
23e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower# pylint: disable=g-bad-todo,g-import-not-at-top
24a274c662b3f090193ead4138791896ffb65d680eMartin Wickefrom tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
25e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.platform import test
26a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
27a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
28e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyclass MetricSpecTest(test.TestCase):
29a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
30e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_named_args_with_weights(self):
31e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value", "f2": "f2_value"}
32e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = {"l1": "l1_value", "l2": "l2_value"}
33e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = {"p1": "p1_value", "p2": "p2_value"}
34e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
35e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn0(predictions, labels, weights=None):
364148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("p1_value", predictions)
374148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("l1_value", labels)
384148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("f2_value", weights)
39e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
40e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
41e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn1(predictions, targets, weights=None):
424148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("p1_value", predictions)
434148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("l1_value", targets)
444148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("f2_value", weights)
45e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
46e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
47e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn2(prediction, label, weight=None):
484148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("p1_value", prediction)
494148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("l1_value", label)
504148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("f2_value", weight)
51e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
52e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
53e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn3(prediction, target, weight=None):
544148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("p1_value", prediction)
554148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("l1_value", target)
564148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("f2_value", weight)
57e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
58e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
59e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    for fn in (_fn0, _fn1, _fn2, _fn3):
60e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec = MetricSpec(
61e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower          metric_fn=fn, prediction_key="p1", label_key="l1", weight_key="f2")
62e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.assertEqual(
63e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower          "metric_fn_result",
64e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower          spec.create_metric_ops(features, labels_, predictions_))
65e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
66e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_no_args(self):
67e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn():
68e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
69e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
70e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn)
714148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower    with self.assertRaises(TypeError):
72e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec.create_metric_ops(
73e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower          {"f1": "f1_value"}, "labels_value", "predictions_value")
74e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
75e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_kwargs(self):
76e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value"}
77e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = "labels_value"
78e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = "predictions_value"
79e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
80e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(**kwargs):
814148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual({}, kwargs)
82e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
83e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
84e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn)
854148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower    with self.assertRaises(TypeError):
86e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec.create_metric_ops(features, labels_, predictions_)
87e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
88e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_named_labels_no_predictions(self):
89e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value"}
90e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = "labels_value"
91e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = "predictions_value"
92e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
93e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(labels):
944148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual(labels_, labels)
95e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
96e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
97e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn)
984148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower    with self.assertRaises(TypeError):
99e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec.create_metric_ops(features, labels_, predictions_)
100e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
101e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_named_labels_no_predictions_with_kwargs(self):
102e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value"}
103e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = "labels_value"
104e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = "predictions_value"
105e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
106e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(labels, **kwargs):
1074148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual(labels_, labels)
1084148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual({}, kwargs)
109e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
110e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
111e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn)
1124148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower    with self.assertRaises(TypeError):
113e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec.create_metric_ops(features, labels_, predictions_)
114e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
115e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_no_named_predictions_named_labels_first_arg(self):
116e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value"}
117e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = "labels_value"
118e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = "predictions_value"
119e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
120e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(labels, predictions_by_another_name):
1214148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual(predictions_, predictions_by_another_name)
1224148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual(labels_, labels)
123e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
124e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
125e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn)
126e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertEqual(
127e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        "metric_fn_result",
128e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        spec.create_metric_ops(features, labels_, predictions_))
129e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
130e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_no_named_predictions_named_labels_second_arg(self):
131e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value"}
132e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = "labels_value"
133e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = "predictions_value"
134e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
135e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(predictions_by_another_name, labels):
1364148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual(predictions_, predictions_by_another_name)
1374148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual(labels_, labels)
138e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
139e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
140e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn)
141e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertEqual(
142e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        "metric_fn_result",
143e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        spec.create_metric_ops(features, labels_, predictions_))
144e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
145e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_no_named_labels(self):
146e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value"}
147e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = "labels_value"
148e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = "predictions_value"
149e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
150e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(predictions):
1514148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual(predictions_, predictions)
152e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
153e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
154e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn)
155e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertEqual(
156e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        "metric_fn_result",
157e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        spec.create_metric_ops(features, labels_, predictions_))
158e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
159e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_no_named_labels_or_predictions_1arg(self):
160e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value"}
161e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = "labels_value"
162e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = "predictions_value"
163e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
164e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(a):
1654148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual(predictions_, a)
166e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
167e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
168e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn)
169e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertEqual(
170e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        "metric_fn_result",
171e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        spec.create_metric_ops(features, labels_, predictions_))
172e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
173e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_no_named_labels_or_predictions_2args(self):
174e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value"}
175e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = "labels_value"
176e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = "predictions_value"
177e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
178e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(a, b):
179e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del a, b
180e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
181e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
182e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn)
1834148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower    with self.assertRaises(TypeError):
184e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec.create_metric_ops(features, labels_, predictions_)
185e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
186e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_named_args_no_weights(self):
187e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value", "f2": "f2_value"}
188e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = {"l1": "l1_value", "l2": "l2_value"}
189e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = {"p1": "p1_value", "p2": "p2_value"}
190e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
191e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn0(predictions, labels):
1924148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("p1_value", predictions)
1934148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("l1_value", labels)
194e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
195e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
196e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn1(predictions, targets):
1974148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("p1_value", predictions)
1984148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("l1_value", targets)
199e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
200e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
201e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn2(prediction, label):
2024148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("p1_value", prediction)
2034148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("l1_value", label)
204e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
205e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
206e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn3(prediction, target):
2074148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("p1_value", prediction)
2084148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("l1_value", target)
209e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
210e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
211e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    for fn in (_fn0, _fn1, _fn2, _fn3):
212e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec = MetricSpec(metric_fn=fn, prediction_key="p1", label_key="l1")
213e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.assertEqual(
214e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower          "metric_fn_result",
215e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower          spec.create_metric_ops(features, labels_, predictions_))
216e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
217e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_predictions_dict_no_key(self):
218e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value", "f2": "f2_value"}
219e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels = {"l1": "l1_value", "l2": "l2_value"}
220e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions = {"p1": "p1_value", "p2": "p2_value"}
221e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
222e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(predictions, labels, weights=None):
223e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del labels, predictions, weights
224e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
225e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
226e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn, label_key="l1", weight_key="f2")
227e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    with self.assertRaisesRegexp(
228e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        ValueError,
229e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        "MetricSpec without specified prediction_key requires predictions"
230e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        " tensor or single element dict"):
231e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec.create_metric_ops(features, labels, predictions)
232a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
233e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_labels_dict_no_key(self):
234e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value", "f2": "f2_value"}
235e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels = {"l1": "l1_value", "l2": "l2_value"}
236e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions = {"p1": "p1_value", "p2": "p2_value"}
237a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
238e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(labels, predictions, weights=None):
239e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del labels, predictions, weights
240e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
241a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
242e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn, prediction_key="p1", weight_key="f2")
243e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    with self.assertRaisesRegexp(
244e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        ValueError,
245e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        "MetricSpec without specified label_key requires labels tensor or"
246e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        " single element dict"):
247e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec.create_metric_ops(features, labels, predictions)
248a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
249a274c662b3f090193ead4138791896ffb65d680eMartin Wicke  def test_single_prediction(self):
250e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value", "f2": "f2_value"}
251e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = {"l1": "l1_value", "l2": "l2_value"}
252e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = "p1_value"
253a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
254e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(predictions, labels, weights=None):
2554148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual(predictions_, predictions)
2564148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("l1_value", labels)
2574148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("f2_value", weights)
258e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
259a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
260e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn, label_key="l1", weight_key="f2")
261e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertEqual(
262e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        "metric_fn_result",
263e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        spec.create_metric_ops(features, labels_, predictions_))
264a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
265a274c662b3f090193ead4138791896ffb65d680eMartin Wicke  def test_single_label(self):
266e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value", "f2": "f2_value"}
267e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels_ = "l1_value"
268e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions_ = {"p1": "p1_value", "p2": "p2_value"}
269e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
270e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(predictions, labels, weights=None):
2714148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("p1_value", predictions)
2724148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual(labels_, labels)
2734148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("f2_value", weights)
274e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return "metric_fn_result"
275e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
276e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(metric_fn=_fn, prediction_key="p1", weight_key="f2")
277e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertEqual(
278e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        "metric_fn_result",
279e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        spec.create_metric_ops(features, labels_, predictions_))
280e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
281e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_single_predictions_with_key(self):
282e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value", "f2": "f2_value"}
283e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels = {"l1": "l1_value", "l2": "l2_value"}
284e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions = "p1_value"
285e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
286e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(predictions, labels, weights=None):
287e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del labels, predictions, weights
288e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
289e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
290e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(
291e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        metric_fn=_fn, prediction_key="p1", label_key="l1", weight_key="f2")
292e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    with self.assertRaisesRegexp(
293e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        ValueError,
294e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        "MetricSpec with prediction_key specified requires predictions dict"):
295e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec.create_metric_ops(features, labels, predictions)
296a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
297e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_single_labels_with_key(self):
298e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value", "f2": "f2_value"}
299e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels = "l1_value"
300e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions = {"p1": "p1_value", "p2": "p2_value"}
301a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
302e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(predictions, labels, weights=None):
303e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del labels, predictions, weights
304e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
305a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
306e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(
307e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        metric_fn=_fn, prediction_key="p1", label_key="l1", weight_key="f2")
308e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    with self.assertRaisesRegexp(
309e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        ValueError, "MetricSpec with label_key specified requires labels dict"):
310e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec.create_metric_ops(features, labels, predictions)
311a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
31266e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke  def test_str(self):
313e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _metric_fn(labels, predictions, weights=None):
314e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      return predictions, labels, weights
315e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
316e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    string = str(MetricSpec(
317e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        metric_fn=_metric_fn,
318e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        label_key="my_label",
319e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        prediction_key="my_prediction",
320e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        weight_key="my_weight"))
321e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertIn("_metric_fn", string)
322e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertIn("my_label", string)
323e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertIn("my_prediction", string)
324e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertIn("my_weight", string)
32566e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke
32666e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke  def test_partial_str(self):
327e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney
32866e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke    def custom_metric(predictions, labels, stuff, weights=None):
32966e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke      return predictions, labels, weights, stuff
33066e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke
331e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    string = str(MetricSpec(
332e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        metric_fn=functools.partial(custom_metric, stuff=5),
333e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        label_key="my_label",
334e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        prediction_key="my_prediction",
335e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        weight_key="my_weight"))
336e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertIn("custom_metric", string)
337e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertIn("my_label", string)
338e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertIn("my_prediction", string)
339e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    self.assertIn("my_weight", string)
34066e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke
34166e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke  def test_partial(self):
342e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    features = {"f1": "f1_value", "f2": "f2_value"}
343e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    labels = {"l1": "l1_value"}
344e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    predictions = {"p1": "p1_value", "p2": "p2_value"}
34566e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke
34666e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke    def custom_metric(predictions, labels, stuff, weights=None):
3474148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("p1_value", predictions)
3484148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("l1_value", labels)
3494148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower      self.assertEqual("f2_value", weights)
35066e69e28ca912a9f379e8dc8aa9445986001945fMartin Wicke      if stuff:
351e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        return "metric_fn_result"
352e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      raise ValueError("No stuff.")
353e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
354e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(
355e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        metric_fn=functools.partial(custom_metric, stuff=5),
356e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        label_key="l1",
357e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        prediction_key="p1",
358e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        weight_key="f2")
3594148fd588de61020d81cf2018c2c7c334e05b568A. Unique TensorFlower    self.assertEqual(
360e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        "metric_fn_result",
361e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        spec.create_metric_ops(features, labels, predictions))
362e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
363e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    spec = MetricSpec(
364e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        metric_fn=functools.partial(custom_metric, stuff=None),
365e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        prediction_key="p1", label_key="l1", weight_key="f2")
366e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    with self.assertRaisesRegexp(ValueError, "No stuff."):
367e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      spec.create_metric_ops(features, labels, predictions)
368e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
369e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_label_key_without_label_arg(self):
370e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn0(predictions, weights=None):
371e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del predictions, weights
372e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
373e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
374e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn1(prediction, weight=None):
375e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del prediction, weight
376e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
377e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
378e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    for fn in (_fn0, _fn1):
379e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "label.*missing"):
380e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        MetricSpec(metric_fn=fn, label_key="l1")
381e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
382e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_weight_key_without_weight_arg(self):
383e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn0(predictions, labels):
384e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del predictions, labels
385e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
386e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
387e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn1(prediction, label):
388e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del prediction, label
389e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
390e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
391e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn2(predictions, targets):
392e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del predictions, targets
393e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
394e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
395e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn3(prediction, target):
396e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del prediction, target
397e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
398e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
399e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    for fn in (_fn0, _fn1, _fn2, _fn3):
400e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "weight.*missing"):
401e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        MetricSpec(metric_fn=fn, weight_key="f2")
402e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
403e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_multiple_label_args(self):
404e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn0(predictions, labels, targets):
405e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del predictions, labels, targets
406e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
407e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
408e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn1(prediction, label, target):
409e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del prediction, label, target
410e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
411e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
412e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    for fn in (_fn0, _fn1):
413e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "provide only one of.*label"):
414e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower        MetricSpec(metric_fn=fn)
415e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
416e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_multiple_prediction_args(self):
417e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(predictions, prediction, labels):
418e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del predictions, prediction, labels
419e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
420e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
421e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    with self.assertRaisesRegexp(ValueError, "provide only one of.*prediction"):
422e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      MetricSpec(metric_fn=_fn)
423e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
424e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower  def test_multiple_weight_args(self):
425e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    def _fn(predictions, labels, weights=None, weight=None):
426e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      del predictions, labels, weights, weight
427e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      self.fail("Expected failure before metric_fn.")
428e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower
429e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower    with self.assertRaisesRegexp(ValueError, "provide only one of.*weight"):
430e47885b01d36f2c6afcffdd2764b18e8a21b56b7A. Unique TensorFlower      MetricSpec(metric_fn=_fn)
431a274c662b3f090193ead4138791896ffb65d680eMartin Wicke
432a274c662b3f090193ead4138791896ffb65d680eMartin Wickeif __name__ == "__main__":
433e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  test.main()
434