metric_ops.py revision 4fb94d38b25c73d3bd8bd6a0ea632eb34c2787d6
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains metric-computing operations on streamed tensors. 16 17Module documentation, including "@@" callouts, should be put in 18third_party/tensorflow/contrib/metrics/__init__.py 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25from tensorflow.contrib.framework import deprecated_args 26from tensorflow.contrib.framework.python.ops import variables as contrib_variables 27from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops 28from tensorflow.contrib.metrics.python.ops import metric_ops_util 29from tensorflow.contrib.metrics.python.ops import set_ops 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import check_ops 34from tensorflow.python.ops import math_ops 35from tensorflow.python.ops import nn 36from tensorflow.python.ops import sparse_ops 37from tensorflow.python.ops import state_ops 38from tensorflow.python.ops import variable_scope 39from tensorflow.python.ops import variables 40from tensorflow.python.util.all_util import make_all 41 42 43IGNORE_MASK_DATE = '2016-10-19' 44IGNORE_MASK_INSTRUCTIONS = ( 45 '`ignore_mask` is being deprecated. 
Instead use `weights` with values 0.0 ' 46 'and 1.0 to mask values. For example, `weights=tf.logical_not(mask)`.') 47 48 49def _mask_weights(mask=None, weights=None): 50 """Mask a given set of weights. 51 52 Elements are included when the corresponding `mask` element is `False`, and 53 excluded otherwise. 54 55 Args: 56 mask: An optional, `bool` `Tensor`. 57 weights: An optional `Tensor` whose shape matches `mask` if `mask` is not 58 `None`. 59 60 Returns: 61 Masked weights if `mask` and `weights` are not `None`, weights equivalent to 62 `mask` if `weights` is `None`, and otherwise `weights`. 63 64 Raises: 65 ValueError: If `weights` and `mask` are not `None` and have mismatched 66 shapes. 67 """ 68 if mask is not None: 69 check_ops.assert_type(mask, dtypes.bool) 70 if weights is None: 71 weights = array_ops.ones_like(mask, dtype=dtypes.float32) 72 weights = math_ops.cast(math_ops.logical_not(mask), weights.dtype) * weights 73 74 return weights 75 76 77def _safe_div(numerator, denominator, name): 78 """Divides two values, returning 0 if the denominator is <= 0. 79 80 Args: 81 numerator: A real `Tensor`. 82 denominator: A real `Tensor`, with dtype matching `numerator`. 83 name: Name for the returned op. 84 85 Returns: 86 0 if `denominator` <= 0, else `numerator` / `denominator` 87 """ 88 return math_ops.select( 89 math_ops.greater(denominator, 0), 90 math_ops.truediv(numerator, denominator), 91 0, 92 name=name) 93 94 95def _create_local(name, shape=None, collections=None, dtype=dtypes.float32): 96 """Creates a new local variable. 97 98 Args: 99 name: The name of the new or existing variable. 100 shape: Shape of the new or existing variable. 101 collections: A list of collection names to which the Variable will be added. 102 dtype: Data type of the variables. 103 104 Returns: 105 The created variable. 
106 """ 107 # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES 108 collections = list(collections or []) 109 collections += [ops.GraphKeys.LOCAL_VARIABLES] 110 return variables.Variable( 111 initial_value=array_ops.zeros(shape, dtype=dtype), 112 name=name, 113 trainable=False, 114 collections=collections) 115 116 117def _count_condition(values, weights=None, metrics_collections=None, 118 updates_collections=None): 119 """Sums the weights of cases where the given values are True. 120 121 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 122 123 Args: 124 values: A `bool` `Tensor` of arbitrary size. 125 weights: An optional `Tensor` whose shape matches `values`. 126 metrics_collections: An optional list of collections that the metric 127 value variable should be added to. 128 updates_collections: An optional list of collections that the metric update 129 ops should be added to. 130 131 Returns: 132 value_tensor: A tensor representing the current value of the metric. 133 update_op: An operation that accumulates the error from a batch of data. 134 135 Raises: 136 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 137 or if either `metrics_collections` or `updates_collections` are not a list 138 or tuple. 
139 """ 140 check_ops.assert_type(values, dtypes.bool) 141 count = _create_local('count', shape=[]) 142 143 values = math_ops.to_float(values) 144 if weights is not None: 145 values.get_shape().assert_is_compatible_with(weights.get_shape()) 146 weights = math_ops.to_float(weights) 147 values = math_ops.mul(values, weights) 148 149 value_tensor = array_ops.identity(count) 150 update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) 151 152 if metrics_collections: 153 ops.add_to_collections(metrics_collections, value_tensor) 154 155 if updates_collections: 156 ops.add_to_collections(updates_collections, update_op) 157 158 return value_tensor, update_op 159 160 161def _streaming_true_negatives(predictions, labels, weights=None, 162 metrics_collections=None, 163 updates_collections=None, 164 name=None): 165 """Sum the weights of true_negatives. 166 167 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 168 169 Args: 170 predictions: The predicted values, a `bool` `Tensor` of arbitrary 171 dimensions. 172 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 173 match `predictions`. 174 weights: An optional `Tensor` whose shape matches `predictions`. 175 metrics_collections: An optional list of collections that the metric 176 value variable should be added to. 177 updates_collections: An optional list of collections that the metric update 178 ops should be added to. 179 name: An optional variable_scope name. 180 181 Returns: 182 value_tensor: A tensor representing the current value of the metric. 183 update_op: An operation that accumulates the error from a batch of data. 184 185 Raises: 186 ValueError: If `predictions` and `labels` have mismatched shapes, or if 187 `weights` is not `None` and its shape doesn't match `predictions`, or if 188 either `metrics_collections` or `updates_collections` are not a list or 189 tuple. 
190 """ 191 with variable_scope.variable_scope( 192 [predictions, labels], name, 'true_negatives'): 193 194 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 195 is_true_negative = math_ops.logical_and(math_ops.equal(labels, 0), 196 math_ops.equal(predictions, 0)) 197 return _count_condition(is_true_negative, weights, metrics_collections, 198 updates_collections) 199 200 201def _streaming_true_positives(predictions, labels, weights=None, 202 metrics_collections=None, 203 updates_collections=None, 204 name=None): 205 """Sum the weights of true_positives. 206 207 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 208 209 Args: 210 predictions: The predicted values, a `bool` `Tensor` of arbitrary 211 dimensions. 212 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 213 match `predictions`. 214 weights: An optional `Tensor` whose shape matches `predictions`. 215 metrics_collections: An optional list of collections that the metric 216 value variable should be added to. 217 updates_collections: An optional list of collections that the metric update 218 ops should be added to. 219 name: An optional variable_scope name. 220 221 Returns: 222 value_tensor: A tensor representing the current value of the metric. 223 update_op: An operation that accumulates the error from a batch of data. 224 225 Raises: 226 ValueError: If `predictions` and `labels` have mismatched shapes, or if 227 `weights` is not `None` and its shape doesn't match `predictions`, or if 228 either `metrics_collections` or `updates_collections` are not a list or 229 tuple. 
230 """ 231 with variable_scope.variable_scope( 232 name, 'true_positives', [predictions, labels]): 233 234 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 235 is_true_positive = math_ops.logical_and(math_ops.equal(labels, 1), 236 math_ops.equal(predictions, 1)) 237 return _count_condition(is_true_positive, weights, metrics_collections, 238 updates_collections) 239 240 241def _streaming_false_positives(predictions, labels, weights=None, 242 metrics_collections=None, 243 updates_collections=None, 244 name=None): 245 """Sum the weights of false positives. 246 247 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 248 249 Args: 250 predictions: The predicted values, a `bool` `Tensor` of arbitrary 251 dimensions. 252 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 253 match `predictions`. 254 weights: An optional `Tensor` whose shape matches `predictions`. 255 metrics_collections: An optional list of collections that the metric 256 value variable should be added to. 257 updates_collections: An optional list of collections that the metric update 258 ops should be added to. 259 name: An optional variable_scope name. 260 261 Returns: 262 value_tensor: A tensor representing the current value of the metric. 263 update_op: An operation that accumulates the error from a batch of data. 264 265 Raises: 266 ValueError: If `predictions` and `labels` have mismatched shapes, or if 267 `weights` is not `None` and its shape doesn't match `predictions`, or if 268 either `metrics_collections` or `updates_collections` are not a list or 269 tuple. 
270 """ 271 with variable_scope.variable_scope( 272 name, 'false_positives', [predictions, labels]): 273 274 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 275 is_false_positive = math_ops.logical_and(math_ops.equal(labels, 0), 276 math_ops.equal(predictions, 1)) 277 return _count_condition(is_false_positive, weights, metrics_collections, 278 updates_collections) 279 280 281def _streaming_false_negatives(predictions, labels, weights=None, 282 metrics_collections=None, 283 updates_collections=None, 284 name=None): 285 """Computes the total number of false positives. 286 287 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 288 289 Args: 290 predictions: The predicted values, a `bool` `Tensor` of arbitrary 291 dimensions. 292 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 293 match `predictions`. 294 weights: An optional `Tensor` whose shape matches `predictions`. 295 metrics_collections: An optional list of collections that the metric 296 value variable should be added to. 297 updates_collections: An optional list of collections that the metric update 298 ops should be added to. 299 name: An optional variable_scope name. 300 301 Returns: 302 value_tensor: A tensor representing the current value of the metric. 303 update_op: An operation that accumulates the error from a batch of data. 304 305 Raises: 306 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 307 or if either `metrics_collections` or `updates_collections` are not a list 308 or tuple. 
309 """ 310 with variable_scope.variable_scope( 311 name, 'false_negatives', [predictions, labels]): 312 313 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 314 is_false_negative = math_ops.logical_and(math_ops.equal(labels, 1), 315 math_ops.equal(predictions, 0)) 316 return _count_condition(is_false_negative, weights, metrics_collections, 317 updates_collections) 318 319 320def streaming_mean(values, weights=None, metrics_collections=None, 321 updates_collections=None, name=None): 322 """Computes the (weighted) mean of the given values. 323 324 The `streaming_mean` function creates two local variables, `total` and `count` 325 that are used to compute the average of `values`. This average is ultimately 326 returned as `mean` which is an idempotent operation that simply divides 327 `total` by `count`. 328 329 For estimation of the metric over a stream of data, the function creates an 330 `update_op` operation that updates these variables and returns the `mean`. 331 `update_op` increments `total` with the reduced sum of the product of `values` 332 and `weights`, and it increments `count` with the reduced sum of `weights`. 333 334 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 335 336 Args: 337 values: A `Tensor` of arbitrary dimensions. 338 weights: An optional `Tensor` whose shape matches `values`. 339 metrics_collections: An optional list of collections that `mean` 340 should be added to. 341 updates_collections: An optional list of collections that `update_op` 342 should be added to. 343 name: An optional variable_scope name. 344 345 Returns: 346 mean: A tensor representing the current mean, the value of `total` divided 347 by `count`. 348 update_op: An operation that increments the `total` and `count` variables 349 appropriately and whose value matches `mean_value`. 
350 351 Raises: 352 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 353 or if either `metrics_collections` or `updates_collections` are not a list 354 or tuple. 355 """ 356 with variable_scope.variable_scope(name, 'mean', [values, weights]): 357 values = math_ops.to_float(values) 358 359 total = _create_local('total', shape=[]) 360 count = _create_local('count', shape=[]) 361 362 if weights is not None: 363 values.get_shape().assert_is_compatible_with(weights.get_shape()) 364 weights = math_ops.to_float(weights) 365 values = math_ops.mul(values, weights) 366 num_values = math_ops.reduce_sum(weights) 367 else: 368 num_values = math_ops.to_float(array_ops.size(values)) 369 370 total_compute_op = state_ops.assign_add(total, math_ops.reduce_sum(values)) 371 count_compute_op = state_ops.assign_add(count, num_values) 372 373 mean = _safe_div(total, count, 'value') 374 with ops.control_dependencies([total_compute_op, count_compute_op]): 375 update_op = _safe_div(total, count, 'update_op') 376 377 if metrics_collections: 378 ops.add_to_collections(metrics_collections, mean) 379 380 if updates_collections: 381 ops.add_to_collections(updates_collections, update_op) 382 383 return mean, update_op 384 385 386def streaming_mean_tensor(values, weights=None, metrics_collections=None, 387 updates_collections=None, name=None): 388 """Computes the element-wise (weighted) mean of the given tensors. 389 390 In contrast to the `streaming_mean` function which returns a scalar with the 391 mean, this function returns an average tensor with the same shape as the 392 input tensors. 393 394 The `streaming_mean_tensor` function creates two local variables, 395 `total_tensor` and `count_tensor` that are used to compute the average of 396 `values`. This average is ultimately returned as `mean` which is an idempotent 397 operation that simply divides `total` by `count`. 
398 399 For estimation of the metric over a stream of data, the function creates an 400 `update_op` operation that updates these variables and returns the `mean`. 401 `update_op` increments `total` with the reduced sum of the product of `values` 402 and `weights`, and it increments `count` with the reduced sum of `weights`. 403 404 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 405 406 Args: 407 values: A `Tensor` of arbitrary dimensions. 408 weights: An optional `Tensor` whose shape matches `values`. 409 metrics_collections: An optional list of collections that `mean` 410 should be added to. 411 updates_collections: An optional list of collections that `update_op` 412 should be added to. 413 name: An optional variable_scope name. 414 415 Returns: 416 mean: A float tensor representing the current mean, the value of `total` 417 divided by `count`. 418 update_op: An operation that increments the `total` and `count` variables 419 appropriately and whose value matches `mean_value`. 420 421 Raises: 422 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 423 or if either `metrics_collections` or `updates_collections` are not a list 424 or tuple. 
425 """ 426 with variable_scope.variable_scope(name, 'mean', [values, weights]): 427 total = _create_local('total_tensor', shape=values.get_shape()) 428 count = _create_local('count_tensor', shape=values.get_shape()) 429 430 if weights is not None: 431 values.get_shape().assert_is_compatible_with(weights.get_shape()) 432 weights = math_ops.to_float(weights) 433 values = math_ops.mul(values, weights) 434 num_values = weights 435 else: 436 num_values = array_ops.ones_like(values) 437 438 total_compute_op = state_ops.assign_add(total, values) 439 count_compute_op = state_ops.assign_add(count, num_values) 440 441 def compute_mean(total, count, name): 442 non_zero_count = math_ops.maximum(count, 443 array_ops.ones_like(count), 444 name=name) 445 return math_ops.truediv(total, non_zero_count, name=name) 446 447 mean = compute_mean(total, count, 'value') 448 with ops.control_dependencies([total_compute_op, count_compute_op]): 449 update_op = compute_mean(total, count, 'update_op') 450 451 if metrics_collections: 452 ops.add_to_collections(metrics_collections, mean) 453 454 if updates_collections: 455 ops.add_to_collections(updates_collections, update_op) 456 457 return mean, update_op 458 459 460def streaming_accuracy(predictions, labels, weights=None, 461 metrics_collections=None, updates_collections=None, 462 name=None): 463 """Calculates how often `predictions` matches `labels`. 464 465 The `streaming_accuracy` function creates two local variables, `total` and 466 `count` that are used to compute the frequency with which `predictions` 467 matches `labels`. This frequency is ultimately returned as `accuracy`: an 468 idempotent operation that simply divides `total` by `count`. 469 470 For estimation of the metric over a stream of data, the function creates an 471 `update_op` operation that updates these variables and returns the `accuracy`. 
472 Internally, an `is_correct` operation computes a `Tensor` with elements 1.0 473 where the corresponding elements of `predictions` and `labels` match and 0.0 474 otherwise. Then `update_op` increments `total` with the reduced sum of the 475 product of `weights` and `is_correct`, and it increments `count` with the 476 reduced sum of `weights`. 477 478 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 479 480 Args: 481 predictions: The predicted values, a `Tensor` of any shape. 482 labels: The ground truth values, a `Tensor` whose shape matches 483 `predictions`. 484 weights: An optional `Tensor` whose shape matches `predictions`. 485 metrics_collections: An optional list of collections that `accuracy` should 486 be added to. 487 updates_collections: An optional list of collections that `update_op` should 488 be added to. 489 name: An optional variable_scope name. 490 491 Returns: 492 accuracy: A tensor representing the accuracy, the value of `total` divided 493 by `count`. 494 update_op: An operation that increments the `total` and `count` variables 495 appropriately and whose value matches `accuracy`. 496 497 Raises: 498 ValueError: If `predictions` and `labels` have mismatched shapes, or if 499 `weights` is not `None` and its shape doesn't match `predictions`, or if 500 either `metrics_collections` or `updates_collections` are not a list or 501 tuple. 
502 """ 503 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 504 predictions, labels) 505 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 506 if labels.dtype != predictions.dtype: 507 predictions = math_ops.cast(predictions, labels.dtype) 508 is_correct = math_ops.to_float(math_ops.equal(predictions, labels)) 509 return streaming_mean(is_correct, weights, metrics_collections, 510 updates_collections, name or 'accuracy') 511 512 513@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask') 514def streaming_precision(predictions, labels, ignore_mask=None, weights=None, 515 metrics_collections=None, updates_collections=None, 516 name=None): 517 """Computes the precision of the predictions with respect to the labels. 518 519 The `streaming_precision` function creates two local variables, 520 `true_positives` and `false_positives`, that are used to compute the 521 precision. This value is ultimately returned as `precision`, an idempotent 522 operation that simply divides `true_positives` by the sum of `true_positives` 523 and `false_positives`. 524 525 For estimation of the metric over a stream of data, the function creates an 526 `update_op` operation that updates these variables and returns the 527 `precision`. `update_op` weights each prediction by the corresponding value in 528 `weights`. 529 530 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 531 Alternatively, if `ignore_mask` is not `None`, then mask values where 532 `ignore_mask` is `True`. 533 534 Args: 535 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 536 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 537 match `predictions`. 538 ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`. 539 weights: An optional `Tensor` whose shape matches `predictions`. 540 metrics_collections: An optional list of collections that `precision` should 541 be added to. 
542 updates_collections: An optional list of collections that `update_op` should 543 be added to. 544 name: An optional variable_scope name. 545 546 Returns: 547 precision: Scalar float `Tensor` with the value of `true_positives` 548 divided by the sum of `true_positives` and `false_positives`. 549 update_op: `Operation` that increments `true_positives` and 550 `false_positives` variables appropriately and whose value matches 551 `precision`. 552 553 Raises: 554 ValueError: If `predictions` and `labels` have mismatched shapes, or if 555 `ignore_mask` is not `None` and its shape doesn't match `predictions`, or 556 if `weights` is not `None` and its shape doesn't match `predictions`, or 557 if either `metrics_collections` or `updates_collections` are not a list or 558 tuple. 559 """ 560 with variable_scope.variable_scope( 561 name, 'precision', [predictions, labels]): 562 563 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 564 predictions, labels) 565 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 566 567 weights = _mask_weights(ignore_mask, weights) 568 true_positives, true_positives_update_op = _streaming_true_positives( 569 predictions, labels, weights, metrics_collections=None, 570 updates_collections=None, name=None) 571 false_positives, false_positives_update_op = _streaming_false_positives( 572 predictions, labels, weights, metrics_collections=None, 573 updates_collections=None, name=None) 574 575 def compute_precision(name): 576 return math_ops.select( 577 math_ops.greater(true_positives + false_positives, 0), 578 math_ops.div(true_positives, true_positives + false_positives), 579 0, 580 name) 581 582 precision = compute_precision('value') 583 with ops.control_dependencies([true_positives_update_op, 584 false_positives_update_op]): 585 update_op = compute_precision('update_op') 586 587 if metrics_collections: 588 ops.add_to_collections(metrics_collections, precision) 589 590 if updates_collections: 591 
ops.add_to_collections(updates_collections, update_op) 592 593 return precision, update_op 594 595 596@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask') 597def streaming_recall(predictions, labels, ignore_mask=None, weights=None, 598 metrics_collections=None, updates_collections=None, 599 name=None): 600 """Computes the recall of the predictions with respect to the labels. 601 602 The `streaming_recall` function creates two local variables, `true_positives` 603 and `false_negatives`, that are used to compute the recall. This value is 604 ultimately returned as `recall`, an idempotent operation that simply divides 605 `true_positives` by the sum of `true_positives` and `false_negatives`. 606 607 For estimation of the metric over a stream of data, the function creates an 608 `update_op` that updates these variables and returns the `recall`. `update_op` 609 weights each prediction by the corresponding value in `weights`. 610 611 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 612 Alternatively, if `ignore_mask` is not `None`, then mask values where 613 `ignore_mask` is `True`. 614 615 Args: 616 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 617 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 618 match `predictions`. 619 ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`. 620 weights: An optional `Tensor` whose shape matches `predictions`. 621 metrics_collections: An optional list of collections that `recall` should 622 be added to. 623 updates_collections: An optional list of collections that `update_op` should 624 be added to. 625 name: An optional variable_scope name. 626 627 Returns: 628 recall: Scalar float `Tensor` with the value of `true_positives` divided 629 by the sum of `true_positives` and `false_negatives`. 
630 update_op: `Operation` that increments `true_positives` and 631 `false_negatives` variables appropriately and whose value matches 632 `recall`. 633 634 Raises: 635 ValueError: If `predictions` and `labels` have mismatched shapes, or if 636 `ignore_mask` is not `None` and its shape doesn't match `predictions`, or 637 if `weights` is not `None` and its shape doesn't match `predictions`, or 638 if either `metrics_collections` or `updates_collections` are not a list or 639 tuple. 640 """ 641 with variable_scope.variable_scope(name, 'recall', [predictions, labels]): 642 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 643 predictions, labels) 644 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 645 646 weights = _mask_weights(ignore_mask, weights) 647 true_positives, true_positives_update_op = _streaming_true_positives( 648 predictions, labels, weights, metrics_collections=None, 649 updates_collections=None, name=None) 650 false_negatives, false_negatives_update_op = _streaming_false_negatives( 651 predictions, labels, weights, metrics_collections=None, 652 updates_collections=None, name=None) 653 654 def compute_recall(true_positives, false_negatives, name): 655 return math_ops.select( 656 math_ops.greater(true_positives + false_negatives, 0), 657 math_ops.div(true_positives, true_positives + false_negatives), 658 0, 659 name) 660 661 recall = compute_recall(true_positives, false_negatives, 'value') 662 with ops.control_dependencies([true_positives_update_op, 663 false_negatives_update_op]): 664 update_op = compute_recall(true_positives, false_negatives, 'update_op') 665 666 if metrics_collections: 667 ops.add_to_collections(metrics_collections, recall) 668 669 if updates_collections: 670 ops.add_to_collections(updates_collections, update_op) 671 672 return recall, update_op 673 674 675def _tp_fn_tn_fp(predictions, labels, thresholds, weights=None): 676 """Computes true_positives, false_negatives, true_negatives, 
false_positives. 677 678 The `_tp_fn_tn_fp` function creates four local variables, `true_positives`, 679 `true_negatives`, `false_positives` and `false_negatives`. 680 `true_positive[i]` is defined as the total weight of values in `predictions` 681 above `thresholds[i]` whose corresponding entry in `labels` is `True`. 682 `false_negatives[i]` is defined as the total weight of values in `predictions` 683 at most `thresholds[i]` whose corresponding entry in `labels` is `True`. 684 `true_negatives[i]` is defined as the total weight of values in `predictions` 685 at most `thresholds[i]` whose corresponding entry in `labels` is `False`. 686 `false_positives[i]` is defined as the total weight of values in `predictions` 687 above `thresholds[i]` whose corresponding entry in `labels` is `False`. 688 689 For estimation of these metrics over a stream of data, for each metric the 690 function respectively creates an `update_op` operation that updates the 691 variable and returns its value. 692 693 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 694 695 Args: 696 predictions: A floating point `Tensor` of arbitrary shape and whose values 697 are in the range `[0, 1]`. 698 labels: A `bool` `Tensor` whose shape matches `predictions`. 699 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 700 weights: An optional `Tensor` whose shape matches `predictions`. 701 702 Returns: 703 true_positive: A variable of shape [len(thresholds)]. 704 false_negative: A variable of shape [len(thresholds)]. 705 true_negatives: A variable of shape [len(thresholds)]. 706 false_positives: A variable of shape [len(thresholds)]. 707 true_positives_update_op: An operation that increments the `true_positives`. 708 false_negative_update_op: An operation that increments the `false_negative`. 709 true_negatives_update_op: An operation that increments the `true_negatives`. 710 false_positives_update_op: An operation that increments the 711 `false_positives`. 
712 713 Raises: 714 ValueError: If `predictions` and `labels` have mismatched shapes, or if 715 `weights` is not `None` and its shape doesn't match `predictions`. 716 """ 717 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 718 predictions, labels) 719 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 720 721 num_thresholds = len(thresholds) 722 723 # Reshape predictions and labels 724 predictions = array_ops.reshape(predictions, [-1, 1]) 725 labels = array_ops.reshape(math_ops.cast(labels, dtype=dtypes.bool), [1, -1]) 726 727 # Use static shape if known. 728 num_predictions = predictions.get_shape().as_list()[0] 729 730 # Otherwise use dynamic shape. 731 if num_predictions is None: 732 num_predictions = array_ops.shape(predictions)[0] 733 thresh_tiled = array_ops.tile( 734 array_ops.expand_dims(array_ops.constant(thresholds), [1]), 735 array_ops.pack([1, num_predictions])) 736 737 # Tile the predictions after thresholding them across different thresholds. 
738 pred_is_pos = math_ops.greater( 739 array_ops.tile(array_ops.transpose(predictions), [num_thresholds, 1]), 740 thresh_tiled) 741 pred_is_neg = math_ops.logical_not(pred_is_pos) 742 743 # Tile labels by number of thresholds 744 label_is_pos = array_ops.tile(labels, [num_thresholds, 1]) 745 label_is_neg = math_ops.logical_not(label_is_pos) 746 747 true_positives = _create_local('true_positives', shape=[num_thresholds]) 748 false_negatives = _create_local('false_negatives', shape=[num_thresholds]) 749 true_negatives = _create_local('true_negatives', shape=[num_thresholds]) 750 false_positives = _create_local('false_positives', shape=[num_thresholds]) 751 752 is_true_positive = math_ops.to_float( 753 math_ops.logical_and(label_is_pos, pred_is_pos)) 754 is_false_negative = math_ops.to_float( 755 math_ops.logical_and(label_is_pos, pred_is_neg)) 756 is_false_positive = math_ops.to_float( 757 math_ops.logical_and(label_is_neg, pred_is_pos)) 758 is_true_negative = math_ops.to_float( 759 math_ops.logical_and(label_is_neg, pred_is_neg)) 760 761 if weights is not None: 762 weights = math_ops.to_float(weights) 763 weights_tiled = array_ops.tile( 764 array_ops.reshape(weights, [1, -1]), [num_thresholds, 1]) 765 thresh_tiled.get_shape().assert_is_compatible_with( 766 weights_tiled.get_shape()) 767 is_true_positive *= weights_tiled 768 is_false_negative *= weights_tiled 769 is_false_positive *= weights_tiled 770 is_true_negative *= weights_tiled 771 772 true_positives_update_op = state_ops.assign_add( 773 true_positives, math_ops.reduce_sum(is_true_positive, 1)) 774 false_negatives_update_op = state_ops.assign_add( 775 false_negatives, math_ops.reduce_sum(is_false_negative, 1)) 776 true_negatives_update_op = state_ops.assign_add( 777 true_negatives, math_ops.reduce_sum(is_true_negative, 1)) 778 false_positives_update_op = state_ops.assign_add( 779 false_positives, math_ops.reduce_sum(is_false_positive, 1)) 780 781 return (true_positives, false_negatives, true_negatives, 
false_positives,
          true_positives_update_op, false_negatives_update_op,
          true_negatives_update_op, false_positives_update_op)


