metric_ops.py revision d4eb834824d79c6a64a3c4a1c4a88b434b73e63e
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains metric-computing operations on streamed tensors. 16 17Module documentation, including "@@" callouts, should be put in 18third_party/tensorflow/contrib/metrics/__init__.py 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25from tensorflow.contrib.framework import deprecated 26from tensorflow.contrib.framework import tensor_util 27from tensorflow.contrib.framework.python.ops import variables as contrib_variables 28from tensorflow.contrib.metrics.python.ops import set_ops 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import sparse_tensor 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import check_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import math_ops 36from tensorflow.python.ops import metrics 37from tensorflow.python.ops import nn 38from tensorflow.python.ops import sparse_ops 39from tensorflow.python.ops import state_ops 40from tensorflow.python.ops import variable_scope 41from tensorflow.python.ops import variables 42 43 44def _safe_div(numerator, denominator, name): 45 """Divides two values, returning 0 if the denominator is <= 
0. 46 47 Args: 48 numerator: A real `Tensor`. 49 denominator: A real `Tensor`, with dtype matching `numerator`. 50 name: Name for the returned op. 51 52 Returns: 53 0 if `denominator` <= 0, else `numerator` / `denominator` 54 """ 55 return array_ops.where( 56 math_ops.greater(denominator, 0), 57 math_ops.truediv(numerator, denominator), 58 0, 59 name=name) 60 61 62def _safe_scalar_div(numerator, denominator, name): 63 """Divides two values, returning 0 if the denominator is 0. 64 65 Args: 66 numerator: A scalar `float64` `Tensor`. 67 denominator: A scalar `float64` `Tensor`. 68 name: Name for the returned op. 69 70 Returns: 71 0 if `denominator` == 0, else `numerator` / `denominator` 72 """ 73 numerator.get_shape().with_rank_at_most(1) 74 denominator.get_shape().with_rank_at_most(1) 75 return control_flow_ops.cond( 76 math_ops.equal( 77 array_ops.constant(0.0, dtype=dtypes.float64), denominator), 78 lambda: array_ops.constant(0.0, dtype=dtypes.float64), 79 lambda: math_ops.div(numerator, denominator), 80 name=name) 81 82 83def _create_local(name, shape, collections=None, validate_shape=True, 84 dtype=dtypes.float32): 85 """Creates a new local variable. 86 87 Args: 88 name: The name of the new or existing variable. 89 shape: Shape of the new or existing variable. 90 collections: A list of collection names to which the Variable will be added. 91 validate_shape: Whether to validate the shape of the variable. 92 dtype: Data type of the variables. 93 94 Returns: 95 The created variable. 
96 """ 97 # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES 98 collections = list(collections or []) 99 collections += [ops.GraphKeys.LOCAL_VARIABLES] 100 return variables.Variable( 101 initial_value=array_ops.zeros(shape, dtype=dtype), 102 name=name, 103 trainable=False, 104 collections=collections, 105 validate_shape=validate_shape) 106 107 108def _count_condition(values, weights=None, metrics_collections=None, 109 updates_collections=None): 110 """Sums the weights of cases where the given values are True. 111 112 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 113 114 Args: 115 values: A `bool` `Tensor` of arbitrary size. 116 weights: An optional `Tensor` whose shape is broadcastable to `values`. 117 metrics_collections: An optional list of collections that the metric 118 value variable should be added to. 119 updates_collections: An optional list of collections that the metric update 120 ops should be added to. 121 122 Returns: 123 value_tensor: A `Tensor` representing the current value of the metric. 124 update_op: An operation that accumulates the error from a batch of data. 125 126 Raises: 127 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 128 or if either `metrics_collections` or `updates_collections` are not a list 129 or tuple. 
130 """ 131 check_ops.assert_type(values, dtypes.bool) 132 count = _create_local('count', shape=[]) 133 134 values = math_ops.to_float(values) 135 if weights is not None: 136 weights = math_ops.to_float(weights) 137 values = math_ops.mul(values, weights) 138 139 value_tensor = array_ops.identity(count) 140 update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) 141 142 if metrics_collections: 143 ops.add_to_collections(metrics_collections, value_tensor) 144 145 if updates_collections: 146 ops.add_to_collections(updates_collections, update_op) 147 148 return value_tensor, update_op 149 150 151def streaming_true_positives(predictions, labels, weights=None, 152 metrics_collections=None, 153 updates_collections=None, 154 name=None): 155 """Sum the weights of true_positives. 156 157 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 158 159 Args: 160 predictions: The predicted values, a `bool` `Tensor` of arbitrary 161 dimensions. 162 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 163 match `predictions`. 164 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 165 metrics_collections: An optional list of collections that the metric 166 value variable should be added to. 167 updates_collections: An optional list of collections that the metric update 168 ops should be added to. 169 name: An optional variable_scope name. 170 171 Returns: 172 value_tensor: A `Tensor` representing the current value of the metric. 173 update_op: An operation that accumulates the error from a batch of data. 174 175 Raises: 176 ValueError: If `predictions` and `labels` have mismatched shapes, or if 177 `weights` is not `None` and its shape doesn't match `predictions`, or if 178 either `metrics_collections` or `updates_collections` are not a list or 179 tuple. 
180 """ 181 return metrics.true_positives( 182 predictions=predictions, labels=labels, weights=weights, 183 metrics_collections=metrics_collections, 184 updates_collections=updates_collections, name=name) 185 186 187def streaming_true_negatives(predictions, labels, weights=None, 188 metrics_collections=None, 189 updates_collections=None, 190 name=None): 191 """Sum the weights of true_negatives. 192 193 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 194 195 Args: 196 predictions: The predicted values, a `bool` `Tensor` of arbitrary 197 dimensions. 198 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 199 match `predictions`. 200 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 201 metrics_collections: An optional list of collections that the metric 202 value variable should be added to. 203 updates_collections: An optional list of collections that the metric update 204 ops should be added to. 205 name: An optional variable_scope name. 206 207 Returns: 208 value_tensor: A `Tensor` representing the current value of the metric. 209 update_op: An operation that accumulates the error from a batch of data. 210 211 Raises: 212 ValueError: If `predictions` and `labels` have mismatched shapes, or if 213 `weights` is not `None` and its shape doesn't match `predictions`, or if 214 either `metrics_collections` or `updates_collections` are not a list or 215 tuple. 
216 """ 217 with variable_scope.variable_scope( 218 name, 'true_negatives', (predictions, labels, weights)): 219 220 predictions = ops.convert_to_tensor(predictions) 221 labels = ops.convert_to_tensor(labels) 222 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 223 is_true_negative = math_ops.logical_and(math_ops.equal(labels, 0), 224 math_ops.equal(predictions, 0)) 225 return _count_condition(is_true_negative, weights, metrics_collections, 226 updates_collections) 227 228 229def streaming_false_positives(predictions, labels, weights=None, 230 metrics_collections=None, 231 updates_collections=None, 232 name=None): 233 """Sum the weights of false positives. 234 235 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 236 237 Args: 238 predictions: The predicted values, a `bool` `Tensor` of arbitrary 239 dimensions. 240 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 241 match `predictions`. 242 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 243 metrics_collections: An optional list of collections that the metric 244 value variable should be added to. 245 updates_collections: An optional list of collections that the metric update 246 ops should be added to. 247 name: An optional variable_scope name. 248 249 Returns: 250 value_tensor: A `Tensor` representing the current value of the metric. 251 update_op: An operation that accumulates the error from a batch of data. 252 253 Raises: 254 ValueError: If `predictions` and `labels` have mismatched shapes, or if 255 `weights` is not `None` and its shape doesn't match `predictions`, or if 256 either `metrics_collections` or `updates_collections` are not a list or 257 tuple. 
258 """ 259 return metrics.false_positives( 260 predictions=predictions, labels=labels, weights=weights, 261 metrics_collections=metrics_collections, 262 updates_collections=updates_collections, name=name) 263 264 265def streaming_false_negatives(predictions, labels, weights=None, 266 metrics_collections=None, 267 updates_collections=None, 268 name=None): 269 """Computes the total number of false positives. 270 271 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 272 273 Args: 274 predictions: The predicted values, a `bool` `Tensor` of arbitrary 275 dimensions. 276 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 277 match `predictions`. 278 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 279 metrics_collections: An optional list of collections that the metric 280 value variable should be added to. 281 updates_collections: An optional list of collections that the metric update 282 ops should be added to. 283 name: An optional variable_scope name. 284 285 Returns: 286 value_tensor: A `Tensor` representing the current value of the metric. 287 update_op: An operation that accumulates the error from a batch of data. 288 289 Raises: 290 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 291 or if either `metrics_collections` or `updates_collections` are not a list 292 or tuple. 293 """ 294 return metrics.false_negatives( 295 predictions=predictions, labels=labels, weights=weights, 296 metrics_collections=metrics_collections, 297 updates_collections=updates_collections, name=name) 298 299 300def _broadcast_weights(weights, values): 301 """Broadcast `weights` to the same shape as `values`. 302 303 This returns a version of `weights` following the same broadcast rules as 304 `mul(weights, values)`. When computing a weighted average, use this function 305 to broadcast `weights` before summing them; e.g., 306 `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`. 
307 308 Args: 309 weights: `Tensor` whose shape is broadcastable to `values`. 310 values: `Tensor` of any shape. 311 312 Returns: 313 `weights` broadcast to `values` shape. 314 """ 315 weights_shape = weights.get_shape() 316 values_shape = values.get_shape() 317 if (weights_shape.is_fully_defined() and 318 values_shape.is_fully_defined() and 319 weights_shape.is_compatible_with(values_shape)): 320 return weights 321 return math_ops.mul( 322 weights, array_ops.ones_like(values), name='broadcast_weights') 323 324 325def streaming_mean(values, weights=None, metrics_collections=None, 326 updates_collections=None, name=None): 327 """Computes the (weighted) mean of the given values. 328 329 The `streaming_mean` function creates two local variables, `total` and `count` 330 that are used to compute the average of `values`. This average is ultimately 331 returned as `mean` which is an idempotent operation that simply divides 332 `total` by `count`. 333 334 For estimation of the metric over a stream of data, the function creates an 335 `update_op` operation that updates these variables and returns the `mean`. 336 `update_op` increments `total` with the reduced sum of the product of `values` 337 and `weights`, and it increments `count` with the reduced sum of `weights`. 338 339 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 340 341 Args: 342 values: A `Tensor` of arbitrary dimensions. 343 weights: An optional `Tensor` whose shape is broadcastable to `values`. 344 metrics_collections: An optional list of collections that `mean` 345 should be added to. 346 updates_collections: An optional list of collections that `update_op` 347 should be added to. 348 name: An optional variable_scope name. 349 350 Returns: 351 mean: A `Tensor` representing the current mean, the value of `total` divided 352 by `count`. 353 update_op: An operation that increments the `total` and `count` variables 354 appropriately and whose value matches `mean_value`. 
355 356 Raises: 357 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 358 or if either `metrics_collections` or `updates_collections` are not a list 359 or tuple. 360 """ 361 return metrics.mean( 362 values=values, weights=weights, metrics_collections=metrics_collections, 363 updates_collections=updates_collections, name=name) 364 365 366def streaming_mean_tensor(values, weights=None, metrics_collections=None, 367 updates_collections=None, name=None): 368 """Computes the element-wise (weighted) mean of the given tensors. 369 370 In contrast to the `streaming_mean` function which returns a scalar with the 371 mean, this function returns an average tensor with the same shape as the 372 input tensors. 373 374 The `streaming_mean_tensor` function creates two local variables, 375 `total_tensor` and `count_tensor` that are used to compute the average of 376 `values`. This average is ultimately returned as `mean` which is an idempotent 377 operation that simply divides `total` by `count`. 378 379 For estimation of the metric over a stream of data, the function creates an 380 `update_op` operation that updates these variables and returns the `mean`. 381 `update_op` increments `total` with the reduced sum of the product of `values` 382 and `weights`, and it increments `count` with the reduced sum of `weights`. 383 384 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 385 386 Args: 387 values: A `Tensor` of arbitrary dimensions. 388 weights: An optional `Tensor` whose shape is broadcastable to `values`. 389 metrics_collections: An optional list of collections that `mean` 390 should be added to. 391 updates_collections: An optional list of collections that `update_op` 392 should be added to. 393 name: An optional variable_scope name. 394 395 Returns: 396 mean: A float `Tensor` representing the current mean, the value of `total` 397 divided by `count`. 
398 update_op: An operation that increments the `total` and `count` variables 399 appropriately and whose value matches `mean_value`. 400 401 Raises: 402 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 403 or if either `metrics_collections` or `updates_collections` are not a list 404 or tuple. 405 """ 406 return metrics.mean_tensor( 407 values=values, weights=weights, metrics_collections=metrics_collections, 408 updates_collections=updates_collections, name=name) 409 410 411def streaming_accuracy(predictions, labels, weights=None, 412 metrics_collections=None, updates_collections=None, 413 name=None): 414 """Calculates how often `predictions` matches `labels`. 415 416 The `streaming_accuracy` function creates two local variables, `total` and 417 `count` that are used to compute the frequency with which `predictions` 418 matches `labels`. This frequency is ultimately returned as `accuracy`: an 419 idempotent operation that simply divides `total` by `count`. 420 421 For estimation of the metric over a stream of data, the function creates an 422 `update_op` operation that updates these variables and returns the `accuracy`. 423 Internally, an `is_correct` operation computes a `Tensor` with elements 1.0 424 where the corresponding elements of `predictions` and `labels` match and 0.0 425 otherwise. Then `update_op` increments `total` with the reduced sum of the 426 product of `weights` and `is_correct`, and it increments `count` with the 427 reduced sum of `weights`. 428 429 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 430 431 Args: 432 predictions: The predicted values, a `Tensor` of any shape. 433 labels: The ground truth values, a `Tensor` whose shape matches 434 `predictions`. 435 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 436 metrics_collections: An optional list of collections that `accuracy` should 437 be added to. 
438 updates_collections: An optional list of collections that `update_op` should 439 be added to. 440 name: An optional variable_scope name. 441 442 Returns: 443 accuracy: A `Tensor` representing the accuracy, the value of `total` divided 444 by `count`. 445 update_op: An operation that increments the `total` and `count` variables 446 appropriately and whose value matches `accuracy`. 447 448 Raises: 449 ValueError: If `predictions` and `labels` have mismatched shapes, or if 450 `weights` is not `None` and its shape doesn't match `predictions`, or if 451 either `metrics_collections` or `updates_collections` are not a list or 452 tuple. 453 """ 454 return metrics.accuracy( 455 predictions=predictions, labels=labels, weights=weights, 456 metrics_collections=metrics_collections, 457 updates_collections=updates_collections, name=name) 458 459 460def streaming_precision(predictions, labels, weights=None, 461 metrics_collections=None, updates_collections=None, 462 name=None): 463 """Computes the precision of the predictions with respect to the labels. 464 465 The `streaming_precision` function creates two local variables, 466 `true_positives` and `false_positives`, that are used to compute the 467 precision. This value is ultimately returned as `precision`, an idempotent 468 operation that simply divides `true_positives` by the sum of `true_positives` 469 and `false_positives`. 470 471 For estimation of the metric over a stream of data, the function creates an 472 `update_op` operation that updates these variables and returns the 473 `precision`. `update_op` weights each prediction by the corresponding value in 474 `weights`. 475 476 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 477 478 Args: 479 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 480 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 481 match `predictions`. 
482 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 483 metrics_collections: An optional list of collections that `precision` should 484 be added to. 485 updates_collections: An optional list of collections that `update_op` should 486 be added to. 487 name: An optional variable_scope name. 488 489 Returns: 490 precision: Scalar float `Tensor` with the value of `true_positives` 491 divided by the sum of `true_positives` and `false_positives`. 492 update_op: `Operation` that increments `true_positives` and 493 `false_positives` variables appropriately and whose value matches 494 `precision`. 495 496 Raises: 497 ValueError: If `predictions` and `labels` have mismatched shapes, or if 498 `weights` is not `None` and its shape doesn't match `predictions`, or if 499 either `metrics_collections` or `updates_collections` are not a list or 500 tuple. 501 """ 502 return metrics.precision( 503 predictions=predictions, labels=labels, weights=weights, 504 metrics_collections=metrics_collections, 505 updates_collections=updates_collections, name=name) 506 507 508def streaming_recall(predictions, labels, weights=None, 509 metrics_collections=None, updates_collections=None, 510 name=None): 511 """Computes the recall of the predictions with respect to the labels. 512 513 The `streaming_recall` function creates two local variables, `true_positives` 514 and `false_negatives`, that are used to compute the recall. This value is 515 ultimately returned as `recall`, an idempotent operation that simply divides 516 `true_positives` by the sum of `true_positives` and `false_negatives`. 517 518 For estimation of the metric over a stream of data, the function creates an 519 `update_op` that updates these variables and returns the `recall`. `update_op` 520 weights each prediction by the corresponding value in `weights`. 521 522 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
523 524 Args: 525 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 526 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 527 match `predictions`. 528 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 529 metrics_collections: An optional list of collections that `recall` should 530 be added to. 531 updates_collections: An optional list of collections that `update_op` should 532 be added to. 533 name: An optional variable_scope name. 534 535 Returns: 536 recall: Scalar float `Tensor` with the value of `true_positives` divided 537 by the sum of `true_positives` and `false_negatives`. 538 update_op: `Operation` that increments `true_positives` and 539 `false_negatives` variables appropriately and whose value matches 540 `recall`. 541 542 Raises: 543 ValueError: If `predictions` and `labels` have mismatched shapes, or if 544 `weights` is not `None` and its shape doesn't match `predictions`, or if 545 either `metrics_collections` or `updates_collections` are not a list or 546 tuple. 547 """ 548 return metrics.recall( 549 predictions=predictions, labels=labels, weights=weights, 550 metrics_collections=metrics_collections, 551 updates_collections=updates_collections, name=name) 552 553 554def _streaming_confusion_matrix_at_thresholds( 555 predictions, labels, thresholds, weights=None, includes=None): 556 """Computes true_positives, false_negatives, true_negatives, false_positives. 557 558 This function creates up to four local variables, `true_positives`, 559 `true_negatives`, `false_positives` and `false_negatives`. 560 `true_positive[i]` is defined as the total weight of values in `predictions` 561 above `thresholds[i]` whose corresponding entry in `labels` is `True`. 562 `false_negatives[i]` is defined as the total weight of values in `predictions` 563 at most `thresholds[i]` whose corresponding entry in `labels` is `True`. 
564 `true_negatives[i]` is defined as the total weight of values in `predictions` 565 at most `thresholds[i]` whose corresponding entry in `labels` is `False`. 566 `false_positives[i]` is defined as the total weight of values in `predictions` 567 above `thresholds[i]` whose corresponding entry in `labels` is `False`. 568 569 For estimation of these metrics over a stream of data, for each metric the 570 function respectively creates an `update_op` operation that updates the 571 variable and returns its value. 572 573 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 574 575 Args: 576 predictions: A floating point `Tensor` of arbitrary shape and whose values 577 are in the range `[0, 1]`. 578 labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast 579 to `bool`. 580 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 581 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 582 includes: Tuple of keys to return, from 'tp', 'fn', 'tn', fp'. If `None`, 583 default to all four. 584 585 Returns: 586 values: Dict of variables of shape `[len(thresholds)]`. Keys are from 587 `includes`. 588 update_ops: Dict of operations that increments the `values`. Keys are from 589 `includes`. 590 591 Raises: 592 ValueError: If `predictions` and `labels` have mismatched shapes, or if 593 `weights` is not `None` and its shape doesn't match `predictions`, or if 594 `includes` contains invalid keys. 595 """ 596 all_includes = ('tp', 'fn', 'tn', 'fp') 597 if includes is None: 598 includes = all_includes 599 else: 600 for include in includes: 601 if include not in all_includes: 602 raise ValueError('Invaild key: %s.' % include) 603 604 predictions, labels, weights = _remove_squeezable_dimensions( 605 predictions, labels, weights) 606 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 607 608 num_thresholds = len(thresholds) 609 610 # Reshape predictions and labels. 
611 predictions_2d = array_ops.reshape(predictions, [-1, 1]) 612 labels_2d = array_ops.reshape( 613 math_ops.cast(labels, dtype=dtypes.bool), [1, -1]) 614 615 # Use static shape if known. 616 num_predictions = predictions_2d.get_shape().as_list()[0] 617 618 # Otherwise use dynamic shape. 619 if num_predictions is None: 620 num_predictions = array_ops.shape(predictions_2d)[0] 621 thresh_tiled = array_ops.tile( 622 array_ops.expand_dims(array_ops.constant(thresholds), [1]), 623 array_ops.pack([1, num_predictions])) 624 625 # Tile the predictions after thresholding them across different thresholds. 626 pred_is_pos = math_ops.greater( 627 array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]), 628 thresh_tiled) 629 if ('fn' in includes) or ('tn' in includes): 630 pred_is_neg = math_ops.logical_not(pred_is_pos) 631 632 # Tile labels by number of thresholds 633 label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1]) 634 if ('fp' in includes) or ('tn' in includes): 635 label_is_neg = math_ops.logical_not(label_is_pos) 636 637 if weights is not None: 638 weights = math_ops.to_float(weights) 639 weights_tiled = array_ops.tile(array_ops.reshape(_broadcast_weights( 640 weights, predictions), [1, -1]), [num_thresholds, 1]) 641 thresh_tiled.get_shape().assert_is_compatible_with( 642 weights_tiled.get_shape()) 643 else: 644 weights_tiled = None 645 646 values = {} 647 update_ops = {} 648 649 if 'tp' in includes: 650 true_positives = _create_local('true_positives', shape=[num_thresholds]) 651 is_true_positive = math_ops.to_float( 652 math_ops.logical_and(label_is_pos, pred_is_pos)) 653 if weights_tiled is not None: 654 is_true_positive *= weights_tiled 655 update_ops['tp'] = state_ops.assign_add( 656 true_positives, math_ops.reduce_sum(is_true_positive, 1)) 657 values['tp'] = true_positives 658 659 if 'fn' in includes: 660 false_negatives = _create_local('false_negatives', shape=[num_thresholds]) 661 is_false_negative = math_ops.to_float( 662 
math_ops.logical_and(label_is_pos, pred_is_neg)) 663 if weights_tiled is not None: 664 is_false_negative *= weights_tiled 665 update_ops['fn'] = state_ops.assign_add( 666 false_negatives, math_ops.reduce_sum(is_false_negative, 1)) 667 values['fn'] = false_negatives 668 669 if 'tn' in includes: 670 true_negatives = _create_local('true_negatives', shape=[num_thresholds]) 671 is_true_negative = math_ops.to_float( 672 math_ops.logical_and(label_is_neg, pred_is_neg)) 673 if weights_tiled is not None: 674 is_true_negative *= weights_tiled 675 update_ops['tn'] = state_ops.assign_add( 676 true_negatives, math_ops.reduce_sum(is_true_negative, 1)) 677 values['tn'] = true_negatives 678 679 if 'fp' in includes: 680 false_positives = _create_local('false_positives', shape=[num_thresholds]) 681 is_false_positive = math_ops.to_float( 682 math_ops.logical_and(label_is_neg, pred_is_pos)) 683 if weights_tiled is not None: 684 is_false_positive *= weights_tiled 685 update_ops['fp'] = state_ops.assign_add( 686 false_positives, math_ops.reduce_sum(is_false_positive, 1)) 687 values['fp'] = false_positives 688 689 return values, update_ops 690 691 692def streaming_true_positives_at_thresholds( 693 predictions, labels, thresholds, weights=None): 694 values, update_ops = _streaming_confusion_matrix_at_thresholds( 695 predictions, labels, thresholds, weights=weights, includes=('tp',)) 696 return values['tp'], update_ops['tp'] 697 698 699def streaming_false_negatives_at_thresholds( 700 predictions, labels, thresholds, weights=None): 701 values, update_ops = _streaming_confusion_matrix_at_thresholds( 702 predictions, labels, thresholds, weights=weights, includes=('fn',)) 703 return values['fn'], update_ops['fn'] 704 705 706def streaming_false_positives_at_thresholds( 707 predictions, labels, thresholds, weights=None): 708 values, update_ops = _streaming_confusion_matrix_at_thresholds( 709 predictions, labels, thresholds, weights=weights, includes=('fp',)) 710 return values['fp'], 
update_ops['fp']


def streaming_true_negatives_at_thresholds(
    predictions, labels, thresholds, weights=None):
  """Computes true negatives at each of the given `thresholds`.

  Thin wrapper over `_streaming_confusion_matrix_at_thresholds` that requests
  only the true-negative (`'tn'`) entry of the confusion matrix.

  Args:
    predictions: Predictions `Tensor`, passed through unchanged.
    labels: Labels `Tensor`, passed through unchanged.
    thresholds: Thresholds, passed through unchanged.
    weights: An optional weights `Tensor`, passed through unchanged.

  Returns:
    A `(value, update_op)` pair: the `'tn'` entries of the value and update-op
    dicts returned by `_streaming_confusion_matrix_at_thresholds`.
  """
  values, update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=('tn',))
  return values['tn'], update_ops['tn']


def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
                  metrics_collections=None, updates_collections=None,
                  curve='ROC', name=None):
  """Computes the approximate AUC via a Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under the discretized curve selected by `curve` (computed
  using the aforementioned variables). The `num_thresholds` variable controls
  the degree of discretization with larger numbers of thresholds more closely
  approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.auc(
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections, num_thresholds=num_thresholds,
      curve=curve, updates_collections=updates_collections, name=name)


