1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for MetricSpec.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22 23# pylint: disable=g-bad-todo,g-import-not-at-top 24from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec 25from tensorflow.python.platform import test 26 27 28class MetricSpecTest(test.TestCase): 29 30 def test_named_args_with_weights(self): 31 features = {"f1": "f1_value", "f2": "f2_value"} 32 labels_ = {"l1": "l1_value", "l2": "l2_value"} 33 predictions_ = {"p1": "p1_value", "p2": "p2_value"} 34 35 def _fn0(predictions, labels, weights=None): 36 self.assertEqual("p1_value", predictions) 37 self.assertEqual("l1_value", labels) 38 self.assertEqual("f2_value", weights) 39 return "metric_fn_result" 40 41 def _fn1(predictions, targets, weights=None): 42 self.assertEqual("p1_value", predictions) 43 self.assertEqual("l1_value", targets) 44 self.assertEqual("f2_value", weights) 45 return "metric_fn_result" 46 47 def _fn2(prediction, label, weight=None): 48 self.assertEqual("p1_value", prediction) 49 self.assertEqual("l1_value", label) 50 self.assertEqual("f2_value", weight) 51 return "metric_fn_result" 52 53 def _fn3(prediction, target, weight=None): 54 self.assertEqual("p1_value", prediction) 55 self.assertEqual("l1_value", target) 56 self.assertEqual("f2_value", weight) 57 return "metric_fn_result" 58 59 for fn in (_fn0, _fn1, _fn2, _fn3): 60 spec = MetricSpec( 61 metric_fn=fn, prediction_key="p1", label_key="l1", weight_key="f2") 62 self.assertEqual( 63 "metric_fn_result", 64 spec.create_metric_ops(features, labels_, predictions_)) 65 66 def test_no_args(self): 67 def _fn(): 68 self.fail("Expected failure before metric_fn.") 69 70 spec = MetricSpec(metric_fn=_fn) 71 with self.assertRaises(TypeError): 72 spec.create_metric_ops( 73 {"f1": "f1_value"}, "labels_value", "predictions_value") 74 75 def test_kwargs(self): 76 features = {"f1": "f1_value"} 77 labels_ = "labels_value" 78 predictions_ = "predictions_value" 79 80 def _fn(**kwargs): 81 self.assertEqual({}, kwargs) 82 return "metric_fn_result" 83 84 spec = MetricSpec(metric_fn=_fn) 85 with self.assertRaises(TypeError): 86 spec.create_metric_ops(features, labels_, predictions_) 87 88 def test_named_labels_no_predictions(self): 89 features = {"f1": "f1_value"} 90 labels_ = "labels_value" 91 predictions_ = "predictions_value" 92 93 def _fn(labels): 94 self.assertEqual(labels_, labels) 95 return "metric_fn_result" 96 97 spec = MetricSpec(metric_fn=_fn) 98 with self.assertRaises(TypeError): 99 spec.create_metric_ops(features, labels_, predictions_) 100 101 def test_named_labels_no_predictions_with_kwargs(self): 102 features = {"f1": "f1_value"} 103 labels_ = "labels_value" 104 predictions_ = "predictions_value" 105 106 def _fn(labels, **kwargs): 107 self.assertEqual(labels_, labels) 108 self.assertEqual({}, kwargs) 109 return "metric_fn_result" 110 111 spec = MetricSpec(metric_fn=_fn) 112 with self.assertRaises(TypeError): 113 spec.create_metric_ops(features, labels_, predictions_) 114 115 def test_no_named_predictions_named_labels_first_arg(self): 116 features = {"f1": "f1_value"} 117 labels_ = "labels_value" 118 predictions_ = "predictions_value" 119 120 def _fn(labels, predictions_by_another_name): 121 self.assertEqual(predictions_, predictions_by_another_name) 122 self.assertEqual(labels_, labels) 123 return "metric_fn_result" 124 125 spec = MetricSpec(metric_fn=_fn) 126 self.assertEqual( 127 "metric_fn_result", 128 spec.create_metric_ops(features, labels_, predictions_)) 129 130 def test_no_named_predictions_named_labels_second_arg(self): 131 features = {"f1": "f1_value"} 132 labels_ = "labels_value" 133 predictions_ = "predictions_value" 134 135 def _fn(predictions_by_another_name, labels): 136 self.assertEqual(predictions_, predictions_by_another_name) 137 self.assertEqual(labels_, labels) 138 return "metric_fn_result" 139 140 spec = MetricSpec(metric_fn=_fn) 141 self.assertEqual( 142 "metric_fn_result", 143 spec.create_metric_ops(features, labels_, predictions_)) 144 145 def test_no_named_labels(self): 146 features = {"f1": "f1_value"} 147 labels_ = "labels_value" 148 predictions_ = "predictions_value" 149 150 def _fn(predictions): 151 self.assertEqual(predictions_, predictions) 152 return "metric_fn_result" 153 154 spec = MetricSpec(metric_fn=_fn) 155 self.assertEqual( 156 "metric_fn_result", 157 spec.create_metric_ops(features, labels_, predictions_)) 158 159 def test_no_named_labels_or_predictions_1arg(self): 160 features = {"f1": "f1_value"} 161 labels_ = "labels_value" 162 predictions_ = "predictions_value" 163 164 def _fn(a): 165 self.assertEqual(predictions_, a) 166 return "metric_fn_result" 167 168 spec = MetricSpec(metric_fn=_fn) 169 self.assertEqual( 170 "metric_fn_result", 171 spec.create_metric_ops(features, labels_, predictions_)) 172 173 def test_no_named_labels_or_predictions_2args(self): 174 features = {"f1": "f1_value"} 175 labels_ = "labels_value" 176 predictions_ = "predictions_value" 177 178 def _fn(a, b): 179 del a, b 180 self.fail("Expected failure before metric_fn.") 181 182 spec = MetricSpec(metric_fn=_fn) 183 with self.assertRaises(TypeError): 184 spec.create_metric_ops(features, labels_, predictions_) 185 186 def test_named_args_no_weights(self): 187 features = {"f1": "f1_value", "f2": "f2_value"} 188 labels_ = {"l1": "l1_value", "l2": "l2_value"} 189 predictions_ = {"p1": "p1_value", "p2": "p2_value"} 190 191 def _fn0(predictions, labels): 192 self.assertEqual("p1_value", predictions) 193 self.assertEqual("l1_value", labels) 194 return "metric_fn_result" 195 196 def _fn1(predictions, targets): 197 self.assertEqual("p1_value", predictions) 198 self.assertEqual("l1_value", targets) 199 return "metric_fn_result" 200 201 def _fn2(prediction, label): 202 self.assertEqual("p1_value", prediction) 203 self.assertEqual("l1_value", label) 204 return "metric_fn_result" 205 206 def _fn3(prediction, target): 207 self.assertEqual("p1_value", prediction) 208 self.assertEqual("l1_value", target) 209 return "metric_fn_result" 210 211 for fn in (_fn0, _fn1, _fn2, _fn3): 212 spec = MetricSpec(metric_fn=fn, prediction_key="p1", label_key="l1") 213 self.assertEqual( 214 "metric_fn_result", 215 spec.create_metric_ops(features, labels_, predictions_)) 216 217 def test_predictions_dict_no_key(self): 218 features = {"f1": "f1_value", "f2": "f2_value"} 219 labels = {"l1": "l1_value", "l2": "l2_value"} 220 predictions = {"p1": "p1_value", "p2": "p2_value"} 221 222 def _fn(predictions, labels, weights=None): 223 del labels, predictions, weights 224 self.fail("Expected failure before metric_fn.") 225 226 spec = MetricSpec(metric_fn=_fn, label_key="l1", weight_key="f2") 227 with self.assertRaisesRegexp( 228 ValueError, 229 "MetricSpec without specified prediction_key requires predictions" 230 " tensor or single element dict"): 231 spec.create_metric_ops(features, labels, predictions) 232 233 def test_labels_dict_no_key(self): 234 features = {"f1": "f1_value", "f2": "f2_value"} 235 labels = {"l1": "l1_value", "l2": "l2_value"} 236 predictions = {"p1": "p1_value", "p2": "p2_value"} 237 238 def _fn(labels, predictions, weights=None): 239 del labels, predictions, weights 240 self.fail("Expected failure before metric_fn.") 241 242 spec = MetricSpec(metric_fn=_fn, prediction_key="p1", weight_key="f2") 243 with self.assertRaisesRegexp( 244 ValueError, 245 "MetricSpec without specified label_key requires labels tensor or" 246 " single element dict"): 247 spec.create_metric_ops(features, labels, predictions) 248 249 def test_single_prediction(self): 250 features = {"f1": "f1_value", "f2": "f2_value"} 251 labels_ = {"l1": "l1_value", "l2": "l2_value"} 252 predictions_ = "p1_value" 253 254 def _fn(predictions, labels, weights=None): 255 self.assertEqual(predictions_, predictions) 256 self.assertEqual("l1_value", labels) 257 self.assertEqual("f2_value", weights) 258 return "metric_fn_result" 259 260 spec = MetricSpec(metric_fn=_fn, label_key="l1", weight_key="f2") 261 self.assertEqual( 262 "metric_fn_result", 263 spec.create_metric_ops(features, labels_, predictions_)) 264 265 def test_single_label(self): 266 features = {"f1": "f1_value", "f2": "f2_value"} 267 labels_ = "l1_value" 268 predictions_ = {"p1": "p1_value", "p2": "p2_value"} 269 270 def _fn(predictions, labels, weights=None): 271 self.assertEqual("p1_value", predictions) 272 self.assertEqual(labels_, labels) 273 self.assertEqual("f2_value", weights) 274 return "metric_fn_result" 275 276 spec = MetricSpec(metric_fn=_fn, prediction_key="p1", weight_key="f2") 277 self.assertEqual( 278 "metric_fn_result", 279 spec.create_metric_ops(features, labels_, predictions_)) 280 281 def test_single_predictions_with_key(self): 282 features = {"f1": "f1_value", "f2": "f2_value"} 283 labels = {"l1": "l1_value", "l2": "l2_value"} 284 predictions = "p1_value" 285 286 def _fn(predictions, labels, weights=None): 287 del labels, predictions, weights 288 self.fail("Expected failure before metric_fn.") 289 290 spec = MetricSpec( 291 metric_fn=_fn, prediction_key="p1", label_key="l1", weight_key="f2") 292 with self.assertRaisesRegexp( 293 ValueError, 294 "MetricSpec with prediction_key specified requires predictions dict"): 295 spec.create_metric_ops(features, labels, predictions) 296 297 def test_single_labels_with_key(self): 298 features = {"f1": "f1_value", "f2": "f2_value"} 299 labels = "l1_value" 300 predictions = {"p1": "p1_value", "p2": "p2_value"} 301 302 def _fn(predictions, labels, weights=None): 303 del labels, predictions, weights 304 self.fail("Expected failure before metric_fn.") 305 306 spec = MetricSpec( 307 metric_fn=_fn, prediction_key="p1", label_key="l1", weight_key="f2") 308 with self.assertRaisesRegexp( 309 ValueError, "MetricSpec with label_key specified requires labels dict"): 310 spec.create_metric_ops(features, labels, predictions) 311 312 def test_str(self): 313 def _metric_fn(labels, predictions, weights=None): 314 return predictions, labels, weights 315 316 string = str(MetricSpec( 317 metric_fn=_metric_fn, 318 label_key="my_label", 319 prediction_key="my_prediction", 320 weight_key="my_weight")) 321 self.assertIn("_metric_fn", string) 322 self.assertIn("my_label", string) 323 self.assertIn("my_prediction", string) 324 self.assertIn("my_weight", string) 325 326 def test_partial_str(self): 327 328 def custom_metric(predictions, labels, stuff, weights=None): 329 return predictions, labels, weights, stuff 330 331 string = str(MetricSpec( 332 metric_fn=functools.partial(custom_metric, stuff=5), 333 label_key="my_label", 334 prediction_key="my_prediction", 335 weight_key="my_weight")) 336 self.assertIn("custom_metric", string) 337 self.assertIn("my_label", string) 338 self.assertIn("my_prediction", string) 339 self.assertIn("my_weight", string) 340 341 def test_partial(self): 342 features = {"f1": "f1_value", "f2": "f2_value"} 343 labels = {"l1": "l1_value"} 344 predictions = {"p1": "p1_value", "p2": "p2_value"} 345 346 def custom_metric(predictions, labels, stuff, weights=None): 347 self.assertEqual("p1_value", predictions) 348 self.assertEqual("l1_value", labels) 349 self.assertEqual("f2_value", weights) 350 if stuff: 351 return "metric_fn_result" 352 raise ValueError("No stuff.") 353 354 spec = MetricSpec( 355 metric_fn=functools.partial(custom_metric, stuff=5), 356 label_key="l1", 357 prediction_key="p1", 358 weight_key="f2") 359 self.assertEqual( 360 "metric_fn_result", 361 spec.create_metric_ops(features, labels, predictions)) 362 363 spec = MetricSpec( 364 metric_fn=functools.partial(custom_metric, stuff=None), 365 prediction_key="p1", label_key="l1", weight_key="f2") 366 with self.assertRaisesRegexp(ValueError, "No stuff."): 367 spec.create_metric_ops(features, labels, predictions) 368 369 def test_label_key_without_label_arg(self): 370 def _fn0(predictions, weights=None): 371 del predictions, weights 372 self.fail("Expected failure before metric_fn.") 373 374 def _fn1(prediction, weight=None): 375 del prediction, weight 376 self.fail("Expected failure before metric_fn.") 377 378 for fn in (_fn0, _fn1): 379 with self.assertRaisesRegexp(ValueError, "label.*missing"): 380 MetricSpec(metric_fn=fn, label_key="l1") 381 382 def test_weight_key_without_weight_arg(self): 383 def _fn0(predictions, labels): 384 del predictions, labels 385 self.fail("Expected failure before metric_fn.") 386 387 def _fn1(prediction, label): 388 del prediction, label 389 self.fail("Expected failure before metric_fn.") 390 391 def _fn2(predictions, targets): 392 del predictions, targets 393 self.fail("Expected failure before metric_fn.") 394 395 def _fn3(prediction, target): 396 del prediction, target 397 self.fail("Expected failure before metric_fn.") 398 399 for fn in (_fn0, _fn1, _fn2, _fn3): 400 with self.assertRaisesRegexp(ValueError, "weight.*missing"): 401 MetricSpec(metric_fn=fn, weight_key="f2") 402 403 def test_multiple_label_args(self): 404 def _fn0(predictions, labels, targets): 405 del predictions, labels, targets 406 self.fail("Expected failure before metric_fn.") 407 408 def _fn1(prediction, label, target): 409 del prediction, label, target 410 self.fail("Expected failure before metric_fn.") 411 412 for fn in (_fn0, _fn1): 413 with self.assertRaisesRegexp(ValueError, "provide only one of.*label"): 414 MetricSpec(metric_fn=fn) 415 416 def test_multiple_prediction_args(self): 417 def _fn(predictions, prediction, labels): 418 del predictions, prediction, labels 419 self.fail("Expected failure before metric_fn.") 420 421 with self.assertRaisesRegexp(ValueError, "provide only one of.*prediction"): 422 MetricSpec(metric_fn=_fn) 423 424 def test_multiple_weight_args(self): 425 def _fn(predictions, labels, weights=None, weight=None): 426 del predictions, labels, weights, weight 427 self.fail("Expected failure before metric_fn.") 428 429 with self.assertRaisesRegexp(ValueError, "provide only one of.*weight"): 430 MetricSpec(metric_fn=_fn) 431 432if __name__ == "__main__": 433 test.main() 434