def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
                  metrics_collections=None, updates_collections=None,
                  curve='ROC', name=None):
  """Computes the approximate AUC via a Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall. Each strip
  of the sum uses the mean of the two adjacent heights (i.e. trapeziums).

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar tensor representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple, or if `curve` is neither 'ROC' nor 'PR'.
  """
  with variable_scope.variable_scope(name, 'auc', [predictions, labels]):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' %
                       (curve))
    kepsilon = 1e-7  # to account for floating point imprecisions
    # Build `num_thresholds` cut points: `num_thresholds - 2` evenly spaced
    # interior values in (0, 1), bracketed by one value just below 0 and one
    # just above 1.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    # Per-threshold confusion-count variables and the ops that increment them.
    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6
    # Sanity check on the static shape: one count per threshold.
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      recall = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        x = fp_rate
        y = recall
      else:  # curve == 'PR'.
        precision = math_ops.div(tp + epsilon, tp + fp + epsilon)
        x = recall
        y = precision
      # Strip width (difference of adjacent x values) times the mean of the
      # two adjacent y values, summed over all strips.
      return math_ops.reduce_sum(math_ops.mul(
          x[:num_thresholds - 1] - x[1:],
          (y[:num_thresholds - 1] + y[1:]) / 2.), name=name)

    # sum up the areas of all the trapeziums
    auc = compute_auc(tp, fn, tn, fp, 'value')
    # Same expression built on the update ops' outputs, so evaluating
    # `update_op` applies the increments before computing the area.
    update_op = compute_auc(
        tp_update_op, fn_update_op, tn_update_op, fp_update_op, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, auc)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc, update_op


