1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for MetricSpec."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23# pylint: disable=g-bad-todo,g-import-not-at-top
24from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
25from tensorflow.python.platform import test
26
27
class MetricSpecTest(test.TestCase):
  """Checks how `MetricSpec` routes inputs to differently shaped metric fns.

  Every test feeds plain strings (or dicts of strings) as features, labels and
  predictions. The parameter names of the nested metric functions, and in the
  `str` tests their function names, are what the tests vary and inspect, so
  they are deliberately left as they are.
  """

  def test_named_args_with_weights(self):
    """Keyed prediction, label and weight reach each argument spelling."""
    features_in = {"f1": "f1_value", "f2": "f2_value"}
    labels_in = {"l1": "l1_value", "l2": "l2_value"}
    predictions_in = {"p1": "p1_value", "p2": "p2_value"}

    def _verify(got_prediction, got_label, got_weight):
      # Shared body of the metric functions below; only their signatures vary.
      self.assertEqual("p1_value", got_prediction)
      self.assertEqual("l1_value", got_label)
      self.assertEqual("f2_value", got_weight)
      return "metric_fn_result"

    def _fn0(predictions, labels, weights=None):
      return _verify(predictions, labels, weights)

    def _fn1(predictions, targets, weights=None):
      return _verify(predictions, targets, weights)

    def _fn2(prediction, label, weight=None):
      return _verify(prediction, label, weight)

    def _fn3(prediction, target, weight=None):
      return _verify(prediction, target, weight)

    for metric_fn in (_fn0, _fn1, _fn2, _fn3):
      spec = MetricSpec(
          metric_fn=metric_fn,
          prediction_key="p1",
          label_key="l1",
          weight_key="f2")
      result = spec.create_metric_ops(features_in, labels_in, predictions_in)
      self.assertEqual("metric_fn_result", result)

  def test_no_args(self):
    """Expects TypeError for a metric function that takes no arguments."""
    def _fn():
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(metric_fn=_fn)
    self.assertRaises(
        TypeError, spec.create_metric_ops,
        {"f1": "f1_value"}, "labels_value", "predictions_value")

  def test_kwargs(self):
    """Expects TypeError for a metric function that only takes `**kwargs`."""
    features_in = {"f1": "f1_value"}
    labels_in = "labels_value"
    predictions_in = "predictions_value"

    def _fn(**kwargs):
      self.assertEqual({}, kwargs)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    self.assertRaises(
        TypeError, spec.create_metric_ops,
        features_in, labels_in, predictions_in)

  def test_named_labels_no_predictions(self):
    """Expects TypeError when the metric function only accepts `labels`."""
    features_in = {"f1": "f1_value"}
    labels_in = "labels_value"
    predictions_in = "predictions_value"

    def _fn(labels):
      self.assertEqual(labels_in, labels)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    self.assertRaises(
        TypeError, spec.create_metric_ops,
        features_in, labels_in, predictions_in)

  def test_named_labels_no_predictions_with_kwargs(self):
    """Expects TypeError when the signature is `labels` plus `**kwargs`."""
    features_in = {"f1": "f1_value"}
    labels_in = "labels_value"
    predictions_in = "predictions_value"

    def _fn(labels, **kwargs):
      self.assertEqual(labels_in, labels)
      self.assertEqual({}, kwargs)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    self.assertRaises(
        TypeError, spec.create_metric_ops,
        features_in, labels_in, predictions_in)

  def test_no_named_predictions_named_labels_first_arg(self):
    """With `labels` first, the other argument receives the predictions."""
    features_in = {"f1": "f1_value"}
    labels_in = "labels_value"
    predictions_in = "predictions_value"

    def _fn(labels, predictions_by_another_name):
      self.assertEqual(predictions_in, predictions_by_another_name)
      self.assertEqual(labels_in, labels)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    result = spec.create_metric_ops(features_in, labels_in, predictions_in)
    self.assertEqual("metric_fn_result", result)

  def test_no_named_predictions_named_labels_second_arg(self):
    """With `labels` second, the first argument receives the predictions."""
    features_in = {"f1": "f1_value"}
    labels_in = "labels_value"
    predictions_in = "predictions_value"

    def _fn(predictions_by_another_name, labels):
      self.assertEqual(predictions_in, predictions_by_another_name)
      self.assertEqual(labels_in, labels)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    result = spec.create_metric_ops(features_in, labels_in, predictions_in)
    self.assertEqual("metric_fn_result", result)

  def test_no_named_labels(self):
    """A metric function taking only `predictions` is given the predictions."""
    features_in = {"f1": "f1_value"}
    labels_in = "labels_value"
    predictions_in = "predictions_value"

    def _fn(predictions):
      self.assertEqual(predictions_in, predictions)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    result = spec.create_metric_ops(features_in, labels_in, predictions_in)
    self.assertEqual("metric_fn_result", result)

  def test_no_named_labels_or_predictions_1arg(self):
    """A single arbitrarily named argument is given the predictions."""
    features_in = {"f1": "f1_value"}
    labels_in = "labels_value"
    predictions_in = "predictions_value"

    def _fn(a):
      self.assertEqual(predictions_in, a)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn)
    result = spec.create_metric_ops(features_in, labels_in, predictions_in)
    self.assertEqual("metric_fn_result", result)

  def test_no_named_labels_or_predictions_2args(self):
    """Expects TypeError for two arguments that carry no recognized name."""
    features_in = {"f1": "f1_value"}
    labels_in = "labels_value"
    predictions_in = "predictions_value"

    def _fn(a, b):
      del a, b  # Unused; the call is expected to fail before reaching here.
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(metric_fn=_fn)
    self.assertRaises(
        TypeError, spec.create_metric_ops,
        features_in, labels_in, predictions_in)

  def test_named_args_no_weights(self):
    """Keyed prediction and label reach each spelling without a weight."""
    features_in = {"f1": "f1_value", "f2": "f2_value"}
    labels_in = {"l1": "l1_value", "l2": "l2_value"}
    predictions_in = {"p1": "p1_value", "p2": "p2_value"}

    def _verify(got_prediction, got_label):
      # Shared body of the metric functions below; only their signatures vary.
      self.assertEqual("p1_value", got_prediction)
      self.assertEqual("l1_value", got_label)
      return "metric_fn_result"

    def _fn0(predictions, labels):
      return _verify(predictions, labels)

    def _fn1(predictions, targets):
      return _verify(predictions, targets)

    def _fn2(prediction, label):
      return _verify(prediction, label)

    def _fn3(prediction, target):
      return _verify(prediction, target)

    for metric_fn in (_fn0, _fn1, _fn2, _fn3):
      spec = MetricSpec(
          metric_fn=metric_fn, prediction_key="p1", label_key="l1")
      result = spec.create_metric_ops(features_in, labels_in, predictions_in)
      self.assertEqual("metric_fn_result", result)

  def test_predictions_dict_no_key(self):
    """Expects ValueError for a two-entry predictions dict and no key."""
    features_in = {"f1": "f1_value", "f2": "f2_value"}
    labels_in = {"l1": "l1_value", "l2": "l2_value"}
    predictions_in = {"p1": "p1_value", "p2": "p2_value"}

    def _fn(predictions, labels, weights=None):
      del labels, predictions, weights
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(metric_fn=_fn, label_key="l1", weight_key="f2")
    expected_error = (
        "MetricSpec without specified prediction_key requires predictions"
        " tensor or single element dict")
    self.assertRaisesRegexp(
        ValueError, expected_error,
        spec.create_metric_ops, features_in, labels_in, predictions_in)

  def test_labels_dict_no_key(self):
    """Expects ValueError for a two-entry labels dict and no `label_key`."""
    features_in = {"f1": "f1_value", "f2": "f2_value"}
    labels_in = {"l1": "l1_value", "l2": "l2_value"}
    predictions_in = {"p1": "p1_value", "p2": "p2_value"}

    def _fn(labels, predictions, weights=None):
      del labels, predictions, weights
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(metric_fn=_fn, prediction_key="p1", weight_key="f2")
    expected_error = (
        "MetricSpec without specified label_key requires labels tensor or"
        " single element dict")
    self.assertRaisesRegexp(
        ValueError, expected_error,
        spec.create_metric_ops, features_in, labels_in, predictions_in)

  def test_single_prediction(self):
    """Non-dict predictions are passed as-is when no prediction key is set."""
    features_in = {"f1": "f1_value", "f2": "f2_value"}
    labels_in = {"l1": "l1_value", "l2": "l2_value"}
    predictions_in = "p1_value"

    def _fn(predictions, labels, weights=None):
      self.assertEqual(predictions_in, predictions)
      self.assertEqual("l1_value", labels)
      self.assertEqual("f2_value", weights)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn, label_key="l1", weight_key="f2")
    result = spec.create_metric_ops(features_in, labels_in, predictions_in)
    self.assertEqual("metric_fn_result", result)

  def test_single_label(self):
    """Non-dict labels are passed as-is when no `label_key` is set."""
    features_in = {"f1": "f1_value", "f2": "f2_value"}
    labels_in = "l1_value"
    predictions_in = {"p1": "p1_value", "p2": "p2_value"}

    def _fn(predictions, labels, weights=None):
      self.assertEqual("p1_value", predictions)
      self.assertEqual(labels_in, labels)
      self.assertEqual("f2_value", weights)
      return "metric_fn_result"

    spec = MetricSpec(metric_fn=_fn, prediction_key="p1", weight_key="f2")
    result = spec.create_metric_ops(features_in, labels_in, predictions_in)
    self.assertEqual("metric_fn_result", result)

  def test_single_predictions_with_key(self):
    """Expects ValueError for a prediction key with non-dict predictions."""
    features_in = {"f1": "f1_value", "f2": "f2_value"}
    labels_in = {"l1": "l1_value", "l2": "l2_value"}
    predictions_in = "p1_value"

    def _fn(predictions, labels, weights=None):
      del labels, predictions, weights
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(
        metric_fn=_fn,
        prediction_key="p1",
        label_key="l1",
        weight_key="f2")
    expected_error = (
        "MetricSpec with prediction_key specified requires predictions dict")
    self.assertRaisesRegexp(
        ValueError, expected_error,
        spec.create_metric_ops, features_in, labels_in, predictions_in)

  def test_single_labels_with_key(self):
    """Expects ValueError when `label_key` is set but labels is not a dict."""
    features_in = {"f1": "f1_value", "f2": "f2_value"}
    labels_in = "l1_value"
    predictions_in = {"p1": "p1_value", "p2": "p2_value"}

    def _fn(predictions, labels, weights=None):
      del labels, predictions, weights
      self.fail("Expected failure before metric_fn.")

    spec = MetricSpec(
        metric_fn=_fn,
        prediction_key="p1",
        label_key="l1",
        weight_key="f2")
    expected_error = "MetricSpec with label_key specified requires labels dict"
    self.assertRaisesRegexp(
        ValueError, expected_error,
        spec.create_metric_ops, features_in, labels_in, predictions_in)

  def test_str(self):
    """`str(MetricSpec)` mentions the metric function name and every key."""
    def _metric_fn(labels, predictions, weights=None):
      return predictions, labels, weights

    spec = MetricSpec(
        metric_fn=_metric_fn,
        label_key="my_label",
        prediction_key="my_prediction",
        weight_key="my_weight")
    description = str(spec)
    for fragment in ("_metric_fn", "my_label", "my_prediction", "my_weight"):
      self.assertIn(fragment, description)

  def test_partial_str(self):
    """`str(MetricSpec)` names the function wrapped by `functools.partial`."""

    def custom_metric(predictions, labels, stuff, weights=None):
      return predictions, labels, weights, stuff

    spec = MetricSpec(
        metric_fn=functools.partial(custom_metric, stuff=5),
        label_key="my_label",
        prediction_key="my_prediction",
        weight_key="my_weight")
    description = str(spec)
    for fragment in ("custom_metric", "my_label", "my_prediction", "my_weight"):
      self.assertIn(fragment, description)

  def test_partial(self):
    """A `functools.partial` metric fn is called with its bound argument."""
    features_in = {"f1": "f1_value", "f2": "f2_value"}
    labels_in = {"l1": "l1_value"}
    predictions_in = {"p1": "p1_value", "p2": "p2_value"}

    def custom_metric(predictions, labels, stuff, weights=None):
      self.assertEqual("p1_value", predictions)
      self.assertEqual("l1_value", labels)
      self.assertEqual("f2_value", weights)
      if not stuff:
        raise ValueError("No stuff.")
      return "metric_fn_result"

    # Truthy bound `stuff`: the metric function returns normally.
    spec_with_stuff = MetricSpec(
        metric_fn=functools.partial(custom_metric, stuff=5),
        label_key="l1",
        prediction_key="p1",
        weight_key="f2")
    result = spec_with_stuff.create_metric_ops(
        features_in, labels_in, predictions_in)
    self.assertEqual("metric_fn_result", result)

    # Falsy bound `stuff`: the metric function's own ValueError propagates.
    spec_without_stuff = MetricSpec(
        metric_fn=functools.partial(custom_metric, stuff=None),
        prediction_key="p1",
        label_key="l1",
        weight_key="f2")
    self.assertRaisesRegexp(
        ValueError, "No stuff.",
        spec_without_stuff.create_metric_ops,
        features_in, labels_in, predictions_in)

  def test_label_key_without_label_arg(self):
    """Expects ValueError at construction: `label_key` but no label arg."""
    def _fn0(predictions, weights=None):
      del predictions, weights
      self.fail("Expected failure before metric_fn.")

    def _fn1(prediction, weight=None):
      del prediction, weight
      self.fail("Expected failure before metric_fn.")

    for metric_fn in (_fn0, _fn1):
      self.assertRaisesRegexp(
          ValueError, "label.*missing",
          MetricSpec, metric_fn=metric_fn, label_key="l1")

  def test_weight_key_without_weight_arg(self):
    """Expects ValueError at construction: `weight_key` but no weight arg."""
    def _fn0(predictions, labels):
      del predictions, labels
      self.fail("Expected failure before metric_fn.")

    def _fn1(prediction, label):
      del prediction, label
      self.fail("Expected failure before metric_fn.")

    def _fn2(predictions, targets):
      del predictions, targets
      self.fail("Expected failure before metric_fn.")

    def _fn3(prediction, target):
      del prediction, target
      self.fail("Expected failure before metric_fn.")

    for metric_fn in (_fn0, _fn1, _fn2, _fn3):
      self.assertRaisesRegexp(
          ValueError, "weight.*missing",
          MetricSpec, metric_fn=metric_fn, weight_key="f2")

  def test_multiple_label_args(self):
    """Expects ValueError at construction for two label-like arguments."""
    def _fn0(predictions, labels, targets):
      del predictions, labels, targets
      self.fail("Expected failure before metric_fn.")

    def _fn1(prediction, label, target):
      del prediction, label, target
      self.fail("Expected failure before metric_fn.")

    for metric_fn in (_fn0, _fn1):
      self.assertRaisesRegexp(
          ValueError, "provide only one of.*label",
          MetricSpec, metric_fn=metric_fn)

  def test_multiple_prediction_args(self):
    """Expects ValueError at construction for two prediction-like arguments."""
    def _fn(predictions, prediction, labels):
      del predictions, prediction, labels
      self.fail("Expected failure before metric_fn.")

    self.assertRaisesRegexp(
        ValueError, "provide only one of.*prediction",
        MetricSpec, metric_fn=_fn)

  def test_multiple_weight_args(self):
    """Expects ValueError at construction for two weight-like arguments."""
    def _fn(predictions, labels, weights=None, weight=None):
      del predictions, labels, weights, weight
      self.fail("Expected failure before metric_fn.")

    self.assertRaisesRegexp(
        ValueError, "provide only one of.*weight",
        MetricSpec, metric_fn=_fn)
431
# Run the tests through `test.main()` only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
  test.main()
434