metric_ops.py revision 6e7c8b76bf2931754d28928623196c8090e87dc0
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains metric-computing operations on streamed tensors.

Module documentation, including "@@" callouts, should be put in
third_party/tensorflow/contrib/metrics/__init__.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import variables as contrib_variables

from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
from tensorflow.contrib.metrics.python.ops import metric_ops_util
from tensorflow.contrib.metrics.python.ops import set_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.util.all_util import make_all


def _mask_to_weights(mask=None):
  """Converts a binary mask to a set of weights.

  Args:
    mask: A binary `Tensor`.

  Returns:
    The corresponding set of weights if `mask` is not `None`, otherwise `None`.
  """
  if mask is not None:
    check_ops.assert_type(mask, dtypes.bool)
    # Weights are the logical complement of the mask: masked (True) entries
    # get weight False, i.e. are excluded from the computation.
    weights = math_ops.logical_not(mask)
  else:
    weights = None
  return weights


def _create_local(name, shape=None, collections=None, dtype=dtypes.float32):
  """Creates a new local variable.

  Args:
    name: The name of the new or existing variable.
    shape: Shape of the new or existing variable.
    collections: A list of collection names to which the Variable will be added.
    dtype: Data type of the variables.

  Returns:
    The created variable.
  """
  # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES
  collections = list(collections or [])
  collections += [ops.GraphKeys.LOCAL_VARIABLES]
  # Local metric variables are zero-initialized and never trained; they only
  # accumulate statistics via the metric's update_op.
  return variables.Variable(
      initial_value=array_ops.zeros(shape, dtype=dtype),
      name=name,
      trainable=False,
      collections=collections)


def _count_condition(values, ignore_mask=None, metrics_collections=None,
                     updates_collections=None):
  """Computes the total number of cases where the given values are True.

  Args:
    values: A binary `Tensor` of arbitrary size.
    ignore_mask: An optional, binary tensor whose size matches 'values'.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If either `metrics_collections` or `updates_collections` are not
      a list or tuple.
  """
  check_ops.assert_type(values, dtypes.bool)
  count = _create_local('count', shape=[])

  if ignore_mask is not None:
    values.get_shape().assert_is_compatible_with(ignore_mask.get_shape())
    check_ops.assert_type(ignore_mask, dtypes.bool)
    # Zero out the masked entries so they do not contribute to the count.
    values = math_ops.select(
        ignore_mask,
        array_ops.zeros_like(values),
        values)
  values = math_ops.to_float(values)

  # Reading `value_tensor` does not trigger an update; only `update_op`
  # accumulates into `count`.
  value_tensor = array_ops.identity(count)
  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))

  if metrics_collections:
    ops.add_to_collections(metrics_collections, value_tensor)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return value_tensor, update_op