def streaming_specificity_at_sensitivity(
    predictions, labels, sensitivity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the specificity at a given sensitivity.

  The `streaming_specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar tensor representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    # `num_thresholds - 2` evenly spaced interior cut points plus two
    # endpoints.
    # NOTE(review): the upper endpoint here is `1.0 - kepsilon`, whereas
    # `streaming_auc` and `streaming_sensitivity_at_specificity` use
    # `1.0 + kepsilon` -- confirm this difference is intentional.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]

    # Per-threshold confusion-count variables and the ops that increment them.
    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)

    # Sanity check on the static shape: one count per threshold.
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_specificity_at_sensitivity(name):
      """Computes the specificity at the given sensitivity.

      Args:
        name: The name of the operation.

      Returns:
        The specificity using the aggregated values.
      """
      sensitivities = math_ops.div(tp, tp + fn + kepsilon)

      # We'll need to use this trick until tf.argmax allows us to specify
      # whether we should use the first or last index in case of ties.
      # The running count of tied positions first reaches its maximum at the
      # last tied position, so taking the argmax of the cumulative sum is
      # presumably meant to select the last tie.
      # NOTE(review): relies on argmax returning the first maximum -- confirm.
      min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
      indices_at_minval = math_ops.equal(
          math_ops.abs(sensitivities - sensitivity), min_val)
      indices_at_minval = math_ops.to_int64(indices_at_minval)
      indices_at_minval = math_ops.cumsum(indices_at_minval)
      tf_index = math_ops.argmax(indices_at_minval, 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the specificity:
      return math_ops.div(tn[tf_index],
                          tn[tf_index] + fp[tf_index] + kepsilon,
                          name)

    specificity = compute_specificity_at_sensitivity('value')
    # The update variant is built under control dependencies on all four
    # increment ops, so it is evaluated only after the counts are updated.
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_specificity_at_sensitivity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, specificity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op


def streaming_sensitivity_at_specificity(
    predictions, labels, specificity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the sensitivity at a given specificity.

  The `streaming_sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is computed
  and used to evaluate the corresponding sensitivity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    specificity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar tensor representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `specificity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  if specificity < 0 or specificity > 1:
    raise ValueError('`specificity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
                                     [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    # `num_thresholds - 2` evenly spaced interior cut points, bracketed by one
    # value just below 0 and one just above 1.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    # Per-threshold confusion-count variables and the ops that increment them.
    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)
    # Sanity check on the static shape: one count per threshold.
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_sensitivity_at_specificity(name):
      """Computes the sensitivity at the given specificity.

      Args:
        name: The name of the operation.

      Returns:
        The sensitivity using the aggregated values.
      """
      specificities = math_ops.div(tn, tn + fp + kepsilon)
      # Index of the threshold whose specificity is closest to the target.
      tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the sensitivity:
      return math_ops.div(tp[tf_index],
                          tp[tf_index] + fn[tf_index] + kepsilon,
                          name)

    sensitivity = compute_sensitivity_at_specificity('value')
    # The update variant is built under control dependencies on all four
    # increment ops, so it is evaluated only after the counts are updated.
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_sensitivity_at_specificity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, sensitivity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return sensitivity, update_op


def streaming_precision_at_thresholds(predictions, labels, thresholds,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None, name=None):
  """Computes precision values for different `thresholds` on `predictions`.

  The `streaming_precision_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `precision[i]` is defined as the total
  weight of values in `predictions` above `thresholds[i]` whose corresponding
  entry in `labels` is `True`, divided by the total weight of values in
  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
  false_positives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that `precision`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'precision_at_thresholds',
                                     [predictions, labels]):

    # TODO(nsilberman): Replace with only tp and fp, this results in unnecessary
    # variable creation. b/30842882
    # Only the true/false positive counts and their update ops are used here.
    (true_positives, _, _, false_positives, true_positives_compute_op, _, _,
     false_positives_compute_op,) = _tp_fn_tn_fp(
         predictions, labels, thresholds, weights)

    # avoid division by zero
    epsilon = 1e-7
    def compute_precision(name):
      # tp / (tp + fp) per threshold; `epsilon` keeps the denominator non-zero.
      precision = math_ops.div(true_positives,
                               epsilon + true_positives + false_positives,
                               name='precision_' + name)
      return precision

    precision = compute_precision('value')
    # Built under control dependencies so the counts are incremented before
    # the updated precision is computed.
    with ops.control_dependencies([true_positives_compute_op,
                                   false_positives_compute_op]):
      update_op = compute_precision('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, precision)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision, update_op


def streaming_recall_at_thresholds(predictions, labels, thresholds,
                                   weights=None, metrics_collections=None,
                                   updates_collections=None, name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  The `streaming_recall_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `recall[i]` is defined as the total weight
  of values in `predictions` above `thresholds[i]` whose corresponding entry in
  `labels` is `True`, divided by the total weight of `True` values in `labels`
  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'recall_at_thresholds',
                                     [predictions, labels]):
    # Only the true positive / false negative counts and their update ops are
    # used here.
    (true_positives, false_negatives, _, _, true_positives_compute_op,
     false_negatives_compute_op, _, _,) = _tp_fn_tn_fp(
         predictions, labels, thresholds, weights)

    # avoid division by zero
    epsilon = 1e-7
    def compute_recall(name):
      # tp / (tp + fn) per threshold; `epsilon` keeps the denominator non-zero.
      recall = math_ops.div(true_positives,
                            epsilon + true_positives + false_negatives,
                            name='recall_' + name)
      return recall

    recall = compute_recall('value')
    # Built under control dependencies so the counts are incremented before
    # the updated recall is computed.
    with ops.control_dependencies([true_positives_compute_op,
                                   false_negatives_compute_op]):
      update_op = compute_recall('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op


@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_recall_at_k(predictions, labels, k, ignore_mask=None,
                          weights=None, metrics_collections=None,
                          updates_collections=None, name=None):
  """Computes the recall@k of the predictions with respect to dense labels.

  The `streaming_recall_at_k` function creates two local variables, `total` and
  `count`, that are used to compute the recall@k frequency. This frequency is
  ultimately returned as `recall_at_<k>`: an idempotent operation that simply
  divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with
  shape [batch_size] whose elements indicate whether or not the corresponding
  label is in the top `k` `predictions`. Then `update_op` increments `total`
  with the reduced sum of `weights` where `in_top_k` is `True`, and it
  increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
  Alternatively, if `ignore_mask` is not `None`, then mask values where
  `ignore_mask` is `True`.

  Args:
    predictions: A floating point tensor of dimension [batch_size, num_classes]
    labels: A tensor of dimension [batch_size] whose type is in `int32`,
      `int64`.
    k: The number of top elements to look at for computing recall.
    ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that `recall_at_k`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    recall_at_k: A tensor representing the recall@k, the fraction of labels
      which fall into the top `k` predictions.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `recall_at_k`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `ignore_mask` is not `None` and its shape doesn't match `predictions`, or
      if `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Float indicator of whether each label is among the top `k` predictions.
  in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k))
  # Accumulation is delegated to `streaming_mean`; when no `name` is given the
  # default name embeds `k`.
  return streaming_mean(in_top_k,
                        _mask_weights(ignore_mask, weights),
                        metrics_collections,
                        updates_collections,
                        name or ('recall_at_%d' % k))


# TODO(ptucker): Validate range of values in labels?
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_sparse_recall_at_k(predictions,
                                 labels,
                                 k,
                                 class_id=None,
                                 ignore_mask=None,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate recall by considering only the
      entries in the batch for which `class_id` is in the label, and computing
      the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, we'll calculate recall as how often on
      average a class among the labels of a batch entry is in the top-k
      `predictions`.

  `streaming_sparse_recall_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
  the recall_at_k frequency. This frequency is ultimately returned as
  `recall_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false negatives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_negative_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
  Alternatively, if `ignore_mask` is not `None`, then mask values where
  `ignore_mask` is `True`.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes], where
      num_classes is the last dimension of `predictions`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes], where num_classes is the last dimension of
      `predictions`.
    ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
      the first [D1, ... DN] dimensions of `predictions` and `labels`.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `ignore_mask` is not `None` and its shape doesn't match
      `predictions`, or if `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  # Default scope name embeds `k`, plus the class ID when one is selected.
  default_name = 'recall_at_%d' % k
  if class_id is not None:
    default_name = '%s_class%d' % (default_name, class_id)

  with ops.name_scope(name, default_name, [predictions, labels]) as scope:
    # Only the indices of the top `k` predictions are needed, as int64.
    _, top_k_idx = nn.top_k(predictions, k)
    top_k_idx = math_ops.to_int64(top_k_idx)
    weights = _mask_weights(ignore_mask, weights)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        weights=weights)
    fn, fn_update = _streaming_sparse_false_negative_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        weights=weights)

    # NOTE(review): unlike the `*_at_thresholds` metrics there is no epsilon in
    # the denominator, so this is 0 / 0 while both counts are still zero.
    metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
    update = math_ops.div(
        tp_update, math_ops.add(tp_update, fn_update), name='update')
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update


