metric_ops.py revision 2800e4d47be9078d2688057722f1385276aae8b7
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains metric-computing operations on streamed tensors. 16 17Module documentation, including "@@" callouts, should be put in 18third_party/tensorflow/contrib/metrics/__init__.py 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25from tensorflow.contrib.framework import deprecated 26from tensorflow.contrib.framework import tensor_util 27from tensorflow.contrib.framework.python.ops import variables as contrib_variables 28from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops 29from tensorflow.contrib.metrics.python.ops import set_ops 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import sparse_tensor 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import check_ops 35from tensorflow.python.ops import control_flow_ops 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import nn 38from tensorflow.python.ops import sparse_ops 39from tensorflow.python.ops import state_ops 40from tensorflow.python.ops import variable_scope 41from tensorflow.python.ops import variables 42 43 44def _safe_div(numerator, denominator, name): 45 """Divides two values, 
returning 0 if the denominator is <= 0. 46 47 Args: 48 numerator: A real `Tensor`. 49 denominator: A real `Tensor`, with dtype matching `numerator`. 50 name: Name for the returned op. 51 52 Returns: 53 0 if `denominator` <= 0, else `numerator` / `denominator` 54 """ 55 return math_ops.select( 56 math_ops.greater(denominator, 0), 57 math_ops.truediv(numerator, denominator), 58 0, 59 name=name) 60 61 62def _safe_scalar_div(numerator, denominator, name): 63 """Divides two values, returning 0 if the denominator is 0. 64 65 Args: 66 numerator: A scalar `float64` `Tensor`. 67 denominator: A scalar `float64` `Tensor`. 68 name: Name for the returned op. 69 70 Returns: 71 0 if `denominator` == 0, else `numerator` / `denominator` 72 """ 73 numerator.get_shape().with_rank_at_most(1) 74 denominator.get_shape().with_rank_at_most(1) 75 return control_flow_ops.cond( 76 math_ops.equal( 77 array_ops.constant(0.0, dtype=dtypes.float64), denominator), 78 lambda: array_ops.constant(0.0, dtype=dtypes.float64), 79 lambda: math_ops.div(numerator, denominator), 80 name=name) 81 82 83def _create_local(name, shape, collections=None, validate_shape=True, 84 dtype=dtypes.float32): 85 """Creates a new local variable. 86 87 Args: 88 name: The name of the new or existing variable. 89 shape: Shape of the new or existing variable. 90 collections: A list of collection names to which the Variable will be added. 91 validate_shape: Whether to validate the shape of the variable. 92 dtype: Data type of the variables. 93 94 Returns: 95 The created variable. 
96 """ 97 # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES 98 collections = list(collections or []) 99 collections += [ops.GraphKeys.LOCAL_VARIABLES] 100 return variables.Variable( 101 initial_value=array_ops.zeros(shape, dtype=dtype), 102 name=name, 103 trainable=False, 104 collections=collections, 105 validate_shape=validate_shape) 106 107 108def _count_condition(values, weights=None, metrics_collections=None, 109 updates_collections=None): 110 """Sums the weights of cases where the given values are True. 111 112 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 113 114 Args: 115 values: A `bool` `Tensor` of arbitrary size. 116 weights: An optional `Tensor` whose shape is broadcastable to `values`. 117 metrics_collections: An optional list of collections that the metric 118 value variable should be added to. 119 updates_collections: An optional list of collections that the metric update 120 ops should be added to. 121 122 Returns: 123 value_tensor: A tensor representing the current value of the metric. 124 update_op: An operation that accumulates the error from a batch of data. 125 126 Raises: 127 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 128 or if either `metrics_collections` or `updates_collections` are not a list 129 or tuple. 
130 """ 131 check_ops.assert_type(values, dtypes.bool) 132 count = _create_local('count', shape=[]) 133 134 values = math_ops.to_float(values) 135 if weights is not None: 136 weights = math_ops.to_float(weights) 137 values = math_ops.mul(values, weights) 138 139 value_tensor = array_ops.identity(count) 140 update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) 141 142 if metrics_collections: 143 ops.add_to_collections(metrics_collections, value_tensor) 144 145 if updates_collections: 146 ops.add_to_collections(updates_collections, update_op) 147 148 return value_tensor, update_op 149 150 151def _streaming_true_positives(predictions, labels, weights=None, 152 metrics_collections=None, 153 updates_collections=None, 154 name=None): 155 """Sum the weights of true_positives. 156 157 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 158 159 Args: 160 predictions: The predicted values, a `bool` `Tensor` of arbitrary 161 dimensions. 162 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 163 match `predictions`. 164 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 165 metrics_collections: An optional list of collections that the metric 166 value variable should be added to. 167 updates_collections: An optional list of collections that the metric update 168 ops should be added to. 169 name: An optional variable_scope name. 170 171 Returns: 172 value_tensor: A tensor representing the current value of the metric. 173 update_op: An operation that accumulates the error from a batch of data. 174 175 Raises: 176 ValueError: If `predictions` and `labels` have mismatched shapes, or if 177 `weights` is not `None` and its shape doesn't match `predictions`, or if 178 either `metrics_collections` or `updates_collections` are not a list or 179 tuple. 
180 """ 181 with variable_scope.variable_scope( 182 name, 'true_positives', [predictions, labels]): 183 184 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 185 is_true_positive = math_ops.logical_and(math_ops.equal(labels, 1), 186 math_ops.equal(predictions, 1)) 187 return _count_condition(is_true_positive, weights, metrics_collections, 188 updates_collections) 189 190 191def _streaming_false_positives(predictions, labels, weights=None, 192 metrics_collections=None, 193 updates_collections=None, 194 name=None): 195 """Sum the weights of false positives. 196 197 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 198 199 Args: 200 predictions: The predicted values, a `bool` `Tensor` of arbitrary 201 dimensions. 202 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 203 match `predictions`. 204 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 205 metrics_collections: An optional list of collections that the metric 206 value variable should be added to. 207 updates_collections: An optional list of collections that the metric update 208 ops should be added to. 209 name: An optional variable_scope name. 210 211 Returns: 212 value_tensor: A tensor representing the current value of the metric. 213 update_op: An operation that accumulates the error from a batch of data. 214 215 Raises: 216 ValueError: If `predictions` and `labels` have mismatched shapes, or if 217 `weights` is not `None` and its shape doesn't match `predictions`, or if 218 either `metrics_collections` or `updates_collections` are not a list or 219 tuple. 
220 """ 221 with variable_scope.variable_scope( 222 name, 'false_positives', [predictions, labels]): 223 224 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 225 is_false_positive = math_ops.logical_and(math_ops.equal(labels, 0), 226 math_ops.equal(predictions, 1)) 227 return _count_condition(is_false_positive, weights, metrics_collections, 228 updates_collections) 229 230 231def _streaming_false_negatives(predictions, labels, weights=None, 232 metrics_collections=None, 233 updates_collections=None, 234 name=None): 235 """Computes the total number of false positives. 236 237 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 238 239 Args: 240 predictions: The predicted values, a `bool` `Tensor` of arbitrary 241 dimensions. 242 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 243 match `predictions`. 244 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 245 metrics_collections: An optional list of collections that the metric 246 value variable should be added to. 247 updates_collections: An optional list of collections that the metric update 248 ops should be added to. 249 name: An optional variable_scope name. 250 251 Returns: 252 value_tensor: A tensor representing the current value of the metric. 253 update_op: An operation that accumulates the error from a batch of data. 254 255 Raises: 256 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 257 or if either `metrics_collections` or `updates_collections` are not a list 258 or tuple. 
259 """ 260 with variable_scope.variable_scope( 261 name, 'false_negatives', [predictions, labels]): 262 263 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 264 is_false_negative = math_ops.logical_and(math_ops.equal(labels, 1), 265 math_ops.equal(predictions, 0)) 266 return _count_condition(is_false_negative, weights, metrics_collections, 267 updates_collections) 268 269 270def _broadcast_weights(weights, values): 271 """Broadcast `weights` to the same shape as `values`. 272 273 This returns a version of `weights` following the same broadcast rules as 274 `mul(weights, values)`. When computing a weighted average, use this function 275 to broadcast `weights` before summing them; e.g., 276 `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`. 277 278 Args: 279 weights: `Tensor` whose shape is broadcastable to `values`. 280 values: `Tensor` of any shape. 281 282 Returns: 283 `weights` broadcast to `values` shape. 284 """ 285 weights_shape = weights.get_shape() 286 values_shape = values.get_shape() 287 if (weights_shape.is_fully_defined() and 288 values_shape.is_fully_defined() and 289 weights_shape.is_compatible_with(values_shape)): 290 return weights 291 return math_ops.mul( 292 weights, array_ops.ones_like(values), name='broadcast_weights') 293 294 295def streaming_mean(values, weights=None, metrics_collections=None, 296 updates_collections=None, name=None): 297 """Computes the (weighted) mean of the given values. 298 299 The `streaming_mean` function creates two local variables, `total` and `count` 300 that are used to compute the average of `values`. This average is ultimately 301 returned as `mean` which is an idempotent operation that simply divides 302 `total` by `count`. 303 304 For estimation of the metric over a stream of data, the function creates an 305 `update_op` operation that updates these variables and returns the `mean`. 
306 `update_op` increments `total` with the reduced sum of the product of `values` 307 and `weights`, and it increments `count` with the reduced sum of `weights`. 308 309 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 310 311 Args: 312 values: A `Tensor` of arbitrary dimensions. 313 weights: An optional `Tensor` whose shape is broadcastable to `values`. 314 metrics_collections: An optional list of collections that `mean` 315 should be added to. 316 updates_collections: An optional list of collections that `update_op` 317 should be added to. 318 name: An optional variable_scope name. 319 320 Returns: 321 mean: A tensor representing the current mean, the value of `total` divided 322 by `count`. 323 update_op: An operation that increments the `total` and `count` variables 324 appropriately and whose value matches `mean_value`. 325 326 Raises: 327 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 328 or if either `metrics_collections` or `updates_collections` are not a list 329 or tuple. 
330 """ 331 with variable_scope.variable_scope(name, 'mean', [values, weights]): 332 values = math_ops.to_float(values) 333 334 total = _create_local('total', shape=[]) 335 count = _create_local('count', shape=[]) 336 337 if weights is not None: 338 weights = math_ops.to_float(weights) 339 values = math_ops.mul(values, weights) 340 num_values = math_ops.reduce_sum(_broadcast_weights(weights, values)) 341 else: 342 num_values = math_ops.to_float(array_ops.size(values)) 343 344 total_compute_op = state_ops.assign_add(total, math_ops.reduce_sum(values)) 345 count_compute_op = state_ops.assign_add(count, num_values) 346 347 mean = _safe_div(total, count, 'value') 348 with ops.control_dependencies([total_compute_op, count_compute_op]): 349 update_op = _safe_div(total, count, 'update_op') 350 351 if metrics_collections: 352 ops.add_to_collections(metrics_collections, mean) 353 354 if updates_collections: 355 ops.add_to_collections(updates_collections, update_op) 356 357 return mean, update_op 358 359 360def streaming_mean_tensor(values, weights=None, metrics_collections=None, 361 updates_collections=None, name=None): 362 """Computes the element-wise (weighted) mean of the given tensors. 363 364 In contrast to the `streaming_mean` function which returns a scalar with the 365 mean, this function returns an average tensor with the same shape as the 366 input tensors. 367 368 The `streaming_mean_tensor` function creates two local variables, 369 `total_tensor` and `count_tensor` that are used to compute the average of 370 `values`. This average is ultimately returned as `mean` which is an idempotent 371 operation that simply divides `total` by `count`. 372 373 For estimation of the metric over a stream of data, the function creates an 374 `update_op` operation that updates these variables and returns the `mean`. 375 `update_op` increments `total` with the reduced sum of the product of `values` 376 and `weights`, and it increments `count` with the reduced sum of `weights`. 
377 378 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 379 380 Args: 381 values: A `Tensor` of arbitrary dimensions. 382 weights: An optional `Tensor` whose shape is broadcastable to `values`. 383 metrics_collections: An optional list of collections that `mean` 384 should be added to. 385 updates_collections: An optional list of collections that `update_op` 386 should be added to. 387 name: An optional variable_scope name. 388 389 Returns: 390 mean: A float tensor representing the current mean, the value of `total` 391 divided by `count`. 392 update_op: An operation that increments the `total` and `count` variables 393 appropriately and whose value matches `mean_value`. 394 395 Raises: 396 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 397 or if either `metrics_collections` or `updates_collections` are not a list 398 or tuple. 399 """ 400 with variable_scope.variable_scope(name, 'mean', [values, weights]): 401 total = _create_local('total_tensor', shape=values.get_shape()) 402 count = _create_local('count_tensor', shape=values.get_shape()) 403 404 num_values = array_ops.ones_like(values) 405 if weights is not None: 406 weights = math_ops.to_float(weights) 407 values = math_ops.mul(values, weights) 408 num_values = math_ops.mul(num_values, weights) 409 410 total_compute_op = state_ops.assign_add(total, values) 411 count_compute_op = state_ops.assign_add(count, num_values) 412 413 def compute_mean(total, count, name): 414 non_zero_count = math_ops.maximum(count, 415 array_ops.ones_like(count), 416 name=name) 417 return math_ops.truediv(total, non_zero_count, name=name) 418 419 mean = compute_mean(total, count, 'value') 420 with ops.control_dependencies([total_compute_op, count_compute_op]): 421 update_op = compute_mean(total, count, 'update_op') 422 423 if metrics_collections: 424 ops.add_to_collections(metrics_collections, mean) 425 426 if updates_collections: 427 
ops.add_to_collections(updates_collections, update_op) 428 429 return mean, update_op 430 431 432def streaming_accuracy(predictions, labels, weights=None, 433 metrics_collections=None, updates_collections=None, 434 name=None): 435 """Calculates how often `predictions` matches `labels`. 436 437 The `streaming_accuracy` function creates two local variables, `total` and 438 `count` that are used to compute the frequency with which `predictions` 439 matches `labels`. This frequency is ultimately returned as `accuracy`: an 440 idempotent operation that simply divides `total` by `count`. 441 442 For estimation of the metric over a stream of data, the function creates an 443 `update_op` operation that updates these variables and returns the `accuracy`. 444 Internally, an `is_correct` operation computes a `Tensor` with elements 1.0 445 where the corresponding elements of `predictions` and `labels` match and 0.0 446 otherwise. Then `update_op` increments `total` with the reduced sum of the 447 product of `weights` and `is_correct`, and it increments `count` with the 448 reduced sum of `weights`. 449 450 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 451 452 Args: 453 predictions: The predicted values, a `Tensor` of any shape. 454 labels: The ground truth values, a `Tensor` whose shape matches 455 `predictions`. 456 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 457 metrics_collections: An optional list of collections that `accuracy` should 458 be added to. 459 updates_collections: An optional list of collections that `update_op` should 460 be added to. 461 name: An optional variable_scope name. 462 463 Returns: 464 accuracy: A tensor representing the accuracy, the value of `total` divided 465 by `count`. 466 update_op: An operation that increments the `total` and `count` variables 467 appropriately and whose value matches `accuracy`. 
468 469 Raises: 470 ValueError: If `predictions` and `labels` have mismatched shapes, or if 471 `weights` is not `None` and its shape doesn't match `predictions`, or if 472 either `metrics_collections` or `updates_collections` are not a list or 473 tuple. 474 """ 475 predictions, labels, weights = _remove_squeezable_dimensions( 476 predictions, labels, weights=weights) 477 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 478 if labels.dtype != predictions.dtype: 479 predictions = math_ops.cast(predictions, labels.dtype) 480 is_correct = math_ops.to_float(math_ops.equal(predictions, labels)) 481 return streaming_mean(is_correct, weights, metrics_collections, 482 updates_collections, name or 'accuracy') 483 484 485def streaming_precision(predictions, labels, weights=None, 486 metrics_collections=None, updates_collections=None, 487 name=None): 488 """Computes the precision of the predictions with respect to the labels. 489 490 The `streaming_precision` function creates two local variables, 491 `true_positives` and `false_positives`, that are used to compute the 492 precision. This value is ultimately returned as `precision`, an idempotent 493 operation that simply divides `true_positives` by the sum of `true_positives` 494 and `false_positives`. 495 496 For estimation of the metric over a stream of data, the function creates an 497 `update_op` operation that updates these variables and returns the 498 `precision`. `update_op` weights each prediction by the corresponding value in 499 `weights`. 500 501 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 502 503 Args: 504 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 505 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 506 match `predictions`. 507 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 508 metrics_collections: An optional list of collections that `precision` should 509 be added to. 
510 updates_collections: An optional list of collections that `update_op` should 511 be added to. 512 name: An optional variable_scope name. 513 514 Returns: 515 precision: Scalar float `Tensor` with the value of `true_positives` 516 divided by the sum of `true_positives` and `false_positives`. 517 update_op: `Operation` that increments `true_positives` and 518 `false_positives` variables appropriately and whose value matches 519 `precision`. 520 521 Raises: 522 ValueError: If `predictions` and `labels` have mismatched shapes, or if 523 `weights` is not `None` and its shape doesn't match `predictions`, or if 524 either `metrics_collections` or `updates_collections` are not a list or 525 tuple. 526 """ 527 with variable_scope.variable_scope( 528 name, 'precision', [predictions, labels]): 529 530 predictions, labels, weights = _remove_squeezable_dimensions( 531 predictions, labels, weights) 532 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 533 534 true_positives, true_positives_update_op = _streaming_true_positives( 535 predictions, labels, weights, metrics_collections=None, 536 updates_collections=None, name=None) 537 false_positives, false_positives_update_op = _streaming_false_positives( 538 predictions, labels, weights, metrics_collections=None, 539 updates_collections=None, name=None) 540 541 def compute_precision(name): 542 return math_ops.select( 543 math_ops.greater(true_positives + false_positives, 0), 544 math_ops.div(true_positives, true_positives + false_positives), 545 0, 546 name) 547 548 precision = compute_precision('value') 549 with ops.control_dependencies([true_positives_update_op, 550 false_positives_update_op]): 551 update_op = compute_precision('update_op') 552 553 if metrics_collections: 554 ops.add_to_collections(metrics_collections, precision) 555 556 if updates_collections: 557 ops.add_to_collections(updates_collections, update_op) 558 559 return precision, update_op 560 561 562def streaming_recall(predictions, 
labels, weights=None, 563 metrics_collections=None, updates_collections=None, 564 name=None): 565 """Computes the recall of the predictions with respect to the labels. 566 567 The `streaming_recall` function creates two local variables, `true_positives` 568 and `false_negatives`, that are used to compute the recall. This value is 569 ultimately returned as `recall`, an idempotent operation that simply divides 570 `true_positives` by the sum of `true_positives` and `false_negatives`. 571 572 For estimation of the metric over a stream of data, the function creates an 573 `update_op` that updates these variables and returns the `recall`. `update_op` 574 weights each prediction by the corresponding value in `weights`. 575 576 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 577 578 Args: 579 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 580 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 581 match `predictions`. 582 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 583 metrics_collections: An optional list of collections that `recall` should 584 be added to. 585 updates_collections: An optional list of collections that `update_op` should 586 be added to. 587 name: An optional variable_scope name. 588 589 Returns: 590 recall: Scalar float `Tensor` with the value of `true_positives` divided 591 by the sum of `true_positives` and `false_negatives`. 592 update_op: `Operation` that increments `true_positives` and 593 `false_negatives` variables appropriately and whose value matches 594 `recall`. 595 596 Raises: 597 ValueError: If `predictions` and `labels` have mismatched shapes, or if 598 `weights` is not `None` and its shape doesn't match `predictions`, or if 599 either `metrics_collections` or `updates_collections` are not a list or 600 tuple. 
601 """ 602 with variable_scope.variable_scope(name, 'recall', [predictions, labels]): 603 predictions, labels, weights = _remove_squeezable_dimensions( 604 predictions, labels, weights) 605 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 606 607 true_positives, true_positives_update_op = _streaming_true_positives( 608 predictions, labels, weights, metrics_collections=None, 609 updates_collections=None, name=None) 610 false_negatives, false_negatives_update_op = _streaming_false_negatives( 611 predictions, labels, weights, metrics_collections=None, 612 updates_collections=None, name=None) 613 614 def compute_recall(true_positives, false_negatives, name): 615 return math_ops.select( 616 math_ops.greater(true_positives + false_negatives, 0), 617 math_ops.div(true_positives, true_positives + false_negatives), 618 0, 619 name) 620 621 recall = compute_recall(true_positives, false_negatives, 'value') 622 with ops.control_dependencies([true_positives_update_op, 623 false_negatives_update_op]): 624 update_op = compute_recall(true_positives, false_negatives, 'update_op') 625 626 if metrics_collections: 627 ops.add_to_collections(metrics_collections, recall) 628 629 if updates_collections: 630 ops.add_to_collections(updates_collections, update_op) 631 632 return recall, update_op 633 634 635def _tp_fn_tn_fp(predictions, labels, thresholds, weights=None): 636 """Computes true_positives, false_negatives, true_negatives, false_positives. 637 638 The `_tp_fn_tn_fp` function creates four local variables, `true_positives`, 639 `true_negatives`, `false_positives` and `false_negatives`. 640 `true_positive[i]` is defined as the total weight of values in `predictions` 641 above `thresholds[i]` whose corresponding entry in `labels` is `True`. 642 `false_negatives[i]` is defined as the total weight of values in `predictions` 643 at most `thresholds[i]` whose corresponding entry in `labels` is `True`. 
644 `true_negatives[i]` is defined as the total weight of values in `predictions` 645 at most `thresholds[i]` whose corresponding entry in `labels` is `False`. 646 `false_positives[i]` is defined as the total weight of values in `predictions` 647 above `thresholds[i]` whose corresponding entry in `labels` is `False`. 648 649 For estimation of these metrics over a stream of data, for each metric the 650 function respectively creates an `update_op` operation that updates the 651 variable and returns its value. 652 653 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 654 655 Args: 656 predictions: A floating point `Tensor` of arbitrary shape and whose values 657 are in the range `[0, 1]`. 658 labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast 659 to `bool`. 660 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 661 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 662 663 Returns: 664 true_positive: A variable of shape [len(thresholds)]. 665 false_negative: A variable of shape [len(thresholds)]. 666 true_negatives: A variable of shape [len(thresholds)]. 667 false_positives: A variable of shape [len(thresholds)]. 668 true_positives_update_op: An operation that increments the `true_positives`. 669 false_negative_update_op: An operation that increments the `false_negative`. 670 true_negatives_update_op: An operation that increments the `true_negatives`. 671 false_positives_update_op: An operation that increments the 672 `false_positives`. 673 674 Raises: 675 ValueError: If `predictions` and `labels` have mismatched shapes, or if 676 `weights` is not `None` and its shape doesn't match `predictions`. 677 """ 678 predictions, labels, weights = _remove_squeezable_dimensions( 679 predictions, labels, weights) 680 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 681 682 num_thresholds = len(thresholds) 683 684 # Reshape predictions and labels. 
685 predictions_2d = array_ops.reshape(predictions, [-1, 1]) 686 labels_2d = array_ops.reshape( 687 math_ops.cast(labels, dtype=dtypes.bool), [1, -1]) 688 689 # Use static shape if known. 690 num_predictions = predictions_2d.get_shape().as_list()[0] 691 692 # Otherwise use dynamic shape. 693 if num_predictions is None: 694 num_predictions = array_ops.shape(predictions_2d)[0] 695 thresh_tiled = array_ops.tile( 696 array_ops.expand_dims(array_ops.constant(thresholds), [1]), 697 array_ops.pack([1, num_predictions])) 698 699 # Tile the predictions after thresholding them across different thresholds. 700 pred_is_pos = math_ops.greater( 701 array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]), 702 thresh_tiled) 703 pred_is_neg = math_ops.logical_not(pred_is_pos) 704 705 # Tile labels by number of thresholds 706 label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1]) 707 label_is_neg = math_ops.logical_not(label_is_pos) 708 709 true_positives = _create_local('true_positives', shape=[num_thresholds]) 710 false_negatives = _create_local('false_negatives', shape=[num_thresholds]) 711 true_negatives = _create_local('true_negatives', shape=[num_thresholds]) 712 false_positives = _create_local('false_positives', shape=[num_thresholds]) 713 714 is_true_positive = math_ops.to_float( 715 math_ops.logical_and(label_is_pos, pred_is_pos)) 716 is_false_negative = math_ops.to_float( 717 math_ops.logical_and(label_is_pos, pred_is_neg)) 718 is_false_positive = math_ops.to_float( 719 math_ops.logical_and(label_is_neg, pred_is_pos)) 720 is_true_negative = math_ops.to_float( 721 math_ops.logical_and(label_is_neg, pred_is_neg)) 722 723 if weights is not None: 724 weights = math_ops.to_float(weights) 725 weights_tiled = array_ops.tile(array_ops.reshape(_broadcast_weights( 726 weights, predictions), [1, -1]), [num_thresholds, 1]) 727 728 thresh_tiled.get_shape().assert_is_compatible_with( 729 weights_tiled.get_shape()) 730 is_true_positive *= weights_tiled 731 
    is_false_negative *= weights_tiled
    is_false_positive *= weights_tiled
    is_true_negative *= weights_tiled

  true_positives_update_op = state_ops.assign_add(
      true_positives, math_ops.reduce_sum(is_true_positive, 1))
  false_negatives_update_op = state_ops.assign_add(
      false_negatives, math_ops.reduce_sum(is_false_negative, 1))
  true_negatives_update_op = state_ops.assign_add(
      true_negatives, math_ops.reduce_sum(is_true_negative, 1))
  false_positives_update_op = state_ops.assign_add(
      false_positives, math_ops.reduce_sum(is_false_positive, 1))

  return (true_positives, false_negatives, true_negatives, false_positives,
          true_positives_update_op, false_negatives_update_op,
          true_negatives_update_op, false_positives_update_op)