def _streaming_true_negatives(predictions, labels, ignore_mask=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Computes the total number of true_negatives.

  Args:
    predictions: The predicted values, a binary `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a binary `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, binary tensor whose size matches 'predictions'.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If either `metrics_collections` or `updates_collections` are not
      a list or tuple.
151 """ 152 with variable_scope.variable_scope( 153 [predictions, labels], name, 'true_negatives'): 154 155 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 156 is_true_negative = math_ops.logical_and(math_ops.equal(labels, 0), 157 math_ops.equal(predictions, 0)) 158 return _count_condition(is_true_negative, ignore_mask, metrics_collections, 159 updates_collections) 160 161 162def _streaming_true_positives(predictions, labels, ignore_mask=None, 163 metrics_collections=None, 164 updates_collections=None, 165 name=None): 166 """Computes the total number of true_positives. 167 168 Args: 169 predictions: The predicted values, a binary `Tensor` of arbitrary 170 dimensions. 171 labels: The ground truth values, a binary `Tensor` whose dimensions must 172 match `predictions`. 173 ignore_mask: An optional, binary tensor whose size matches 'predictions'. 174 metrics_collections: An optional list of collections that the metric 175 value variable should be added to. 176 updates_collections: An optional list of collections that the metric update 177 ops should be added to. 178 name: An optional variable_scope name. 179 180 Returns: 181 value_tensor: A tensor representing the current value of the metric. 182 update_op: An operation that accumulates the error from a batch of data. 183 184 Raises: 185 ValueError: If either `metrics_collections` or `updates_collections` are not 186 a list or tuple. 
  """
  with variable_scope.variable_scope(
      name, 'true_positives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # A true positive is a positive label (1) with a positive prediction (1).
    is_true_positive = math_ops.logical_and(math_ops.equal(labels, 1),
                                            math_ops.equal(predictions, 1))
    return _count_condition(is_true_positive, ignore_mask, metrics_collections,
                            updates_collections)


def _streaming_false_positives(predictions, labels, ignore_mask=None,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the total number of false positives.

  Args:
    predictions: The predicted values, a binary `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a binary `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, binary tensor whose size matches 'predictions'.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If either `metrics_collections` or `updates_collections` are not
      a list or tuple.
  """
  with variable_scope.variable_scope(
      name, 'false_positives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # A false positive is a negative label (0) with a positive prediction (1).
    is_false_positive = math_ops.logical_and(math_ops.equal(labels, 0),
                                             math_ops.equal(predictions, 1))
    return _count_condition(is_false_positive, ignore_mask,
                            metrics_collections, updates_collections)


def _streaming_false_negatives(predictions, labels, ignore_mask=None,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the total number of false negatives.

  Args:
    predictions: The predicted values, a binary `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a binary `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, binary tensor whose size matches 'predictions'.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If either `metrics_collections` or `updates_collections` are not
      a list or tuple.
  """
  with variable_scope.variable_scope(
      name, 'false_negatives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # A false negative is a positive label (1) with a negative prediction (0).
    is_false_negative = math_ops.logical_and(math_ops.equal(labels, 1),
                                             math_ops.equal(predictions, 0))
    return _count_condition(is_false_negative, ignore_mask,
                            metrics_collections, updates_collections)


def streaming_mean(values, weights=None, metrics_collections=None,
                   updates_collections=None, name=None):
  """Computes the (weighted) mean of the given values.

  The `streaming_mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`. To facilitate the estimation of a mean over a stream
  of data, the function creates an `update_op` operation whose behavior is
  dependent on the value of `weights`. If `weights` is None, then `update_op`
  increments `total` with the reduced sum of `values` and increments `count`
  with the number of elements in `values`. If `weights` is not `None`, then
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights` and increments `count` with the reduced sum of weights.
  In addition to performing the updates, `update_op` also returns the
  `mean`.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: An optional set of weights of the same shape as `values`. If
      `weights` is not None, the function computes a weighted mean.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A tensor representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'mean', [values, weights]):
    values = math_ops.to_float(values)

    total = _create_local('total', shape=[])
    count = _create_local('count', shape=[])

    if weights is not None:
      values.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.to_float(weights)
      values = math_ops.mul(values, weights)
      num_values = math_ops.reduce_sum(weights)
    else:
      num_values = math_ops.to_float(array_ops.size(values))

    total_compute_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    count_compute_op = state_ops.assign_add(count, num_values)

    def compute_mean(total, count, name):
      # Returns 0 until at least one value has been accumulated (count == 0),
      # avoiding a divide by zero.
      return math_ops.select(math_ops.greater(count, 0),
                             math_ops.div(total, count),
                             0, name)

    mean = compute_mean(total, count, 'value')
    with ops.control_dependencies([total_compute_op, count_compute_op]):
      update_op = compute_mean(total, count, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean, update_op


def streaming_mean_tensor(values, weights=None, metrics_collections=None,
                          updates_collections=None, name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  In contrast to the `streaming_mean` function which returns a scalar with the
  mean, this function returns an average tensor with the same shape as the
  input tensors.

  The `streaming_mean_tensor` function creates two local variables,
  `total_tensor` and `count_tensor` that are used to compute the average of
  `values`. This average is ultimately returned as `mean` which is an idempotent
  operation that simply divides `total` by `count`. To facilitate the estimation
  of a mean over a stream of data, the function creates an `update_op` operation
  whose behavior is dependent on the value of `weights`. If `weights` is None,
  then `update_op` increments `total` with the reduced sum of `values` and
  increments `count` with the number of elements in `values`. If `weights` is
  not `None`, then `update_op` increments `total` with the reduced sum of the
  product of `values` and `weights` and increments `count` with the reduced sum
  of weights. In addition to performing the updates, `update_op` also returns
  the `mean`.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: An optional set of weights of the same shape as `values`. If
      `weights` is not None, the function computes a weighted mean.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float tensor representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'mean', [values, weights]):
    total = _create_local('total_tensor', shape=values.get_shape())
    count = _create_local('count_tensor', shape=values.get_shape())

    if weights is not None:
      values.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.to_float(weights)
      values = math_ops.mul(values, weights)
      num_values = weights
    else:
      num_values = array_ops.ones_like(values)

    total_compute_op = state_ops.assign_add(total, values)
    count_compute_op = state_ops.assign_add(count, num_values)

    def compute_mean(total, count, name):
      # Clamp per-element counts at 1 so elements that have received no
      # updates yet divide by 1 instead of 0.
      non_zero_count = math_ops.maximum(count,
                                        array_ops.ones_like(count),
                                        name=name)
      return math_ops.truediv(total, non_zero_count, name=name)

    mean = compute_mean(total, count, 'value')
    with ops.control_dependencies([total_compute_op, count_compute_op]):
      update_op = compute_mean(total, count, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean, update_op


def streaming_accuracy(predictions, labels, weights=None,
                       metrics_collections=None, updates_collections=None,
                       name=None):
  """Calculates how often `predictions` matches `labels`.

  The `streaming_accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.
  To facilitate the estimation of the accuracy over a stream of data, the
  function utilizes two operations. First, an `is_correct` operation that
  computes a tensor whose shape matches `predictions` and whose elements are
  set to 1.0 when the corresponding values of `predictions` and `labels` match
  and 0.0 otherwise. Second, an `update_op` operation whose behavior is
  dependent on the value of `weights`. If `weights` is None, then `update_op`
  increments `total` with the number of elements of `predictions` that match
  `labels` and increments `count` with the number of elements in `values`. If
  `weights` is not `None`, then `update_op` increments `total` with the reduced
  sum of the product of `weights` and `is_correct` and increments `count` with
  the reduced sum of `weights`. In addition to performing the updates,
  `update_op` also returns the `accuracy` value.

  Args:
    predictions: The predicted values, a `Tensor` of any shape.
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    weights: An optional set of weights whose shape matches `predictions`
      which, when not `None`, produces a weighted mean accuracy.
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A tensor representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If the dimensions of `predictions` and `labels` don't match or
      if `weight` is not `None` and its shape doesn't match `predictions` or
      if either `metrics_collections` or `updates_collections` are not
      a list or tuple.
  """
  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  # Accuracy is just the streaming mean of the 0/1 correctness indicator.
  is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
  return streaming_mean(is_correct, weights, metrics_collections,
                        updates_collections, name or 'accuracy')


def streaming_precision(predictions, labels, ignore_mask=None,
                        metrics_collections=None, updates_collections=None,
                        name=None):
  """Computes the precision of the predictions with respect to the labels.

  The `streaming_precision` function creates two local variables,
  `true_positives` and `false_positives`, that are used to compute the
  precision. This value is ultimately returned as `precision`, an idempotent
  operation that simply divides `true_positives` by the sum of `true_positives`
  and `false_positives`. To facilitate the calculation of the precision over a
  stream of data, the function creates an `update_op` operation whose behavior
  is dependent on the value of `ignore_mask`. If `ignore_mask` is None, then
  `update_op` increments `true_positives` with the number of elements of
  `predictions` and `labels` that are both `True` and increments
  `false_positives` with the number of elements of `predictions` that are `True`
  whose corresponding `labels` element is `False`. If `ignore_mask` is not
  `None`, then the increments for `true_positives` and `false_positives` are
  only computed using elements of `predictions` and `labels` whose corresponding
  values in `ignore_mask` are `False`. In addition to performing the updates,
  `update_op` also returns the value of `precision`.

  Args:
    predictions: The predicted values, a binary `Tensor` of arbitrary shape.
    labels: The ground truth values, a binary `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, binary tensor whose size matches `predictions`.
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: Scalar float `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately and whose value matches
      `precision`.

  Raises:
    ValueError: If the dimensions of `predictions` and `labels` don't match or
      if `ignore_mask` is not `None` and its shape doesn't match `predictions`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(
      name, 'precision', [predictions, labels]):

    predictions, labels = metric_ops_util.remove_squeezable_dimensions(
        predictions, labels)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    true_positives, true_positives_update_op = _streaming_true_positives(
        predictions, labels, ignore_mask, metrics_collections=None,
        updates_collections=None, name=None)
    false_positives, false_positives_update_op = _streaming_false_positives(
        predictions, labels, ignore_mask, metrics_collections=None,
        updates_collections=None, name=None)

    def compute_precision(name):
      # Returns 0 when there are no predicted positives (tp + fp == 0),
      # avoiding a divide by zero.
      return math_ops.select(
          math_ops.greater(true_positives + false_positives, 0),
          math_ops.div(true_positives, true_positives + false_positives),
          0,
          name)

    precision = compute_precision('value')
    with ops.control_dependencies([true_positives_update_op,
                                   false_positives_update_op]):
      update_op = compute_precision('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, precision)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision, update_op


def streaming_recall(predictions, labels, ignore_mask=None,
                     metrics_collections=None, updates_collections=None,
                     name=None):
  """Computes the recall of the predictions with respect to the labels.

  The `streaming_recall` function creates two local variables,
  `true_positives` and `false_negatives`, that are used to compute the
  recall. This value is ultimately returned as `recall`, an idempotent
  operation that simply divides `true_positives` by the sum of `true_positives`
  and `false_negatives`. To facilitate the calculation of the recall over a
  stream of data, the function creates an `update_op` operation whose behavior
  is dependent on the value of `ignore_mask`. If `ignore_mask` is None, then
  `update_op` increments `true_positives` with the number of elements of
  `predictions` and `labels` that are both `True` and increments
  `false_negatives` with the number of elements of `predictions` that are
  `False` whose corresponding `labels` element is `True`. If `ignore_mask` is
  not `None`, then the increments for `true_positives` and `false_negatives` are
  only computed using elements of `predictions` and `labels` whose corresponding
  values in `ignore_mask` are `False`. In addition to performing the updates,
  `update_op` also returns the value of `recall`.

  Args:
    predictions: The predicted values, a binary `Tensor` of arbitrary shape.
    labels: The ground truth values, a binary `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, binary tensor whose size matches `predictions`.
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If the dimensions of `predictions` and `labels` don't match or
      if `ignore_mask` is not `None` and its shape doesn't match `predictions`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'recall', [predictions, labels]):
    predictions, labels = metric_ops_util.remove_squeezable_dimensions(
        predictions, labels)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    true_positives, true_positives_update_op = _streaming_true_positives(
        predictions, labels, ignore_mask, metrics_collections=None,
        updates_collections=None, name=None)
    false_negatives, false_negatives_update_op = _streaming_false_negatives(
        predictions, labels, ignore_mask, metrics_collections=None,
        updates_collections=None, name=None)

    def compute_recall(true_positives, false_negatives, name):
      # Returns 0 when there are no actual positives (tp + fn == 0),
      # avoiding a divide by zero.
      return math_ops.select(
          math_ops.greater(true_positives + false_negatives, 0),
          math_ops.div(true_positives, true_positives + false_negatives),
          0,
          name)

    recall = compute_recall(true_positives, false_negatives, 'value')
    with ops.control_dependencies([true_positives_update_op,
                                   false_negatives_update_op]):
      update_op = compute_recall(true_positives, false_negatives, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op


def _tp_fn_tn_fp(predictions, labels, thresholds, weights):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  The `_tp_fn_tn_fp` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  These four variables are updated through the `update_op`.
  The streaming behavior is that the values of the variables after a few
  `update_op`s is the same as if the inputs had been concatenated and a single
  `update_op` had been performed.

  If `weights` is `None`, all entries are assumed to have weight 1. Note that
  a weight of 0 effectively discards an entry from consideration.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A binary `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional, floating point `Tensor` with the same shape as
      `predictions`.

  Returns:
    true_positive: A variable of shape [len(thresholds)].
    false_negative: A variable of shape [len(thresholds)].
    true_negatives: A variable of shape [len(thresholds)].
    false_positives: A variable of shape [len(thresholds)].
    true_positives_update_op: An operation that increments the `true_positives`.
667 false_negative_update_op: An operation that increments the `false_negative`. 668 true_negatives_update_op: An operation that increments the `true_negatives`. 669 false_positives_update_op: An operation that increments the 670 `false_positives`. 671 672 Raises: 673 ValueError: If the shape of `predictions` and `labels` do not match or if 674 `weights` is not `None` and its shape doesn't match `predictions` 675 or if either `metrics_collections` or `updates_collections` are not a list 676 or tuple. 677 """ 678 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 679 predictions, labels) 680 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 681 682 num_thresholds = len(thresholds) 683 684 # Reshape predictions and labels 685 predictions = array_ops.reshape(predictions, [-1, 1]) 686 labels = array_ops.reshape(math_ops.cast(labels, dtype=dtypes.bool), [1, -1]) 687 688 # Use static shape if known. 689 num_predictions = predictions.get_shape().as_list()[0] 690 691 # Otherwise use dynamic shape. 692 if num_predictions is None: 693 num_predictions = array_ops.shape(predictions)[0] 694 thresh_tiled = array_ops.tile( 695 array_ops.expand_dims(array_ops.constant(thresholds), [1]), 696 array_ops.pack([1, num_predictions])) 697 698 # Tile the predictions after thresholding them across different thresholds. 
699 pred_is_pos = math_ops.greater( 700 array_ops.tile(array_ops.transpose(predictions), [num_thresholds, 1]), 701 thresh_tiled) 702 pred_is_neg = math_ops.logical_not(pred_is_pos) 703 704 # Tile labels by number of thresholds 705 label_is_pos = array_ops.tile(labels, [num_thresholds, 1]) 706 label_is_neg = math_ops.logical_not(label_is_pos) 707 708 true_positives = _create_local('true_positives', shape=[num_thresholds]) 709 false_negatives = _create_local('false_negatives', shape=[num_thresholds]) 710 true_negatives = _create_local('true_negatives', shape=[num_thresholds]) 711 false_positives = _create_local('false_positives', shape=[num_thresholds]) 712 713 is_true_positive = math_ops.to_float( 714 math_ops.logical_and(label_is_pos, pred_is_pos)) 715 is_false_negative = math_ops.to_float( 716 math_ops.logical_and(label_is_pos, pred_is_neg)) 717 is_false_positive = math_ops.to_float( 718 math_ops.logical_and(label_is_neg, pred_is_pos)) 719 is_true_negative = math_ops.to_float( 720 math_ops.logical_and(label_is_neg, pred_is_neg)) 721 722 if weights is not None: 723 weights_tiled = array_ops.tile( 724 array_ops.reshape(weights, [1, -1]), [num_thresholds, 1]) 725 thresh_tiled.get_shape().assert_is_compatible_with( 726 weights_tiled.get_shape()) 727 check_ops.assert_type(weights_tiled, dtypes.float32) 728 is_true_positive *= weights_tiled 729 is_false_negative *= weights_tiled 730 is_false_positive *= weights_tiled 731 is_true_negative *= weights_tiled 732 733 true_positives_update_op = state_ops.assign_add( 734 true_positives, math_ops.reduce_sum(is_true_positive, 1)) 735 false_negatives_update_op = state_ops.assign_add( 736 false_negatives, math_ops.reduce_sum(is_false_negative, 1)) 737 true_negatives_update_op = state_ops.assign_add( 738 true_negatives, math_ops.reduce_sum(is_true_negative, 1)) 739 false_positives_update_op = state_ops.assign_add( 740 false_positives, math_ops.reduce_sum(is_false_positive, 1)) 741 742 return (true_positives, false_negatives, 
          true_negatives, false_positives,
          true_positives_update_op, false_negatives_update_op,
          true_negatives_update_op, false_positives_update_op)


def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
                  metrics_collections=None, updates_collections=None,
                  curve='ROC', name=None):
  """Computes the approximate AUC via a Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall
  values (computed using the aforementioned variables). The `num_thresholds`
  variable controls the degree of discretization with larger numbers of
  thresholds more closely approximating the true AUC.

  To facilitate the estimation of the AUC over a stream of data, the function
  creates an `update_op` operation. `update_op` increments the
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  counts with the weighted number of each found in the current `predictions`
  and `labels` `Tensors`. If `weights` is `None`, it is assumed that all
  entries have weight 1. Note that a weight of 0 can be used to effectively
  mask out and ignore specific entries. In addition to performing the updates,
  `update_op` also returns the `auc`.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A binary `Tensor` whose shape matches `predictions`.
    weights: An optional, floating point `Tensor` of same shape as
      `predictions`.
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar tensor representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If the shape of `predictions` and `labels` do not match or if
      `weights` is not `None` and its shape doesn't match `predictions` or
      if either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'auc', [predictions, labels]):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' %
                       (curve))
    kepsilon = 1e-7  # to account for floating point imprecisions
    # Evenly spaced interior thresholds, bracketed by values just outside
    # [0, 1] so the extreme thresholds can never be met exactly.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      recall = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        x = fp_rate
        y = recall
      else:  # curve == 'PR'.
        precision = math_ops.div(tp + epsilon, tp + fp + epsilon)
        x = recall
        y = precision
      # Trapezoidal rule over consecutive (x, y) pairs.
      return math_ops.reduce_sum(math_ops.mul(
          x[:num_thresholds - 1] - x[1:],
          (y[:num_thresholds - 1] + y[1:]) / 2.), name=name)

    # sum up the areas of all the trapeziums
    auc = compute_auc(tp, fn, tn, fp, 'value')
    update_op = compute_auc(
        tp_update_op, fn_update_op, tn_update_op, fp_update_op, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, auc)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc, update_op


def streaming_specificity_at_sensitivity(
    predictions, labels, sensitivity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the specificity at a given sensitivity.

  The `streaming_specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  To facilitate the estimation of the metric over a stream of data, the
  function creates an `update_op` operation. `update_op` increments the
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  counts with the weighted number of each found in the current `predictions`
  and `labels` `Tensors`.
If `weights` is `None`, it is assumed that all 864 entries have weight 1. Note that a weight of 0 can be used to effectively 865 mask out and ignore specific entries. In addition to performing the updates, 866 `update_op` also returns the `specificity`. 867 868 For additional information about specificity and sensitivity, see the 869 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 870 871 Args: 872 predictions: A floating point `Tensor` of arbitrary shape and whose values 873 are in the range `[0, 1]`. 874 labels: A binary `Tensor` whose shape matches `predictions`. 875 sensitivity: A scalar value in range `[0, 1]`. 876 weights: An optional, floating point `Tensor` of same shape as 877 `predictions`. 878 num_thresholds: The number of thresholds to use for matching the given 879 sensitivity. 880 metrics_collections: An optional list of collections that `specificity` 881 should be added to. 882 updates_collections: An optional list of collections that `update_op` should 883 be added to. 884 name: An optional variable_scope name. 885 886 Returns: 887 specificity: A scalar tensor representing the specificity at the given 888 `specificity` value. 889 update_op: An operation that increments the `true_positives`, 890 `true_negatives`, `false_positives` and `false_negatives` variables 891 appropriately and whose value matches `specificity`. 892 893 Raises: 894 ValueError: If the shape of `predictions` and `labels` do not match or if 895 `weights` is not `None` and its shape doesn't match `predictions` or 896 `sensitivity` is not between 0 and 1 or if either `metrics_collections` or 897 `updates_collections` are not a list or tuple. 
898 """ 899 if sensitivity < 0 or sensitivity > 1: 900 raise ValueError('`sensitivity` must be in the range [0, 1].') 901 902 with variable_scope.variable_scope(name, 'specificity_at_sensitivity', 903 [predictions, labels]): 904 kepsilon = 1e-7 # to account for floating point imprecisions 905 thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) 906 for i in range(num_thresholds-2)] 907 thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon] 908 909 (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op, 910 fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights) 911 912 assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds 913 914 def compute_specificity_at_sensitivity(name): 915 """Computes the specificity at the given sensitivity. 916 917 Args: 918 name: The name of the operation. 919 920 Returns: 921 The specificity using the aggregated values. 922 """ 923 sensitivities = math_ops.div(tp, tp + fn + kepsilon) 924 925 # We'll need to use this trick until tf.argmax allows us to specify 926 # whether we should use the first or last index in case of ties. 
      min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
      indices_at_minval = math_ops.equal(
          math_ops.abs(sensitivities - sensitivity), min_val)
      indices_at_minval = math_ops.to_int64(indices_at_minval)
      indices_at_minval = math_ops.cumsum(indices_at_minval)
      # The cumulative sum is maximal at the *last* tied index, so argmax
      # selects the last threshold whose sensitivity is closest to the target.
      tf_index = math_ops.argmax(indices_at_minval, 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the specificity:
      return math_ops.div(tn[tf_index],
                          tn[tf_index] + fp[tf_index] + kepsilon,
                          name)

    specificity = compute_specificity_at_sensitivity('value')
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_specificity_at_sensitivity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, specificity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op


def streaming_sensitivity_at_specificity(
    predictions, labels, specificity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the sensitivity at a given specificity.

  The `streaming_sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is computed
  and used to evaluate the corresponding sensitivity.

  To facilitate the estimation of the metric over a stream of data, the
  function creates an `update_op` operation. `update_op` increments the
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  counts with the weighted number of each found in the current `predictions`
  and `labels` `Tensors`. If `weights` is `None`, it is assumed that all
  entries have weight 1. Note that a weight of 0 can be used to effectively
  mask out and ignore specific entries. In addition to performing the updates,
  `update_op` also returns the `sensitivity`.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A binary `Tensor` whose shape matches `predictions`.
    specificity: A scalar value in range `[0, 1]`.
    weights: An optional, floating point `Tensor` of same shape as
      `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar tensor representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If the shape of `predictions` and `labels` do not match or if
      `weights` is not `None` and its shape doesn't match `predictions` or
      `specificity` is not between 0 and 1 or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
  """
  if specificity < 0 or specificity > 1:
    raise ValueError('`specificity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
                                     [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_sensitivity_at_specificity(name):
      # Pick the threshold whose specificity is closest to the target, then
      # report the sensitivity (recall) at that threshold.
      specificities = math_ops.div(tn, tn + fp + kepsilon)
      tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the sensitivity:
      return math_ops.div(tp[tf_index],
                          tp[tf_index] + fn[tf_index] + kepsilon,
                          name)

    sensitivity = compute_sensitivity_at_specificity('value')
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_sensitivity_at_specificity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, sensitivity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return sensitivity, update_op


def streaming_precision_at_thresholds(predictions, labels, thresholds,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None, name=None):
  """Computes precision values for different `thresholds` on `predictions`.

  The `streaming_precision_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds.
  `precision[i]` is defined as the total weight of values in `predictions` above
  `thresholds[i]` whose corresponding entry in `labels` is `True`
  (`true_positives[i]`) divided by the total weight of values in `predictions`
  above `thresholds[i]` (`true_positives[i] + false_positives[i]`).

  If `weights` is `None` then all entries are assumed to have equal weight 1.

  `precision` is returned along with an `update_op` whose value equals that of
  `precision`.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A binary `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional, floating point `Tensor` of same shape as
      `predictions`.
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `precision`.

  Raises:
    ValueError: If the shape of `predictions` and `labels` do not match or if
      `weights` is not `None` and its shape doesn't match `predictions`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'precision_at_thresholds',
                                     [predictions, labels]):

    # TODO(nsilberman): Replace with only tp and fp, this results in unnecessary
    # variable creation. b/30842882
    (true_positives, _, _, false_positives, true_positives_compute_op, _, _,
     false_positives_compute_op,) = _tp_fn_tn_fp(
         predictions, labels, thresholds, weights)

    # avoid division by zero
    epsilon = 1e-7
    def compute_precision(name):
      precision = math_ops.div(true_positives,
                               epsilon + true_positives + false_positives,
                               name='precision_' + name)
      return precision

    precision = compute_precision('value')
    with ops.control_dependencies([true_positives_compute_op,
                                   false_positives_compute_op]):
      update_op = compute_precision('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, precision)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision, update_op


def streaming_recall_at_thresholds(predictions, labels, thresholds,
                                   weights=None, metrics_collections=None,
                                   updates_collections=None, name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  The `streaming_recall_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds.
  `recall[i]` is defined as the total weight of values in `predictions` above
  `thresholds[i]` whose corresponding entry in `labels` is `True`
  (`true_positives[i]`) divided by the total weight of True values in `labels`
  (`true_positives[i] + false_negatives[i]`).

  If `weights` is `None` then all entries are assumed to have equal weight 1.

  `recall` is returned along with an `update_op` whose value equals that of
  `recall`.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A binary `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional, floating point `Tensor` of same shape as
      `predictions`.
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `recall`.

  Raises:
    ValueError: If the shape of `predictions` and `labels` do not match or if
      `weights` is not `None` and its shape doesn't match `predictions`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'recall_at_thresholds',
                                     [predictions, labels]):
    (true_positives, false_negatives, _, _, true_positives_compute_op,
     false_negatives_compute_op, _, _,) = _tp_fn_tn_fp(
         predictions, labels, thresholds, weights)

    # avoid division by zero
    epsilon = 1e-7
    def compute_recall(name):
      recall = math_ops.div(true_positives,
                            epsilon + true_positives + false_negatives,
                            name='recall_' + name)
      return recall

    recall = compute_recall('value')
    with ops.control_dependencies([true_positives_compute_op,
                                   false_negatives_compute_op]):
      update_op = compute_recall('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op


def streaming_recall_at_k(predictions, labels, k, ignore_mask=None,
                          metrics_collections=None, updates_collections=None,
                          name=None):
  """Computes the recall@k of the predictions with respect to dense labels.

  The `streaming_recall_at_k` function creates two local variables, `total` and
  `count`, that are used to compute the recall@k frequency. This frequency is
  ultimately returned as `recall_at_<k>`: an idempotent operation that simply
  divides `total` by `count`. To facilitate the estimation of recall@k over a
  stream of data, the function utilizes two operations. First, an `in_top_k`
  operation computes a tensor with shape [batch_size] whose elements indicate
  whether or not the corresponding label is in the top `k` predictions of the
  `predictions` `Tensor`. Second, an `update_op` operation whose behavior is
  dependent on the value of `ignore_mask`. If `ignore_mask` is None, then
  `update_op` increments `total` with the number of elements of `in_top_k` that
  are set to `True` and increments `count` with the batch size. If `ignore_mask`
  is not `None`, then `update_op` increments `total` with the number of elements
  in `in_top_k` that are `True` whose corresponding element in `ignore_mask` is
  `False`. In addition to performing the updates, `update_op` also returns the
  recall value.

  Args:
    predictions: A floating point tensor of dimension [batch_size, num_classes]
    labels: A tensor of dimension [batch_size] whose type is in `int32`,
      `int64`.
    k: The number of top elements to look at for computing recall.
    ignore_mask: An optional, binary tensor whose size matches `labels`. If an
      element of `ignore_mask` is True, the corresponding prediction and label
      pair is ignored. Otherwise, the pair is used to compute the metrics.
    metrics_collections: An optional list of collections that `recall_at_k`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    recall_at_k: A tensor representing the recall@k, the fraction of labels
      which fall into the top `k` predictions.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `recall_at_k`.

  Raises:
    ValueError: If the dimensions of `predictions` and `labels` don't match or
      if `ignore_mask` is not `None` and its shape doesn't match `predictions`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k))
  # recall@k over dense labels reduces to a (possibly masked) streaming mean
  # of the per-example in_top_k indicator.
  return streaming_mean(in_top_k, _mask_to_weights(ignore_mask),
                        metrics_collections,
                        updates_collections,
                        name or ('recall_at_%d' % k))


# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_recall_at_k(predictions,
                                 labels,
                                 k,
                                 class_id=None,
                                 ignore_mask=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate recall by considering only the
  entries in the batch for which `class_id` is in the label, and computing
  the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, we'll calculate recall as how often on
  average a class among the labels of a batch entry is in the top-k
  `predictions`.

  `streaming_sparse_recall_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
  the recall_at_k frequency. This frequency is ultimately returned as
  `recall_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`). To facilitate the estimation of recall@k over a
  stream of data, the function utilizes three steps.
  * A `top_k` operation computes a tensor whose elements indicate the top `k`
    predictions of the `predictions` `Tensor`.
  * Set operations are applied to `top_k` and `labels` to calculate true
    positives and false negatives.
  * An `update_op` operation increments `true_positive_at_<k>` and
    `false_negative_at_<k>`. It also returns the recall value.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`.
      Values should be in range [0, num_classes], where num_classes is the last
      dimension of `predictions`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes], where num_classes is the last dimension of
      `predictions`.
    ignore_mask: An optional, binary tensor whose shape is broadcastable to the
      first [D1, ... DN] dimensions of `predictions_idx` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependant ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.
  """
  default_name = 'recall_at_%d' % k
  if class_id is not None:
    default_name = '%s_class%d' % (default_name, class_id)

  with ops.name_scope(name, default_name, [predictions, labels]) as scope:
    _, top_k_idx = nn.top_k(predictions, k)
    top_k_idx = math_ops.to_int64(top_k_idx)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        ignore_mask=ignore_mask)
    fn, fn_update = _streaming_sparse_false_negative_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        ignore_mask=ignore_mask)

    # NOTE(review): division yields NaN when tp + fn == 0 (no labeled
    # positives seen yet) — confirm this is the intended behavior.
    metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
    update = math_ops.div(
        tp_update, math_ops.add(tp_update, fn_update), name='update')
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update


# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_k(predictions,
                                    labels,
                                    k,
                                    class_id=None,
                                    ignore_mask=None,
                                    metrics_collections=None,
                                    updates_collections=None,
                                    name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
  entries in the batch for which `class_id` is in the top-k highest
  `predictions`, and computing the fraction of them for which `class_id` is
  indeed a correct label.
  If `class_id` is not specified, we'll calculate precision as how often on
  average a class among the top-k classes with the highest predicted values
  of a batch entry is correct and can be found in the label for that entry.

  `streaming_sparse_precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`). To facilitate the estimation of
  precision@k over a stream of data, the function utilizes three
  steps.
  * A `top_k` operation computes a tensor whose elements indicate the top `k`
    predictions of the `predictions` `Tensor`.
  * Set operations are applied to `top_k` and `labels` to calculate true
    positives and false positives.
  * An `update_op` operation increments `true_positive_at_<k>` and
    `false_positive_at_<k>`. It also returns the precision value.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`. Values should be in range [0, num_classes], where
      num_classes is the last dimension of `predictions`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes], where num_classes is the last dimension of
      `predictions`.
    ignore_mask: An optional, binary tensor whose shape is broadcastable to the
      first [D1, ... DN] dimensions of `predictions_idx` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependant ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.
  """
  default_name = 'precision_at_%d' % k
  if class_id is not None:
    default_name = '%s_class%d' % (default_name, class_id)
  with ops.name_scope(name, default_name, [predictions, labels]) as scope:
    _, top_k_idx = nn.top_k(predictions, k)
    top_k_idx = math_ops.to_int64(top_k_idx)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        ignore_mask=ignore_mask)
    fp, fp_update = _streaming_sparse_false_positive_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        ignore_mask=ignore_mask)

    # NOTE(review): division yields NaN when tp + fp == 0 (no predictions for
    # `class_id` yet) — confirm this is the intended behavior.
    metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
    update = math_ops.div(
        tp_update, math_ops.add(tp_update, fp_update), name='update')
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update


def _select_class_id(ids, selected_id):
  """Filter all but `selected_id` out of `ids`.

  Args:
    ids: `int64` `Tensor` or `SparseTensor` of IDs.
    selected_id: Int id to select.

  Returns:
    `SparseTensor` of same dimensions as `ids`, except for the last dimension,
    which might be smaller. This contains only the entries equal to
    `selected_id`.
  """
  if isinstance(ids, ops.SparseTensor):
    return sparse_ops.sparse_retain(
        ids, math_ops.equal(ids.values, selected_id))

  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
  # tf.equal and tf.reduce_any?

  # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
  ids_shape = array_ops.shape(ids)
  ids_last_dim = array_ops.size(ids_shape) - 1
  filled_selected_id_shape = math_ops.reduced_shape(
      ids_shape, array_ops.reshape(ids_last_dim, [1]))

  # Intersect `ids` with the selected ID.
  filled_selected_id = array_ops.fill(
      filled_selected_id_shape, math_ops.to_int64(selected_id))
  return set_ops.set_intersection(filled_selected_id, ids)


def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
  """If class ID is specified, filter all other classes.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
      where N >= 1. Commonly, N=1 and predictions has shape [batch size, k].
    selected_id: Int id to select.

  Returns:
    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
  """
  if selected_id is None:
    return labels, predictions_idx
  return (_select_class_id(labels, selected_id),
          _select_class_id(predictions_idx, selected_id))


def _streaming_sparse_true_positive_at_k(predictions_idx,
                                         labels,
                                         k,
                                         class_id=None,
                                         ignore_mask=None,
                                         name=None):
  """Calculates per step true positives for recall@k and precision@k.

  If `class_id` is specified, calculate binary true positives for `class_id`
  only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
  `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    ignore_mask: An optional, binary tensor whose shape is broadcastable to the
      first [D1, ... DN] dimensions of `predictions_idx` and `labels`.
    name: Name of new variable, and namespace for other dependant ops.

  Returns:
    A tuple of `Variable` and update `Operation`.