# TODO(ptucker): Validate range of values in labels?
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_sparse_precision_at_k(predictions,
                                    labels,
                                    k,
                                    class_id=None,
                                    ignore_mask=None,
                                    weights=None,
                                    metrics_collections=None,
                                    updates_collections=None,
                                    name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
      entries in the batch for which `class_id` is in the top-k highest
      `predictions`, and computing the fraction of them for which `class_id` is
      indeed a correct label.
  If `class_id` is not specified, we'll calculate precision as how often on
      average a class among the top-k classes with the highest predicted values
      of a batch entry is correct and can be found in the label for that entry.

  `streaming_sparse_precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
  Alternatively, if `ignore_mask` is not `None`, then mask values where
  `ignore_mask` is `True`.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes], where
      num_classes is the last dimension of `predictions`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes], where num_classes is the last dimension of
      `predictions`.
    ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
      the first [D1, ... DN] dimensions of `predictions` and `labels`.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `ignore_mask` is not `None` and its shape doesn't match
      `predictions`, or if `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  # Default scope name embeds `k`, plus the class ID when one is selected.
  default_name = 'precision_at_%d' % k
  if class_id is not None:
    default_name = '%s_class%d' % (default_name, class_id)
  with ops.name_scope(name, default_name, [predictions, labels]) as scope:
    # Only the indices of the top `k` predictions are needed, as int64.
    _, top_k_idx = nn.top_k(predictions, k)
    top_k_idx = math_ops.to_int64(top_k_idx)
    weights = _mask_weights(ignore_mask, weights)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        weights=weights)
    fp, fp_update = _streaming_sparse_false_positive_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        weights=weights)

    # NOTE(review): unlike the `*_at_thresholds` metrics there is no epsilon in
    # the denominator, so this is 0 / 0 while both counts are still zero.
    metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
    update = math_ops.div(
        tp_update, math_ops.add(tp_update, fp_update), name='update')
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update