def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
                  metrics_collections=None, updates_collections=None,
                  curve='ROC', name=None):
  """Computes the approximate AUC via a Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of (x, y) values: (false positive rate,
  recall) for the ROC curve, or (recall, precision) for the PR curve. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall. Adjacent
  points are combined with the trapezoidal rule.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under the discretized curve (computed using the
  aforementioned variables). The `num_thresholds` variable controls the degree
  of discretization with larger numbers of thresholds more closely
  approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use when discretizing the
      curve (ROC or PR). Values below 2 trip the internal `assert` on the
      number of thresholds.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar tensor representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `curve` is neither 'ROC' nor 'PR', if `predictions` and
      `labels` have mismatched shapes, or if `weights` is not `None` and its
      shape doesn't match `predictions`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
  """
  with variable_scope.variable_scope(name, 'auc', [predictions, labels]):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' %
                       (curve))
    kepsilon = 1e-7  # to account for floating point imprecisions
    # `num_thresholds` thresholds in total: the `num_thresholds - 2` interior
    # points i / (num_thresholds - 1), plus two endpoints pushed just outside
    # [0, 1]. Predictions are compared with a strict `greater`, so every
    # prediction in [0, 1] is positive at the first threshold and none is
    # positive at the last one.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    # tp/fn/tn/fp hold one weighted count per threshold. The *_update_op
    # values are the corresponding `assign_add` results.
    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6
    # Graph-construction sanity check on the static shape. Note that `assert`
    # is stripped when Python runs with -O.
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      # Epsilon appears in both numerator and denominator, so recall (and
      # precision below) evaluate to 1 rather than NaN when the denominator
      # counts are all zero.
      recall = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        x = fp_rate
        y = recall
      else:  # curve == 'PR'.
        precision = math_ops.div(tp + epsilon, tp + fp + epsilon)
        x = recall
        y = precision
      # Trapezoidal rule over consecutive thresholds: width x[i] - x[i+1]
      # times mean height (y[i] + y[i+1]) / 2. Thresholds increase with i,
      # so for non-negative weights x is non-increasing and the widths are
      # non-negative.
      return math_ops.reduce_sum(math_ops.mul(
          x[:num_thresholds - 1] - x[1:],
          (y[:num_thresholds - 1] + y[1:]) / 2.), name=name)

    # sum up the areas of all the trapeziums
    auc = compute_auc(tp, fn, tn, fp, 'value')
    # `update_op` evaluates the same formula on the `assign_add` results, so
    # running it applies this batch's increments and yields the updated AUC.
    update_op = compute_auc(
        tp_update_op, fn_update_op, tn_update_op, fp_update_op, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, auc)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc, update_op