1499 """ 1500 default_name = 'true_positive_at_%d' % k 1501 if class_id is not None: 1502 default_name = '%s_class%d' % (default_name, class_id) 1503 with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope: 1504 labels, predictions_idx = _maybe_select_class_id(labels, 1505 predictions_idx, 1506 class_id) 1507 tp = set_ops.set_size(set_ops.set_intersection(predictions_idx, labels)) 1508 if ignore_mask is not None: 1509 tp = math_ops.select(ignore_mask, array_ops.zeros_like(tp), tp) 1510 batch_total_tp = math_ops.cast( 1511 math_ops.reduce_sum(tp), dtype=dtypes.float64) 1512 1513 var = contrib_variables.local_variable( 1514 array_ops.zeros([], dtype=dtypes.float64), name=scope) 1515 return var, state_ops.assign_add(var, batch_total_tp, name='update') 1516 1517 1518def _streaming_sparse_false_positive_at_k(predictions_idx, 1519 labels, 1520 k, 1521 class_id=None, 1522 ignore_mask=None, 1523 name=None): 1524 """Calculates per step false positives for precision@k. 1525 1526 If `class_id` is specified, calculate binary true positives for `class_id` 1527 only. 1528 If `class_id` is not specified, calculate metrics for `k` predicted vs 1529 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 1530 1531 Args: 1532 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1533 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1534 match `labels`. 1535 labels: `int64` `Tensor` or `SparseTensor` with shape 1536 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1537 target classes for the associated prediction. Commonly, N=1 and `labels` 1538 has shape [batch_size, num_labels]. [D1, ... DN] must match 1539 `predictions_idx`. 1540 k: Integer, k for @k metric. This is only used for default op name. 1541 class_id: Class for which we want binary metrics. 1542 ignore_mask: An optional, binary tensor whose shape is broadcastable to the 1543 the first [D1, ... 
DN] dimensions of `predictions_idx` and `labels`. 1544 name: Name of new variable, and namespace for other dependant ops. 1545 1546 Returns: 1547 A tuple of `Variable` and update `Operation`. 1548 """ 1549 default_name = 'false_positive_at_%d' % k 1550 if class_id is not None: 1551 default_name = '%s_class%d' % (default_name, class_id) 1552 with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope: 1553 labels, predictions_idx = _maybe_select_class_id(labels, 1554 predictions_idx, 1555 class_id) 1556 fp = set_ops.set_size(set_ops.set_difference(predictions_idx, 1557 labels, 1558 aminusb=True)) 1559 if ignore_mask is not None: 1560 fp = math_ops.select(ignore_mask, array_ops.zeros_like(fp), fp) 1561 batch_total_fp = math_ops.cast( 1562 math_ops.reduce_sum(fp), dtype=dtypes.float64) 1563 1564 var = contrib_variables.local_variable( 1565 array_ops.zeros([], dtype=dtypes.float64), name=scope) 1566 return var, state_ops.assign_add(var, batch_total_fp, name='update') 1567 1568 1569def _streaming_sparse_false_negative_at_k(predictions_idx, 1570 labels, 1571 k, 1572 class_id=None, 1573 ignore_mask=None, 1574 name=None): 1575 """Calculates per step false negatives for recall@k. 1576 1577 If `class_id` is specified, calculate binary true positives for `class_id` 1578 only. 1579 If `class_id` is not specified, calculate metrics for `k` predicted vs 1580 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 1581 1582 Args: 1583 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1584 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1585 match `labels`. 1586 labels: `int64` `Tensor` or `SparseTensor` with shape 1587 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1588 target classes for the associated prediction. Commonly, N=1 and `labels` 1589 has shape [batch_size, num_labels]. [D1, ... DN] must match 1590 `predictions_idx`. 1591 k: Integer, k for @k metric. 
This is only used for default op name. 1592 class_id: Class for which we want binary metrics. 1593 ignore_mask: An optional, binary tensor whose shape is broadcastable to the 1594 the first [D1, ... DN] dimensions of `predictions_idx` and `labels`. 1595 name: Name of new variable, and namespace for other dependant ops. 1596 1597 Returns: 1598 A tuple of `Variable` and update `Operation`. 1599 """ 1600 default_name = 'false_negative_at_%d' % k 1601 if class_id is not None: 1602 default_name = '%s_class%d' % (default_name, class_id) 1603 with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope: 1604 labels, predictions_idx = _maybe_select_class_id(labels, 1605 predictions_idx, 1606 class_id) 1607 fn = set_ops.set_size(set_ops.set_difference(predictions_idx, 1608 labels, 1609 aminusb=False)) 1610 if ignore_mask is not None: 1611 fn = math_ops.select(ignore_mask, array_ops.zeros_like(fn), fn) 1612 batch_total_fn = math_ops.cast( 1613 math_ops.reduce_sum(fn), dtype=dtypes.float64) 1614 1615 var = contrib_variables.local_variable( 1616 array_ops.zeros([], dtype=dtypes.float64), name=scope) 1617 return var, state_ops.assign_add(var, batch_total_fn, name='update') 1618 1619 1620def streaming_mean_absolute_error(predictions, labels, weights=None, 1621 metrics_collections=None, 1622 updates_collections=None, 1623 name=None): 1624 """Computes the mean absolute error between the labels and predictions. 1625 1626 The `streaming_mean_absolute_error` function creates two local variables, 1627 `total` and `count` that are used to compute the mean absolute error. This 1628 average is ultimately returned as `mean_absolute_error`: an idempotent 1629 operation that simply divides `total` by `count`. To facilitate the estimation 1630 of the mean absolute error over a stream of data, the function utilizes two 1631 operations. First, an `absolute_errors` operation computes the absolute value 1632 of the differences between `predictions` and `labels`. 
Second, an `update_op` 1633 operation whose behavior is dependent on the value of `weights`. If `weights` 1634 is None, then `update_op` increments `total` with the reduced sum of 1635 `absolute_errors` and increments `count` with the number of elements in 1636 `absolute_errors`. If `weights` is not `None`, then `update_op` increments 1637 `total` with the reduced sum of the product of `weights` and `absolute_errors` 1638 and increments `count` with the reduced sum of `weights`. In addition to 1639 performing the updates, `update_op` also returns the `mean_absolute_error` 1640 value. 1641 1642 Args: 1643 predictions: A `Tensor` of arbitrary shape. 1644 labels: A `Tensor` of the same shape as `predictions`. 1645 weights: An optional set of weights of the same shape as `predictions`. If 1646 `weights` is not None, the function computes a weighted mean. 1647 metrics_collections: An optional list of collections that 1648 `mean_absolute_error` should be added to. 1649 updates_collections: An optional list of collections that `update_op` should 1650 be added to. 1651 name: An optional variable_scope name. 1652 1653 Returns: 1654 mean_absolute_error: A tensor representing the current mean, the value of 1655 `total` divided by `count`. 1656 update_op: An operation that increments the `total` and `count` variables 1657 appropriately and whose value matches `mean_absolute_error`. 1658 1659 Raises: 1660 ValueError: If `weights` is not `None` and its shape doesn't match 1661 `predictions` or if either `metrics_collections` or `updates_collections` 1662 are not a list or tuple. 
1663 """ 1664 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 1665 predictions, labels) 1666 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1667 absolute_errors = math_ops.abs(predictions - labels) 1668 return streaming_mean(absolute_errors, weights, metrics_collections, 1669 updates_collections, name or 'mean_absolute_error') 1670 1671 1672def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, 1673 metrics_collections=None, 1674 updates_collections=None, 1675 name=None): 1676 """Computes the mean relative error by normalizing with the given values. 1677 1678 The `streaming_mean_relative_error` function creates two local variables, 1679 `total` and `count` that are used to compute the mean relative absolute error. 1680 This average is ultimately returned as `mean_relative_error`: an idempotent 1681 operation that simply divides `total` by `count`. To facilitate the estimation 1682 of the mean relative error over a stream of data, the function utilizes two 1683 operations. First, a `relative_errors` operation divides the absolute value 1684 of the differences between `predictions` and `labels` by the `normalizer`. 1685 Second, an `update_op` operation whose behavior is dependent on the value of 1686 `weights`. If `weights` is None, then `update_op` increments `total` with the 1687 reduced sum of `relative_errors` and increments `count` with the number of 1688 elements in `relative_errors`. If `weights` is not `None`, then `update_op` 1689 increments `total` with the reduced sum of the product of `weights` and 1690 `relative_errors` and increments `count` with the reduced sum of `weights`. In 1691 addition to performing the updates, `update_op` also returns the 1692 `mean_relative_error` value. 1693 1694 Args: 1695 predictions: A `Tensor` of arbitrary shape. 1696 labels: A `Tensor` of the same shape as `predictions`. 1697 normalizer: A `Tensor` of the same shape as `predictions`. 
1698 weights: An optional set of weights of the same shape as `predictions`. If 1699 `weights` is not None, the function computes a weighted mean. 1700 metrics_collections: An optional list of collections that 1701 `mean_relative_error` should be added to. 1702 updates_collections: An optional list of collections that `update_op` should 1703 be added to. 1704 name: An optional variable_scope name. 1705 1706 Returns: 1707 mean_relative_error: A tensor representing the current mean, the value of 1708 `total` divided by `count`. 1709 update_op: An operation that increments the `total` and `count` variables 1710 appropriately and whose value matches `mean_relative_error`. 1711 1712 Raises: 1713 ValueError: If `weights` is not `None` and its shape doesn't match 1714 `predictions` or if either `metrics_collections` or `updates_collections` 1715 are not a list or tuple. 1716 """ 1717 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 1718 predictions, labels) 1719 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1720 1721 predictions, normalizer = metric_ops_util.remove_squeezable_dimensions( 1722 predictions, normalizer) 1723 predictions.get_shape().assert_is_compatible_with(normalizer.get_shape()) 1724 relative_errors = math_ops.select( 1725 math_ops.equal(normalizer, 0.0), 1726 array_ops.zeros_like(labels), 1727 math_ops.div(math_ops.abs(labels - predictions), normalizer)) 1728 return streaming_mean(relative_errors, weights, metrics_collections, 1729 updates_collections, name or 'mean_relative_error') 1730 1731 1732def streaming_mean_squared_error(predictions, labels, weights=None, 1733 metrics_collections=None, 1734 updates_collections=None, 1735 name=None): 1736 """Computes the mean squared error between the labels and predictions. 1737 1738 The `streaming_mean_squared_error` function creates two local variables, 1739 `total` and `count` that are used to compute the mean squared error. 
1740 This average is ultimately returned as `mean_squared_error`: an idempotent 1741 operation that simply divides `total` by `count`. To facilitate the estimation 1742 of the mean squared error over a stream of data, the function utilizes two 1743 operations. First, a `squared_error` operation computes the element-wise 1744 square of the difference between `predictions` and `labels`. Second, an 1745 `update_op` operation whose behavior is dependent on the value of `weights`. 1746 If `weights` is None, then `update_op` increments `total` with the 1747 reduced sum of `squared_error` and increments `count` with the number of 1748 elements in `squared_error`. If `weights` is not `None`, then `update_op` 1749 increments `total` with the reduced sum of the product of `weights` and 1750 `squared_error` and increments `count` with the reduced sum of `weights`. In 1751 addition to performing the updates, `update_op` also returns the 1752 `mean_squared_error` value. 1753 1754 Args: 1755 predictions: A `Tensor` of arbitrary shape. 1756 labels: A `Tensor` of the same shape as `predictions`. 1757 weights: An optional set of weights of the same shape as `predictions`. If 1758 `weights` is not None, the function computes a weighted mean. 1759 metrics_collections: An optional list of collections that 1760 `mean_squared_error` should be added to. 1761 updates_collections: An optional list of collections that `update_op` should 1762 be added to. 1763 name: An optional variable_scope name. 1764 1765 Returns: 1766 mean_squared_error: A tensor representing the current mean, the value of 1767 `total` divided by `count`. 1768 update_op: An operation that increments the `total` and `count` variables 1769 appropriately and whose value matches `mean_squared_error`. 1770 1771 Raises: 1772 ValueError: If `weights` is not `None` and its shape doesn't match 1773 `predictions` or if either `metrics_collections` or `updates_collections` 1774 are not a list or tuple. 
1775 """ 1776 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 1777 predictions, labels) 1778 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1779 squared_error = math_ops.square(labels - predictions) 1780 return streaming_mean(squared_error, weights, metrics_collections, 1781 updates_collections, name or 'mean_squared_error') 1782 1783 1784def streaming_root_mean_squared_error(predictions, labels, weights=None, 1785 metrics_collections=None, 1786 updates_collections=None, 1787 name=None): 1788 """Computes the root mean squared error between the labels and predictions. 1789 1790 The `streaming_root_mean_squared_error` function creates two local variables, 1791 `total` and `count` that are used to compute the root mean squared error. 1792 This average is ultimately returned as `root_mean_squared_error`: an 1793 idempotent operation that takes the square root of the division of `total` 1794 by `count`. To facilitate the estimation of the root mean squared error over a 1795 stream of data, the function utilizes two operations. First, a `squared_error` 1796 operation computes the element-wise square of the difference between 1797 `predictions` and `labels`. Second, an `update_op` operation whose behavior is 1798 dependent on the value of `weights`. If `weights` is None, then `update_op` 1799 increments `total` with the reduced sum of `squared_error` and increments 1800 `count` with the number of elements in `squared_error`. If `weights` is not 1801 `None`, then `update_op` increments `total` with the reduced sum of the 1802 product of `weights` and `squared_error` and increments `count` with the 1803 reduced sum of `weights`. In addition to performing the updates, `update_op` 1804 also returns the `root_mean_squared_error` value. 1805 1806 Args: 1807 predictions: A `Tensor` of arbitrary shape. 1808 labels: A `Tensor` of the same shape as `predictions`. 1809 weights: An optional set of weights of the same shape as `predictions`. 
If 1810 `weights` is not None, the function computes a weighted mean. 1811 metrics_collections: An optional list of collections that 1812 `root_mean_squared_error` should be added to. 1813 updates_collections: An optional list of collections that `update_op` should 1814 be added to. 1815 name: An optional variable_scope name. 1816 1817 Returns: 1818 root_mean_squared_error: A tensor representing the current mean, the value 1819 of `total` divided by `count`. 1820 update_op: An operation that increments the `total` and `count` variables 1821 appropriately and whose value matches `root_mean_squared_error`. 1822 1823 Raises: 1824 ValueError: If `weights` is not `None` and its shape doesn't match 1825 `predictions` or if either `metrics_collections` or `updates_collections` 1826 are not a list or tuple. 1827 """ 1828 predictions, labels = metric_ops_util.remove_squeezable_dimensions( 1829 predictions, labels) 1830 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1831 value_tensor, update_op = streaming_mean_squared_error( 1832 predictions, labels, weights, None, None, 1833 name or 'root_mean_squared_error') 1834 1835 root_mean_squared_error = math_ops.sqrt(value_tensor) 1836 with ops.control_dependencies([update_op]): 1837 update_op = math_ops.sqrt(update_op) 1838 1839 if metrics_collections: 1840 ops.add_to_collections(metrics_collections, root_mean_squared_error) 1841 1842 if updates_collections: 1843 ops.add_to_collections(updates_collections, update_op) 1844 1845 return root_mean_squared_error, update_op 1846 1847 1848# TODO(nsilberman): add a 'normalized' flag so that the user can request 1849# normalization if the inputs are not normalized. 1850def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, 1851 metrics_collections=None, 1852 updates_collections=None, 1853 name=None): 1854 """Computes the cosine distance between the labels and predictions. 

  The `streaming_mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is ultimately returned as
  `mean_distance` which is an idempotent operation that simply divides `total`
  by `count`. To facilitate the estimation of a mean over multiple batches
  of data, the function creates an `update_op` operation whose behavior is
  dependent on the value of `weights`. If `weights` is None, then `update_op`
  increments `total` with the reduced sum of `values` and increments `count`
  with the number of elements in `values`. If `weights` is not `None`, then
  `update_op` increments `total` with the reduced sum of the product of
  `values` and `weights` and increments `count` with the reduced sum of
  weights.

  Args:
    predictions: A tensor of the same size as labels.
    labels: A tensor of arbitrary size.
    dim: The dimension along which the cosine distance is computed.
    weights: An optional set of weights which indicates which predictions to
      ignore during metric computation. Its size matches that of labels except
      for the value of 'dim' which should be 1. For example if labels has
      dimensions [32, 100, 200, 3], then `weights` should have dimensions
      [32, 100, 200, 1].
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A tensor representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If labels and predictions are of different sizes or if the
      ignore_mask is of the wrong size or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
  """
  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  # Cosine similarity as the sum of element-wise products along `dim`.
  # NOTE(review): this assumes the inputs are unit-normalized along `dim` --
  # see the TODO above about adding a 'normalized' flag.
  radial_diffs = math_ops.mul(predictions, labels)
  radial_diffs = math_ops.reduce_sum(radial_diffs,
                                     reduction_indices=[dim,],
                                     keep_dims=True)
  mean_distance, update_op = streaming_mean(radial_diffs, weights,
                                            None,
                                            None,
                                            name or 'mean_cosine_distance')
  # Cosine distance = 1 - cosine similarity, for both value and update op.
  mean_distance = math_ops.sub(1.0, mean_distance)
  update_op = math_ops.sub(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op


def streaming_percentage_less(values, threshold, ignore_mask=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Computes the percentage of values less than the given threshold.

  The `streaming_percentage_less` function creates two local variables,
  `total` and `count` that are used to compute the percentage of `values` that
  fall below `threshold`. This rate is ultimately returned as `percentage`
  which is an idempotent operation that simply divides `total` by `count`.
  To facilitate the estimation of the percentage of values that fall under
  `threshold` over multiple batches of data, the function creates an
  `update_op` operation whose behavior is dependent on the value of
  `ignore_mask`. If `ignore_mask` is None, then `update_op`
  increments `total` with the number of elements of `values` that are less
  than `threshold` and `count` with the number of elements in `values`. If
  `ignore_mask` is not `None`, then `update_op` increments `total` with the
  number of elements of `values` that are less than `threshold` and whose
  corresponding entries in `ignore_mask` are False, and `count` is incremented
  with the number of elements of `ignore_mask` that are False.

  Args:
    values: A numeric `Tensor` of arbitrary size.
    threshold: A scalar threshold.
    ignore_mask: An optional mask of the same shape as 'values' which indicates
      which elements to ignore during metric computation.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    percentage: A tensor representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `ignore_mask` is not None and its shape doesn't match
      `values` or if either `metrics_collections` or `updates_collections` are
      supplied but are not a list or tuple.
  """
  # The percentage-below-threshold is just the streaming mean of the 0/1
  # indicator, with the ignore mask converted to 0/1 weights.
  is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
  return streaming_mean(is_below_threshold, _mask_to_weights(ignore_mask),
                        metrics_collections, updates_collections,
                        name or 'percentage_below_threshold')


def streaming_mean_iou(predictions,
                       labels,
                       num_classes,
                       ignore_mask=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for
  semantic image segmentation, which first computes the IOU for each
  semantic class and then computes the average over classes.
  IOU is defined as follows:
    IOU = true_positive / (true_positive + false_positive + false_negative).
  The predictions are accumulated in a confusion matrix, and mIOU is then
  calculated from it.

  Args:
    predictions: A tensor of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened, if its rank > 1.
    labels: A tensor of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened, if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    ignore_mask: An optional, boolean tensor whose size matches `labels`. If an
      element of `ignore_mask` is True, the corresponding prediction and label
      pair is NOT used to compute the metrics. Otherwise, the pair is included.
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A tensor representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If the dimensions of `predictions` and `labels` don't match or
      if `ignore_mask` is not `None` and its shape doesn't match `labels`
      or if either `metrics_collections` or `updates_collections` are not a
      list or tuple.
  """
  with variable_scope.variable_scope(name, 'mean_iou', [predictions, labels]):
    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    if ignore_mask is not None:
      labels.get_shape().assert_is_compatible_with(ignore_mask.get_shape())

    # Local variable to accumulate the predictions in the confusion matrix.
    # NOTE(review): `_create_local` is a file-local helper defined outside
    # this excerpt.
    total_cm = _create_local('total_confusion_matrix',
                             shape=[num_classes, num_classes],
                             dtype=dtypes.int64)

    # Cast the type to int64 required by confusion_matrix_ops.
    predictions = math_ops.to_int64(predictions)
    labels = math_ops.to_int64(labels)
    num_classes = math_ops.to_int64(num_classes)

    # Flatten the input if its rank > 1.
    predictions_rank = predictions.get_shape().ndims
    if predictions_rank > 1:
      predictions = array_ops.reshape(predictions, [-1])

    labels_rank = labels.get_shape().ndims
    if labels_rank > 1:
      labels = array_ops.reshape(labels, [-1])

    if ignore_mask is not None:
      ignore_mask_rank = ignore_mask.get_shape().ndims
      if ignore_mask_rank > 1:
        ignore_mask = array_ops.reshape(ignore_mask, [-1])

      # Drop masked elements entirely rather than zeroing them, so they do
      # not contribute to any confusion-matrix cell.
      check_ops.assert_type(ignore_mask, dtypes.bool)
      not_ignore_mask = math_ops.logical_not(ignore_mask)
      predictions = array_ops.boolean_mask(predictions, not_ignore_mask)
      labels = array_ops.boolean_mask(labels, not_ignore_mask)

    # Accumulate the prediction to current confusion matrix.
    current_cm = confusion_matrix_ops.confusion_matrix(
        predictions, labels, num_classes, dtype=dtypes.int64)
    update_op = state_ops.assign_add(total_cm, current_cm)

    def compute_mean_iou(name):
      """Compute the mean intersection-over-union via the confusion matrix."""
      # Per class: union = row sum + column sum - diagonal (the diagonal is
      # counted once in each sum); intersection = diagonal.
      sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
      sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
      cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
      denominator = sum_over_row + sum_over_col - cm_diag

      # If the value of the denominator is 0, set it to 1 to avoid
      # zero division.
      denominator = math_ops.select(
          math_ops.greater(denominator, 0),
          denominator,
          array_ops.ones_like(denominator))
      iou = math_ops.div(cm_diag, denominator)
      return math_ops.reduce_mean(iou, name=name)

    mean_iou = compute_mean_iou('mean_iou')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_iou)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_iou, update_op


def aggregate_metrics(*value_update_tuples):
  """Aggregates the metric value tensors and update ops into two lists.

  Args:
    *value_update_tuples: a variable number of tuples, each of which contain
      the pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    a list of value tensors and a list of update ops.

  Raises:
    ValueError: if `value_update_tuples` is empty.
  """
  if not value_update_tuples:
    raise ValueError('Expected at least one value_tensor/update_op pair')
  # Transpose [(v1, u1), (v2, u2), ...] into ([v1, v2, ...], [u1, u2, ...]).
  value_ops, update_ops = zip(*value_update_tuples)
  return list(value_ops), list(update_ops)


def aggregate_metric_map(names_to_tuples):
  """Aggregates the metric names to tuple dictionary.

  This function is useful for pairing metric names with their associated value
  and update ops when the list of metrics is long. For example:

    metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
        'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error(
            predictions, labels, weights),
        'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error(
            predictions, labels, labels, weights),
        'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
        'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
    })

  Args:
    names_to_tuples: a map of metric names to tuples, each of which contain the
      pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    A dictionary from metric names to value ops and a dictionary from metric
    names to update ops.
  """
  # Split each (value, update) pair, re-keyed by the original metric name.
  metric_names = names_to_tuples.keys()
  value_ops, update_ops = zip(*names_to_tuples.values())
  return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))


__all__ = make_all(__name__)