def _select_class_id(ids, selected_id):
  """Filter all but `selected_id` out of `ids`.

  A `SparseTensor` input is filtered with `sparse_retain`; any other input is
  intersected with a tensor filled with `selected_id`.

  Args:
    ids: `int64` `Tensor` or `SparseTensor` of IDs.
    selected_id: Int id to select.

  Returns:
    `SparseTensor` of same dimensions as `ids`, except for the last dimension,
    which might be smaller. This contains only the entries equal to
    `selected_id`.
  """
  if isinstance(ids, ops.SparseTensor):
    # Keep only the sparse entries whose value equals `selected_id`.
    return sparse_ops.sparse_retain(
        ids, math_ops.equal(ids.values, selected_id))

  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
  # tf.equal and tf.reduce_any?

  # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
  ids_shape = array_ops.shape(ids)
  ids_last_dim = array_ops.size(ids_shape) - 1
  filled_selected_id_shape = math_ops.reduced_shape(
      ids_shape, array_ops.reshape(ids_last_dim, [1]))

  # Intersect `ids` with the selected ID.
  filled_selected_id = array_ops.fill(
      filled_selected_id_shape, math_ops.to_int64(selected_id))
  return set_ops.set_intersection(filled_selected_id, ids)


def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
  """If class ID is specified, filter all other classes.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
      where N >= 1. Commonly, N=1 and predictions has shape [batch size, k].
    selected_id: Int id to select.

  Returns:
    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
  """
  # No class selected: hand both inputs back untouched.
  if selected_id is None:
    return labels, predictions_idx
  return (_select_class_id(labels, selected_id),
          _select_class_id(predictions_idx, selected_id))