def streaming_specificity_at_sensitivity(
    predictions, labels, sensitivity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the specificity at a given sensitivity.

  The `streaming_specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar tensor representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    # NOTE(review): the last threshold here is 1.0 - kepsilon, whereas
    # streaming_auc and streaming_sensitivity_at_specificity use
    # 1.0 + kepsilon. Predictions in (1 - kepsilon, 1] therefore still count
    # as positive at the last threshold. Confirm this asymmetry is intended.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)

    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_specificity_at_sensitivity(name):
      """Computes the specificity at the given sensitivity.

      Args:
        name: The name of the operation.

      Returns:
        The specificity using the aggregated values.
      """
      # Per-threshold sensitivity (recall); kepsilon guards the empty case.
      sensitivities = math_ops.div(tp, tp + fn + kepsilon)

      # We'll need to use this trick until tf.argmax allows us to specify
      # whether we should use the first or last index in case of ties.
      # The 0/1 mask of thresholds closest to the target sensitivity is
      # cumulatively summed. The cumsum first reaches its maximum at the
      # *last* tied index, so (assuming argmax returns the first occurrence
      # of the maximum) tf_index selects the highest threshold among ties.
      min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
      indices_at_minval = math_ops.equal(
          math_ops.abs(sensitivities - sensitivity), min_val)
      indices_at_minval = math_ops.to_int64(indices_at_minval)
      indices_at_minval = math_ops.cumsum(indices_at_minval)
      tf_index = math_ops.argmax(indices_at_minval, 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the specificity:
      return math_ops.div(tn[tf_index],
                          tn[tf_index] + fp[tf_index] + kepsilon,
                          name)

    specificity = compute_specificity_at_sensitivity('value')
    # The update variant is built under control dependencies on all four
    # count updates.
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_specificity_at_sensitivity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, specificity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op


def streaming_sensitivity_at_specificity(
    predictions, labels, specificity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the sensitivity at a given specificity.

  The `streaming_sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is computed
  and used to evaluate the corresponding sensitivity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    specificity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar tensor representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `specificity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  if specificity < 0 or specificity > 1:
    raise ValueError('`specificity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
                                     [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_sensitivity_at_specificity(name):
      """Computes the sensitivity at the given specificity.

      Args:
        name: The name of the operation.

      Returns:
        The sensitivity using the aggregated values.
      """
      specificities = math_ops.div(tn, tn + fp + kepsilon)
      # Threshold whose specificity is closest to the target. Unlike
      # streaming_specificity_at_sensitivity, no explicit tie-break is
      # applied; ties are resolved by whatever index `argmin` returns.
      tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the sensitivity:
      return math_ops.div(tp[tf_index],
                          tp[tf_index] + fn[tf_index] + kepsilon,
                          name)

    sensitivity = compute_sensitivity_at_specificity('value')
    # The update variant is built under control dependencies on all four
    # count updates.
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_sensitivity_at_specificity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, sensitivity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return sensitivity, update_op


def streaming_precision_at_thresholds(predictions, labels, thresholds,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None, name=None):
  """Computes precision values for different `thresholds` on `predictions`.

  The `streaming_precision_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `precision[i]` is defined as the total
  weight of values in `predictions` above `thresholds[i]` whose corresponding
  entry in `labels` is `True`, divided by the total weight of values in
  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
  false_positives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `precision`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'precision_at_thresholds',
                                     [predictions, labels]):

    # TODO(nsilberman): Replace with only tp and fp, this results in unnecessary
    # variable creation. b/30842882
    (true_positives, _, _, false_positives, true_positives_compute_op, _, _,
     false_positives_compute_op,) = _tp_fn_tn_fp(
         predictions, labels, thresholds, weights)

    # avoid division by zero
    epsilon = 1e-7
    def compute_precision(name):
      # Element-wise over thresholds: tp / (tp + fp), which is 0 (not NaN)
      # when there are no predicted positives at a threshold.
      precision = math_ops.div(true_positives,
                               epsilon + true_positives + false_positives,
                               name='precision_' + name)
      return precision

    precision = compute_precision('value')
    # The update variant depends on the tp/fp increments only; the fn/tn
    # updates returned by _tp_fn_tn_fp are discarded above.
    with ops.control_dependencies([true_positives_compute_op,
                                   false_positives_compute_op]):
      update_op = compute_precision('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, precision)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision, update_op


def streaming_recall_at_thresholds(predictions, labels, thresholds,
                                   weights=None, metrics_collections=None,
                                   updates_collections=None, name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  The `streaming_recall_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `recall[i]` is defined as the total weight
  of values in `predictions` above `thresholds[i]` whose corresponding entry in
  `labels` is `True`, divided by the total weight of `True` values in `labels`
  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'recall_at_thresholds',
                                     [predictions, labels]):
    # Only the tp/fn counts (and their updates) are needed for recall.
    (true_positives, false_negatives, _, _, true_positives_compute_op,
     false_negatives_compute_op, _, _,) = _tp_fn_tn_fp(
         predictions, labels, thresholds, weights)

    # avoid division by zero
    epsilon = 1e-7
    def compute_recall(name):
      # Element-wise over thresholds: tp / (tp + fn), which is 0 (not NaN)
      # when there are no positive labels.
      recall = math_ops.div(true_positives,
                            epsilon + true_positives + false_negatives,
                            name='recall_' + name)
      return recall

    recall = compute_recall('value')
    with ops.control_dependencies([true_positives_compute_op,
                                   false_negatives_compute_op]):
      update_op = compute_recall('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op


def _at_k_name(name, k=None, class_id=None):
  """Builds a default metric name such as `recall_at_5_class3`.

  Args:
    name: Base metric name, e.g. 'recall' or 'precision'.
    k: Optional integer k. It is formatted with `%d`; when `None`, the literal
      suffix `_at_k` is used instead.
    class_id: Optional integer class ID, appended as `_class<class_id>`.

  Returns:
    The composed name string.
  """
  if k is not None:
    name = '%s_at_%d' % (name, k)
  else:
    name = '%s_at_k' % (name)
  if class_id is not None:
    name = '%s_class%d' % (name, class_id)
  return name


@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, '
            'and reshape labels from [batch_size] to [batch_size, 1].')
def streaming_recall_at_k(predictions, labels, k, weights=None,
                          metrics_collections=None, updates_collections=None,
                          name=None):
  """Computes the recall@k of the predictions with respect to dense labels.

  The `streaming_recall_at_k` function creates two local variables, `total` and
  `count`, that are used to compute the recall@k frequency. This frequency is
  ultimately returned as `recall_at_<k>`: an idempotent operation that simply
  divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with
  shape [batch_size] whose elements indicate whether or not the corresponding
  label is in the top `k` `predictions`. Then `update_op` increments `total`
  with the reduced sum of `weights` where `in_top_k` is `True`, and it
  increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point tensor of dimension [batch_size, num_classes]
    labels: A tensor of dimension [batch_size] whose type is in `int32`,
      `int64`.
    k: The number of top elements to look at for computing recall.
    weights: An optional `Tensor` of per-example weights. It is passed to
      `streaming_mean` together with the `[batch_size]` `in_top_k` values, so
      its shape should be broadcastable to `[batch_size]`.
    metrics_collections: An optional list of collections that `recall_at_k`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    recall_at_k: A tensor representing the recall@k, the fraction of labels
      which fall into the top `k` predictions.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `recall_at_k`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # 1.0 where the true label is among the top-k predictions, else 0.0. The
  # (weighted) mean of these indicators is recall@k.
  in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k))
  return streaming_mean(in_top_k,
                        weights,
                        metrics_collections,
                        updates_collections,
                        name or _at_k_name('recall', k))


# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_recall_at_k(predictions,
                                 labels,
                                 k,
                                 class_id=None,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate recall by considering only the
  entries in the batch for which `class_id` is in the label, and computing
  the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, we'll calculate recall as how often on
  average a class among the labels of a batch entry is in the top-k
  `predictions`.

  `streaming_sparse_recall_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
  the recall_at_k frequency. This frequency is ultimately returned as
  `recall_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false negatives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_negative_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`.
      Values should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range always count
      towards `false_negative_at_<k>`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If class_id is outside this range, the method returns NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  default_name = _at_k_name('recall', k, class_id=class_id)
  with ops.name_scope(name, default_name, (predictions, labels)) as scope:
    # Indices of the k highest-scoring classes, cast to int64 (the documented
    # dtype of `labels`) before the set-based tp/fn computations.
    _, top_k_idx = nn.top_k(predictions, k)
    top_k_idx = math_ops.to_int64(top_k_idx)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        weights=weights)
    fn, fn_update = _streaming_sparse_false_negative_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        weights=weights)

    # No smoothing epsilon is added, so an empty denominator (tp + fn == 0)
    # yields NaN, matching the documented NAN result for out-of-range
    # `class_id`.
    metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
    update = math_ops.div(
        tp_update, math_ops.add(tp_update, fn_update), name='update')
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update


def _streaming_sparse_precision_at_k(top_k_idx,
                                     labels,
                                     k=None,
                                     class_id=None,
                                     weights=None,
                                     metrics_collections=None,
                                     updates_collections=None,
                                     name=None):
  """Computes precision@k of the top-k indices with respect to sparse labels.

  This method contains the code shared by streaming_sparse_precision_at_k and
  streaming_sparse_precision_at_top_k. Refer to those methods for more details.

  Args:
    top_k_idx: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and top_k_idx has shape [batch size, k].
      The final dimension contains the indices of top-k labels. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `top_k_idx`. Values should be in range [0, num_classes), where
      num_classes is the number of classes in the predictions from which
      `top_k_idx` was derived. Values outside this range are ignored.
    k: Integer, k for @k metric or `None`. Only used for default op name.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the number of classes in
      the predictions from which `top_k_idx` was derived. If `class_id` is
      outside this range, the method returns NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `top_k_idx` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of the metric and of the enclosing scope.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `top_k_idx`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  # Callers may pass int32 indices (e.g. from nn.top_k); normalize to int64.
  top_k_idx = math_ops.to_int64(top_k_idx)
  tp, tp_update = _streaming_sparse_true_positive_at_k(
      predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
      weights=weights)
  fp, fp_update = _streaming_sparse_false_positive_at_k(
      predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
      weights=weights)

  # No smoothing epsilon is added, so an empty denominator (tp + fp == 0)
  # yields NaN, matching the documented NAN result for out-of-range
  # `class_id`.
  metric = math_ops.div(tp, math_ops.add(tp, fp), name=name)
  update = math_ops.div(
      tp_update, math_ops.add(tp_update, fp_update), name='update')
  if metrics_collections:
    ops.add_to_collections(metrics_collections, metric)
  if updates_collections:
    ops.add_to_collections(updates_collections, update)
  return metric, update


# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_k(predictions,
                                    labels,
                                    k,
                                    class_id=None,
                                    weights=None,
                                    metrics_collections=None,
                                    updates_collections=None,
                                    name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
  entries in the batch for which `class_id` is in the top-k highest
  `predictions`, and computing the fraction of them for which `class_id` is
  indeed a correct label.
  If `class_id` is not specified, we'll calculate precision as how often on
  average a class among the top-k classes with the highest predicted values
  of a batch entry is correct and can be found in the label for that entry.

  `streaming_sparse_precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  default_name = _at_k_name('precision', k, class_id=class_id)
  with ops.name_scope(name, default_name,
                      (predictions, labels, weights)) as scope:
    # Reduce dense predictions to top-k class indices, then delegate to the
    # shared implementation; the enclosing scope name becomes the metric name.
    _, top_k_idx = nn.top_k(predictions, k)
    return _streaming_sparse_precision_at_k(
        top_k_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)


# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_top_k(top_k_predictions,
                                        labels,
                                        class_id=None,
                                        weights=None,
                                        metrics_collections=None,
                                        updates_collections=None,
                                        name=None):
  """Computes precision@k of top-k predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
  entries in the batch for which `class_id` is in the top-k highest
  `predictions`, and computing the fraction of them for which `class_id` is
  indeed a correct label.
  If `class_id` is not specified, we'll calculate precision as how often on
  average a class among the top-k classes with the highest predicted values
  of a batch entry is correct and can be found in the label for that entry.
1514 1515 `streaming_sparse_precision_at_top_k` creates two local variables, 1516 `true_positive_at_k` and `false_positive_at_k`, that are used to compute 1517 the precision@k frequency. This frequency is ultimately returned as 1518 `precision_at_k`: an idempotent operation that simply divides 1519 `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`). 1520 1521 For estimation of the metric over a stream of data, the function creates an 1522 `update_op` operation that updates these variables and returns the 1523 `precision_at_k`. Internally, set operations applied to `top_k_predictions` 1524 and `labels` calculate the true positives and false positives weighted by 1525 `weights`. Then `update_op` increments `true_positive_at_k` and 1526 `false_positive_at_k` using these values. 1527 1528 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1529 1530 Args: 1531 top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where 1532 N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k]. 1533 The final dimension contains the indices of top-k labels. [D1, ... DN] 1534 must match `labels`. 1535 labels: `int64` `Tensor` or `SparseTensor` with shape 1536 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1537 target classes for the associated prediction. Commonly, N=1 and `labels` 1538 has shape [batch_size, num_labels]. [D1, ... DN] must match 1539 `top_k_predictions`. Values should be in range [0, num_classes), where 1540 num_classes is the last dimension of `predictions`. Values outside this 1541 range are ignored. 1542 class_id: Integer class ID for which we want binary metrics. This should be 1543 in range [0, num_classes), where num_classes is the last dimension of 1544 `predictions`. If `class_id` is outside this range, the method returns 1545 NAN. 1546 weights: An optional `Tensor` whose shape is broadcastable to the first 1547 [D1, ... 
DN] dimensions of `predictions` and `labels`. 1548 metrics_collections: An optional list of collections that values should 1549 be added to. 1550 updates_collections: An optional list of collections that updates should 1551 be added to. 1552 name: Name of new update operation, and namespace for other dependent ops. 1553 1554 Returns: 1555 precision: Scalar `float64` `Tensor` with the value of `true_positives` 1556 divided by the sum of `true_positives` and `false_positives`. 1557 update_op: `Operation` that increments `true_positives` and 1558 `false_positives` variables appropriately, and whose value matches 1559 `precision`. 1560 1561 Raises: 1562 ValueError: If `weights` is not `None` and its shape doesn't match 1563 `predictions`, or if either `metrics_collections` or `updates_collections` 1564 are not a list or tuple. 1565 ValueError: If `top_k_predictions` has rank < 2. 1566 """ 1567 default_name = _at_k_name('precision', class_id=class_id) 1568 with ops.name_scope( 1569 name, default_name, 1570 (top_k_predictions, labels, weights)) as scope: 1571 rank = array_ops.rank(top_k_predictions) 1572 check_rank_op = control_flow_ops.Assert( 1573 math_ops.greater_equal(rank, 2), 1574 ['top_k_predictions must have rank 2 or higher, e.g. [batch_size, k].']) 1575 with ops.control_dependencies([check_rank_op]): 1576 return _streaming_sparse_precision_at_k( 1577 top_k_idx=top_k_predictions, 1578 labels=labels, 1579 class_id=class_id, 1580 weights=weights, 1581 metrics_collections=metrics_collections, 1582 updates_collections=updates_collections, 1583 name=scope) 1584 1585 1586def num_relevant(labels, k): 1587 """Computes number of relevant values for each row in labels. 1588 1589 For labels with shape [D1, ... DN, num_labels], this is the minimum of 1590 `num_labels` and `k`. 1591 1592 Args: 1593 labels: `int64` `Tensor` or `SparseTensor` with shape 1594 [D1, ... 
DN, num_labels], where N >= 1 and num_labels is the number of 1595 target classes for the associated prediction. Commonly, N=1 and `labels` 1596 has shape [batch_size, num_labels]. 1597 k: Integer, k for @k metric. 1598 1599 Returns: 1600 Integer `Tensor` of shape [D1, ... DN], where each value is the number of 1601 relevant values for that row. 1602 1603 Raises: 1604 ValueError: if inputs have invalid dtypes or values. 1605 """ 1606 if k < 1: 1607 raise ValueError('Invalid k=%s.' % k) 1608 with ops.name_scope(None, 'num_relevant', (labels,)) as scope: 1609 # For SparseTensor, calculate separate count for each row. 1610 if isinstance( 1611 labels, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): 1612 labels_sizes = set_ops.set_size(labels) 1613 return math_ops.minimum(labels_sizes, k, name=scope) 1614 1615 # For dense Tensor, calculate scalar count based on last dimension, and 1616 # tile across labels shape. 1617 labels_shape = array_ops.shape(labels) 1618 labels_size = labels_shape[-1] 1619 num_relevant_scalar = math_ops.minimum(labels_size, k) 1620 return array_ops.fill(labels_shape[0:-1], num_relevant_scalar, name=scope) 1621 1622 1623def expand_and_tile(tensor, multiple, dim=0, name=None): 1624 """Slice `tensor` shape in 2, then tile along the sliced dimension. 1625 1626 A new dimension is inserted in shape of `tensor` before `dim`, then values are 1627 tiled `multiple` times along the new dimension. 1628 1629 Args: 1630 tensor: Input `Tensor` or `SparseTensor`. 1631 multiple: Integer, number of times to tile. 1632 dim: Integer, dimension along which to tile. 1633 name: Name of operation. 1634 1635 Returns: 1636 `Tensor` result of expanding and tiling `tensor`. 1637 1638 Raises: 1639 ValueError: if `multiple` is less than 1, or `dim` is not in 1640 `[-rank(tensor), rank(tensor)]`. 1641 """ 1642 if multiple < 1: 1643 raise ValueError('Invalid multiple %s, must be > 0.' 
% multiple) 1644 with ops.name_scope( 1645 name, 'expand_and_tile', (tensor, multiple, dim)) as scope: 1646 # Sparse. 1647 if isinstance(tensor, sparse_tensor.SparseTensorValue): 1648 tensor = sparse_tensor.SparseTensor.from_value(tensor) 1649 if isinstance(tensor, sparse_tensor.SparseTensor): 1650 if dim < 0: 1651 expand_dims = array_ops.reshape( 1652 array_ops.size(tensor.shape) + dim, [1]) 1653 else: 1654 expand_dims = [dim] 1655 expanded_shape = array_ops.concat( 1656 0, (array_ops.slice(tensor.shape, [0], expand_dims), [1], 1657 array_ops.slice(tensor.shape, expand_dims, [-1])), 1658 name='expanded_shape') 1659 expanded = sparse_ops.sparse_reshape( 1660 tensor, shape=expanded_shape, name='expand') 1661 if multiple == 1: 1662 return expanded 1663 return sparse_ops.sparse_concat( 1664 dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope) 1665 1666 # Dense. 1667 expanded = array_ops.expand_dims( 1668 tensor, dim if (dim >= 0) else (dim - 1), name='expand') 1669 if multiple == 1: 1670 return expanded 1671 ones = array_ops.ones_like(array_ops.shape(tensor)) 1672 tile_multiples = array_ops.concat( 1673 0, (ones[:dim], (multiple,), ones[dim:]), name='multiples') 1674 return array_ops.tile(expanded, tile_multiples, name=scope) 1675 1676 1677def sparse_average_precision_at_k(predictions, labels, k): 1678 """Computes average precision@k of predictions with respect to sparse labels. 1679 1680 From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula 1681 for each row is: 1682 1683 AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items 1684 1685 A "row" is the elements in dimension [D1, ... DN] of `predictions`, `labels`, 1686 and the result `Tensors`. In the common case, this is [batch_size]. Each row 1687 of the results contains the average precision for that row. 1688 1689 Internally, a `top_k` operation computes a `Tensor` indicating the top `k` 1690 `predictions`. 
Set operations applied to `top_k` and `labels` calculate the 1691 true positives, which are used to calculate the precision ("P_{i}" term, 1692 above). 1693 1694 Args: 1695 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1696 N >= 1. Commonly, N=1 and `predictions` has shape 1697 [batch size, num_classes]. The final dimension contains the logit values 1698 for each class. [D1, ... DN] must match `labels`. 1699 labels: `int64` `Tensor` or `SparseTensor` with shape 1700 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1701 target classes for the associated prediction. Commonly, N=1 and `labels` 1702 has shape [batch_size, num_labels]. [D1, ... DN] must match 1703 `predictions`. Values should be in range [0, num_classes), where 1704 num_classes is the last dimension of `predictions`. Values outside this 1705 range are ignored. 1706 k: Integer, k for @k metric. This will calculate an average precision for 1707 range `[1,k]`, as documented above. 1708 1709 Returns: 1710 `float64` `Tensor` of shape [D1, ... DN], where each value is the average 1711 precision for that row. 1712 1713 Raises: 1714 ValueError: if k is invalid. 1715 """ 1716 if k < 1: 1717 raise ValueError('Invalid k=%s.' % k) 1718 with ops.name_scope( 1719 None, 'average_precision', (predictions, labels, k)) as scope: 1720 # Calculate top k indices to produce [D1, ... DN, k] tensor. 1721 _, predictions_idx = nn.top_k(predictions, k) 1722 predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx') 1723 1724 # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate 1725 # prediction for each k, so we can calculate separate true positive values 1726 # for each k. 1727 predictions_idx_per_k = array_ops.expand_dims( 1728 predictions_idx, -1, name='predictions_idx_per_k') 1729 1730 # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor. 
1731 labels_per_k = expand_and_tile( 1732 labels, multiple=k, dim=-1, name='labels_per_k') 1733 1734 # The following tensors are all of shape [D1, ... DN, k], containing values 1735 # per row, per k value. 1736 # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at 1737 # that k value is correct, 0 otherwise. This is the "rel_{i}" term from 1738 # the formula above. 1739 # `tp_per_k` (int32) - True positive counts. 1740 # `retrieved_per_k` (int32) - Number of predicted values at each k. This is 1741 # the precision denominator. 1742 # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}" 1743 # term from the formula above. 1744 # `relevant_precision_per_k` (float64) - Relevant precisions; i.e., 1745 # precisions at all k for which relevance indicator is true. 1746 relevant_per_k = _sparse_true_positive_at_k( 1747 predictions_idx_per_k, labels_per_k, name='relevant_per_k') 1748 tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k') 1749 retrieved_per_k = math_ops.cumsum( 1750 array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k') 1751 precision_per_k = math_ops.div( 1752 math_ops.to_double(tp_per_k), math_ops.to_double(retrieved_per_k), 1753 name='precision_per_k') 1754 relevant_precision_per_k = math_ops.mul( 1755 precision_per_k, math_ops.to_double(relevant_per_k), 1756 name='relevant_precision_per_k') 1757 1758 # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor. 1759 precision_sum = math_ops.reduce_sum( 1760 relevant_precision_per_k, reduction_indices=(-1,), name='precision_sum') 1761 1762 # Divide by number of relevant items to get average precision. These are 1763 # the "num_relevant_items" and "AveP" terms from the formula above. 
1764 num_relevant_items = math_ops.to_double(num_relevant(labels, k)) 1765 return math_ops.div(precision_sum, num_relevant_items, name=scope) 1766 1767 1768def streaming_sparse_average_precision_at_k(predictions, 1769 labels, 1770 k, 1771 weights=None, 1772 metrics_collections=None, 1773 updates_collections=None, 1774 name=None): 1775 """Computes average precision@k of predictions with respect to sparse labels. 1776 1777 See `sparse_average_precision_at_k` for details on formula. `weights` are 1778 applied to the result of `sparse_average_precision_at_k` 1779 1780 `streaming_sparse_average_precision_at_k` creates two local variables, 1781 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 1782 are used to compute the frequency. This frequency is ultimately returned as 1783 `average_precision_at_<k>`: an idempotent operation that simply divides 1784 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 1785 1786 For estimation of the metric over a stream of data, the function creates an 1787 `update_op` operation that updates these variables and returns the 1788 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 1789 indicating the top `k` `predictions`. Set operations applied to `top_k` and 1790 `labels` calculate the true positives and false positives weighted by 1791 `weights`. Then `update_op` increments `true_positive_at_<k>` and 1792 `false_positive_at_<k>` using these values. 1793 1794 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1795 1796 Args: 1797 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1798 N >= 1. Commonly, N=1 and `predictions` has shape 1799 [batch size, num_classes]. The final dimension contains the logit values 1800 for each class. [D1, ... DN] must match `labels`. 1801 labels: `int64` `Tensor` or `SparseTensor` with shape 1802 [D1, ... 
DN, num_labels], where N >= 1 and num_labels is the number of 1803 target classes for the associated prediction. Commonly, N=1 and `labels` 1804 has shape [batch_size, num_labels]. [D1, ... DN] must match 1805 `predictions_`. Values should be in range [0, num_classes), where 1806 num_classes is the last dimension of `predictions`. Values outside this 1807 range are ignored. 1808 k: Integer, k for @k metric. This will calculate an average precision for 1809 range `[1,k]`, as documented above. 1810 weights: An optional `Tensor` whose shape is broadcastable to the first 1811 [D1, ... DN] dimensions of `predictions` and `labels`. 1812 metrics_collections: An optional list of collections that values should 1813 be added to. 1814 updates_collections: An optional list of collections that updates should 1815 be added to. 1816 name: Name of new update operation, and namespace for other dependent ops. 1817 1818 Returns: 1819 mean_average_precision: Scalar `float64` `Tensor` with the mean average 1820 precision values. 1821 update: `Operation` that increments variables appropriately, and whose 1822 value matches `metric`. 1823 """ 1824 default_name = _at_k_name('average_precision', k) 1825 with ops.name_scope(name, default_name, (predictions, labels)) as scope: 1826 # Calculate per-example average precision, and apply weights. 1827 average_precision = sparse_average_precision_at_k( 1828 predictions=predictions, labels=labels, k=k) 1829 if weights is not None: 1830 weights = math_ops.to_double(weights) 1831 average_precision = math_ops.mul(average_precision, weights) 1832 1833 # Create accumulation variables and update ops for max average precision and 1834 # total average precision. 1835 with ops.name_scope(None, 'max', (average_precision,)) as max_scope: 1836 # `max` is the max possible precision. Since max for any row is 1.0: 1837 # - For the unweighted case, this is just the number of rows. 
1838 # - For the weighted case, it's the sum of the weights broadcast across 1839 # `average_precision` rows. 1840 max_var = contrib_variables.local_variable( 1841 array_ops.zeros([], dtype=dtypes.float64), name=max_scope) 1842 if weights is None: 1843 batch_max = math_ops.to_double( 1844 array_ops.size(average_precision, name='batch_max')) 1845 else: 1846 # TODO(ptucker): More efficient way to broadcast? 1847 broadcast_weights = math_ops.mul( 1848 weights, array_ops.ones_like(average_precision), 1849 name='broadcast_weights') 1850 batch_max = math_ops.reduce_sum(broadcast_weights, name='batch_max') 1851 max_update = state_ops.assign_add(max_var, batch_max, name='update') 1852 with ops.name_scope(None, 'total', (average_precision,)) as total_scope: 1853 total_var = contrib_variables.local_variable( 1854 array_ops.zeros([], dtype=dtypes.float64), name=total_scope) 1855 batch_total = math_ops.reduce_sum(average_precision, name='batch_total') 1856 total_update = state_ops.assign_add(total_var, batch_total, name='update') 1857 1858 # Divide total by max to get mean, for both vars and the update ops. 1859 mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean') 1860 update = _safe_scalar_div(total_update, max_update, name=scope) 1861 1862 if metrics_collections: 1863 ops.add_to_collections(metrics_collections, mean_average_precision) 1864 if updates_collections: 1865 ops.add_to_collections(updates_collections, update) 1866 1867 return mean_average_precision, update 1868 1869 1870def _select_class_id(ids, selected_id): 1871 """Filter all but `selected_id` out of `ids`. 1872 1873 Args: 1874 ids: `int64` `Tensor` or `SparseTensor` of IDs. 1875 selected_id: Int id to select. 1876 1877 Returns: 1878 `SparseTensor` of same dimensions as `ids`. This contains only the entries 1879 equal to `selected_id`. 
1880 """ 1881 if isinstance( 1882 ids, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): 1883 return sparse_ops.sparse_retain( 1884 ids, math_ops.equal(ids.values, selected_id)) 1885 1886 # TODO(ptucker): Make this more efficient, maybe add a sparse version of 1887 # tf.equal and tf.reduce_any? 1888 1889 # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1. 1890 ids_shape = array_ops.shape(ids, out_type=dtypes.int64) 1891 ids_last_dim = array_ops.size(ids_shape) - 1 1892 filled_selected_id_shape = math_ops.reduced_shape( 1893 ids_shape, array_ops.reshape(ids_last_dim, [1])) 1894 1895 # Intersect `ids` with the selected ID. 1896 filled_selected_id = array_ops.fill( 1897 filled_selected_id_shape, math_ops.to_int64(selected_id)) 1898 result = set_ops.set_intersection(filled_selected_id, ids) 1899 return sparse_tensor.SparseTensor( 1900 indices=result.indices, values=result.values, shape=ids_shape) 1901 1902 1903def _maybe_select_class_id(labels, predictions_idx, selected_id=None): 1904 """If class ID is specified, filter all other classes. 1905 1906 Args: 1907 labels: `int64` `Tensor` or `SparseTensor` with shape 1908 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1909 target classes for the associated prediction. Commonly, N=1 and `labels` 1910 has shape [batch_size, num_labels]. [D1, ... DN] must match 1911 `predictions_idx`. 1912 predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k] 1913 where N >= 1. Commonly, N=1 and `predictions_idx` has shape 1914 [batch size, k]. 1915 selected_id: Int id to select. 1916 1917 Returns: 1918 Tuple of `labels` and `predictions_idx`, possibly with classes removed. 
1919 """ 1920 if selected_id is None: 1921 return labels, predictions_idx 1922 return (_select_class_id(labels, selected_id), 1923 _select_class_id(predictions_idx, selected_id)) 1924 1925 1926def _sparse_true_positive_at_k(predictions_idx, 1927 labels, 1928 class_id=None, 1929 weights=None, 1930 name=None): 1931 """Calculates true positives for recall@k and precision@k. 1932 1933 If `class_id` is specified, calculate binary true positives for `class_id` 1934 only. 1935 If `class_id` is not specified, calculate metrics for `k` predicted vs 1936 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 1937 1938 Args: 1939 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1940 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1941 match `labels`. 1942 labels: `int64` `Tensor` or `SparseTensor` with shape 1943 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1944 target classes for the associated prediction. Commonly, N=1 and `labels` 1945 has shape [batch_size, num_labels]. [D1, ... DN] must match 1946 `predictions_idx`. 1947 class_id: Class for which we want binary metrics. 1948 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 1949 dimensions of `predictions_idx` and `labels`. 1950 name: Name of operation. 1951 1952 Returns: 1953 A [D1, ... DN] `Tensor` of true positive counts. 
1954 """ 1955 with ops.name_scope(name, 'true_positives', (predictions_idx, labels)): 1956 labels, predictions_idx = _maybe_select_class_id( 1957 labels, predictions_idx, class_id) 1958 tp = set_ops.set_size(set_ops.set_intersection(predictions_idx, labels)) 1959 tp = math_ops.to_double(tp) 1960 if weights is not None: 1961 weights = math_ops.to_double(weights) 1962 tp = math_ops.mul(tp, weights) 1963 return tp 1964 1965 1966def _streaming_sparse_true_positive_at_k(predictions_idx, 1967 labels, 1968 k=None, 1969 class_id=None, 1970 weights=None, 1971 name=None): 1972 """Calculates weighted per step true positives for recall@k and precision@k. 1973 1974 If `class_id` is specified, calculate binary true positives for `class_id` 1975 only. 1976 If `class_id` is not specified, calculate metrics for `k` predicted vs 1977 `n` label classes, where `n` is the 2nd dimension of `labels`. 1978 1979 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1980 1981 Args: 1982 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1983 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1984 match `labels`. 1985 labels: `int64` `Tensor` or `SparseTensor` with shape 1986 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1987 target classes for the associated prediction. Commonly, N=1 and `labels` 1988 has shape [batch_size, num_labels]. [D1, ... DN] must match 1989 `predictions_idx`. 1990 k: Integer, k for @k metric. This is only used for default op name. 1991 class_id: Class for which we want binary metrics. 1992 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 1993 dimensions of `predictions_idx` and `labels`. 1994 name: Name of new variable, and namespace for other dependent ops. 1995 1996 Returns: 1997 A tuple of `Variable` and update `Operation`. 1998 1999 Raises: 2000 ValueError: If `weights` is not `None` and has an incomptable shape. 
2001 """ 2002 default_name = _at_k_name('true_positive', k, class_id=class_id) 2003 with ops.name_scope(name, default_name, (predictions_idx, labels)) as scope: 2004 tp = _sparse_true_positive_at_k( 2005 predictions_idx=predictions_idx, labels=labels, class_id=class_id, 2006 weights=weights) 2007 batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp)) 2008 2009 var = contrib_variables.local_variable( 2010 array_ops.zeros([], dtype=dtypes.float64), name=scope) 2011 return var, state_ops.assign_add(var, batch_total_tp, name='update') 2012 2013 2014def _sparse_false_positive_at_k(predictions_idx, 2015 labels, 2016 class_id=None, 2017 weights=None): 2018 """Calculates false positives for precision@k. 2019 2020 If `class_id` is specified, calculate binary true positives for `class_id` 2021 only. 2022 If `class_id` is not specified, calculate metrics for `k` predicted vs 2023 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 2024 2025 Args: 2026 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2027 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2028 match `labels`. 2029 labels: `int64` `Tensor` or `SparseTensor` with shape 2030 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2031 target classes for the associated prediction. Commonly, N=1 and `labels` 2032 has shape [batch_size, num_labels]. [D1, ... DN] must match 2033 `predictions_idx`. 2034 class_id: Class for which we want binary metrics. 2035 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 2036 dimensions of `predictions_idx` and `labels`. 2037 2038 Returns: 2039 A [D1, ... DN] `Tensor` of false positive counts. 
2040 """ 2041 with ops.name_scope(None, 'false_positives', (predictions_idx, labels)): 2042 labels, predictions_idx = _maybe_select_class_id(labels, 2043 predictions_idx, 2044 class_id) 2045 fp = set_ops.set_size(set_ops.set_difference( 2046 predictions_idx, labels, aminusb=True)) 2047 fp = math_ops.to_double(fp) 2048 if weights is not None: 2049 weights = math_ops.to_double(weights) 2050 fp = math_ops.mul(fp, weights) 2051 return fp 2052 2053 2054def _streaming_sparse_false_positive_at_k(predictions_idx, 2055 labels, 2056 k=None, 2057 class_id=None, 2058 weights=None, 2059 name=None): 2060 """Calculates weighted per step false positives for precision@k. 2061 2062 If `class_id` is specified, calculate binary true positives for `class_id` 2063 only. 2064 If `class_id` is not specified, calculate metrics for `k` predicted vs 2065 `n` label classes, where `n` is the 2nd dimension of `labels`. 2066 2067 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2068 2069 Args: 2070 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2071 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2072 match `labels`. 2073 labels: `int64` `Tensor` or `SparseTensor` with shape 2074 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2075 target classes for the associated prediction. Commonly, N=1 and `labels` 2076 has shape [batch_size, num_labels]. [D1, ... DN] must match 2077 `predictions_idx`. 2078 k: Integer, k for @k metric. This is only used for default op name. 2079 class_id: Class for which we want binary metrics. 2080 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 2081 dimensions of `predictions_idx` and `labels`. 2082 name: Name of new variable, and namespace for other dependent ops. 2083 2084 Returns: 2085 A tuple of `Variable` and update `Operation`. 2086 2087 Raises: 2088 ValueError: If `weights` is not `None` and has an incomptable shape. 
2089 """ 2090 default_name = _at_k_name('false_positive', k, class_id=class_id) 2091 with ops.name_scope(name, default_name, (predictions_idx, labels)) as scope: 2092 fp = _sparse_false_positive_at_k( 2093 predictions_idx=predictions_idx, labels=labels, class_id=class_id, 2094 weights=weights) 2095 batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp)) 2096 2097 var = contrib_variables.local_variable( 2098 array_ops.zeros([], dtype=dtypes.float64), name=scope) 2099 return var, state_ops.assign_add(var, batch_total_fp, name='update') 2100 2101 2102def _sparse_false_negative_at_k(predictions_idx, 2103 labels, 2104 class_id=None, 2105 weights=None): 2106 """Calculates false negatives for recall@k. 2107 2108 If `class_id` is specified, calculate binary true positives for `class_id` 2109 only. 2110 If `class_id` is not specified, calculate metrics for `k` predicted vs 2111 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 2112 2113 Args: 2114 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2115 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2116 match `labels`. 2117 labels: `int64` `Tensor` or `SparseTensor` with shape 2118 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2119 target classes for the associated prediction. Commonly, N=1 and `labels` 2120 has shape [batch_size, num_labels]. [D1, ... DN] must match 2121 `predictions_idx`. 2122 class_id: Class for which we want binary metrics. 2123 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 2124 dimensions of `predictions_idx` and `labels`. 2125 2126 Returns: 2127 A [D1, ... DN] `Tensor` of false negative counts. 
2128 """ 2129 with ops.name_scope(None, 'false_negatives', (predictions_idx, labels)): 2130 labels, predictions_idx = _maybe_select_class_id(labels, 2131 predictions_idx, 2132 class_id) 2133 fn = set_ops.set_size(set_ops.set_difference(predictions_idx, 2134 labels, 2135 aminusb=False)) 2136 fn = math_ops.to_double(fn) 2137 if weights is not None: 2138 weights = math_ops.to_double(weights) 2139 fn = math_ops.mul(fn, weights) 2140 return fn 2141 2142 2143def _streaming_sparse_false_negative_at_k(predictions_idx, 2144 labels, 2145 k, 2146 class_id=None, 2147 weights=None, 2148 name=None): 2149 """Calculates weighted per step false negatives for recall@k. 2150 2151 If `class_id` is specified, calculate binary true positives for `class_id` 2152 only. 2153 If `class_id` is not specified, calculate metrics for `k` predicted vs 2154 `n` label classes, where `n` is the 2nd dimension of `labels`. 2155 2156 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2157 2158 Args: 2159 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2160 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2161 match `labels`. 2162 labels: `int64` `Tensor` or `SparseTensor` with shape 2163 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2164 target classes for the associated prediction. Commonly, N=1 and `labels` 2165 has shape [batch_size, num_labels]. [D1, ... DN] must match 2166 `predictions_idx`. 2167 k: Integer, k for @k metric. This is only used for default op name. 2168 class_id: Class for which we want binary metrics. 2169 weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN] 2170 dimensions of `predictions_idx` and `labels`. 2171 name: Name of new variable, and namespace for other dependent ops. 2172 2173 Returns: 2174 A tuple of `Variable` and update `Operation`. 2175 2176 Raises: 2177 ValueError: If `weights` is not `None` and has an incomptable shape. 
2178 """ 2179 default_name = _at_k_name('false_negative', k, class_id=class_id) 2180 with ops.name_scope(name, default_name, (predictions_idx, labels)) as scope: 2181 fn = _sparse_false_negative_at_k( 2182 predictions_idx=predictions_idx, labels=labels, class_id=class_id, 2183 weights=weights) 2184 batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn)) 2185 2186 var = contrib_variables.local_variable( 2187 array_ops.zeros([], dtype=dtypes.float64), name=scope) 2188 return var, state_ops.assign_add(var, batch_total_fn, name='update') 2189 2190 2191def streaming_mean_absolute_error(predictions, labels, weights=None, 2192 metrics_collections=None, 2193 updates_collections=None, 2194 name=None): 2195 """Computes the mean absolute error between the labels and predictions. 2196 2197 The `streaming_mean_absolute_error` function creates two local variables, 2198 `total` and `count` that are used to compute the mean absolute error. This 2199 average is weighted by `weights`, and it is ultimately returned as 2200 `mean_absolute_error`: an idempotent operation that simply divides `total` by 2201 `count`. 2202 2203 For estimation of the metric over a stream of data, the function creates an 2204 `update_op` operation that updates these variables and returns the 2205 `mean_absolute_error`. Internally, an `absolute_errors` operation computes the 2206 absolute value of the differences between `predictions` and `labels`. Then 2207 `update_op` increments `total` with the reduced sum of the product of 2208 `weights` and `absolute_errors`, and it increments `count` with the reduced 2209 sum of `weights` 2210 2211 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2212 2213 Args: 2214 predictions: A `Tensor` of arbitrary shape. 2215 labels: A `Tensor` of the same shape as `predictions`. 2216 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 
2217 metrics_collections: An optional list of collections that 2218 `mean_absolute_error` should be added to. 2219 updates_collections: An optional list of collections that `update_op` should 2220 be added to. 2221 name: An optional variable_scope name. 2222 2223 Returns: 2224 mean_absolute_error: A tensor representing the current mean, the value of 2225 `total` divided by `count`. 2226 update_op: An operation that increments the `total` and `count` variables 2227 appropriately and whose value matches `mean_absolute_error`. 2228 2229 Raises: 2230 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2231 `weights` is not `None` and its shape doesn't match `predictions`, or if 2232 either `metrics_collections` or `updates_collections` are not a list or 2233 tuple. 2234 """ 2235 predictions, labels, weights = _remove_squeezable_dimensions( 2236 predictions, labels, weights) 2237 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2238 absolute_errors = math_ops.abs(predictions - labels) 2239 return streaming_mean(absolute_errors, weights, metrics_collections, 2240 updates_collections, name or 'mean_absolute_error') 2241 2242 2243def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, 2244 metrics_collections=None, 2245 updates_collections=None, 2246 name=None): 2247 """Computes the mean relative error by normalizing with the given values. 2248 2249 The `streaming_mean_relative_error` function creates two local variables, 2250 `total` and `count` that are used to compute the mean relative absolute error. 2251 This average is weighted by `weights`, and it is ultimately returned as 2252 `mean_relative_error`: an idempotent operation that simply divides `total` by 2253 `count`. 2254 2255 For estimation of the metric over a stream of data, the function creates an 2256 `update_op` operation that updates these variables and returns the 2257 `mean_reative_error`. 
Internally, a `relative_errors` operation divides the 2258 absolute value of the differences between `predictions` and `labels` by the 2259 `normalizer`. Then `update_op` increments `total` with the reduced sum of the 2260 product of `weights` and `relative_errors`, and it increments `count` with the 2261 reduced sum of `weights`. 2262 2263 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2264 2265 Args: 2266 predictions: A `Tensor` of arbitrary shape. 2267 labels: A `Tensor` of the same shape as `predictions`. 2268 normalizer: A `Tensor` of the same shape as `predictions`. 2269 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 2270 metrics_collections: An optional list of collections that 2271 `mean_relative_error` should be added to. 2272 updates_collections: An optional list of collections that `update_op` should 2273 be added to. 2274 name: An optional variable_scope name. 2275 2276 Returns: 2277 mean_relative_error: A tensor representing the current mean, the value of 2278 `total` divided by `count`. 2279 update_op: An operation that increments the `total` and `count` variables 2280 appropriately and whose value matches `mean_relative_error`. 2281 2282 Raises: 2283 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2284 `weights` is not `None` and its shape doesn't match `predictions`, or if 2285 either `metrics_collections` or `updates_collections` are not a list or 2286 tuple. 
2287 """ 2288 predictions, labels, weights = _remove_squeezable_dimensions( 2289 predictions, labels, weights) 2290 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2291 2292 predictions, normalizer = tensor_util.remove_squeezable_dimensions( 2293 predictions, normalizer) 2294 predictions.get_shape().assert_is_compatible_with(normalizer.get_shape()) 2295 relative_errors = math_ops.select( 2296 math_ops.equal(normalizer, 0.0), 2297 array_ops.zeros_like(labels), 2298 math_ops.div(math_ops.abs(labels - predictions), normalizer)) 2299 return streaming_mean(relative_errors, weights, metrics_collections, 2300 updates_collections, name or 'mean_relative_error') 2301 2302 2303def streaming_mean_squared_error(predictions, labels, weights=None, 2304 metrics_collections=None, 2305 updates_collections=None, 2306 name=None): 2307 """Computes the mean squared error between the labels and predictions. 2308 2309 The `streaming_mean_squared_error` function creates two local variables, 2310 `total` and `count` that are used to compute the mean squared error. 2311 This average is weighted by `weights`, and it is ultimately returned as 2312 `mean_squared_error`: an idempotent operation that simply divides `total` by 2313 `count`. 2314 2315 For estimation of the metric over a stream of data, the function creates an 2316 `update_op` operation that updates these variables and returns the 2317 `mean_squared_error`. Internally, a `squared_error` operation computes the 2318 element-wise square of the difference between `predictions` and `labels`. Then 2319 `update_op` increments `total` with the reduced sum of the product of 2320 `weights` and `squared_error`, and it increments `count` with the reduced sum 2321 of `weights`. 2322 2323 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2324 2325 Args: 2326 predictions: A `Tensor` of arbitrary shape. 2327 labels: A `Tensor` of the same shape as `predictions`. 
2328 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 2329 metrics_collections: An optional list of collections that 2330 `mean_squared_error` should be added to. 2331 updates_collections: An optional list of collections that `update_op` should 2332 be added to. 2333 name: An optional variable_scope name. 2334 2335 Returns: 2336 mean_squared_error: A tensor representing the current mean, the value of 2337 `total` divided by `count`. 2338 update_op: An operation that increments the `total` and `count` variables 2339 appropriately and whose value matches `mean_squared_error`. 2340 2341 Raises: 2342 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2343 `weights` is not `None` and its shape doesn't match `predictions`, or if 2344 either `metrics_collections` or `updates_collections` are not a list or 2345 tuple. 2346 """ 2347 predictions, labels, weights = _remove_squeezable_dimensions( 2348 predictions, labels, weights) 2349 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2350 squared_error = math_ops.square(labels - predictions) 2351 return streaming_mean(squared_error, weights, metrics_collections, 2352 updates_collections, name or 'mean_squared_error') 2353 2354 2355def streaming_root_mean_squared_error(predictions, labels, weights=None, 2356 metrics_collections=None, 2357 updates_collections=None, 2358 name=None): 2359 """Computes the root mean squared error between the labels and predictions. 2360 2361 The `streaming_root_mean_squared_error` function creates two local variables, 2362 `total` and `count` that are used to compute the root mean squared error. 2363 This average is weighted by `weights`, and it is ultimately returned as 2364 `root_mean_squared_error`: an idempotent operation that takes the square root 2365 of the division of `total` by `count`. 
2366 2367 For estimation of the metric over a stream of data, the function creates an 2368 `update_op` operation that updates these variables and returns the 2369 `root_mean_squared_error`. Internally, a `squared_error` operation computes 2370 the element-wise square of the difference between `predictions` and `labels`. 2371 Then `update_op` increments `total` with the reduced sum of the product of 2372 `weights` and `squared_error`, and it increments `count` with the reduced sum 2373 of `weights`. 2374 2375 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2376 2377 Args: 2378 predictions: A `Tensor` of arbitrary shape. 2379 labels: A `Tensor` of the same shape as `predictions`. 2380 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 2381 metrics_collections: An optional list of collections that 2382 `root_mean_squared_error` should be added to. 2383 updates_collections: An optional list of collections that `update_op` should 2384 be added to. 2385 name: An optional variable_scope name. 2386 2387 Returns: 2388 root_mean_squared_error: A tensor representing the current mean, the value 2389 of `total` divided by `count`. 2390 update_op: An operation that increments the `total` and `count` variables 2391 appropriately and whose value matches `root_mean_squared_error`. 2392 2393 Raises: 2394 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2395 `weights` is not `None` and its shape doesn't match `predictions`, or if 2396 either `metrics_collections` or `updates_collections` are not a list or 2397 tuple. 
2398 """ 2399 predictions, labels, weights = _remove_squeezable_dimensions( 2400 predictions, labels, weights) 2401 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2402 value_tensor, update_op = streaming_mean_squared_error( 2403 predictions, labels, weights, None, None, 2404 name or 'root_mean_squared_error') 2405 2406 root_mean_squared_error = math_ops.sqrt(value_tensor) 2407 with ops.control_dependencies([update_op]): 2408 update_op = math_ops.sqrt(update_op) 2409 2410 if metrics_collections: 2411 ops.add_to_collections(metrics_collections, root_mean_squared_error) 2412 2413 if updates_collections: 2414 ops.add_to_collections(updates_collections, update_op) 2415 2416 return root_mean_squared_error, update_op 2417 2418 2419def streaming_covariance(predictions, 2420 labels, 2421 weights=None, 2422 metrics_collections=None, 2423 updates_collections=None, 2424 name=None): 2425 """Computes the unbiased sample covariance between `predictions` and `labels`. 2426 2427 The `streaming_covariance` function creates four local variables, 2428 `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to 2429 compute the sample covariance between predictions and labels across multiple 2430 batches of data. The covariance is ultimately returned as an idempotent 2431 operation that simply divides `comoment` by `count` - 1. We use `count` - 1 2432 in order to get an unbiased estimate. 2433 2434 The algorithm used for this online computation is described in 2435 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. 2436 Specifically, the formula used to combine two sample comoments is 2437 `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB` 2438 The comoment for a single batch of data is simply 2439 `sum((x - E[x]) * (y - E[y]))`, optionally weighted. 2440 2441 If `weights` is not None, then it is used to compute weighted comoments, 2442 means, and count. 
NOTE: these weights are treated as "frequency weights", as 2443 opposed to "reliability weights". See discussion of the difference on 2444 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 2445 2446 To facilitate the computation of covariance across multiple batches of data, 2447 the function creates an `update_op` operation, which updates underlying 2448 variables and returns the updated covariance. 2449 2450 Args: 2451 predictions: A `Tensor` of arbitrary size. 2452 labels: A `Tensor` of the same size as `predictions`. 2453 weights: An optional set of weights which indicates the frequency with which 2454 an example is sampled. Must be broadcastable with `labels`. 2455 metrics_collections: An optional list of collections that the metric 2456 value variable should be added to. 2457 updates_collections: An optional list of collections that the metric update 2458 ops should be added to. 2459 name: An optional variable_scope name. 2460 2461 Returns: 2462 covariance: A `Tensor` representing the current unbiased sample covariance, 2463 `comoment` / (`count` - 1). 2464 update_op: An operation that updates the local variables appropriately. 2465 2466 Raises: 2467 ValueError: If labels and predictions are of different sizes or if either 2468 `metrics_collections` or `updates_collections` are not a list or tuple. 
2469 """ 2470 with variable_scope.variable_scope(name, 'covariance', [predictions, labels]): 2471 predictions, labels, weights = _remove_squeezable_dimensions( 2472 predictions, labels, weights) 2473 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2474 count = _create_local('count', []) 2475 mean_prediction = _create_local('mean_prediction', []) 2476 mean_label = _create_local('mean_label', []) 2477 comoment = _create_local('comoment', []) # C_A in update equation 2478 2479 if weights is None: 2480 batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn 2481 weighted_predictions = predictions 2482 weighted_labels = labels 2483 else: 2484 batch_count = math_ops.reduce_sum( 2485 _broadcast_weights(weights, labels)) # n_B in eqn 2486 weighted_predictions = predictions * weights 2487 weighted_labels = labels * weights 2488 2489 update_count = state_ops.assign_add(count, batch_count) # n_AB in eqn 2490 prev_count = update_count - batch_count # n_A in update equation 2491 2492 # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) 2493 # batch_mean_prediction is E[x_B] in the update equation 2494 batch_mean_prediction = _safe_div( 2495 math_ops.reduce_sum(weighted_predictions), batch_count, 2496 'batch_mean_prediction') 2497 delta_mean_prediction = _safe_div( 2498 (batch_mean_prediction - mean_prediction) * batch_count, update_count, 2499 'delta_mean_prediction') 2500 update_mean_prediction = state_ops.assign_add(mean_prediction, 2501 delta_mean_prediction) 2502 # prev_mean_prediction is E[x_A] in the update equation 2503 prev_mean_prediction = update_mean_prediction - delta_mean_prediction 2504 2505 # batch_mean_label is E[y_B] in the update equation 2506 batch_mean_label = _safe_div( 2507 math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label') 2508 delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count, 2509 update_count, 'delta_mean_label') 2510 update_mean_label = 
state_ops.assign_add(mean_label, delta_mean_label) 2511 # prev_mean_label is E[y_A] in the update equation 2512 prev_mean_label = update_mean_label - delta_mean_label 2513 2514 unweighted_batch_coresiduals = ( 2515 (predictions - batch_mean_prediction) * (labels - batch_mean_label)) 2516 # batch_comoment is C_B in the update equation 2517 if weights is None: 2518 batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals) 2519 else: 2520 batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals * 2521 weights) 2522 2523 # View delta_comoment as = C_AB - C_A in the update equation above. 2524 # Since C_A is stored in a var, by how much do we need to increment that var 2525 # to make the var = C_AB? 2526 delta_comoment = (batch_comoment + 2527 (prev_mean_prediction - batch_mean_prediction) * 2528 (prev_mean_label - batch_mean_label) * 2529 (prev_count * batch_count / update_count)) 2530 update_comoment = state_ops.assign_add(comoment, delta_comoment) 2531 2532 covariance = _safe_div(comoment, count - 1, 'covariance') 2533 with ops.control_dependencies([update_comoment]): 2534 update_op = _safe_div(comoment, count - 1, 'update_op') 2535 2536 if metrics_collections: 2537 ops.add_to_collections(metrics_collections, covariance) 2538 2539 if updates_collections: 2540 ops.add_to_collections(updates_collections, update_op) 2541 2542 return covariance, update_op 2543 2544 2545def streaming_pearson_correlation(predictions, 2546 labels, 2547 weights=None, 2548 metrics_collections=None, 2549 updates_collections=None, 2550 name=None): 2551 """Computes Pearson correlation coefficient between `predictions`, `labels`. 2552 2553 The `streaming_pearson_correlation` function delegates to 2554 `streaming_covariance` the tracking of three [co]variances: 2555 2556 - `streaming_covariance(predictions, labels)`, i.e. covariance 2557 - `streaming_covariance(predictions, predictions)`, i.e. variance 2558 - `streaming_covariance(labels, labels)`, i.e. 
variance 2559 2560 The product-moment correlation ultimately returned is an idempotent operation 2561 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To 2562 facilitate correlation computation across multiple batches, the function 2563 groups the `update_op`s of the underlying streaming_covariance and returns an 2564 `update_op`. 2565 2566 If `weights` is not None, then it is used to compute a weighted correlation. 2567 NOTE: these weights are treated as "frequency weights", as opposed to 2568 "reliability weights". See discussion of the difference on 2569 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 2570 2571 Args: 2572 predictions: A `Tensor` of arbitrary size. 2573 labels: A `Tensor` of the same size as predictions. 2574 weights: An optional set of weights which indicates the frequency with which 2575 an example is sampled. Must be broadcastable with `labels`. 2576 metrics_collections: An optional list of collections that the metric 2577 value variable should be added to. 2578 updates_collections: An optional list of collections that the metric update 2579 ops should be added to. 2580 name: An optional variable_scope name. 2581 2582 Returns: 2583 pearson_r: A tensor representing the current Pearson product-moment 2584 correlation coefficient, the value of 2585 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. 2586 update_op: An operation that updates the underlying variables appropriately. 2587 2588 Raises: 2589 ValueError: If `labels` and `predictions` are of different sizes, or if 2590 `weights` is the wrong size, or if either `metrics_collections` or 2591 `updates_collections` are not a `list` or `tuple`. 
2592 """ 2593 with variable_scope.variable_scope(name, 'pearson_r', [predictions, labels]): 2594 predictions, labels, weights = _remove_squeezable_dimensions( 2595 predictions, labels, weights) 2596 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2597 cov, update_cov = streaming_covariance( 2598 predictions, labels, weights=weights, name='covariance') 2599 var_predictions, update_var_predictions = streaming_covariance( 2600 predictions, predictions, weights=weights, name='variance_predictions') 2601 var_labels, update_var_labels = streaming_covariance( 2602 labels, labels, weights=weights, name='variance_labels') 2603 2604 pearson_r = _safe_div( 2605 cov, 2606 math_ops.mul(math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)), 2607 'pearson_r') 2608 with ops.control_dependencies( 2609 [update_cov, update_var_predictions, update_var_labels]): 2610 update_op = _safe_div(update_cov, math_ops.mul( 2611 math_ops.sqrt(update_var_predictions), 2612 math_ops.sqrt(update_var_labels)), 'update_op') 2613 2614 if metrics_collections: 2615 ops.add_to_collections(metrics_collections, pearson_r) 2616 2617 if updates_collections: 2618 ops.add_to_collections(updates_collections, update_op) 2619 2620 return pearson_r, update_op 2621 2622 2623# TODO(nsilberman): add a 'normalized' flag so that the user can request 2624# normalization if the inputs are not normalized. 2625def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, 2626 metrics_collections=None, 2627 updates_collections=None, 2628 name=None): 2629 """Computes the cosine distance between the labels and predictions. 2630 2631 The `streaming_mean_cosine_distance` function creates two local variables, 2632 `total` and `count` that are used to compute the average cosine distance 2633 between `predictions` and `labels`. 
This average is weighted by `weights`, 2634 and it is ultimately returned as `mean_distance`, which is an idempotent 2635 operation that simply divides `total` by `count`. 2636 2637 For estimation of the metric over a stream of data, the function creates an 2638 `update_op` operation that updates these variables and returns the 2639 `mean_distance`. 2640 2641 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2642 2643 Args: 2644 predictions: A `Tensor` of the same shape as `labels`. 2645 labels: A `Tensor` of arbitrary shape. 2646 dim: The dimension along which the cosine distance is computed. 2647 weights: An optional `Tensor` whose shape is broadcastable to `predictions`, 2648 and whose dimension `dim` is 1. 2649 metrics_collections: An optional list of collections that the metric 2650 value variable should be added to. 2651 updates_collections: An optional list of collections that the metric update 2652 ops should be added to. 2653 name: An optional variable_scope name. 2654 2655 Returns: 2656 mean_distance: A tensor representing the current mean, the value of `total` 2657 divided by `count`. 2658 update_op: An operation that increments the `total` and `count` variables 2659 appropriately. 2660 2661 Raises: 2662 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2663 `weights` is not `None` and its shape doesn't match `predictions`, or if 2664 either `metrics_collections` or `updates_collections` are not a list or 2665 tuple. 
2666 """ 2667 predictions, labels, weights = _remove_squeezable_dimensions( 2668 predictions, labels, weights) 2669 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2670 radial_diffs = math_ops.mul(predictions, labels) 2671 radial_diffs = math_ops.reduce_sum(radial_diffs, 2672 reduction_indices=[dim,], 2673 keep_dims=True) 2674 mean_distance, update_op = streaming_mean(radial_diffs, weights, 2675 None, 2676 None, 2677 name or 'mean_cosine_distance') 2678 mean_distance = math_ops.sub(1.0, mean_distance) 2679 update_op = math_ops.sub(1.0, update_op) 2680 2681 if metrics_collections: 2682 ops.add_to_collections(metrics_collections, mean_distance) 2683 2684 if updates_collections: 2685 ops.add_to_collections(updates_collections, update_op) 2686 2687 return mean_distance, update_op 2688 2689 2690def streaming_percentage_less(values, threshold, weights=None, 2691 metrics_collections=None, 2692 updates_collections=None, 2693 name=None): 2694 """Computes the percentage of values less than the given threshold. 2695 2696 The `streaming_percentage_less` function creates two local variables, 2697 `total` and `count` that are used to compute the percentage of `values` that 2698 fall below `threshold`. This rate is weighted by `weights`, and it is 2699 ultimately returned as `percentage` which is an idempotent operation that 2700 simply divides `total` by `count`. 2701 2702 For estimation of the metric over a stream of data, the function creates an 2703 `update_op` operation that updates these variables and returns the 2704 `percentage`. 2705 2706 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2707 2708 Args: 2709 values: A numeric `Tensor` of arbitrary size. 2710 threshold: A scalar threshold. 2711 weights: An optional `Tensor` whose shape is broadcastable to `values`. 2712 metrics_collections: An optional list of collections that the metric 2713 value variable should be added to. 
2714 updates_collections: An optional list of collections that the metric update 2715 ops should be added to. 2716 name: An optional variable_scope name. 2717 2718 Returns: 2719 percentage: A tensor representing the current mean, the value of `total` 2720 divided by `count`. 2721 update_op: An operation that increments the `total` and `count` variables 2722 appropriately. 2723 2724 Raises: 2725 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 2726 or if either `metrics_collections` or `updates_collections` are not a list 2727 or tuple. 2728 """ 2729 is_below_threshold = math_ops.to_float(math_ops.less(values, threshold)) 2730 return streaming_mean(is_below_threshold, 2731 weights, 2732 metrics_collections, 2733 updates_collections, 2734 name or 'percentage_below_threshold') 2735 2736 2737def streaming_mean_iou(predictions, 2738 labels, 2739 num_classes, 2740 weights=None, 2741 metrics_collections=None, 2742 updates_collections=None, 2743 name=None): 2744 """Calculate per-step mean Intersection-Over-Union (mIOU). 2745 2746 Mean Intersection-Over-Union is a common evaluation metric for 2747 semantic image segmentation, which first computes the IOU for each 2748 semantic class and then computes the average over classes. 2749 IOU is defined as follows: 2750 IOU = true_positive / (true_positive + false_positive + false_negative). 2751 The predictions are accumulated in a confusion matrix, weighted by `weights`, 2752 and mIOU is then calculated from it. 2753 2754 For estimation of the metric over a stream of data, the function creates an 2755 `update_op` operation that updates these variables and returns the `mean_iou`. 2756 2757 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2758 2759 Args: 2760 predictions: A tensor of prediction results for semantic labels, whose 2761 shape is [batch size] and type `int32` or `int64`. The tensor will be 2762 flattened, if its rank > 1. 
2763 labels: A tensor of ground truth labels with shape [batch size] and of 2764 type `int32` or `int64`. The tensor will be flattened, if its rank > 1. 2765 num_classes: The possible number of labels the prediction task can 2766 have. This value must be provided, since a confusion matrix of 2767 dimension = [num_classes, num_classes] will be allocated. 2768 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 2769 metrics_collections: An optional list of collections that `mean_iou` 2770 should be added to. 2771 updates_collections: An optional list of collections `update_op` should be 2772 added to. 2773 name: An optional variable_scope name. 2774 2775 Returns: 2776 mean_iou: A tensor representing the mean intersection-over-union. 2777 update_op: An operation that increments the confusion matrix. 2778 2779 Raises: 2780 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2781 `weights` is not `None` and its shape doesn't match `predictions`, or if 2782 either `metrics_collections` or `updates_collections` are not a list or 2783 tuple. 2784 """ 2785 with variable_scope.variable_scope(name, 'mean_iou', [predictions, labels]): 2786 # Check if shape is compatible. 2787 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2788 2789 # Local variable to accumulate the predictions in the confusion matrix. 2790 cm_dtype = dtypes.int64 if weights is not None else dtypes.float64 2791 total_cm = _create_local('total_confusion_matrix', 2792 shape=[num_classes, num_classes], dtype=cm_dtype) 2793 2794 # Cast the type to int64 required by confusion_matrix_ops. 2795 predictions = math_ops.to_int64(predictions) 2796 labels = math_ops.to_int64(labels) 2797 num_classes = math_ops.to_int64(num_classes) 2798 2799 # Flatten the input if its rank > 1. 
2800 predictions_rank = predictions.get_shape().ndims 2801 if predictions_rank > 1: 2802 predictions = array_ops.reshape(predictions, [-1]) 2803 2804 labels_rank = labels.get_shape().ndims 2805 if labels_rank > 1: 2806 labels = array_ops.reshape(labels, [-1]) 2807 2808 if weights is not None: 2809 weights_rank = weights.get_shape().ndims 2810 if weights_rank > 1: 2811 weights = array_ops.reshape(weights, [-1]) 2812 2813 # Accumulate the prediction to current confusion matrix. 2814 current_cm = confusion_matrix_ops.confusion_matrix( 2815 predictions, labels, num_classes, weights=weights, dtype=cm_dtype) 2816 update_op = state_ops.assign_add(total_cm, current_cm) 2817 2818 def compute_mean_iou(name): 2819 """Compute the mean intersection-over-union via the confusion matrix.""" 2820 sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0)) 2821 sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1)) 2822 cm_diag = math_ops.to_float(array_ops.diag_part(total_cm)) 2823 denominator = sum_over_row + sum_over_col - cm_diag 2824 2825 # If the value of the denominator is 0, set it to 1 to avoid 2826 # zero division. 2827 denominator = math_ops.select( 2828 math_ops.greater(denominator, 0), 2829 denominator, 2830 array_ops.ones_like(denominator)) 2831 iou = math_ops.div(cm_diag, denominator) 2832 return math_ops.reduce_mean(iou, name=name) 2833 2834 mean_iou = compute_mean_iou('mean_iou') 2835 2836 if metrics_collections: 2837 ops.add_to_collections(metrics_collections, mean_iou) 2838 2839 if updates_collections: 2840 ops.add_to_collections(updates_collections, update_op) 2841 2842 return mean_iou, update_op 2843 2844 2845def _next_array_size(required_size, growth_factor=1.5): 2846 """Calculate the next size for reallocating a dynamic array. 2847 2848 Args: 2849 required_size: number or tf.Tensor specifying required array capacity. 2850 growth_factor: optional number or tf.Tensor specifying the growth factor 2851 between subsequent allocations. 
2852 2853 Returns: 2854 tf.Tensor with dtype=int32 giving the next array size. 2855 """ 2856 exponent = math_ops.ceil( 2857 math_ops.log(math_ops.cast(required_size, dtypes.float32)) 2858 / math_ops.log(math_ops.cast(growth_factor, dtypes.float32))) 2859 return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32) 2860 2861 2862def streaming_concat(values, 2863 axis=0, 2864 max_size=None, 2865 metrics_collections=None, 2866 updates_collections=None, 2867 name=None): 2868 """Concatenate values along an axis across batches. 2869 2870 The function `streaming_concat` creates two local variables, `array` and 2871 `size`, that are used to store concatenated values. Internally, `array` is 2872 used as storage for a dynamic array (if `maxsize` is `None`), which ensures 2873 that updates can be run in amortized constant time. 2874 2875 For estimation of the metric over a stream of data, the function creates an 2876 `update_op` operation that appends the values of a tensor and returns the 2877 `value` of the concatenated tensors. 2878 2879 This op allows for evaluating metrics that cannot be updated incrementally 2880 using the same framework as other streaming metrics. 2881 2882 Args: 2883 values: tensor to concatenate. Rank and the shape along all axes other than 2884 the axis to concatenate along must be statically known. 2885 axis: optional integer axis to concatenate along. 2886 max_size: optional integer maximum size of `value` along the given axis. 2887 Once the maximum size is reached, further updates are no-ops. By default, 2888 there is no maximum size: the array is resized as necessary. 2889 metrics_collections: An optional list of collections that `value` 2890 should be added to. 2891 updates_collections: An optional list of collections `update_op` should be 2892 added to. 2893 name: An optional variable_scope name. 2894 2895 Returns: 2896 value: A tensor representing the concatenated values. 
2897 update_op: An operation that concatenates the next values. 2898 2899 Raises: 2900 ValueError: if `values` does not have a statically known rank, `axis` is 2901 not in the valid range or the size of `values` is not statically known 2902 along any axis other than `axis`. 2903 """ 2904 with variable_scope.variable_scope(name, 'streaming_concat', [values]): 2905 # pylint: disable=invalid-slice-index 2906 values_shape = values.get_shape() 2907 if values_shape.dims is None: 2908 raise ValueError('`values` must have known statically known rank') 2909 2910 ndim = len(values_shape) 2911 if axis < 0: 2912 axis += ndim 2913 if not 0 <= axis < ndim: 2914 raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) 2915 2916 fixed_shape = [dim.value for n, dim in enumerate(values_shape) 2917 if n != axis] 2918 if any(value is None for value in fixed_shape): 2919 raise ValueError('all dimensions of `values` other than the dimension to ' 2920 'concatenate along must have statically known size') 2921 2922 # We move `axis` to the front of the internal array so assign ops can be 2923 # applied to contiguous slices 2924 init_size = 0 if max_size is None else max_size 2925 init_shape = [init_size] + fixed_shape 2926 array = _create_local( 2927 'array', shape=init_shape, validate_shape=False, dtype=values.dtype) 2928 size = _create_local('size', shape=[], dtype=dtypes.int32) 2929 2930 perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] 2931 valid_array = array[:size] 2932 valid_array.set_shape([None] + fixed_shape) 2933 value = array_ops.transpose(valid_array, perm, name='concat') 2934 2935 values_size = array_ops.shape(values)[axis] 2936 if max_size is None: 2937 batch_size = values_size 2938 else: 2939 batch_size = math_ops.minimum(values_size, max_size - size) 2940 2941 perm = [axis] + [n for n in range(ndim) if n != axis] 2942 batch_values = array_ops.transpose(values, perm)[:batch_size] 2943 2944 def reallocate(): 2945 next_size = 
_next_array_size(new_size) 2946 next_shape = array_ops.pack([next_size] + fixed_shape) 2947 new_value = array_ops.zeros(next_shape, dtype=values.dtype) 2948 old_value = array.value() 2949 assign_op = state_ops.assign(array, new_value, validate_shape=False) 2950 with ops.control_dependencies([assign_op]): 2951 copy_op = array[:size].assign(old_value[:size]) 2952 # return value needs to be the same dtype as no_op() for cond 2953 with ops.control_dependencies([copy_op]): 2954 return control_flow_ops.no_op() 2955 2956 new_size = size + batch_size 2957 array_size = array_ops.shape_internal(array, optimize=False)[0] 2958 maybe_reallocate_op = control_flow_ops.cond( 2959 new_size > array_size, reallocate, control_flow_ops.no_op) 2960 with ops.control_dependencies([maybe_reallocate_op]): 2961 append_values_op = array[size:new_size].assign(batch_values) 2962 with ops.control_dependencies([append_values_op]): 2963 update_op = size.assign(new_size) 2964 2965 if metrics_collections: 2966 ops.add_to_collections(metrics_collections, value) 2967 2968 if updates_collections: 2969 ops.add_to_collections(updates_collections, update_op) 2970 2971 return value, update_op 2972 # pylint: enable=invalid-slice-index 2973 2974 2975def aggregate_metrics(*value_update_tuples): 2976 """Aggregates the metric value tensors and update ops into two lists. 2977 2978 Args: 2979 *value_update_tuples: a variable number of tuples, each of which contain the 2980 pair of (value_tensor, update_op) from a streaming metric. 2981 2982 Returns: 2983 a list of value tensors and a list of update ops. 2984 2985 Raises: 2986 ValueError: if `value_update_tuples` is empty. 2987 """ 2988 if not value_update_tuples: 2989 raise ValueError('Expected at least one value_tensor/update_op pair') 2990 value_ops, update_ops = zip(*value_update_tuples) 2991 return list(value_ops), list(update_ops) 2992 2993 2994def aggregate_metric_map(names_to_tuples): 2995 """Aggregates the metric names to tuple dictionary. 
2996 2997 This function is useful for pairing metric names with their associated value 2998 and update ops when the list of metrics is long. For example: 2999 3000 ```python 3001 metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({ 3002 'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error( 3003 predictions, labels, weights), 3004 'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error( 3005 predictions, labels, labels, weights), 3006 'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error( 3007 predictions, labels, weights), 3008 'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error( 3009 predictions, labels, weights), 3010 }) 3011 ``` 3012 3013 Args: 3014 names_to_tuples: a map of metric names to tuples, each of which contain the 3015 pair of (value_tensor, update_op) from a streaming metric. 3016 3017 Returns: 3018 A dictionary from metric names to value ops and a dictionary from metric 3019 names to update ops. 3020 """ 3021 metric_names = names_to_tuples.keys() 3022 value_ops, update_ops = zip(*names_to_tuples.values()) 3023 return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops)) 3024 3025 3026def _remove_squeezable_dimensions(predictions, labels, weights): 3027 """Squeeze last dim if needed. 3028 3029 Squeezes `predictions` and `labels` if their rank differs by 1. 3030 Squeezes `weights` if its rank is 1 more than the new rank of `predictions` 3031 3032 This will use static shape if available. Otherwise, it will add graph 3033 operations, which could result in a performance hit. 3034 3035 Args: 3036 predictions: Predicted values, a `Tensor` of arbitrary dimensions. 3037 labels: Label values, a `Tensor` whose dimensions match `predictions`. 3038 weights: optional `weights` tensor. 
It will be squeezed if its rank is 1 3039 more than the new rank of `predictions` 3040 3041 Returns: 3042 Tuple of `predictions`, `labels` and `weights`, possibly with the last 3043 dimension squeezed. 3044 """ 3045 predictions, labels = tensor_util.remove_squeezable_dimensions( 3046 predictions, labels) 3047 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 3048 3049 if weights is not None: 3050 predictions_shape = predictions.get_shape() 3051 predictions_rank = predictions_shape.ndims 3052 weights_shape = weights.get_shape() 3053 weights_rank = weights_shape.ndims 3054 3055 if (predictions_rank is not None) and (weights_rank is not None): 3056 # Use static rank. 3057 if weights_rank - predictions_rank == 1: 3058 weights = array_ops.squeeze(weights, [-1]) 3059 elif (weights_rank is None) or ( 3060 weights_shape.dims[-1].is_compatible_with(1)): 3061 # Use dynamic rank 3062 weights = control_flow_ops.cond( 3063 math_ops.equal(array_ops.rank(weights), 3064 math_ops.add(array_ops.rank(predictions), 1)), 3065 lambda: array_ops.squeeze(weights, [-1]), 3066 lambda: weights) 3067 return predictions, labels, weights 3068 3069 3070__all__ = [ 3071 'aggregate_metric_map', 3072 'aggregate_metrics', 3073 'streaming_accuracy', 3074 'streaming_auc', 3075 'streaming_mean', 3076 'streaming_mean_absolute_error', 3077 'streaming_mean_cosine_distance', 3078 'streaming_mean_iou', 3079 'streaming_mean_relative_error', 3080 'streaming_mean_squared_error', 3081 'streaming_mean_tensor', 3082 'streaming_percentage_less', 3083 'streaming_precision', 3084 'streaming_precision_at_thresholds', 3085 'streaming_recall', 3086 'streaming_recall_at_k', 3087 'streaming_recall_at_thresholds', 3088 'streaming_root_mean_squared_error', 3089 'streaming_sensitivity_at_specificity', 3090 'streaming_sparse_average_precision_at_k', 3091 'streaming_sparse_precision_at_k', 3092 'streaming_sparse_recall_at_k', 3093 'streaming_specificity_at_sensitivity', 3094] 3095