def streaming_specificity_at_sensitivity(
    predictions, labels, sensitivity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the specificity at a given sensitivity.

  The `streaming_specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar `Tensor` representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  return metrics.specificity_at_sensitivity(
      sensitivity=sensitivity, num_thresholds=num_thresholds,
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)


def streaming_sensitivity_at_specificity(
    predictions, labels, specificity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the sensitivity at a given specificity.

  The `streaming_sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is computed
  and used to evaluate the corresponding sensitivity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    specificity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar `Tensor` representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `specificity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  return metrics.sensitivity_at_specificity(
      specificity=specificity, num_thresholds=num_thresholds,
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)


def streaming_precision_at_thresholds(predictions, labels, thresholds,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None, name=None):
  """Computes precision values for different `thresholds` on `predictions`.

  The `streaming_precision_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `precision[i]` is defined as the total
  weight of values in `predictions` above `thresholds[i]` whose corresponding
  entry in `labels` is `True`, divided by the total weight of values in
  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
  false_positives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `precision`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.precision_at_thresholds(
      thresholds=thresholds,
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)


def streaming_recall_at_thresholds(predictions, labels, thresholds,
                                   weights=None, metrics_collections=None,
                                   updates_collections=None, name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  The `streaming_recall_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `recall[i]` is defined as the total weight
  of values in `predictions` above `thresholds[i]` whose corresponding entry in
  `labels` is `True`, divided by the total weight of `True` values in `labels`
  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.recall_at_thresholds(
      thresholds=thresholds,
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)


def _at_k_name(name, k=None, class_id=None):
  """Builds a default metric name such as `recall_at_5` or `precision_at_k`.

  Args:
    name: Base metric name, e.g. 'precision' or 'recall'.
    k: Optional integer k. When `None`, the literal suffix `_at_k` is used
      instead of `_at_<k>`.
    class_id: Optional integer class ID, appended as `_class<class_id>`.

  Returns:
    The composed name string.
  """
  if k is not None:
    name = '%s_at_%d' % (name, k)
  else:
    name = '%s_at_k' % (name)
  if class_id is not None:
    name = '%s_class%d' % (name, class_id)
  return name


@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, '
            'and reshape labels from [batch_size] to [batch_size, 1].')
def streaming_recall_at_k(predictions, labels, k, weights=None,
                          metrics_collections=None, updates_collections=None,
                          name=None):
  """Computes the recall@k of the predictions with respect to dense labels.

  The `streaming_recall_at_k` function creates two local variables, `total` and
  `count`, that are used to compute the recall@k frequency. This frequency is
  ultimately returned as `recall_at_<k>`: an idempotent operation that simply
  divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with
  shape [batch_size] whose elements indicate whether or not the corresponding
  label is in the top `k` `predictions`. Then `update_op` increments `total`
  with the reduced sum of `weights` where `in_top_k` is `True`, and it
  increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A float `Tensor` of dimension [batch_size, num_classes].
    labels: A `Tensor` of dimension [batch_size] whose type is in `int32`,
      `int64`.
    k: The number of top elements to look at for computing recall.
    weights: An optional `Tensor` whose shape is broadcastable to `labels`
      (i.e. `[batch_size]`). Weights are applied to the per-example `in_top_k`
      result, not to individual entries of `predictions`.
    metrics_collections: An optional list of collections that `recall_at_k`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    recall_at_k: A `Tensor` representing the recall@k, the fraction of labels
      which fall into the top `k` predictions.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `recall_at_k`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `labels`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # One 0./1. value per example: 1. when the true label is among the top k.
  in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k))
  return streaming_mean(in_top_k,
                        weights,
                        metrics_collections,
                        updates_collections,
                        name or _at_k_name('recall', k))


# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_recall_at_k(predictions,
                                 labels,
                                 k,
                                 class_id=None,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate recall by considering only the
  entries in the batch for which `class_id` is in the label, and computing
  the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, we'll calculate recall as how often on
  average a class among the labels of a batch entry is in the top-k
  `predictions`.

  `streaming_sparse_recall_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
  the recall_at_k frequency. This frequency is ultimately returned as
  `recall_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false negatives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_negative_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`.
      Values should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range always count
      towards `false_negative_at_<k>`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If class_id is outside this range, the method returns NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  return metrics.recall_at_k(
      k=k, class_id=class_id,
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)


def _streaming_sparse_precision_at_k(top_k_idx,
                                     labels,
                                     k=None,
                                     class_id=None,
                                     weights=None,
                                     metrics_collections=None,
                                     updates_collections=None,
                                     name=None):
  """Computes precision@k of the top-k indices with respect to sparse labels.

  This method contains the code shared by streaming_sparse_precision_at_k and
  streaming_sparse_precision_at_top_k. Refer to those methods for more details.

  Args:
    top_k_idx: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and top_k_idx has shape [batch size, k].
      The final dimension contains the indices of top-k labels. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `top_k_idx`. Values should be in range [0, num_classes), where
      num_classes is the number of classes the top-k indices were drawn from.
      Values outside this range are ignored.
    k: Integer, k for @k metric or `None`. Only used for default op name.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the number of classes the
      top-k indices were drawn from. If `class_id` is outside this range, the
      method returns NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `top_k_idx` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of the metric and of the enclosing scope.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `top_k_idx`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  top_k_idx = math_ops.to_int64(top_k_idx)
  tp, tp_update = _streaming_sparse_true_positive_at_k(
      predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
      weights=weights)
  fp, fp_update = _streaming_sparse_false_positive_at_k(
      predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
      weights=weights)

  # Plain (unguarded) division: tp + fp == 0 yields NaN, which is the
  # documented result for an out-of-range `class_id`.
  metric = math_ops.div(tp, math_ops.add(tp, fp), name=name)
  update = math_ops.div(
      tp_update, math_ops.add(tp_update, fp_update), name='update')
  if metrics_collections:
    ops.add_to_collections(metrics_collections, metric)
  if updates_collections:
    ops.add_to_collections(updates_collections, update)
  return metric, update


# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_k(predictions,
                                    labels,
                                    k,
                                    class_id=None,
                                    weights=None,
                                    metrics_collections=None,
                                    updates_collections=None,
                                    name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
  entries in the batch for which `class_id` is in the top-k highest
  `predictions`, and computing the fraction of them for which `class_id` is
  indeed a correct label.
  If `class_id` is not specified, we'll calculate precision as how often on
  average a class among the top-k classes with the highest predicted values
  of a batch entry is correct and can be found in the label for that entry.

  `streaming_sparse_precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  return metrics.sparse_precision_at_k(
      k=k, class_id=class_id,
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)


# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_top_k(top_k_predictions,
                                        labels,
                                        class_id=None,
                                        weights=None,
                                        metrics_collections=None,
                                        updates_collections=None,
                                        name=None):
  """Computes precision@k of top-k predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
  entries in the batch for which `class_id` is in the top-k highest
  `predictions`, and computing the fraction of them for which `class_id` is
  indeed a correct label.
  If `class_id` is not specified, we'll calculate precision as how often on
  average a class among the top-k classes with the highest predicted values
  of a batch entry is correct and can be found in the label for that entry.

  `streaming_sparse_precision_at_top_k` creates two local variables,
  `true_positive_at_k` and `false_positive_at_k`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_k`: an idempotent operation that simply divides
  `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_k`. Internally, set operations applied to `top_k_predictions`
  and `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_k` and
  `false_positive_at_k` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
      The final dimension contains the indices of top-k labels. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `top_k_predictions`. Values should be in range [0, num_classes), where
      num_classes is the number of classes the top-k indices were drawn from.
      Values outside this range are ignored.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the number of classes the
      top-k indices were drawn from. If `class_id` is outside this range, the
      method returns NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `top_k_predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `top_k_predictions`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.

  Note: the rank of `top_k_predictions` is checked with a graph `Assert` op,
  not in Python. A rank < 2 input is therefore reported when the returned ops
  are run (as an `InvalidArgumentError`), not when this function is called.
  """
  default_name = _at_k_name('precision', class_id=class_id)
  with ops.name_scope(
      name, default_name,
      (top_k_predictions, labels, weights)) as scope:
    rank = array_ops.rank(top_k_predictions)
    check_rank_op = control_flow_ops.Assert(
        math_ops.greater_equal(rank, 2),
        ['top_k_predictions must have rank 2 or higher, e.g. [batch_size, k].'])
    # Ops built below depend on the rank check, so it runs before them.
    with ops.control_dependencies([check_rank_op]):
      return _streaming_sparse_precision_at_k(
          top_k_idx=top_k_predictions,
          labels=labels,
          class_id=class_id,
          weights=weights,
          metrics_collections=metrics_collections,
          updates_collections=updates_collections,
          name=scope)


def num_relevant(labels, k):
  """Computes number of relevant values for each row in labels.

  For labels with shape [D1, ... DN, num_labels], this is the minimum of
  `num_labels` and `k`.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels].
    k: Integer, k for @k metric.

  Returns:
    Integer `Tensor` of shape [D1, ... DN], where each value is the number of
    relevant values for that row.

  Raises:
    ValueError: if inputs have invalid dtypes or values.
  """
  if k < 1:
    raise ValueError('Invalid k=%s.' % k)
  with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
    # For SparseTensor, calculate separate count for each row.
    if isinstance(
        labels, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      labels_sizes = set_ops.set_size(labels)
      return math_ops.minimum(labels_sizes, k, name=scope)

    # For dense Tensor, calculate scalar count based on last dimension, and
    # tile across labels shape.
    labels_shape = array_ops.shape(labels)
    labels_size = labels_shape[-1]
    num_relevant_scalar = math_ops.minimum(labels_size, k)
    return array_ops.fill(labels_shape[0:-1], num_relevant_scalar, name=scope)


def expand_and_tile(tensor, multiple, dim=0, name=None):
  """Slice `tensor` shape in 2, then tile along the sliced dimension.

  A new dimension is inserted in shape of `tensor` before `dim`, then values are
  tiled `multiple` times along the new dimension.

  Args:
    tensor: Input `Tensor` or `SparseTensor`.
    multiple: Integer, number of times to tile.
    dim: Integer, dimension along which to tile.
    name: Name of operation.

  Returns:
    `Tensor` result of expanding and tiling `tensor`.

  Raises:
    ValueError: if `multiple` is less than 1. An out-of-range `dim` is not
      validated here; rejecting it is left to the underlying array/sparse ops.
  """
  if multiple < 1:
    raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
  with ops.name_scope(
      name, 'expand_and_tile', (tensor, multiple, dim)) as scope:
    # Sparse.
    if isinstance(tensor, sparse_tensor.SparseTensorValue):
      tensor = sparse_tensor.SparseTensor.from_value(tensor)
    if isinstance(tensor, sparse_tensor.SparseTensor):
      # Resolve a negative `dim` against the runtime rank of the dense shape.
      if dim < 0:
        expand_dims = array_ops.reshape(
            array_ops.size(tensor.dense_shape) + dim, [1])
      else:
        expand_dims = [dim]
      # New shape: dense_shape[:dim] + [1] + dense_shape[dim:].
      expanded_shape = array_ops.concat_v2(
          (array_ops.strided_slice(tensor.dense_shape, [0], expand_dims),
           [1],
           array_ops.strided_slice(
               tensor.dense_shape, expand_dims, [-1], end_mask=1 << 0)),
          0, name='expanded_shape')
      expanded = sparse_ops.sparse_reshape(
          tensor, shape=expanded_shape, name='expand')
      if multiple == 1:
        return expanded
      # The inserted axis sits one further from the end for negative `dim`.
      return sparse_ops.sparse_concat(
          dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)

    # Dense.
    expanded = array_ops.expand_dims(
        tensor, dim if (dim >= 0) else (dim - 1), name='expand')
    if multiple == 1:
      return expanded
    # Tile multiples: 1 everywhere except `multiple` on the inserted axis.
    ones = array_ops.ones_like(array_ops.shape(tensor))
    tile_multiples = array_ops.concat_v2(
        (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples')
    return array_ops.tile(expanded, tile_multiples, name=scope)


def sparse_average_precision_at_k(predictions, labels, k):
  """Computes average precision@k of predictions with respect to sparse labels.

  From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
  for each row is:

    AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items

  A "row" is the elements in dimension [D1, ... DN] of `predictions`, `labels`,
  and the result `Tensors`. In the common case, this is [batch_size]. Each row
  of the results contains the average precision for that row.

  Internally, a `top_k` operation computes a `Tensor` indicating the top `k`
  `predictions`. Set operations applied to `top_k` and `labels` calculate the
  true positives, which are used to calculate the precision ("P_{i}" term,
  above).

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and `predictions` has shape
      [batch size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric. This will calculate an average precision for
      range `[1,k]`, as documented above.

  Returns:
    `float64` `Tensor` of shape [D1, ... DN], where each value is the average
    precision for that row.

  Raises:
    ValueError: if k is invalid.
  """
  if k < 1:
    raise ValueError('Invalid k=%s.' % k)
  with ops.name_scope(
      None, 'average_precision', (predictions, labels, k)) as scope:
    # Calculate top k indices to produce [D1, ... DN, k] tensor.
    _, predictions_idx = nn.top_k(predictions, k)
    predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx')

    # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate
    # prediction for each k, so we can calculate separate true positive values
    # for each k.
    predictions_idx_per_k = array_ops.expand_dims(
        predictions_idx, -1, name='predictions_idx_per_k')

    # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor.
    labels_per_k = expand_and_tile(
        labels, multiple=k, dim=-1, name='labels_per_k')

    # The following tensors are all of shape [D1, ... DN, k], containing values
    # per row, per k value.
    # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at
    #     that k value is correct, 0 otherwise. This is the "rel_{i}" term from
    #     the formula above.
    # `tp_per_k` (int32) - True positive counts.
    # `retrieved_per_k` (int32) - Number of predicted values at each k. This is
    #     the precision denominator.
    # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}"
    #     term from the formula above.
    # `relevant_precision_per_k` (float64) - Relevant precisions; i.e.,
    #     precisions at all k for which relevance indicator is true.
    relevant_per_k = _sparse_true_positive_at_k(
        predictions_idx_per_k, labels_per_k, name='relevant_per_k')
    tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
    retrieved_per_k = math_ops.cumsum(
        array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
    precision_per_k = math_ops.div(
        math_ops.to_double(tp_per_k), math_ops.to_double(retrieved_per_k),
        name='precision_per_k')
    relevant_precision_per_k = math_ops.mul(
        precision_per_k, math_ops.to_double(relevant_per_k),
        name='relevant_precision_per_k')

    # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
    precision_sum = math_ops.reduce_sum(
        relevant_precision_per_k, reduction_indices=(-1,), name='precision_sum')

    # Divide by number of relevant items to get average precision. These are
    # the "num_relevant_items" and "AveP" terms from the formula above.
1552 num_relevant_items = math_ops.to_double(num_relevant(labels, k)) 1553 return math_ops.div(precision_sum, num_relevant_items, name=scope) 1554 1555 1556def streaming_sparse_average_precision_at_k(predictions, 1557 labels, 1558 k, 1559 weights=None, 1560 metrics_collections=None, 1561 updates_collections=None, 1562 name=None): 1563 """Computes average precision@k of predictions with respect to sparse labels. 1564 1565 See `sparse_average_precision_at_k` for details on formula. `weights` are 1566 applied to the result of `sparse_average_precision_at_k` 1567 1568 `streaming_sparse_average_precision_at_k` creates two local variables, 1569 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 1570 are used to compute the frequency. This frequency is ultimately returned as 1571 `average_precision_at_<k>`: an idempotent operation that simply divides 1572 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 1573 1574 For estimation of the metric over a stream of data, the function creates an 1575 `update_op` operation that updates these variables and returns the 1576 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 1577 indicating the top `k` `predictions`. Set operations applied to `top_k` and 1578 `labels` calculate the true positives and false positives weighted by 1579 `weights`. Then `update_op` increments `true_positive_at_<k>` and 1580 `false_positive_at_<k>` using these values. 1581 1582 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1583 1584 Args: 1585 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1586 N >= 1. Commonly, N=1 and `predictions` has shape 1587 [batch size, num_classes]. The final dimension contains the logit values 1588 for each class. [D1, ... DN] must match `labels`. 1589 labels: `int64` `Tensor` or `SparseTensor` with shape 1590 [D1, ... 
DN, num_labels], where N >= 1 and num_labels is the number of 1591 target classes for the associated prediction. Commonly, N=1 and `labels` 1592 has shape [batch_size, num_labels]. [D1, ... DN] must match 1593 `predictions_`. Values should be in range [0, num_classes), where 1594 num_classes is the last dimension of `predictions`. Values outside this 1595 range are ignored. 1596 k: Integer, k for @k metric. This will calculate an average precision for 1597 range `[1,k]`, as documented above. 1598 weights: An optional `Tensor` whose shape is broadcastable to the first 1599 [D1, ... DN] dimensions of `predictions` and `labels`. 1600 metrics_collections: An optional list of collections that values should 1601 be added to. 1602 updates_collections: An optional list of collections that updates should 1603 be added to. 1604 name: Name of new update operation, and namespace for other dependent ops. 1605 1606 Returns: 1607 mean_average_precision: Scalar `float64` `Tensor` with the mean average 1608 precision values. 1609 update: `Operation` that increments variables appropriately, and whose 1610 value matches `metric`. 1611 """ 1612 return metrics.sparse_average_precision_at_k( 1613 k=k, predictions=predictions, labels=labels, weights=weights, 1614 metrics_collections=metrics_collections, 1615 updates_collections=updates_collections, name=name) 1616 1617 1618def _select_class_id(ids, selected_id): 1619 """Filter all but `selected_id` out of `ids`. 1620 1621 Args: 1622 ids: `int64` `Tensor` or `SparseTensor` of IDs. 1623 selected_id: Int id to select. 1624 1625 Returns: 1626 `SparseTensor` of same dimensions as `ids`. This contains only the entries 1627 equal to `selected_id`. 
1628 """ 1629 if isinstance( 1630 ids, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): 1631 return sparse_ops.sparse_retain( 1632 ids, math_ops.equal(ids.values, selected_id)) 1633 1634 # TODO(ptucker): Make this more efficient, maybe add a sparse version of 1635 # tf.equal and tf.reduce_any? 1636 1637 # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1. 1638 ids_shape = array_ops.shape(ids, out_type=dtypes.int64) 1639 ids_last_dim = array_ops.size(ids_shape) - 1 1640 filled_selected_id_shape = math_ops.reduced_shape( 1641 ids_shape, array_ops.reshape(ids_last_dim, [1])) 1642 1643 # Intersect `ids` with the selected ID. 1644 filled_selected_id = array_ops.fill( 1645 filled_selected_id_shape, math_ops.to_int64(selected_id)) 1646 result = set_ops.set_intersection(filled_selected_id, ids) 1647 return sparse_tensor.SparseTensor( 1648 indices=result.indices, values=result.values, dense_shape=ids_shape) 1649 1650 1651def _maybe_select_class_id(labels, predictions_idx, selected_id=None): 1652 """If class ID is specified, filter all other classes. 1653 1654 Args: 1655 labels: `int64` `Tensor` or `SparseTensor` with shape 1656 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1657 target classes for the associated prediction. Commonly, N=1 and `labels` 1658 has shape [batch_size, num_labels]. [D1, ... DN] must match 1659 `predictions_idx`. 1660 predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k] 1661 where N >= 1. Commonly, N=1 and `predictions_idx` has shape 1662 [batch size, k]. 1663 selected_id: Int id to select. 1664 1665 Returns: 1666 Tuple of `labels` and `predictions_idx`, possibly with classes removed. 
1667 """ 1668 if selected_id is None: 1669 return labels, predictions_idx 1670 return (_select_class_id(labels, selected_id), 1671 _select_class_id(predictions_idx, selected_id)) 1672 1673 1674def _sparse_true_positive_at_k(predictions_idx, 1675 labels, 1676 class_id=None, 1677 weights=None, 1678 name=None): 1679 """Calculates true positives for recall@k and precision@k. 1680 1681 If `class_id` is specified, calculate binary true positives for `class_id` 1682 only. 1683 If `class_id` is not specified, calculate metrics for `k` predicted vs 1684 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 1685 1686 Args: 1687 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1688 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1689 match `labels`. 1690 labels: `int64` `Tensor` or `SparseTensor` with shape 1691 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1692 target classes for the associated prediction. Commonly, N=1 and `labels` 1693 has shape [batch_size, num_labels]. [D1, ... DN] must match 1694 `predictions_idx`. 1695 class_id: Class for which we want binary metrics. 1696 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 1697 dimensions of `predictions_idx` and `labels`. 1698 name: Name of operation. 1699 1700 Returns: 1701 A [D1, ... DN] `Tensor` of true positive counts. 
1702 """ 1703 with ops.name_scope(name, 'true_positives', (predictions_idx, labels)): 1704 labels, predictions_idx = _maybe_select_class_id( 1705 labels, predictions_idx, class_id) 1706 tp = set_ops.set_size(set_ops.set_intersection(predictions_idx, labels)) 1707 tp = math_ops.to_double(tp) 1708 if weights is not None: 1709 weights = math_ops.to_double(weights) 1710 tp = math_ops.mul(tp, weights) 1711 return tp 1712 1713 1714def _streaming_sparse_true_positive_at_k(predictions_idx, 1715 labels, 1716 k=None, 1717 class_id=None, 1718 weights=None, 1719 name=None): 1720 """Calculates weighted per step true positives for recall@k and precision@k. 1721 1722 If `class_id` is specified, calculate binary true positives for `class_id` 1723 only. 1724 If `class_id` is not specified, calculate metrics for `k` predicted vs 1725 `n` label classes, where `n` is the 2nd dimension of `labels`. 1726 1727 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1728 1729 Args: 1730 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1731 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1732 match `labels`. 1733 labels: `int64` `Tensor` or `SparseTensor` with shape 1734 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1735 target classes for the associated prediction. Commonly, N=1 and `labels` 1736 has shape [batch_size, num_labels]. [D1, ... DN] must match 1737 `predictions_idx`. 1738 k: Integer, k for @k metric. This is only used for default op name. 1739 class_id: Class for which we want binary metrics. 1740 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 1741 dimensions of `predictions_idx` and `labels`. 1742 name: Name of new variable, and namespace for other dependent ops. 1743 1744 Returns: 1745 A tuple of `Variable` and update `Operation`. 1746 1747 Raises: 1748 ValueError: If `weights` is not `None` and has an incomptable shape. 
1749 """ 1750 default_name = _at_k_name('true_positive', k, class_id=class_id) 1751 with ops.name_scope(name, default_name, (predictions_idx, labels)) as scope: 1752 tp = _sparse_true_positive_at_k( 1753 predictions_idx=predictions_idx, labels=labels, class_id=class_id, 1754 weights=weights) 1755 batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp)) 1756 1757 var = contrib_variables.local_variable( 1758 array_ops.zeros([], dtype=dtypes.float64), name=scope) 1759 return var, state_ops.assign_add(var, batch_total_tp, name='update') 1760 1761 1762def _sparse_false_positive_at_k(predictions_idx, 1763 labels, 1764 class_id=None, 1765 weights=None): 1766 """Calculates false positives for precision@k. 1767 1768 If `class_id` is specified, calculate binary true positives for `class_id` 1769 only. 1770 If `class_id` is not specified, calculate metrics for `k` predicted vs 1771 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 1772 1773 Args: 1774 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1775 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1776 match `labels`. 1777 labels: `int64` `Tensor` or `SparseTensor` with shape 1778 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1779 target classes for the associated prediction. Commonly, N=1 and `labels` 1780 has shape [batch_size, num_labels]. [D1, ... DN] must match 1781 `predictions_idx`. 1782 class_id: Class for which we want binary metrics. 1783 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 1784 dimensions of `predictions_idx` and `labels`. 1785 1786 Returns: 1787 A [D1, ... DN] `Tensor` of false positive counts. 
1788 """ 1789 with ops.name_scope(None, 'false_positives', (predictions_idx, labels)): 1790 labels, predictions_idx = _maybe_select_class_id(labels, 1791 predictions_idx, 1792 class_id) 1793 fp = set_ops.set_size(set_ops.set_difference( 1794 predictions_idx, labels, aminusb=True)) 1795 fp = math_ops.to_double(fp) 1796 if weights is not None: 1797 weights = math_ops.to_double(weights) 1798 fp = math_ops.mul(fp, weights) 1799 return fp 1800 1801 1802def _streaming_sparse_false_positive_at_k(predictions_idx, 1803 labels, 1804 k=None, 1805 class_id=None, 1806 weights=None, 1807 name=None): 1808 """Calculates weighted per step false positives for precision@k. 1809 1810 If `class_id` is specified, calculate binary true positives for `class_id` 1811 only. 1812 If `class_id` is not specified, calculate metrics for `k` predicted vs 1813 `n` label classes, where `n` is the 2nd dimension of `labels`. 1814 1815 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1816 1817 Args: 1818 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1819 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1820 match `labels`. 1821 labels: `int64` `Tensor` or `SparseTensor` with shape 1822 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1823 target classes for the associated prediction. Commonly, N=1 and `labels` 1824 has shape [batch_size, num_labels]. [D1, ... DN] must match 1825 `predictions_idx`. 1826 k: Integer, k for @k metric. This is only used for default op name. 1827 class_id: Class for which we want binary metrics. 1828 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 1829 dimensions of `predictions_idx` and `labels`. 1830 name: Name of new variable, and namespace for other dependent ops. 1831 1832 Returns: 1833 A tuple of `Variable` and update `Operation`. 1834 1835 Raises: 1836 ValueError: If `weights` is not `None` and has an incomptable shape. 
1837 """ 1838 default_name = _at_k_name('false_positive', k, class_id=class_id) 1839 with ops.name_scope(name, default_name, (predictions_idx, labels)) as scope: 1840 fp = _sparse_false_positive_at_k( 1841 predictions_idx=predictions_idx, labels=labels, class_id=class_id, 1842 weights=weights) 1843 batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp)) 1844 1845 var = contrib_variables.local_variable( 1846 array_ops.zeros([], dtype=dtypes.float64), name=scope) 1847 return var, state_ops.assign_add(var, batch_total_fp, name='update') 1848 1849 1850def _sparse_false_negative_at_k(predictions_idx, 1851 labels, 1852 class_id=None, 1853 weights=None): 1854 """Calculates false negatives for recall@k. 1855 1856 If `class_id` is specified, calculate binary true positives for `class_id` 1857 only. 1858 If `class_id` is not specified, calculate metrics for `k` predicted vs 1859 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 1860 1861 Args: 1862 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1863 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1864 match `labels`. 1865 labels: `int64` `Tensor` or `SparseTensor` with shape 1866 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1867 target classes for the associated prediction. Commonly, N=1 and `labels` 1868 has shape [batch_size, num_labels]. [D1, ... DN] must match 1869 `predictions_idx`. 1870 class_id: Class for which we want binary metrics. 1871 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 1872 dimensions of `predictions_idx` and `labels`. 1873 1874 Returns: 1875 A [D1, ... DN] `Tensor` of false negative counts. 
1876 """ 1877 with ops.name_scope(None, 'false_negatives', (predictions_idx, labels)): 1878 labels, predictions_idx = _maybe_select_class_id(labels, 1879 predictions_idx, 1880 class_id) 1881 fn = set_ops.set_size(set_ops.set_difference(predictions_idx, 1882 labels, 1883 aminusb=False)) 1884 fn = math_ops.to_double(fn) 1885 if weights is not None: 1886 weights = math_ops.to_double(weights) 1887 fn = math_ops.mul(fn, weights) 1888 return fn 1889 1890 1891def _streaming_sparse_false_negative_at_k(predictions_idx, 1892 labels, 1893 k, 1894 class_id=None, 1895 weights=None, 1896 name=None): 1897 """Calculates weighted per step false negatives for recall@k. 1898 1899 If `class_id` is specified, calculate binary true positives for `class_id` 1900 only. 1901 If `class_id` is not specified, calculate metrics for `k` predicted vs 1902 `n` label classes, where `n` is the 2nd dimension of `labels`. 1903 1904 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1905 1906 Args: 1907 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1908 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1909 match `labels`. 1910 labels: `int64` `Tensor` or `SparseTensor` with shape 1911 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1912 target classes for the associated prediction. Commonly, N=1 and `labels` 1913 has shape [batch_size, num_labels]. [D1, ... DN] must match 1914 `predictions_idx`. 1915 k: Integer, k for @k metric. This is only used for default op name. 1916 class_id: Class for which we want binary metrics. 1917 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 1918 dimensions of `predictions_idx` and `labels`. 1919 name: Name of new variable, and namespace for other dependent ops. 1920 1921 Returns: 1922 A tuple of `Variable` and update `Operation`. 1923 1924 Raises: 1925 ValueError: If `weights` is not `None` and has an incomptable shape. 
1926 """ 1927 default_name = _at_k_name('false_negative', k, class_id=class_id) 1928 with ops.name_scope(name, default_name, (predictions_idx, labels)) as scope: 1929 fn = _sparse_false_negative_at_k( 1930 predictions_idx=predictions_idx, labels=labels, class_id=class_id, 1931 weights=weights) 1932 batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn)) 1933 1934 var = contrib_variables.local_variable( 1935 array_ops.zeros([], dtype=dtypes.float64), name=scope) 1936 return var, state_ops.assign_add(var, batch_total_fn, name='update') 1937 1938 1939def streaming_mean_absolute_error(predictions, labels, weights=None, 1940 metrics_collections=None, 1941 updates_collections=None, 1942 name=None): 1943 """Computes the mean absolute error between the labels and predictions. 1944 1945 The `streaming_mean_absolute_error` function creates two local variables, 1946 `total` and `count` that are used to compute the mean absolute error. This 1947 average is weighted by `weights`, and it is ultimately returned as 1948 `mean_absolute_error`: an idempotent operation that simply divides `total` by 1949 `count`. 1950 1951 For estimation of the metric over a stream of data, the function creates an 1952 `update_op` operation that updates these variables and returns the 1953 `mean_absolute_error`. Internally, an `absolute_errors` operation computes the 1954 absolute value of the differences between `predictions` and `labels`. Then 1955 `update_op` increments `total` with the reduced sum of the product of 1956 `weights` and `absolute_errors`, and it increments `count` with the reduced 1957 sum of `weights` 1958 1959 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1960 1961 Args: 1962 predictions: A `Tensor` of arbitrary shape. 1963 labels: A `Tensor` of the same shape as `predictions`. 1964 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 
1965 metrics_collections: An optional list of collections that 1966 `mean_absolute_error` should be added to. 1967 updates_collections: An optional list of collections that `update_op` should 1968 be added to. 1969 name: An optional variable_scope name. 1970 1971 Returns: 1972 mean_absolute_error: A `Tensor` representing the current mean, the value of 1973 `total` divided by `count`. 1974 update_op: An operation that increments the `total` and `count` variables 1975 appropriately and whose value matches `mean_absolute_error`. 1976 1977 Raises: 1978 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1979 `weights` is not `None` and its shape doesn't match `predictions`, or if 1980 either `metrics_collections` or `updates_collections` are not a list or 1981 tuple. 1982 """ 1983 return metrics.mean_absolute_error( 1984 predictions=predictions, labels=labels, weights=weights, 1985 metrics_collections=metrics_collections, 1986 updates_collections=updates_collections, name=name) 1987 1988 1989def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, 1990 metrics_collections=None, 1991 updates_collections=None, 1992 name=None): 1993 """Computes the mean relative error by normalizing with the given values. 1994 1995 The `streaming_mean_relative_error` function creates two local variables, 1996 `total` and `count` that are used to compute the mean relative absolute error. 1997 This average is weighted by `weights`, and it is ultimately returned as 1998 `mean_relative_error`: an idempotent operation that simply divides `total` by 1999 `count`. 2000 2001 For estimation of the metric over a stream of data, the function creates an 2002 `update_op` operation that updates these variables and returns the 2003 `mean_reative_error`. Internally, a `relative_errors` operation divides the 2004 absolute value of the differences between `predictions` and `labels` by the 2005 `normalizer`. 
Then `update_op` increments `total` with the reduced sum of the 2006 product of `weights` and `relative_errors`, and it increments `count` with the 2007 reduced sum of `weights`. 2008 2009 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2010 2011 Args: 2012 predictions: A `Tensor` of arbitrary shape. 2013 labels: A `Tensor` of the same shape as `predictions`. 2014 normalizer: A `Tensor` of the same shape as `predictions`. 2015 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 2016 metrics_collections: An optional list of collections that 2017 `mean_relative_error` should be added to. 2018 updates_collections: An optional list of collections that `update_op` should 2019 be added to. 2020 name: An optional variable_scope name. 2021 2022 Returns: 2023 mean_relative_error: A `Tensor` representing the current mean, the value of 2024 `total` divided by `count`. 2025 update_op: An operation that increments the `total` and `count` variables 2026 appropriately and whose value matches `mean_relative_error`. 2027 2028 Raises: 2029 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2030 `weights` is not `None` and its shape doesn't match `predictions`, or if 2031 either `metrics_collections` or `updates_collections` are not a list or 2032 tuple. 2033 """ 2034 return metrics.mean_relative_error( 2035 normalizer=normalizer, predictions=predictions, labels=labels, 2036 weights=weights, metrics_collections=metrics_collections, 2037 updates_collections=updates_collections, name=name) 2038 2039 2040def streaming_mean_squared_error(predictions, labels, weights=None, 2041 metrics_collections=None, 2042 updates_collections=None, 2043 name=None): 2044 """Computes the mean squared error between the labels and predictions. 2045 2046 The `streaming_mean_squared_error` function creates two local variables, 2047 `total` and `count` that are used to compute the mean squared error. 
2048 This average is weighted by `weights`, and it is ultimately returned as 2049 `mean_squared_error`: an idempotent operation that simply divides `total` by 2050 `count`. 2051 2052 For estimation of the metric over a stream of data, the function creates an 2053 `update_op` operation that updates these variables and returns the 2054 `mean_squared_error`. Internally, a `squared_error` operation computes the 2055 element-wise square of the difference between `predictions` and `labels`. Then 2056 `update_op` increments `total` with the reduced sum of the product of 2057 `weights` and `squared_error`, and it increments `count` with the reduced sum 2058 of `weights`. 2059 2060 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2061 2062 Args: 2063 predictions: A `Tensor` of arbitrary shape. 2064 labels: A `Tensor` of the same shape as `predictions`. 2065 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 2066 metrics_collections: An optional list of collections that 2067 `mean_squared_error` should be added to. 2068 updates_collections: An optional list of collections that `update_op` should 2069 be added to. 2070 name: An optional variable_scope name. 2071 2072 Returns: 2073 mean_squared_error: A `Tensor` representing the current mean, the value of 2074 `total` divided by `count`. 2075 update_op: An operation that increments the `total` and `count` variables 2076 appropriately and whose value matches `mean_squared_error`. 2077 2078 Raises: 2079 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2080 `weights` is not `None` and its shape doesn't match `predictions`, or if 2081 either `metrics_collections` or `updates_collections` are not a list or 2082 tuple. 
2083 """ 2084 return metrics.mean_squared_error( 2085 predictions=predictions, labels=labels, weights=weights, 2086 metrics_collections=metrics_collections, 2087 updates_collections=updates_collections, name=name) 2088 2089 2090def streaming_root_mean_squared_error(predictions, labels, weights=None, 2091 metrics_collections=None, 2092 updates_collections=None, 2093 name=None): 2094 """Computes the root mean squared error between the labels and predictions. 2095 2096 The `streaming_root_mean_squared_error` function creates two local variables, 2097 `total` and `count` that are used to compute the root mean squared error. 2098 This average is weighted by `weights`, and it is ultimately returned as 2099 `root_mean_squared_error`: an idempotent operation that takes the square root 2100 of the division of `total` by `count`. 2101 2102 For estimation of the metric over a stream of data, the function creates an 2103 `update_op` operation that updates these variables and returns the 2104 `root_mean_squared_error`. Internally, a `squared_error` operation computes 2105 the element-wise square of the difference between `predictions` and `labels`. 2106 Then `update_op` increments `total` with the reduced sum of the product of 2107 `weights` and `squared_error`, and it increments `count` with the reduced sum 2108 of `weights`. 2109 2110 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2111 2112 Args: 2113 predictions: A `Tensor` of arbitrary shape. 2114 labels: A `Tensor` of the same shape as `predictions`. 2115 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 2116 metrics_collections: An optional list of collections that 2117 `root_mean_squared_error` should be added to. 2118 updates_collections: An optional list of collections that `update_op` should 2119 be added to. 2120 name: An optional variable_scope name. 
2121 2122 Returns: 2123 root_mean_squared_error: A `Tensor` representing the current mean, the value 2124 of `total` divided by `count`. 2125 update_op: An operation that increments the `total` and `count` variables 2126 appropriately and whose value matches `root_mean_squared_error`. 2127 2128 Raises: 2129 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2130 `weights` is not `None` and its shape doesn't match `predictions`, or if 2131 either `metrics_collections` or `updates_collections` are not a list or 2132 tuple. 2133 """ 2134 return metrics.root_mean_squared_error( 2135 predictions=predictions, labels=labels, weights=weights, 2136 metrics_collections=metrics_collections, 2137 updates_collections=updates_collections, name=name) 2138 2139 2140def streaming_covariance(predictions, 2141 labels, 2142 weights=None, 2143 metrics_collections=None, 2144 updates_collections=None, 2145 name=None): 2146 """Computes the unbiased sample covariance between `predictions` and `labels`. 2147 2148 The `streaming_covariance` function creates four local variables, 2149 `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to 2150 compute the sample covariance between predictions and labels across multiple 2151 batches of data. The covariance is ultimately returned as an idempotent 2152 operation that simply divides `comoment` by `count` - 1. We use `count` - 1 2153 in order to get an unbiased estimate. 2154 2155 The algorithm used for this online computation is described in 2156 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. 2157 Specifically, the formula used to combine two sample comoments is 2158 `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB` 2159 The comoment for a single batch of data is simply 2160 `sum((x - E[x]) * (y - E[y]))`, optionally weighted. 2161 2162 If `weights` is not None, then it is used to compute weighted comoments, 2163 means, and count. 
NOTE: these weights are treated as "frequency weights", as 2164 opposed to "reliability weights". See discussion of the difference on 2165 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 2166 2167 To facilitate the computation of covariance across multiple batches of data, 2168 the function creates an `update_op` operation, which updates underlying 2169 variables and returns the updated covariance. 2170 2171 Args: 2172 predictions: A `Tensor` of arbitrary size. 2173 labels: A `Tensor` of the same size as `predictions`. 2174 weights: An optional set of weights which indicates the frequency with which 2175 an example is sampled. Must be broadcastable with `labels`. 2176 metrics_collections: An optional list of collections that the metric 2177 value variable should be added to. 2178 updates_collections: An optional list of collections that the metric update 2179 ops should be added to. 2180 name: An optional variable_scope name. 2181 2182 Returns: 2183 covariance: A `Tensor` representing the current unbiased sample covariance, 2184 `comoment` / (`count` - 1). 2185 update_op: An operation that updates the local variables appropriately. 2186 2187 Raises: 2188 ValueError: If labels and predictions are of different sizes or if either 2189 `metrics_collections` or `updates_collections` are not a list or tuple. 
2190 """ 2191 with variable_scope.variable_scope( 2192 name, 'covariance', (predictions, labels, weights)): 2193 predictions, labels, weights = _remove_squeezable_dimensions( 2194 predictions, labels, weights) 2195 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2196 count = _create_local('count', []) 2197 mean_prediction = _create_local('mean_prediction', []) 2198 mean_label = _create_local('mean_label', []) 2199 comoment = _create_local('comoment', []) # C_A in update equation 2200 2201 if weights is None: 2202 batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn 2203 weighted_predictions = predictions 2204 weighted_labels = labels 2205 else: 2206 batch_count = math_ops.reduce_sum( 2207 _broadcast_weights(weights, labels)) # n_B in eqn 2208 weighted_predictions = predictions * weights 2209 weighted_labels = labels * weights 2210 2211 update_count = state_ops.assign_add(count, batch_count) # n_AB in eqn 2212 prev_count = update_count - batch_count # n_A in update equation 2213 2214 # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) 2215 # batch_mean_prediction is E[x_B] in the update equation 2216 batch_mean_prediction = _safe_div( 2217 math_ops.reduce_sum(weighted_predictions), batch_count, 2218 'batch_mean_prediction') 2219 delta_mean_prediction = _safe_div( 2220 (batch_mean_prediction - mean_prediction) * batch_count, update_count, 2221 'delta_mean_prediction') 2222 update_mean_prediction = state_ops.assign_add(mean_prediction, 2223 delta_mean_prediction) 2224 # prev_mean_prediction is E[x_A] in the update equation 2225 prev_mean_prediction = update_mean_prediction - delta_mean_prediction 2226 2227 # batch_mean_label is E[y_B] in the update equation 2228 batch_mean_label = _safe_div( 2229 math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label') 2230 delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count, 2231 update_count, 'delta_mean_label') 2232 update_mean_label = 
state_ops.assign_add(mean_label, delta_mean_label) 2233 # prev_mean_label is E[y_A] in the update equation 2234 prev_mean_label = update_mean_label - delta_mean_label 2235 2236 unweighted_batch_coresiduals = ( 2237 (predictions - batch_mean_prediction) * (labels - batch_mean_label)) 2238 # batch_comoment is C_B in the update equation 2239 if weights is None: 2240 batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals) 2241 else: 2242 batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals * 2243 weights) 2244 2245 # View delta_comoment as = C_AB - C_A in the update equation above. 2246 # Since C_A is stored in a var, by how much do we need to increment that var 2247 # to make the var = C_AB? 2248 delta_comoment = (batch_comoment + 2249 (prev_mean_prediction - batch_mean_prediction) * 2250 (prev_mean_label - batch_mean_label) * 2251 (prev_count * batch_count / update_count)) 2252 update_comoment = state_ops.assign_add(comoment, delta_comoment) 2253 2254 covariance = _safe_div(comoment, count - 1, 'covariance') 2255 with ops.control_dependencies([update_comoment]): 2256 update_op = _safe_div(comoment, count - 1, 'update_op') 2257 2258 if metrics_collections: 2259 ops.add_to_collections(metrics_collections, covariance) 2260 2261 if updates_collections: 2262 ops.add_to_collections(updates_collections, update_op) 2263 2264 return covariance, update_op 2265 2266 2267def streaming_pearson_correlation(predictions, 2268 labels, 2269 weights=None, 2270 metrics_collections=None, 2271 updates_collections=None, 2272 name=None): 2273 """Computes Pearson correlation coefficient between `predictions`, `labels`. 2274 2275 The `streaming_pearson_correlation` function delegates to 2276 `streaming_covariance` the tracking of three [co]variances: 2277 2278 - `streaming_covariance(predictions, labels)`, i.e. covariance 2279 - `streaming_covariance(predictions, predictions)`, i.e. variance 2280 - `streaming_covariance(labels, labels)`, i.e. 
variance 2281 2282 The product-moment correlation ultimately returned is an idempotent operation 2283 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To 2284 facilitate correlation computation across multiple batches, the function 2285 groups the `update_op`s of the underlying streaming_covariance and returns an 2286 `update_op`. 2287 2288 If `weights` is not None, then it is used to compute a weighted correlation. 2289 NOTE: these weights are treated as "frequency weights", as opposed to 2290 "reliability weights". See discussion of the difference on 2291 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 2292 2293 Args: 2294 predictions: A `Tensor` of arbitrary size. 2295 labels: A `Tensor` of the same size as predictions. 2296 weights: An optional set of weights which indicates the frequency with which 2297 an example is sampled. Must be broadcastable with `labels`. 2298 metrics_collections: An optional list of collections that the metric 2299 value variable should be added to. 2300 updates_collections: An optional list of collections that the metric update 2301 ops should be added to. 2302 name: An optional variable_scope name. 2303 2304 Returns: 2305 pearson_r: A `Tensor` representing the current Pearson product-moment 2306 correlation coefficient, the value of 2307 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. 2308 update_op: An operation that updates the underlying variables appropriately. 2309 2310 Raises: 2311 ValueError: If `labels` and `predictions` are of different sizes, or if 2312 `weights` is the wrong size, or if either `metrics_collections` or 2313 `updates_collections` are not a `list` or `tuple`. 
2314 """ 2315 with variable_scope.variable_scope( 2316 name, 'pearson_r', (predictions, labels, weights)): 2317 predictions, labels, weights = _remove_squeezable_dimensions( 2318 predictions, labels, weights) 2319 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2320 cov, update_cov = streaming_covariance( 2321 predictions, labels, weights=weights, name='covariance') 2322 var_predictions, update_var_predictions = streaming_covariance( 2323 predictions, predictions, weights=weights, name='variance_predictions') 2324 var_labels, update_var_labels = streaming_covariance( 2325 labels, labels, weights=weights, name='variance_labels') 2326 2327 pearson_r = _safe_div( 2328 cov, 2329 math_ops.mul(math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)), 2330 'pearson_r') 2331 with ops.control_dependencies( 2332 [update_cov, update_var_predictions, update_var_labels]): 2333 update_op = _safe_div(update_cov, math_ops.mul( 2334 math_ops.sqrt(update_var_predictions), 2335 math_ops.sqrt(update_var_labels)), 'update_op') 2336 2337 if metrics_collections: 2338 ops.add_to_collections(metrics_collections, pearson_r) 2339 2340 if updates_collections: 2341 ops.add_to_collections(updates_collections, update_op) 2342 2343 return pearson_r, update_op 2344 2345 2346# TODO(nsilberman): add a 'normalized' flag so that the user can request 2347# normalization if the inputs are not normalized. 2348def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, 2349 metrics_collections=None, 2350 updates_collections=None, 2351 name=None): 2352 """Computes the cosine distance between the labels and predictions. 2353 2354 The `streaming_mean_cosine_distance` function creates two local variables, 2355 `total` and `count` that are used to compute the average cosine distance 2356 between `predictions` and `labels`. 
This average is weighted by `weights`, 2357 and it is ultimately returned as `mean_distance`, which is an idempotent 2358 operation that simply divides `total` by `count`. 2359 2360 For estimation of the metric over a stream of data, the function creates an 2361 `update_op` operation that updates these variables and returns the 2362 `mean_distance`. 2363 2364 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2365 2366 Args: 2367 predictions: A `Tensor` of the same shape as `labels`. 2368 labels: A `Tensor` of arbitrary shape. 2369 dim: The dimension along which the cosine distance is computed. 2370 weights: An optional `Tensor` whose shape is broadcastable to `predictions`, 2371 and whose dimension `dim` is 1. 2372 metrics_collections: An optional list of collections that the metric 2373 value variable should be added to. 2374 updates_collections: An optional list of collections that the metric update 2375 ops should be added to. 2376 name: An optional variable_scope name. 2377 2378 Returns: 2379 mean_distance: A `Tensor` representing the current mean, the value of `total` 2380 divided by `count`. 2381 update_op: An operation that increments the `total` and `count` variables 2382 appropriately. 2383 2384 Raises: 2385 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2386 `weights` is not `None` and its shape doesn't match `predictions`, or if 2387 either `metrics_collections` or `updates_collections` are not a list or 2388 tuple. 
2389 """ 2390 predictions, labels, weights = _remove_squeezable_dimensions( 2391 predictions, labels, weights) 2392 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2393 radial_diffs = math_ops.mul(predictions, labels) 2394 radial_diffs = math_ops.reduce_sum(radial_diffs, 2395 reduction_indices=[dim,], 2396 keep_dims=True) 2397 mean_distance, update_op = streaming_mean(radial_diffs, weights, 2398 None, 2399 None, 2400 name or 'mean_cosine_distance') 2401 mean_distance = math_ops.sub(1.0, mean_distance) 2402 update_op = math_ops.sub(1.0, update_op) 2403 2404 if metrics_collections: 2405 ops.add_to_collections(metrics_collections, mean_distance) 2406 2407 if updates_collections: 2408 ops.add_to_collections(updates_collections, update_op) 2409 2410 return mean_distance, update_op 2411 2412 2413def streaming_percentage_less(values, threshold, weights=None, 2414 metrics_collections=None, 2415 updates_collections=None, 2416 name=None): 2417 """Computes the percentage of values less than the given threshold. 2418 2419 The `streaming_percentage_less` function creates two local variables, 2420 `total` and `count` that are used to compute the percentage of `values` that 2421 fall below `threshold`. This rate is weighted by `weights`, and it is 2422 ultimately returned as `percentage` which is an idempotent operation that 2423 simply divides `total` by `count`. 2424 2425 For estimation of the metric over a stream of data, the function creates an 2426 `update_op` operation that updates these variables and returns the 2427 `percentage`. 2428 2429 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2430 2431 Args: 2432 values: A numeric `Tensor` of arbitrary size. 2433 threshold: A scalar threshold. 2434 weights: An optional `Tensor` whose shape is broadcastable to `values`. 2435 metrics_collections: An optional list of collections that the metric 2436 value variable should be added to. 
2437 updates_collections: An optional list of collections that the metric update 2438 ops should be added to. 2439 name: An optional variable_scope name. 2440 2441 Returns: 2442 percentage: A `Tensor` representing the current mean, the value of `total` 2443 divided by `count`. 2444 update_op: An operation that increments the `total` and `count` variables 2445 appropriately. 2446 2447 Raises: 2448 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 2449 or if either `metrics_collections` or `updates_collections` are not a list 2450 or tuple. 2451 """ 2452 return metrics.percentage_below( 2453 values=values, threshold=threshold, weights=weights, 2454 metrics_collections=metrics_collections, 2455 updates_collections=updates_collections, name=name) 2456 2457 2458def streaming_mean_iou(predictions, 2459 labels, 2460 num_classes, 2461 weights=None, 2462 metrics_collections=None, 2463 updates_collections=None, 2464 name=None): 2465 """Calculate per-step mean Intersection-Over-Union (mIOU). 2466 2467 Mean Intersection-Over-Union is a common evaluation metric for 2468 semantic image segmentation, which first computes the IOU for each 2469 semantic class and then computes the average over classes. 2470 IOU is defined as follows: 2471 IOU = true_positive / (true_positive + false_positive + false_negative). 2472 The predictions are accumulated in a confusion matrix, weighted by `weights`, 2473 and mIOU is then calculated from it. 2474 2475 For estimation of the metric over a stream of data, the function creates an 2476 `update_op` operation that updates these variables and returns the `mean_iou`. 2477 2478 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2479 2480 Args: 2481 predictions: A `Tensor` of prediction results for semantic labels, whose 2482 shape is [batch size] and type `int32` or `int64`. The tensor will be 2483 flattened, if its rank > 1. 
2484 labels: A `Tensor` of ground truth labels with shape [batch size] and of 2485 type `int32` or `int64`. The tensor will be flattened, if its rank > 1. 2486 num_classes: The possible number of labels the prediction task can 2487 have. This value must be provided, since a confusion matrix of 2488 dimension = [num_classes, num_classes] will be allocated. 2489 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 2490 metrics_collections: An optional list of collections that `mean_iou` 2491 should be added to. 2492 updates_collections: An optional list of collections `update_op` should be 2493 added to. 2494 name: An optional variable_scope name. 2495 2496 Returns: 2497 mean_iou: A `Tensor` representing the mean intersection-over-union. 2498 update_op: An operation that increments the confusion matrix. 2499 2500 Raises: 2501 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2502 `weights` is not `None` and its shape doesn't match `predictions`, or if 2503 either `metrics_collections` or `updates_collections` are not a list or 2504 tuple. 2505 """ 2506 return metrics.mean_iou( 2507 num_classes=num_classes, predictions=predictions, labels=labels, 2508 weights=weights, metrics_collections=metrics_collections, 2509 updates_collections=updates_collections, name=name) 2510 2511 2512def _next_array_size(required_size, growth_factor=1.5): 2513 """Calculate the next size for reallocating a dynamic array. 2514 2515 Args: 2516 required_size: number or tf.Tensor specifying required array capacity. 2517 growth_factor: optional number or tf.Tensor specifying the growth factor 2518 between subsequent allocations. 2519 2520 Returns: 2521 tf.Tensor with dtype=int32 giving the next array size. 
2522 """ 2523 exponent = math_ops.ceil( 2524 math_ops.log(math_ops.cast(required_size, dtypes.float32)) 2525 / math_ops.log(math_ops.cast(growth_factor, dtypes.float32))) 2526 return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32) 2527 2528 2529def streaming_concat(values, 2530 axis=0, 2531 max_size=None, 2532 metrics_collections=None, 2533 updates_collections=None, 2534 name=None): 2535 """Concatenate values along an axis across batches. 2536 2537 The function `streaming_concat` creates two local variables, `array` and 2538 `size`, that are used to store concatenated values. Internally, `array` is 2539 used as storage for a dynamic array (if `maxsize` is `None`), which ensures 2540 that updates can be run in amortized constant time. 2541 2542 For estimation of the metric over a stream of data, the function creates an 2543 `update_op` operation that appends the values of a tensor and returns the 2544 `value` of the concatenated tensors. 2545 2546 This op allows for evaluating metrics that cannot be updated incrementally 2547 using the same framework as other streaming metrics. 2548 2549 Args: 2550 values: `Tensor` to concatenate. Rank and the shape along all axes other 2551 than the axis to concatenate along must be statically known. 2552 axis: optional integer axis to concatenate along. 2553 max_size: optional integer maximum size of `value` along the given axis. 2554 Once the maximum size is reached, further updates are no-ops. By default, 2555 there is no maximum size: the array is resized as necessary. 2556 metrics_collections: An optional list of collections that `value` 2557 should be added to. 2558 updates_collections: An optional list of collections `update_op` should be 2559 added to. 2560 name: An optional variable_scope name. 2561 2562 Returns: 2563 value: A `Tensor` representing the concatenated values. 2564 update_op: An operation that concatenates the next values. 
2565 2566 Raises: 2567 ValueError: if `values` does not have a statically known rank, `axis` is 2568 not in the valid range or the size of `values` is not statically known 2569 along any axis other than `axis`. 2570 """ 2571 with variable_scope.variable_scope(name, 'streaming_concat', (values,)): 2572 # pylint: disable=invalid-slice-index 2573 values_shape = values.get_shape() 2574 if values_shape.dims is None: 2575 raise ValueError('`values` must have known statically known rank') 2576 2577 ndim = len(values_shape) 2578 if axis < 0: 2579 axis += ndim 2580 if not 0 <= axis < ndim: 2581 raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) 2582 2583 fixed_shape = [dim.value for n, dim in enumerate(values_shape) 2584 if n != axis] 2585 if any(value is None for value in fixed_shape): 2586 raise ValueError('all dimensions of `values` other than the dimension to ' 2587 'concatenate along must have statically known size') 2588 2589 # We move `axis` to the front of the internal array so assign ops can be 2590 # applied to contiguous slices 2591 init_size = 0 if max_size is None else max_size 2592 init_shape = [init_size] + fixed_shape 2593 array = _create_local( 2594 'array', shape=init_shape, validate_shape=False, dtype=values.dtype) 2595 size = _create_local('size', shape=[], dtype=dtypes.int32) 2596 2597 perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] 2598 valid_array = array[:size] 2599 valid_array.set_shape([None] + fixed_shape) 2600 value = array_ops.transpose(valid_array, perm, name='concat') 2601 2602 values_size = array_ops.shape(values)[axis] 2603 if max_size is None: 2604 batch_size = values_size 2605 else: 2606 batch_size = math_ops.minimum(values_size, max_size - size) 2607 2608 perm = [axis] + [n for n in range(ndim) if n != axis] 2609 batch_values = array_ops.transpose(values, perm)[:batch_size] 2610 2611 def reallocate(): 2612 next_size = _next_array_size(new_size) 2613 next_shape = array_ops.pack([next_size] + fixed_shape) 
2614 new_value = array_ops.zeros(next_shape, dtype=values.dtype) 2615 old_value = array.value() 2616 assign_op = state_ops.assign(array, new_value, validate_shape=False) 2617 with ops.control_dependencies([assign_op]): 2618 copy_op = array[:size].assign(old_value[:size]) 2619 # return value needs to be the same dtype as no_op() for cond 2620 with ops.control_dependencies([copy_op]): 2621 return control_flow_ops.no_op() 2622 2623 new_size = size + batch_size 2624 array_size = array_ops.shape_internal(array, optimize=False)[0] 2625 maybe_reallocate_op = control_flow_ops.cond( 2626 new_size > array_size, reallocate, control_flow_ops.no_op) 2627 with ops.control_dependencies([maybe_reallocate_op]): 2628 append_values_op = array[size:new_size].assign(batch_values) 2629 with ops.control_dependencies([append_values_op]): 2630 update_op = size.assign(new_size) 2631 2632 if metrics_collections: 2633 ops.add_to_collections(metrics_collections, value) 2634 2635 if updates_collections: 2636 ops.add_to_collections(updates_collections, update_op) 2637 2638 return value, update_op 2639 # pylint: enable=invalid-slice-index 2640 2641 2642def aggregate_metrics(*value_update_tuples): 2643 """Aggregates the metric value tensors and update ops into two lists. 2644 2645 Args: 2646 *value_update_tuples: a variable number of tuples, each of which contain the 2647 pair of (value_tensor, update_op) from a streaming metric. 2648 2649 Returns: 2650 A list of value `Tensor` objects and a list of update ops. 2651 2652 Raises: 2653 ValueError: if `value_update_tuples` is empty. 2654 """ 2655 if not value_update_tuples: 2656 raise ValueError('Expected at least one value_tensor/update_op pair') 2657 value_ops, update_ops = zip(*value_update_tuples) 2658 return list(value_ops), list(update_ops) 2659 2660 2661def aggregate_metric_map(names_to_tuples): 2662 """Aggregates the metric names to tuple dictionary. 
2663 2664 This function is useful for pairing metric names with their associated value 2665 and update ops when the list of metrics is long. For example: 2666 2667 ```python 2668 metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({ 2669 'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error( 2670 predictions, labels, weights), 2671 'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error( 2672 predictions, labels, labels, weights), 2673 'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error( 2674 predictions, labels, weights), 2675 'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error( 2676 predictions, labels, weights), 2677 }) 2678 ``` 2679 2680 Args: 2681 names_to_tuples: a map of metric names to tuples, each of which contain the 2682 pair of (value_tensor, update_op) from a streaming metric. 2683 2684 Returns: 2685 A dictionary from metric names to value ops and a dictionary from metric 2686 names to update ops. 2687 """ 2688 metric_names = names_to_tuples.keys() 2689 value_ops, update_ops = zip(*names_to_tuples.values()) 2690 return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops)) 2691 2692 2693def _remove_squeezable_dimensions(predictions, labels, weights): 2694 """Squeeze last dim if needed. 2695 2696 Squeezes `predictions` and `labels` if their rank differs by 1. 2697 Squeezes `weights` if its rank is 1 more than the new rank of `predictions` 2698 2699 This will use static shape if available. Otherwise, it will add graph 2700 operations, which could result in a performance hit. 2701 2702 Args: 2703 predictions: Predicted values, a `Tensor` of arbitrary dimensions. 2704 labels: Label values, a `Tensor` whose dimensions match `predictions`. 2705 weights: Optional weight `Tensor`. 
It will be squeezed if its rank is 1 2706 more than the new rank of `predictions` 2707 2708 Returns: 2709 Tuple of `predictions`, `labels` and `weights`, possibly with the last 2710 dimension squeezed. 2711 """ 2712 predictions, labels = tensor_util.remove_squeezable_dimensions( 2713 predictions, labels) 2714 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2715 2716 if weights is not None: 2717 weights = ops.convert_to_tensor(weights) 2718 predictions_shape = predictions.get_shape() 2719 predictions_rank = predictions_shape.ndims 2720 weights_shape = weights.get_shape() 2721 weights_rank = weights_shape.ndims 2722 2723 if (predictions_rank is not None) and (weights_rank is not None): 2724 # Use static rank. 2725 if weights_rank - predictions_rank == 1: 2726 weights = array_ops.squeeze(weights, [-1]) 2727 elif (weights_rank is None) or ( 2728 weights_shape.dims[-1].is_compatible_with(1)): 2729 # Use dynamic rank 2730 weights = control_flow_ops.cond( 2731 math_ops.equal(array_ops.rank(weights), 2732 math_ops.add(array_ops.rank(predictions), 1)), 2733 lambda: array_ops.squeeze(weights, [-1]), 2734 lambda: weights) 2735 return predictions, labels, weights 2736 2737 2738__all__ = [ 2739 'aggregate_metric_map', 2740 'aggregate_metrics', 2741 'streaming_accuracy', 2742 'streaming_auc', 2743 'streaming_false_negatives', 2744 'streaming_false_negatives_at_thresholds', 2745 'streaming_false_positives', 2746 'streaming_false_positives_at_thresholds', 2747 'streaming_mean', 2748 'streaming_mean_absolute_error', 2749 'streaming_mean_cosine_distance', 2750 'streaming_mean_iou', 2751 'streaming_mean_relative_error', 2752 'streaming_mean_squared_error', 2753 'streaming_mean_tensor', 2754 'streaming_percentage_less', 2755 'streaming_precision', 2756 'streaming_precision_at_thresholds', 2757 'streaming_recall', 2758 'streaming_recall_at_k', 2759 'streaming_recall_at_thresholds', 2760 'streaming_root_mean_squared_error', 2761 
'streaming_sensitivity_at_specificity', 2762 'streaming_sparse_average_precision_at_k', 2763 'streaming_sparse_precision_at_k', 2764 'streaming_sparse_recall_at_k', 2765 'streaming_specificity_at_sensitivity', 2766 'streaming_true_negatives', 2767 'streaming_true_negatives_at_thresholds', 2768 'streaming_true_positives', 2769 'streaming_true_positives_at_thresholds', 2770] 2771