def _streaming_sparse_true_positive_at_k(predictions_idx,
                                         labels,
                                         k,
                                         class_id=None,
                                         weights=None,
                                         name=None):
  """Calculates weighted per step true positives for recall@k and precision@k.

  If `class_id` is specified, calculate binary true positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`. The variable is a scalar
    `float64` local variable; the update adds the batch total to it.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  default_name = 'true_positive_at_%d' % k
  if class_id is not None:
    default_name = '%s_class%d' % (default_name, class_id)
  with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope:
    labels, predictions_idx = _maybe_select_class_id(labels,
                                                     predictions_idx,
                                                     class_id)
    # Size of the intersection of predicted and label classes.
    tp = set_ops.set_size(set_ops.set_intersection(predictions_idx, labels))
    tp = math_ops.to_double(tp)
    if weights is not None:
      # The shape check runs before the cast, so `weights` must already
      # provide `get_shape()` at this point.
      tp.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.to_double(weights)
      tp = math_ops.mul(tp, weights)
    batch_total_tp = math_ops.reduce_sum(tp)

    # Scalar float64 accumulator named after the enclosing scope.
    var = contrib_variables.local_variable(
        array_ops.zeros([], dtype=dtypes.float64), name=scope)
    return var, state_ops.assign_add(var, batch_total_tp, name='update')


def _streaming_sparse_false_positive_at_k(predictions_idx,
                                          labels,
                                          k,
                                          class_id=None,
                                          weights=None,
                                          name=None):
"""Calculates weighted per step false positives for precision@k. 1593 1594 If `class_id` is specified, calculate binary true positives for `class_id` 1595 only. 1596 If `class_id` is not specified, calculate metrics for `k` predicted vs 1597 `n` label classes, where `n` is the 2nd dimension of `labels`. 1598 1599 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1600 1601 Args: 1602 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1603 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1604 match `labels`. 1605 labels: `int64` `Tensor` or `SparseTensor` with shape 1606 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1607 target classes for the associated prediction. Commonly, N=1 and `labels` 1608 has shape [batch_size, num_labels]. [D1, ... DN] must match 1609 `predictions_idx`. 1610 k: Integer, k for @k metric. This is only used for default op name. 1611 class_id: Class for which we want binary metrics. 1612 weights: `Tensor` whose shape is broadcastable to the the first [D1, ... DN] 1613 dimensions of `predictions_idx` and `labels`. 1614 name: Name of new variable, and namespace for other dependant ops. 1615 1616 Returns: 1617 A tuple of `Variable` and update `Operation`. 1618 1619 Raises: 1620 ValueError: If `weights` is not `None` and has an incomptable shape. 
1621 """ 1622 default_name = 'false_positive_at_%d' % k 1623 if class_id is not None: 1624 default_name = '%s_class%d' % (default_name, class_id) 1625 with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope: 1626 labels, predictions_idx = _maybe_select_class_id(labels, 1627 predictions_idx, 1628 class_id) 1629 fp = set_ops.set_size(set_ops.set_difference(predictions_idx, 1630 labels, 1631 aminusb=True)) 1632 fp = math_ops.to_double(fp) 1633 if weights is not None: 1634 fp.get_shape().assert_is_compatible_with(weights.get_shape()) 1635 weights = math_ops.to_double(weights) 1636 fp = math_ops.mul(fp, weights) 1637 batch_total_fp = math_ops.reduce_sum(fp) 1638 1639 var = contrib_variables.local_variable( 1640 array_ops.zeros([], dtype=dtypes.float64), name=scope) 1641 return var, state_ops.assign_add(var, batch_total_fp, name='update') 1642 1643 1644def _streaming_sparse_false_negative_at_k(predictions_idx, 1645 labels, 1646 k, 1647 class_id=None, 1648 weights=None, 1649 name=None): 1650 """Calculates weighted per step false negatives for recall@k. 1651 1652 If `class_id` is specified, calculate binary true positives for `class_id` 1653 only. 1654 If `class_id` is not specified, calculate metrics for `k` predicted vs 1655 `n` label classes, where `n` is the 2nd dimension of `labels`. 1656 1657 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1658 1659 Args: 1660 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1661 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1662 match `labels`. 1663 labels: `int64` `Tensor` or `SparseTensor` with shape 1664 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1665 target classes for the associated prediction. Commonly, N=1 and `labels` 1666 has shape [batch_size, num_labels]. [D1, ... DN] must match 1667 `predictions_idx`. 1668 k: Integer, k for @k metric. This is only used for default op name. 
1669 class_id: Class for which we want binary metrics. 1670 weights: `Tensor` whose shape is broadcastable to the the first [D1, ... DN] 1671 dimensions of `predictions_idx` and `labels`. 1672 name: Name of new variable, and namespace for other dependant ops. 1673 1674 Returns: 1675 A tuple of `Variable` and update `Operation`. 1676 1677 Raises: 1678 ValueError: If `weights` is not `None` and has an incomptable shape. 1679 """ 1680 default_name = 'false_negative_at_%d' % k 1681 if class_id is not None: 1682 default_name = '%s_class%d' % (default_name, class_id) 1683 with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope: 1684 labels, predictions_idx = _maybe_select_class_id(labels, 1685 predictions_idx, 1686 class_id) 1687 fn = set_ops.set_size(set_ops.set_difference(predictions_idx, 1688 labels, 1689 aminusb=False)) 1690 fn = math_ops.to_double(fn) 1691 if weights is not None: 1692 fn.get_shape().assert_is_compatible_with(weights.get_shape()) 1693 weights = math_ops.to_double(weights) 1694 fn = math_ops.mul(fn, weights) 1695 batch_total_fn = math_ops.reduce_sum(fn) 1696 1697 var = contrib_variables.local_variable( 1698 array_ops.zeros([], dtype=dtypes.float64), name=scope) 1699 return var, state_ops.assign_add(var, batch_total_fn, name='update') 1700 1701 1702def streaming_mean_absolute_error(predictions, labels, weights=None, 1703 metrics_collections=None, 1704 updates_collections=None, 1705 name=None): 1706 """Computes the mean absolute error between the labels and predictions. 1707 1708 The `streaming_mean_absolute_error` function creates two local variables, 1709 `total` and `count` that are used to compute the mean absolute error. This 1710 average is weighted by `weights`, and it is ultimately returned as 1711 `mean_absolute_error`: an idempotent operation that simply divides `total` by 1712 `count`. 
1713 1714 For estimation of the metric over a stream of data, the function creates an 1715 `update_op` operation that updates these variables and returns the 1716 `mean_absolute_error`. Internally, an `absolute_errors` operation computes the 1717 absolute value of the differences between `predictions` and `labels`. Then 1718 `update_op` increments `total` with the reduced sum of the product of 1719 `weights` and `absolute_errors`, and it increments `count` with the reduced 1720 sum of `weights` 1721 1722 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1723 1724 Args: 1725 predictions: A `Tensor` of arbitrary shape. 1726 labels: A `Tensor` of the same shape as `predictions`. 1727 weights: An optional `Tensor` whose shape matches `predictions`. 1728 metrics_collections: An optional list of collections that 1729 `mean_absolute_error` should be added to. 1730 updates_collections: An optional list of collections that `update_op` should 1731 be added to. 1732 name: An optional variable_scope name. 1733 1734 Returns: 1735 mean_absolute_error: A tensor representing the current mean, the value of 1736 `total` divided by `count`. 1737 update_op: An operation that increments the `total` and `count` variables 1738 appropriately and whose value matches `mean_absolute_error`. 1739 1740 Raises: 1741 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1742 `weights` is not `None` and its shape doesn't match `predictions`, or if 1743 either `metrics_collections` or `updates_collections` are not a list or 1744 tuple. 
1745 """ 1746 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 1747 predictions, labels) 1748 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1749 absolute_errors = math_ops.abs(predictions - labels) 1750 return streaming_mean(absolute_errors, weights, metrics_collections, 1751 updates_collections, name or 'mean_absolute_error') 1752 1753 1754def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, 1755 metrics_collections=None, 1756 updates_collections=None, 1757 name=None): 1758 """Computes the mean relative error by normalizing with the given values. 1759 1760 The `streaming_mean_relative_error` function creates two local variables, 1761 `total` and `count` that are used to compute the mean relative absolute error. 1762 This average is weighted by `weights`, and it is ultimately returned as 1763 `mean_relative_error`: an idempotent operation that simply divides `total` by 1764 `count`. 1765 1766 For estimation of the metric over a stream of data, the function creates an 1767 `update_op` operation that updates these variables and returns the 1768 `mean_reative_error`. Internally, a `relative_errors` operation divides the 1769 absolute value of the differences between `predictions` and `labels` by the 1770 `normalizer`. Then `update_op` increments `total` with the reduced sum of the 1771 product of `weights` and `relative_errors`, and it increments `count` with the 1772 reduced sum of `weights`. 1773 1774 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1775 1776 Args: 1777 predictions: A `Tensor` of arbitrary shape. 1778 labels: A `Tensor` of the same shape as `predictions`. 1779 normalizer: A `Tensor` of the same shape as `predictions`. 1780 weights: An optional `Tensor` whose shape matches `predictions`. 1781 metrics_collections: An optional list of collections that 1782 `mean_relative_error` should be added to. 
1783 updates_collections: An optional list of collections that `update_op` should 1784 be added to. 1785 name: An optional variable_scope name. 1786 1787 Returns: 1788 mean_relative_error: A tensor representing the current mean, the value of 1789 `total` divided by `count`. 1790 update_op: An operation that increments the `total` and `count` variables 1791 appropriately and whose value matches `mean_relative_error`. 1792 1793 Raises: 1794 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1795 `weights` is not `None` and its shape doesn't match `predictions`, or if 1796 either `metrics_collections` or `updates_collections` are not a list or 1797 tuple. 1798 """ 1799 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 1800 predictions, labels) 1801 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1802 1803 predictions, normalizer = metric_ops_util.remove_squeezable_dimensions( 1804 predictions, normalizer) 1805 predictions.get_shape().assert_is_compatible_with(normalizer.get_shape()) 1806 relative_errors = math_ops.select( 1807 math_ops.equal(normalizer, 0.0), 1808 array_ops.zeros_like(labels), 1809 math_ops.div(math_ops.abs(labels - predictions), normalizer)) 1810 return streaming_mean(relative_errors, weights, metrics_collections, 1811 updates_collections, name or 'mean_relative_error') 1812 1813 1814def streaming_mean_squared_error(predictions, labels, weights=None, 1815 metrics_collections=None, 1816 updates_collections=None, 1817 name=None): 1818 """Computes the mean squared error between the labels and predictions. 1819 1820 The `streaming_mean_squared_error` function creates two local variables, 1821 `total` and `count` that are used to compute the mean squared error. 1822 This average is weighted by `weights`, and it is ultimately returned as 1823 `mean_squared_error`: an idempotent operation that simply divides `total` by 1824 `count`. 
1825 1826 For estimation of the metric over a stream of data, the function creates an 1827 `update_op` operation that updates these variables and returns the 1828 `mean_squared_error`. Internally, a `squared_error` operation computes the 1829 element-wise square of the difference between `predictions` and `labels`. Then 1830 `update_op` increments `total` with the reduced sum of the product of 1831 `weights` and `squared_error`, and it increments `count` with the reduced sum 1832 of `weights`. 1833 1834 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1835 1836 Args: 1837 predictions: A `Tensor` of arbitrary shape. 1838 labels: A `Tensor` of the same shape as `predictions`. 1839 weights: An optional `Tensor` whose shape matches `predictions`. 1840 metrics_collections: An optional list of collections that 1841 `mean_squared_error` should be added to. 1842 updates_collections: An optional list of collections that `update_op` should 1843 be added to. 1844 name: An optional variable_scope name. 1845 1846 Returns: 1847 mean_squared_error: A tensor representing the current mean, the value of 1848 `total` divided by `count`. 1849 update_op: An operation that increments the `total` and `count` variables 1850 appropriately and whose value matches `mean_squared_error`. 1851 1852 Raises: 1853 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1854 `weights` is not `None` and its shape doesn't match `predictions`, or if 1855 either `metrics_collections` or `updates_collections` are not a list or 1856 tuple. 
1857 """ 1858 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 1859 predictions, labels) 1860 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1861 squared_error = math_ops.square(labels - predictions) 1862 return streaming_mean(squared_error, weights, metrics_collections, 1863 updates_collections, name or 'mean_squared_error') 1864 1865 1866def streaming_root_mean_squared_error(predictions, labels, weights=None, 1867 metrics_collections=None, 1868 updates_collections=None, 1869 name=None): 1870 """Computes the root mean squared error between the labels and predictions. 1871 1872 The `streaming_root_mean_squared_error` function creates two local variables, 1873 `total` and `count` that are used to compute the root mean squared error. 1874 This average is weighted by `weights`, and it is ultimately returned as 1875 `root_mean_squared_error`: an idempotent operation that takes the square root 1876 of the division of `total` by `count`. 1877 1878 For estimation of the metric over a stream of data, the function creates an 1879 `update_op` operation that updates these variables and returns the 1880 `root_mean_squared_error`. Internally, a `squared_error` operation computes 1881 the element-wise square of the difference between `predictions` and `labels`. 1882 Then `update_op` increments `total` with the reduced sum of the product of 1883 `weights` and `squared_error`, and it increments `count` with the reduced sum 1884 of `weights`. 1885 1886 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1887 1888 Args: 1889 predictions: A `Tensor` of arbitrary shape. 1890 labels: A `Tensor` of the same shape as `predictions`. 1891 weights: An optional `Tensor` whose shape matches `predictions`. 1892 metrics_collections: An optional list of collections that 1893 `root_mean_squared_error` should be added to. 1894 updates_collections: An optional list of collections that `update_op` should 1895 be added to. 
1896 name: An optional variable_scope name. 1897 1898 Returns: 1899 root_mean_squared_error: A tensor representing the current mean, the value 1900 of `total` divided by `count`. 1901 update_op: An operation that increments the `total` and `count` variables 1902 appropriately and whose value matches `root_mean_squared_error`. 1903 1904 Raises: 1905 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1906 `weights` is not `None` and its shape doesn't match `predictions`, or if 1907 either `metrics_collections` or `updates_collections` are not a list or 1908 tuple. 1909 """ 1910 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 1911 predictions, labels) 1912 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1913 value_tensor, update_op = streaming_mean_squared_error( 1914 predictions, labels, weights, None, None, 1915 name or 'root_mean_squared_error') 1916 1917 root_mean_squared_error = math_ops.sqrt(value_tensor) 1918 with ops.control_dependencies([update_op]): 1919 update_op = math_ops.sqrt(update_op) 1920 1921 if metrics_collections: 1922 ops.add_to_collections(metrics_collections, root_mean_squared_error) 1923 1924 if updates_collections: 1925 ops.add_to_collections(updates_collections, update_op) 1926 1927 return root_mean_squared_error, update_op 1928 1929 1930def streaming_covariance(predictions, 1931 labels, 1932 weights=None, 1933 metrics_collections=None, 1934 updates_collections=None, 1935 name=None): 1936 """Computes the unbiased sample covariance between `predictions` and `labels`. 1937 1938 The `streaming_covariance` function creates four local variables, 1939 `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to 1940 compute the sample covariance between predictions and labels across multiple 1941 batches of data. The covariance is ultimately returned as an idempotent 1942 operation that simply divides `comoment` by `count` - 1. 
We use `count` - 1 1943 in order to get an unbiased estimate. 1944 1945 The algorithm used for this online computation is described in 1946 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. 1947 Specifically, the formula used to combine two sample comoments is 1948 `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB` 1949 The comoment for a single batch of data is simply 1950 `sum((x - E[x]) * (y - E[y]))`, optionally weighted. 1951 1952 If `weights` is not None, then it is used to compute weighted comoments, 1953 means, and count. NOTE: these weights are treated as "frequency weights", as 1954 opposed to "reliability weights". See discussion of the difference on 1955 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 1956 1957 To facilitate the computation of covariance across multiple batches of data, 1958 the function creates an `update_op` operation, which updates underlying 1959 variables and returns the updated covariance. 1960 1961 Args: 1962 predictions: A `Tensor` of arbitrary size. 1963 labels: A `Tensor` of the same size as `predictions`. 1964 weights: An optional set of weights which indicates the frequency with which 1965 an example is sampled. Must be broadcastable with `labels`. 1966 metrics_collections: An optional list of collections that the metric 1967 value variable should be added to. 1968 updates_collections: An optional list of collections that the metric update 1969 ops should be added to. 1970 name: An optional variable_scope name. 1971 1972 Returns: 1973 covariance: A `Tensor` representing the current unbiased sample covariance, 1974 `comoment` / (`count` - 1). 1975 update_op: An operation that updates the local variables appropriately. 1976 1977 Raises: 1978 ValueError: If labels and predictions are of different sizes or if either 1979 `metrics_collections` or `updates_collections` are not a list or tuple. 
1980 """ 1981 with variable_scope.variable_scope(name, 'covariance', [predictions, labels]): 1982 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 1983 predictions, labels) 1984 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1985 count = _create_local('count', []) 1986 mean_prediction = _create_local('mean_prediction', []) 1987 mean_label = _create_local('mean_label', []) 1988 comoment = _create_local('comoment', []) # C_A in update equation 1989 1990 if weights is None: 1991 batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn 1992 weighted_predictions = predictions 1993 weighted_labels = labels 1994 else: 1995 batch_count = math_ops.reduce_sum(weights) # n_B in update equation 1996 weighted_predictions = predictions * weights 1997 weighted_labels = labels * weights 1998 1999 update_count = state_ops.assign_add(count, batch_count) # n_AB in eqn 2000 prev_count = update_count - batch_count # n_A in update equation 2001 2002 # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) 2003 # batch_mean_prediction is E[x_B] in the update equation 2004 batch_mean_prediction = _safe_div( 2005 math_ops.reduce_sum(weighted_predictions), batch_count, 2006 'batch_mean_prediction') 2007 delta_mean_prediction = _safe_div( 2008 (batch_mean_prediction - mean_prediction) * batch_count, update_count, 2009 'delta_mean_prediction') 2010 update_mean_prediction = state_ops.assign_add(mean_prediction, 2011 delta_mean_prediction) 2012 # prev_mean_prediction is E[x_A] in the update equation 2013 prev_mean_prediction = update_mean_prediction - delta_mean_prediction 2014 2015 # batch_mean_label is E[y_B] in the update equation 2016 batch_mean_label = _safe_div( 2017 math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label') 2018 delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count, 2019 update_count, 'delta_mean_label') 2020 update_mean_label = state_ops.assign_add(mean_label, 
delta_mean_label) 2021 # prev_mean_label is E[y_A] in the update equation 2022 prev_mean_label = update_mean_label - delta_mean_label 2023 2024 unweighted_batch_coresiduals = ( 2025 (predictions - batch_mean_prediction) * (labels - batch_mean_label)) 2026 # batch_comoment is C_B in the update equation 2027 if weights is None: 2028 batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals) 2029 else: 2030 batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals * 2031 weights) 2032 2033 # View delta_comoment as = C_AB - C_A in the update equation above. 2034 # Since C_A is stored in a var, by how much do we need to increment that var 2035 # to make the var = C_AB? 2036 delta_comoment = (batch_comoment + 2037 (prev_mean_prediction - batch_mean_prediction) * 2038 (prev_mean_label - batch_mean_label) * 2039 (prev_count * batch_count / update_count)) 2040 update_comoment = state_ops.assign_add(comoment, delta_comoment) 2041 2042 covariance = _safe_div(comoment, count - 1, 'covariance') 2043 with ops.control_dependencies([update_comoment]): 2044 update_op = _safe_div(comoment, count - 1, 'update_op') 2045 2046 if metrics_collections: 2047 ops.add_to_collections(metrics_collections, covariance) 2048 2049 if updates_collections: 2050 ops.add_to_collections(updates_collections, update_op) 2051 2052 return covariance, update_op 2053 2054 2055def streaming_pearson_correlation(predictions, 2056 labels, 2057 weights=None, 2058 metrics_collections=None, 2059 updates_collections=None, 2060 name=None): 2061 """Computes pearson correlation coefficient between `predictions`, `labels`. 2062 2063 The `streaming_pearson_correlation` function delegates to 2064 `streaming_covariance` the tracking of three [co]variances: 2065 - streaming_covariance(predictions, labels), i.e. covariance 2066 - streaming_covariance(predictions, predictions), i.e. variance 2067 - streaming_covariance(labels, labels), i.e. 
variance 2068 2069 The product-moment correlation ultimately returned is an idempotent operation 2070 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To 2071 facilitate correlation computation across multiple batches, the function 2072 groups the `update_op`s of the underlying streaming_covariance and returns an 2073 `update_op`. 2074 2075 If `weights` is not None, then it is used to compute a weighted correlation. 2076 NOTE: these weights are treated as "frequency weights", as opposed to 2077 "reliability weights". See discussion of the difference on 2078 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 2079 2080 Args: 2081 predictions: A `Tensor` of arbitrary size. 2082 labels: A `Tensor` of the same size as predictions. 2083 weights: An optional set of weights which indicates the frequency with which 2084 an example is sampled. Must be broadcastable with `labels`. 2085 metrics_collections: An optional list of collections that the metric 2086 value variable should be added to. 2087 updates_collections: An optional list of collections that the metric update 2088 ops should be added to. 2089 name: An optional variable_scope name. 2090 2091 Returns: 2092 pearson_r: A tensor representing the current pearson product-moment 2093 correlation coefficient, the value of 2094 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. 2095 update_op: An operation that updates the underlying variables appropriately. 2096 2097 Raises: 2098 ValueError: If labels and predictions are of different sizes or if the 2099 ignore_mask is of the wrong size or if either `metrics_collections` or 2100 `updates_collections` are not a list or tuple. 
2101 """ 2102 with variable_scope.variable_scope(name, 'pearson_r', [predictions, labels]): 2103 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 2104 predictions, labels) 2105 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2106 cov, update_cov = streaming_covariance( 2107 predictions, labels, weights=weights, name='covariance') 2108 var_predictions, update_var_predictions = streaming_covariance( 2109 predictions, predictions, weights=weights, name='variance_predictions') 2110 var_labels, update_var_labels = streaming_covariance( 2111 labels, labels, weights=weights, name='variance_labels') 2112 2113 pearson_r = _safe_div( 2114 cov, 2115 math_ops.mul(math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)), 2116 'pearson_r') 2117 with ops.control_dependencies( 2118 [update_cov, update_var_predictions, update_var_labels]): 2119 update_op = _safe_div(update_cov, math_ops.mul( 2120 math_ops.sqrt(update_var_predictions), 2121 math_ops.sqrt(update_var_labels)), 'update_op') 2122 2123 if metrics_collections: 2124 ops.add_to_collections(metrics_collections, pearson_r) 2125 2126 if updates_collections: 2127 ops.add_to_collections(updates_collections, update_op) 2128 2129 return pearson_r, update_op 2130 2131 2132# TODO(nsilberman): add a 'normalized' flag so that the user can request 2133# normalization if the inputs are not normalized. 2134def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, 2135 metrics_collections=None, 2136 updates_collections=None, 2137 name=None): 2138 """Computes the cosine distance between the labels and predictions. 2139 2140 The `streaming_mean_cosine_distance` function creates two local variables, 2141 `total` and `count` that are used to compute the average cosine distance 2142 between `predictions` and `labels`. 
This average is weighted by `weights`, 2143 and it is ultimately returned as `mean_distance`, which is an idempotent 2144 operation that simply divides `total` by `count`. 2145 2146 For estimation of the metric over a stream of data, the function creates an 2147 `update_op` operation that updates these variables and returns the 2148 `mean_distance`. 2149 2150 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2151 2152 Args: 2153 predictions: A `Tensor` of the same shape as `labels`. 2154 labels: A `Tensor` of arbitrary shape. 2155 dim: The dimension along which the cosine distance is computed. 2156 weights: An optional `Tensor` whose shape matches `predictions`, and whose 2157 dimension `dim` is 1. 2158 metrics_collections: An optional list of collections that the metric 2159 value variable should be added to. 2160 updates_collections: An optional list of collections that the metric update 2161 ops should be added to. 2162 name: An optional variable_scope name. 2163 2164 Returns: 2165 mean_distance: A tensor representing the current mean, the value of `total` 2166 divided by `count`. 2167 update_op: An operation that increments the `total` and `count` variables 2168 appropriately. 2169 2170 Raises: 2171 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2172 `weights` is not `None` and its shape doesn't match `predictions`, or if 2173 either `metrics_collections` or `updates_collections` are not a list or 2174 tuple. 
2175 """ 2176 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 2177 predictions, labels) 2178 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2179 radial_diffs = math_ops.mul(predictions, labels) 2180 radial_diffs = math_ops.reduce_sum(radial_diffs, 2181 reduction_indices=[dim,], 2182 keep_dims=True) 2183 mean_distance, update_op = streaming_mean(radial_diffs, weights, 2184 None, 2185 None, 2186 name or 'mean_cosine_distance') 2187 mean_distance = math_ops.sub(1.0, mean_distance) 2188 update_op = math_ops.sub(1.0, update_op) 2189 2190 if metrics_collections: 2191 ops.add_to_collections(metrics_collections, mean_distance) 2192 2193 if updates_collections: 2194 ops.add_to_collections(updates_collections, update_op) 2195 2196 return mean_distance, update_op 2197 2198 2199@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask') 2200def streaming_percentage_less(values, threshold, ignore_mask=None, weights=None, 2201 metrics_collections=None, 2202 updates_collections=None, 2203 name=None): 2204 """Computes the percentage of values less than the given threshold. 2205 2206 The `streaming_percentage_less` function creates two local variables, 2207 `total` and `count` that are used to compute the percentage of `values` that 2208 fall below `threshold`. This rate is weighted by `weights`, and it is 2209 ultimately returned as `percentage` which is an idempotent operation that 2210 simply divides `total` by `count`. 2211 2212 For estimation of the metric over a stream of data, the function creates an 2213 `update_op` operation that updates these variables and returns the 2214 `percentage`. 2215 2216 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2217 Alternatively, if `ignore_mask` is not `None`, then mask values where 2218 `ignore_mask` is `True`. 2219 2220 Args: 2221 values: A numeric `Tensor` of arbitrary size. 2222 threshold: A scalar threshold. 
2223 ignore_mask: An optional, `bool` `Tensor` whose shape matches `values`. 2224 weights: An optional `Tensor` whose shape matches `values`. 2225 metrics_collections: An optional list of collections that the metric 2226 value variable should be added to. 2227 updates_collections: An optional list of collections that the metric update 2228 ops should be added to. 2229 name: An optional variable_scope name. 2230 2231 Returns: 2232 percentage: A tensor representing the current mean, the value of `total` 2233 divided by `count`. 2234 update_op: An operation that increments the `total` and `count` variables 2235 appropriately. 2236 2237 Raises: 2238 ValueError: If `ignore_mask` is not `None` and its shape doesn't match 2239 `values`, or if `weights` is not `None` and its shape doesn't match 2240 `values`, or if either `metrics_collections` or `updates_collections` are 2241 not a list or tuple. 2242 """ 2243 is_below_threshold = math_ops.to_float(math_ops.less(values, threshold)) 2244 return streaming_mean(is_below_threshold, _mask_weights(ignore_mask, weights), 2245 metrics_collections, 2246 updates_collections, 2247 name or 'percentage_below_threshold') 2248 2249 2250@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask') 2251def streaming_mean_iou(predictions, 2252 labels, 2253 num_classes, 2254 ignore_mask=None, 2255 weights=None, 2256 metrics_collections=None, 2257 updates_collections=None, 2258 name=None): 2259 """Calculate per-step mean Intersection-Over-Union (mIOU). 2260 2261 Mean Intersection-Over-Union is a common evaluation metric for 2262 semantic image segmentation, which first computes the IOU for each 2263 semantic class and then computes the average over classes. 2264 IOU is defined as follows: 2265 IOU = true_positive / (true_positive + false_positive + false_negative). 2266 The predictions are accumulated in a confusion matrix, weighted by `weights`, 2267 and mIOU is then calculated from it. 
2268 2269 For estimation of the metric over a stream of data, the function creates an 2270 `update_op` operation that updates these variables and returns the `mean_iou`. 2271 2272 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2273 Alternatively, if `ignore_mask` is not `None`, then mask values where 2274 `ignore_mask` is `True`. 2275 2276 Args: 2277 predictions: A tensor of prediction results for semantic labels, whose 2278 shape is [batch size] and type `int32` or `int64`. The tensor will be 2279 flattened, if its rank > 1. 2280 labels: A tensor of ground truth labels with shape [batch size] and of 2281 type `int32` or `int64`. The tensor will be flattened, if its rank > 1. 2282 num_classes: The possible number of labels the prediction task can 2283 have. This value must be provided, since a confusion matrix of 2284 dimension = [num_classes, num_classes] will be allocated. 2285 ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`. 2286 weights: An optional `Tensor` whose shape matches `predictions`. 2287 metrics_collections: An optional list of collections that `mean_iou` 2288 should be added to. 2289 updates_collections: An optional list of collections `update_op` should be 2290 added to. 2291 name: An optional variable_scope name. 2292 2293 Returns: 2294 mean_iou: A tensor representing the mean intersection-over-union. 2295 update_op: An operation that increments the confusion matrix. 2296 2297 Raises: 2298 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2299 `ignore_mask` is not `None` and its shape doesn't match `predictions`, or 2300 if `weights` is not `None` and its shape doesn't match `predictions`, or 2301 if either `metrics_collections` or `updates_collections` are not a list or 2302 tuple. 2303 """ 2304 with variable_scope.variable_scope(name, 'mean_iou', [predictions, labels]): 2305 # Check if shape is compatible. 
2306 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2307 2308 # Local variable to accumulate the predictions in the confusion matrix. 2309 cm_dtype = dtypes.int64 if weights is not None else dtypes.float64 2310 total_cm = _create_local('total_confusion_matrix', 2311 shape=[num_classes, num_classes], dtype=cm_dtype) 2312 2313 # Cast the type to int64 required by confusion_matrix_ops. 2314 predictions = math_ops.to_int64(predictions) 2315 labels = math_ops.to_int64(labels) 2316 num_classes = math_ops.to_int64(num_classes) 2317 2318 # Flatten the input if its rank > 1. 2319 predictions_rank = predictions.get_shape().ndims 2320 if predictions_rank > 1: 2321 predictions = array_ops.reshape(predictions, [-1]) 2322 2323 labels_rank = labels.get_shape().ndims 2324 if labels_rank > 1: 2325 labels = array_ops.reshape(labels, [-1]) 2326 2327 weights = _mask_weights(ignore_mask, weights) 2328 if weights is not None: 2329 weights_rank = weights.get_shape().ndims 2330 if weights_rank > 1: 2331 weights = array_ops.reshape(weights, [-1]) 2332 2333 # Accumulate the prediction to current confusion matrix. 2334 current_cm = confusion_matrix_ops.confusion_matrix( 2335 predictions, labels, num_classes, weights=weights, dtype=cm_dtype) 2336 update_op = state_ops.assign_add(total_cm, current_cm) 2337 2338 def compute_mean_iou(name): 2339 """Compute the mean intersection-over-union via the confusion matrix.""" 2340 sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0)) 2341 sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1)) 2342 cm_diag = math_ops.to_float(array_ops.diag_part(total_cm)) 2343 denominator = sum_over_row + sum_over_col - cm_diag 2344 2345 # If the value of the denominator is 0, set it to 1 to avoid 2346 # zero division. 
2347 denominator = math_ops.select( 2348 math_ops.greater(denominator, 0), 2349 denominator, 2350 array_ops.ones_like(denominator)) 2351 iou = math_ops.div(cm_diag, denominator) 2352 return math_ops.reduce_mean(iou, name=name) 2353 2354 mean_iou = compute_mean_iou('mean_iou') 2355 2356 if metrics_collections: 2357 ops.add_to_collections(metrics_collections, mean_iou) 2358 2359 if updates_collections: 2360 ops.add_to_collections(updates_collections, update_op) 2361 2362 return mean_iou, update_op 2363 2364 2365def aggregate_metrics(*value_update_tuples): 2366 """Aggregates the metric value tensors and update ops into two lists. 2367 2368 Args: 2369 *value_update_tuples: a variable number of tuples, each of which contain the 2370 pair of (value_tensor, update_op) from a streaming metric. 2371 2372 Returns: 2373 a list of value tensors and a list of update ops. 2374 2375 Raises: 2376 ValueError: if `value_update_tuples` is empty. 2377 """ 2378 if not value_update_tuples: 2379 raise ValueError('Expected at least one value_tensor/update_op pair') 2380 value_ops, update_ops = zip(*value_update_tuples) 2381 return list(value_ops), list(update_ops) 2382 2383 2384def aggregate_metric_map(names_to_tuples): 2385 """Aggregates the metric names to tuple dictionary. 2386 2387 This function is useful for pairing metric names with their associated value 2388 and update ops when the list of metrics is long. 
For example: 2389 2390 metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({ 2391 'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error( 2392 predictions, labels, weights), 2393 'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error( 2394 predictions, labels, labels, weights), 2395 'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error( 2396 predictions, labels, weights), 2397 'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error( 2398 predictions, labels, weights), 2399 }) 2400 2401 Args: 2402 names_to_tuples: a map of metric names to tuples, each of which contain the 2403 pair of (value_tensor, update_op) from a streaming metric. 2404 2405 Returns: 2406 A dictionary from metric names to value ops and a dictionary from metric 2407 names to update ops. 2408 """ 2409 metric_names = names_to_tuples.keys() 2410 value_ops, update_ops = zip(*names_to_tuples.values()) 2411 return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops)) 2412 2413 2414__all__ = make_all(__name__) 2415