metric_ops.py revision be270ecb79c8548a0caddf67908189e6169b1472
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains metric-computing operations on streamed tensors. 16 17Module documentation, including "@@" callouts, should be put in 18third_party/tensorflow/contrib/metrics/__init__.py 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25from tensorflow.contrib.framework import deprecated 26from tensorflow.contrib.framework import tensor_util 27from tensorflow.contrib.framework.python.ops import variables as contrib_variables 28from tensorflow.contrib.metrics.python.ops import set_ops 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import sparse_tensor 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import check_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import math_ops 36from tensorflow.python.ops import metrics 37from tensorflow.python.ops import nn 38from tensorflow.python.ops import sparse_ops 39from tensorflow.python.ops import state_ops 40from tensorflow.python.ops import variable_scope 41from tensorflow.python.ops import variables 42 43 44def _safe_div(numerator, denominator, name): 45 """Divides two values, returning 0 if the denominator is <= 0. 46 47 Args: 48 numerator: A real `Tensor`. 49 denominator: A real `Tensor`, with dtype matching `numerator`. 50 name: Name for the returned op. 51 52 Returns: 53 0 if `denominator` <= 0, else `numerator` / `denominator` 54 """ 55 return array_ops.where( 56 math_ops.greater(denominator, 0), 57 math_ops.truediv(numerator, denominator), 58 0, 59 name=name) 60 61 62def _safe_scalar_div(numerator, denominator, name): 63 """Divides two values, returning 0 if the denominator is 0. 64 65 Args: 66 numerator: A scalar `float64` `Tensor`. 67 denominator: A scalar `float64` `Tensor`. 68 name: Name for the returned op. 69 70 Returns: 71 0 if `denominator` == 0, else `numerator` / `denominator` 72 """ 73 numerator.get_shape().with_rank_at_most(1) 74 denominator.get_shape().with_rank_at_most(1) 75 return control_flow_ops.cond( 76 math_ops.equal( 77 array_ops.constant(0.0, dtype=dtypes.float64), denominator), 78 lambda: array_ops.constant(0.0, dtype=dtypes.float64), 79 lambda: math_ops.div(numerator, denominator), 80 name=name) 81 82 83def _create_local(name, shape, collections=None, validate_shape=True, 84 dtype=dtypes.float32): 85 """Creates a new local variable. 86 87 Args: 88 name: The name of the new or existing variable. 89 shape: Shape of the new or existing variable. 90 collections: A list of collection names to which the Variable will be added. 91 validate_shape: Whether to validate the shape of the variable. 92 dtype: Data type of the variables. 93 94 Returns: 95 The created variable. 96 """ 97 # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES 98 collections = list(collections or []) 99 collections += [ops.GraphKeys.LOCAL_VARIABLES] 100 return variables.Variable( 101 initial_value=array_ops.zeros(shape, dtype=dtype), 102 name=name, 103 trainable=False, 104 collections=collections, 105 validate_shape=validate_shape) 106 107 108# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. 109def _assert_weights_rank(weights, values): 110 """`weights` rank must be either `0`, or the same as 'values'.""" 111 return check_ops.assert_rank_in(weights, (0, array_ops.rank(values))) 112 113 114def _count_condition(values, weights=None, metrics_collections=None, 115 updates_collections=None): 116 """Sums the weights of cases where the given values are True. 117 118 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 119 120 Args: 121 values: A `bool` `Tensor` of arbitrary size. 122 weights: Optional `Tensor` whose rank is either 0, or the same rank as 123 `values`, and must be broadcastable to `values` (i.e., all dimensions 124 must be either `1`, or the same as the corresponding `values` 125 dimension). 126 metrics_collections: An optional list of collections that the metric 127 value variable should be added to. 128 updates_collections: An optional list of collections that the metric update 129 ops should be added to. 130 131 Returns: 132 value_tensor: A `Tensor` representing the current value of the metric. 133 update_op: An operation that accumulates the error from a batch of data. 134 135 Raises: 136 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 137 or if either `metrics_collections` or `updates_collections` are not a list 138 or tuple. 139 """ 140 check_ops.assert_type(values, dtypes.bool) 141 count = _create_local('count', shape=[]) 142 143 values = math_ops.to_float(values) 144 if weights is not None: 145 weights = math_ops.to_float(weights) 146 with ops.control_dependencies((_assert_weights_rank(weights, values),)): 147 values = math_ops.multiply(values, weights) 148 149 value_tensor = array_ops.identity(count) 150 update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) 151 152 if metrics_collections: 153 ops.add_to_collections(metrics_collections, value_tensor) 154 155 if updates_collections: 156 ops.add_to_collections(updates_collections, update_op) 157 158 return value_tensor, update_op 159 160 161def streaming_true_positives(predictions, labels, weights=None, 162 metrics_collections=None, 163 updates_collections=None, 164 name=None): 165 """Sum the weights of true_positives. 166 167 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 168 169 Args: 170 predictions: The predicted values, a `bool` `Tensor` of arbitrary 171 dimensions. 172 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 173 match `predictions`. 174 weights: Optional `Tensor` whose rank is either 0, or the same rank as 175 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 176 must be either `1`, or the same as the corresponding `labels` 177 dimension). 178 metrics_collections: An optional list of collections that the metric 179 value variable should be added to. 180 updates_collections: An optional list of collections that the metric update 181 ops should be added to. 182 name: An optional variable_scope name. 183 184 Returns: 185 value_tensor: A `Tensor` representing the current value of the metric. 186 update_op: An operation that accumulates the error from a batch of data. 187 188 Raises: 189 ValueError: If `predictions` and `labels` have mismatched shapes, or if 190 `weights` is not `None` and its shape doesn't match `predictions`, or if 191 either `metrics_collections` or `updates_collections` are not a list or 192 tuple. 193 """ 194 return metrics.true_positives( 195 predictions=predictions, labels=labels, weights=weights, 196 metrics_collections=metrics_collections, 197 updates_collections=updates_collections, name=name) 198 199 200def streaming_true_negatives(predictions, labels, weights=None, 201 metrics_collections=None, 202 updates_collections=None, 203 name=None): 204 """Sum the weights of true_negatives. 205 206 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 207 208 Args: 209 predictions: The predicted values, a `bool` `Tensor` of arbitrary 210 dimensions. 211 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 212 match `predictions`. 213 weights: Optional `Tensor` whose rank is either 0, or the same rank as 214 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 215 must be either `1`, or the same as the corresponding `labels` 216 dimension). 217 metrics_collections: An optional list of collections that the metric 218 value variable should be added to. 219 updates_collections: An optional list of collections that the metric update 220 ops should be added to. 221 name: An optional variable_scope name. 222 223 Returns: 224 value_tensor: A `Tensor` representing the current value of the metric. 225 update_op: An operation that accumulates the error from a batch of data. 226 227 Raises: 228 ValueError: If `predictions` and `labels` have mismatched shapes, or if 229 `weights` is not `None` and its shape doesn't match `predictions`, or if 230 either `metrics_collections` or `updates_collections` are not a list or 231 tuple. 232 """ 233 with variable_scope.variable_scope( 234 name, 'true_negatives', (predictions, labels, weights)): 235 236 predictions = ops.convert_to_tensor(predictions) 237 labels = ops.convert_to_tensor(labels) 238 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 239 is_true_negative = math_ops.logical_and(math_ops.equal(labels, 0), 240 math_ops.equal(predictions, 0)) 241 return _count_condition(is_true_negative, weights, metrics_collections, 242 updates_collections) 243 244 245def streaming_false_positives(predictions, labels, weights=None, 246 metrics_collections=None, 247 updates_collections=None, 248 name=None): 249 """Sum the weights of false positives. 250 251 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 252 253 Args: 254 predictions: The predicted values, a `bool` `Tensor` of arbitrary 255 dimensions. 256 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 257 match `predictions`. 258 weights: Optional `Tensor` whose rank is either 0, or the same rank as 259 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 260 must be either `1`, or the same as the corresponding `labels` 261 dimension). 262 metrics_collections: An optional list of collections that the metric 263 value variable should be added to. 264 updates_collections: An optional list of collections that the metric update 265 ops should be added to. 266 name: An optional variable_scope name. 267 268 Returns: 269 value_tensor: A `Tensor` representing the current value of the metric. 270 update_op: An operation that accumulates the error from a batch of data. 271 272 Raises: 273 ValueError: If `predictions` and `labels` have mismatched shapes, or if 274 `weights` is not `None` and its shape doesn't match `predictions`, or if 275 either `metrics_collections` or `updates_collections` are not a list or 276 tuple. 277 """ 278 return metrics.false_positives( 279 predictions=predictions, labels=labels, weights=weights, 280 metrics_collections=metrics_collections, 281 updates_collections=updates_collections, name=name) 282 283 284def streaming_false_negatives(predictions, labels, weights=None, 285 metrics_collections=None, 286 updates_collections=None, 287 name=None): 288 """Computes the total number of false positives. 289 290 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 291 292 Args: 293 predictions: The predicted values, a `bool` `Tensor` of arbitrary 294 dimensions. 295 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 296 match `predictions`. 297 weights: Optional `Tensor` whose rank is either 0, or the same rank as 298 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 299 must be either `1`, or the same as the corresponding `labels` 300 dimension). 301 metrics_collections: An optional list of collections that the metric 302 value variable should be added to. 303 updates_collections: An optional list of collections that the metric update 304 ops should be added to. 305 name: An optional variable_scope name. 306 307 Returns: 308 value_tensor: A `Tensor` representing the current value of the metric. 309 update_op: An operation that accumulates the error from a batch of data. 310 311 Raises: 312 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 313 or if either `metrics_collections` or `updates_collections` are not a list 314 or tuple. 315 """ 316 return metrics.false_negatives( 317 predictions=predictions, labels=labels, weights=weights, 318 metrics_collections=metrics_collections, 319 updates_collections=updates_collections, name=name) 320 321 322# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. 323def _broadcast_weights(weights, values): 324 """Broadcast `weights` to the same shape as `values`. 325 326 This returns a version of `weights` following the same broadcast rules as 327 `mul(weights, values)`. When computing a weighted average, use this function 328 to broadcast `weights` before summing them; e.g., 329 `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`. 330 331 Args: 332 weights: `Tensor` whose rank is either 0, or the same rank as `values`, and 333 must be broadcastable to `values` (i.e., all dimensions must be either 334 `1`, or the same as the corresponding `values` dimension). 335 values: `Tensor` of any shape. 336 337 Returns: 338 `weights` broadcast to `values` shape. 339 """ 340 with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope: 341 weights_shape = weights.get_shape() 342 values_shape = values.get_shape() 343 if (weights_shape.is_fully_defined() and 344 values_shape.is_fully_defined() and 345 weights_shape.is_compatible_with(values_shape)): 346 return weights 347 with ops.control_dependencies((_assert_weights_rank(weights, values),)): 348 return math_ops.multiply( 349 weights, array_ops.ones_like(values), name=scope) 350 351 352def streaming_mean(values, weights=None, metrics_collections=None, 353 updates_collections=None, name=None): 354 """Computes the (weighted) mean of the given values. 355 356 The `streaming_mean` function creates two local variables, `total` and `count` 357 that are used to compute the average of `values`. This average is ultimately 358 returned as `mean` which is an idempotent operation that simply divides 359 `total` by `count`. 360 361 For estimation of the metric over a stream of data, the function creates an 362 `update_op` operation that updates these variables and returns the `mean`. 363 `update_op` increments `total` with the reduced sum of the product of `values` 364 and `weights`, and it increments `count` with the reduced sum of `weights`. 365 366 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 367 368 Args: 369 values: A `Tensor` of arbitrary dimensions. 370 weights: `Tensor` whose rank is either 0, or the same rank as `values`, and 371 must be broadcastable to `values` (i.e., all dimensions must be either 372 `1`, or the same as the corresponding `values` dimension). 373 metrics_collections: An optional list of collections that `mean` 374 should be added to. 375 updates_collections: An optional list of collections that `update_op` 376 should be added to. 377 name: An optional variable_scope name. 378 379 Returns: 380 mean: A `Tensor` representing the current mean, the value of `total` divided 381 by `count`. 382 update_op: An operation that increments the `total` and `count` variables 383 appropriately and whose value matches `mean_value`. 384 385 Raises: 386 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 387 or if either `metrics_collections` or `updates_collections` are not a list 388 or tuple. 389 """ 390 return metrics.mean( 391 values=values, weights=weights, metrics_collections=metrics_collections, 392 updates_collections=updates_collections, name=name) 393 394 395def streaming_mean_tensor(values, weights=None, metrics_collections=None, 396 updates_collections=None, name=None): 397 """Computes the element-wise (weighted) mean of the given tensors. 398 399 In contrast to the `streaming_mean` function which returns a scalar with the 400 mean, this function returns an average tensor with the same shape as the 401 input tensors. 402 403 The `streaming_mean_tensor` function creates two local variables, 404 `total_tensor` and `count_tensor` that are used to compute the average of 405 `values`. This average is ultimately returned as `mean` which is an idempotent 406 operation that simply divides `total` by `count`. 407 408 For estimation of the metric over a stream of data, the function creates an 409 `update_op` operation that updates these variables and returns the `mean`. 410 `update_op` increments `total` with the reduced sum of the product of `values` 411 and `weights`, and it increments `count` with the reduced sum of `weights`. 412 413 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 414 415 Args: 416 values: A `Tensor` of arbitrary dimensions. 417 weights: `Tensor` whose rank is either 0, or the same rank as `values`, and 418 must be broadcastable to `values` (i.e., all dimensions must be either 419 `1`, or the same as the corresponding `values` dimension). 420 metrics_collections: An optional list of collections that `mean` 421 should be added to. 422 updates_collections: An optional list of collections that `update_op` 423 should be added to. 424 name: An optional variable_scope name. 425 426 Returns: 427 mean: A float `Tensor` representing the current mean, the value of `total` 428 divided by `count`. 429 update_op: An operation that increments the `total` and `count` variables 430 appropriately and whose value matches `mean_value`. 431 432 Raises: 433 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 434 or if either `metrics_collections` or `updates_collections` are not a list 435 or tuple. 436 """ 437 return metrics.mean_tensor( 438 values=values, weights=weights, metrics_collections=metrics_collections, 439 updates_collections=updates_collections, name=name) 440 441 442def streaming_accuracy(predictions, labels, weights=None, 443 metrics_collections=None, updates_collections=None, 444 name=None): 445 """Calculates how often `predictions` matches `labels`. 446 447 The `streaming_accuracy` function creates two local variables, `total` and 448 `count` that are used to compute the frequency with which `predictions` 449 matches `labels`. This frequency is ultimately returned as `accuracy`: an 450 idempotent operation that simply divides `total` by `count`. 451 452 For estimation of the metric over a stream of data, the function creates an 453 `update_op` operation that updates these variables and returns the `accuracy`. 454 Internally, an `is_correct` operation computes a `Tensor` with elements 1.0 455 where the corresponding elements of `predictions` and `labels` match and 0.0 456 otherwise. Then `update_op` increments `total` with the reduced sum of the 457 product of `weights` and `is_correct`, and it increments `count` with the 458 reduced sum of `weights`. 459 460 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 461 462 Args: 463 predictions: The predicted values, a `Tensor` of any shape. 464 labels: The ground truth values, a `Tensor` whose shape matches 465 `predictions`. 466 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 467 must be broadcastable to `labels` (i.e., all dimensions must be either 468 `1`, or the same as the corresponding `labels` dimension). 469 metrics_collections: An optional list of collections that `accuracy` should 470 be added to. 471 updates_collections: An optional list of collections that `update_op` should 472 be added to. 473 name: An optional variable_scope name. 474 475 Returns: 476 accuracy: A `Tensor` representing the accuracy, the value of `total` divided 477 by `count`. 478 update_op: An operation that increments the `total` and `count` variables 479 appropriately and whose value matches `accuracy`. 480 481 Raises: 482 ValueError: If `predictions` and `labels` have mismatched shapes, or if 483 `weights` is not `None` and its shape doesn't match `predictions`, or if 484 either `metrics_collections` or `updates_collections` are not a list or 485 tuple. 486 """ 487 return metrics.accuracy( 488 predictions=predictions, labels=labels, weights=weights, 489 metrics_collections=metrics_collections, 490 updates_collections=updates_collections, name=name) 491 492 493def streaming_precision(predictions, labels, weights=None, 494 metrics_collections=None, updates_collections=None, 495 name=None): 496 """Computes the precision of the predictions with respect to the labels. 497 498 The `streaming_precision` function creates two local variables, 499 `true_positives` and `false_positives`, that are used to compute the 500 precision. This value is ultimately returned as `precision`, an idempotent 501 operation that simply divides `true_positives` by the sum of `true_positives` 502 and `false_positives`. 503 504 For estimation of the metric over a stream of data, the function creates an 505 `update_op` operation that updates these variables and returns the 506 `precision`. `update_op` weights each prediction by the corresponding value in 507 `weights`. 508 509 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 510 511 Args: 512 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 513 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 514 match `predictions`. 515 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 516 must be broadcastable to `labels` (i.e., all dimensions must be either 517 `1`, or the same as the corresponding `labels` dimension). 518 metrics_collections: An optional list of collections that `precision` should 519 be added to. 520 updates_collections: An optional list of collections that `update_op` should 521 be added to. 522 name: An optional variable_scope name. 523 524 Returns: 525 precision: Scalar float `Tensor` with the value of `true_positives` 526 divided by the sum of `true_positives` and `false_positives`. 527 update_op: `Operation` that increments `true_positives` and 528 `false_positives` variables appropriately and whose value matches 529 `precision`. 530 531 Raises: 532 ValueError: If `predictions` and `labels` have mismatched shapes, or if 533 `weights` is not `None` and its shape doesn't match `predictions`, or if 534 either `metrics_collections` or `updates_collections` are not a list or 535 tuple. 536 """ 537 return metrics.precision( 538 predictions=predictions, labels=labels, weights=weights, 539 metrics_collections=metrics_collections, 540 updates_collections=updates_collections, name=name) 541 542 543def streaming_recall(predictions, labels, weights=None, 544 metrics_collections=None, updates_collections=None, 545 name=None): 546 """Computes the recall of the predictions with respect to the labels. 547 548 The `streaming_recall` function creates two local variables, `true_positives` 549 and `false_negatives`, that are used to compute the recall. This value is 550 ultimately returned as `recall`, an idempotent operation that simply divides 551 `true_positives` by the sum of `true_positives` and `false_negatives`. 552 553 For estimation of the metric over a stream of data, the function creates an 554 `update_op` that updates these variables and returns the `recall`. `update_op` 555 weights each prediction by the corresponding value in `weights`. 556 557 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 558 559 Args: 560 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 561 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 562 match `predictions`. 563 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 564 must be broadcastable to `labels` (i.e., all dimensions must be either 565 `1`, or the same as the corresponding `labels` dimension). 566 metrics_collections: An optional list of collections that `recall` should 567 be added to. 568 updates_collections: An optional list of collections that `update_op` should 569 be added to. 570 name: An optional variable_scope name. 571 572 Returns: 573 recall: Scalar float `Tensor` with the value of `true_positives` divided 574 by the sum of `true_positives` and `false_negatives`. 575 update_op: `Operation` that increments `true_positives` and 576 `false_negatives` variables appropriately and whose value matches 577 `recall`. 578 579 Raises: 580 ValueError: If `predictions` and `labels` have mismatched shapes, or if 581 `weights` is not `None` and its shape doesn't match `predictions`, or if 582 either `metrics_collections` or `updates_collections` are not a list or 583 tuple. 584 """ 585 return metrics.recall( 586 predictions=predictions, labels=labels, weights=weights, 587 metrics_collections=metrics_collections, 588 updates_collections=updates_collections, name=name) 589 590 591def _streaming_confusion_matrix_at_thresholds( 592 predictions, labels, thresholds, weights=None, includes=None): 593 """Computes true_positives, false_negatives, true_negatives, false_positives. 594 595 This function creates up to four local variables, `true_positives`, 596 `true_negatives`, `false_positives` and `false_negatives`. 597 `true_positive[i]` is defined as the total weight of values in `predictions` 598 above `thresholds[i]` whose corresponding entry in `labels` is `True`. 599 `false_negatives[i]` is defined as the total weight of values in `predictions` 600 at most `thresholds[i]` whose corresponding entry in `labels` is `True`. 601 `true_negatives[i]` is defined as the total weight of values in `predictions` 602 at most `thresholds[i]` whose corresponding entry in `labels` is `False`. 603 `false_positives[i]` is defined as the total weight of values in `predictions` 604 above `thresholds[i]` whose corresponding entry in `labels` is `False`. 605 606 For estimation of these metrics over a stream of data, for each metric the 607 function respectively creates an `update_op` operation that updates the 608 variable and returns its value. 609 610 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 611 612 Args: 613 predictions: A floating point `Tensor` of arbitrary shape and whose values 614 are in the range `[0, 1]`. 615 labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast 616 to `bool`. 617 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 618 weights: Optional `Tensor` whose rank is either 0, or the same rank as 619 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 620 must be either `1`, or the same as the corresponding `labels` 621 dimension). 622 includes: Tuple of keys to return, from 'tp', 'fn', 'tn', fp'. If `None`, 623 default to all four. 624 625 Returns: 626 values: Dict of variables of shape `[len(thresholds)]`. Keys are from 627 `includes`. 628 update_ops: Dict of operations that increments the `values`. Keys are from 629 `includes`. 630 631 Raises: 632 ValueError: If `predictions` and `labels` have mismatched shapes, or if 633 `weights` is not `None` and its shape doesn't match `predictions`, or if 634 `includes` contains invalid keys. 635 """ 636 all_includes = ('tp', 'fn', 'tn', 'fp') 637 if includes is None: 638 includes = all_includes 639 else: 640 for include in includes: 641 if include not in all_includes: 642 raise ValueError('Invaild key: %s.' % include) 643 644 predictions, labels, weights = _remove_squeezable_dimensions( 645 predictions, labels, weights) 646 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 647 648 num_thresholds = len(thresholds) 649 650 # Reshape predictions and labels. 651 predictions_2d = array_ops.reshape(predictions, [-1, 1]) 652 labels_2d = array_ops.reshape( 653 math_ops.cast(labels, dtype=dtypes.bool), [1, -1]) 654 655 # Use static shape if known. 656 num_predictions = predictions_2d.get_shape().as_list()[0] 657 658 # Otherwise use dynamic shape. 659 if num_predictions is None: 660 num_predictions = array_ops.shape(predictions_2d)[0] 661 thresh_tiled = array_ops.tile( 662 array_ops.expand_dims(array_ops.constant(thresholds), [1]), 663 array_ops.stack([1, num_predictions])) 664 665 # Tile the predictions after thresholding them across different thresholds. 666 pred_is_pos = math_ops.greater( 667 array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]), 668 thresh_tiled) 669 if ('fn' in includes) or ('tn' in includes): 670 pred_is_neg = math_ops.logical_not(pred_is_pos) 671 672 # Tile labels by number of thresholds 673 label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1]) 674 if ('fp' in includes) or ('tn' in includes): 675 label_is_neg = math_ops.logical_not(label_is_pos) 676 677 if weights is not None: 678 broadcast_weights = _broadcast_weights( 679 math_ops.to_float(weights), predictions) 680 weights_tiled = array_ops.tile(array_ops.reshape( 681 broadcast_weights, [1, -1]), [num_thresholds, 1]) 682 thresh_tiled.get_shape().assert_is_compatible_with( 683 weights_tiled.get_shape()) 684 else: 685 weights_tiled = None 686 687 values = {} 688 update_ops = {} 689 690 if 'tp' in includes: 691 true_positives = _create_local('true_positives', shape=[num_thresholds]) 692 is_true_positive = math_ops.to_float( 693 math_ops.logical_and(label_is_pos, pred_is_pos)) 694 if weights_tiled is not None: 695 is_true_positive *= weights_tiled 696 update_ops['tp'] = state_ops.assign_add( 697 true_positives, math_ops.reduce_sum(is_true_positive, 1)) 698 values['tp'] = true_positives 699 700 if 'fn' in includes: 701 false_negatives = _create_local('false_negatives', shape=[num_thresholds]) 702 is_false_negative = math_ops.to_float( 703 math_ops.logical_and(label_is_pos, pred_is_neg)) 704 if weights_tiled is not None: 705 is_false_negative *= weights_tiled 706 update_ops['fn'] = state_ops.assign_add( 707 false_negatives, math_ops.reduce_sum(is_false_negative, 1)) 708 values['fn'] = false_negatives 709 710 if 'tn' in includes: 711 true_negatives = _create_local('true_negatives', shape=[num_thresholds]) 712 is_true_negative = math_ops.to_float( 713 math_ops.logical_and(label_is_neg, pred_is_neg)) 714 if weights_tiled is not None: 715 is_true_negative *= weights_tiled 716 update_ops['tn'] = state_ops.assign_add( 717 true_negatives, math_ops.reduce_sum(is_true_negative, 1)) 718 values['tn'] = true_negatives 719 720 if 'fp' in includes: 721 false_positives = _create_local('false_positives', shape=[num_thresholds]) 722 is_false_positive = math_ops.to_float( 723 math_ops.logical_and(label_is_neg, pred_is_pos)) 724 if weights_tiled is not None: 725 is_false_positive *= weights_tiled 726 update_ops['fp'] = state_ops.assign_add( 727 false_positives, math_ops.reduce_sum(is_false_positive, 1)) 728 values['fp'] = false_positives 729 730 return values, update_ops 731 732 733def streaming_true_positives_at_thresholds( 734 predictions, labels, thresholds, weights=None): 735 values, update_ops = _streaming_confusion_matrix_at_thresholds( 736 predictions, labels, thresholds, weights=weights, includes=('tp',)) 737 return values['tp'], update_ops['tp'] 738 739 740def streaming_false_negatives_at_thresholds( 741 predictions, labels, thresholds, weights=None): 742 values, update_ops = _streaming_confusion_matrix_at_thresholds( 743 predictions, labels, thresholds, weights=weights, includes=('fn',)) 744 return values['fn'], update_ops['fn'] 745 746 747def streaming_false_positives_at_thresholds( 748 predictions, labels, thresholds, weights=None): 749 values, update_ops = _streaming_confusion_matrix_at_thresholds( 750 predictions, labels, thresholds, weights=weights, includes=('fp',)) 751 return values['fp'], update_ops['fp'] 752 753 754def streaming_true_negatives_at_thresholds( 755 predictions, labels, thresholds, weights=None): 756 values, update_ops = _streaming_confusion_matrix_at_thresholds( 757 predictions, labels, thresholds, weights=weights, includes=('tn',)) 758 return values['tn'], update_ops['tn'] 759 760 761def streaming_auc(predictions, labels, weights=None, num_thresholds=200, 762 metrics_collections=None, updates_collections=None, 763 curve='ROC', name=None): 764 """Computes the approximate AUC via a Riemann sum. 765 766 The `streaming_auc` function creates four local variables, `true_positives`, 767 `true_negatives`, `false_positives` and `false_negatives` that are used to 768 compute the AUC. To discretize the AUC curve, a linearly spaced set of 769 thresholds is used to compute pairs of recall and precision values. The area 770 under the ROC-curve is therefore computed using the height of the recall 771 values by the false positive rate, while the area under the PR-curve is the 772 computed using the height of the precision values by the recall. 773 774 This value is ultimately returned as `auc`, an idempotent operation that 775 computes the area under a discretized curve of precision versus recall values 776 (computed using the aforementioned variables). The `num_thresholds` variable 777 controls the degree of discretization with larger numbers of thresholds more 778 closely approximating the true AUC. The quality of the approximation may vary 779 dramatically depending on `num_thresholds`. 780 781 For best results, `predictions` should be distributed approximately uniformly 782 in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC 783 approximation may be poor if this is not the case. 784 785 For estimation of the metric over a stream of data, the function creates an 786 `update_op` operation that updates these variables and returns the `auc`. 787 788 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 789 790 Args: 791 predictions: A floating point `Tensor` of arbitrary shape and whose values 792 are in the range `[0, 1]`. 793 labels: A `bool` `Tensor` whose shape matches `predictions`. 794 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 795 must be broadcastable to `labels` (i.e., all dimensions must be either 796 `1`, or the same as the corresponding `labels` dimension). 797 num_thresholds: The number of thresholds to use when discretizing the roc 798 curve. 799 metrics_collections: An optional list of collections that `auc` should be 800 added to. 801 updates_collections: An optional list of collections that `update_op` should 802 be added to. 803 curve: Specifies the name of the curve to be computed, 'ROC' [default] or 804 'PR' for the Precision-Recall-curve. 805 name: An optional variable_scope name. 806 807 Returns: 808 auc: A scalar `Tensor` representing the current area-under-curve. 809 update_op: An operation that increments the `true_positives`, 810 `true_negatives`, `false_positives` and `false_negatives` variables 811 appropriately and whose value matches `auc`. 812 813 Raises: 814 ValueError: If `predictions` and `labels` have mismatched shapes, or if 815 `weights` is not `None` and its shape doesn't match `predictions`, or if 816 either `metrics_collections` or `updates_collections` are not a list or 817 tuple. 818 """ 819 return metrics.auc( 820 predictions=predictions, labels=labels, weights=weights, 821 metrics_collections=metrics_collections, num_thresholds=num_thresholds, 822 curve=curve, updates_collections=updates_collections, name=name) 823 824 825def streaming_specificity_at_sensitivity( 826 predictions, labels, sensitivity, weights=None, num_thresholds=200, 827 metrics_collections=None, updates_collections=None, name=None): 828 """Computes the specificity at a given sensitivity. 829 830 The `streaming_specificity_at_sensitivity` function creates four local 831 variables, `true_positives`, `true_negatives`, `false_positives` and 832 `false_negatives` that are used to compute the specificity at the given 833 sensitivity value. The threshold for the given sensitivity value is computed 834 and used to evaluate the corresponding specificity. 835 836 For estimation of the metric over a stream of data, the function creates an 837 `update_op` operation that updates these variables and returns the 838 `specificity`. `update_op` increments the `true_positives`, `true_negatives`, 839 `false_positives` and `false_negatives` counts with the weight of each case 840 found in the `predictions` and `labels`. 841 842 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 843 844 For additional information about specificity and sensitivity, see the 845 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 846 847 Args: 848 predictions: A floating point `Tensor` of arbitrary shape and whose values 849 are in the range `[0, 1]`. 850 labels: A `bool` `Tensor` whose shape matches `predictions`. 851 sensitivity: A scalar value in range `[0, 1]`. 852 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 853 must be broadcastable to `labels` (i.e., all dimensions must be either 854 `1`, or the same as the corresponding `labels` dimension). 855 num_thresholds: The number of thresholds to use for matching the given 856 sensitivity. 857 metrics_collections: An optional list of collections that `specificity` 858 should be added to. 859 updates_collections: An optional list of collections that `update_op` should 860 be added to. 861 name: An optional variable_scope name. 862 863 Returns: 864 specificity: A scalar `Tensor` representing the specificity at the given 865 `specificity` value. 866 update_op: An operation that increments the `true_positives`, 867 `true_negatives`, `false_positives` and `false_negatives` variables 868 appropriately and whose value matches `specificity`. 869 870 Raises: 871 ValueError: If `predictions` and `labels` have mismatched shapes, if 872 `weights` is not `None` and its shape doesn't match `predictions`, or if 873 `sensitivity` is not between 0 and 1, or if either `metrics_collections` 874 or `updates_collections` are not a list or tuple. 875 """ 876 return metrics.specificity_at_sensitivity( 877 sensitivity=sensitivity, num_thresholds=num_thresholds, 878 predictions=predictions, labels=labels, weights=weights, 879 metrics_collections=metrics_collections, 880 updates_collections=updates_collections, name=name) 881 882 883def streaming_sensitivity_at_specificity( 884 predictions, labels, specificity, weights=None, num_thresholds=200, 885 metrics_collections=None, updates_collections=None, name=None): 886 """Computes the specificity at a given sensitivity. 887 888 The `streaming_sensitivity_at_specificity` function creates four local 889 variables, `true_positives`, `true_negatives`, `false_positives` and 890 `false_negatives` that are used to compute the sensitivity at the given 891 specificity value. The threshold for the given specificity value is computed 892 and used to evaluate the corresponding sensitivity. 893 894 For estimation of the metric over a stream of data, the function creates an 895 `update_op` operation that updates these variables and returns the 896 `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`, 897 `false_positives` and `false_negatives` counts with the weight of each case 898 found in the `predictions` and `labels`. 899 900 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 901 902 For additional information about specificity and sensitivity, see the 903 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 904 905 Args: 906 predictions: A floating point `Tensor` of arbitrary shape and whose values 907 are in the range `[0, 1]`. 908 labels: A `bool` `Tensor` whose shape matches `predictions`. 909 specificity: A scalar value in range `[0, 1]`. 910 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 911 must be broadcastable to `labels` (i.e., all dimensions must be either 912 `1`, or the same as the corresponding `labels` dimension). 913 num_thresholds: The number of thresholds to use for matching the given 914 specificity. 915 metrics_collections: An optional list of collections that `sensitivity` 916 should be added to. 917 updates_collections: An optional list of collections that `update_op` should 918 be added to. 919 name: An optional variable_scope name. 920 921 Returns: 922 sensitivity: A scalar `Tensor` representing the sensitivity at the given 923 `specificity` value. 924 update_op: An operation that increments the `true_positives`, 925 `true_negatives`, `false_positives` and `false_negatives` variables 926 appropriately and whose value matches `sensitivity`. 927 928 Raises: 929 ValueError: If `predictions` and `labels` have mismatched shapes, if 930 `weights` is not `None` and its shape doesn't match `predictions`, or if 931 `specificity` is not between 0 and 1, or if either `metrics_collections` 932 or `updates_collections` are not a list or tuple. 933 """ 934 return metrics.sensitivity_at_specificity( 935 specificity=specificity, num_thresholds=num_thresholds, 936 predictions=predictions, labels=labels, weights=weights, 937 metrics_collections=metrics_collections, 938 updates_collections=updates_collections, name=name) 939 940 941def streaming_precision_at_thresholds(predictions, labels, thresholds, 942 weights=None, 943 metrics_collections=None, 944 updates_collections=None, name=None): 945 """Computes precision values for different `thresholds` on `predictions`. 946 947 The `streaming_precision_at_thresholds` function creates four local variables, 948 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 949 for various values of thresholds. `precision[i]` is defined as the total 950 weight of values in `predictions` above `thresholds[i]` whose corresponding 951 entry in `labels` is `True`, divided by the total weight of values in 952 `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] + 953 false_positives[i])`). 954 955 For estimation of the metric over a stream of data, the function creates an 956 `update_op` operation that updates these variables and returns the 957 `precision`. 958 959 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 960 961 Args: 962 predictions: A floating point `Tensor` of arbitrary shape and whose values 963 are in the range `[0, 1]`. 964 labels: A `bool` `Tensor` whose shape matches `predictions`. 965 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 966 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 967 must be broadcastable to `labels` (i.e., all dimensions must be either 968 `1`, or the same as the corresponding `labels` dimension). 969 metrics_collections: An optional list of collections that `auc` should be 970 added to. 971 updates_collections: An optional list of collections that `update_op` should 972 be added to. 973 name: An optional variable_scope name. 974 975 Returns: 976 precision: A float `Tensor` of shape `[len(thresholds)]`. 977 update_op: An operation that increments the `true_positives`, 978 `true_negatives`, `false_positives` and `false_negatives` variables that 979 are used in the computation of `precision`. 980 981 Raises: 982 ValueError: If `predictions` and `labels` have mismatched shapes, or if 983 `weights` is not `None` and its shape doesn't match `predictions`, or if 984 either `metrics_collections` or `updates_collections` are not a list or 985 tuple. 986 """ 987 return metrics.precision_at_thresholds( 988 thresholds=thresholds, 989 predictions=predictions, labels=labels, weights=weights, 990 metrics_collections=metrics_collections, 991 updates_collections=updates_collections, name=name) 992 993 994def streaming_recall_at_thresholds(predictions, labels, thresholds, 995 weights=None, metrics_collections=None, 996 updates_collections=None, name=None): 997 """Computes various recall values for different `thresholds` on `predictions`. 998 999 The `streaming_recall_at_thresholds` function creates four local variables, 1000 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 1001 for various values of thresholds. `recall[i]` is defined as the total weight 1002 of values in `predictions` above `thresholds[i]` whose corresponding entry in 1003 `labels` is `True`, divided by the total weight of `True` values in `labels` 1004 (`true_positives[i] / (true_positives[i] + false_negatives[i])`). 1005 1006 For estimation of the metric over a stream of data, the function creates an 1007 `update_op` operation that updates these variables and returns the `recall`. 1008 1009 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1010 1011 Args: 1012 predictions: A floating point `Tensor` of arbitrary shape and whose values 1013 are in the range `[0, 1]`. 1014 labels: A `bool` `Tensor` whose shape matches `predictions`. 1015 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1016 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1017 must be broadcastable to `labels` (i.e., all dimensions must be either 1018 `1`, or the same as the corresponding `labels` dimension). 1019 metrics_collections: An optional list of collections that `recall` should be 1020 added to. 1021 updates_collections: An optional list of collections that `update_op` should 1022 be added to. 1023 name: An optional variable_scope name. 1024 1025 Returns: 1026 recall: A float `Tensor` of shape `[len(thresholds)]`. 1027 update_op: An operation that increments the `true_positives`, 1028 `true_negatives`, `false_positives` and `false_negatives` variables that 1029 are used in the computation of `recall`. 1030 1031 Raises: 1032 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1033 `weights` is not `None` and its shape doesn't match `predictions`, or if 1034 either `metrics_collections` or `updates_collections` are not a list or 1035 tuple. 1036 """ 1037 return metrics.recall_at_thresholds( 1038 thresholds=thresholds, 1039 predictions=predictions, labels=labels, weights=weights, 1040 metrics_collections=metrics_collections, 1041 updates_collections=updates_collections, name=name) 1042 1043 1044def _at_k_name(name, k=None, class_id=None): 1045 if k is not None: 1046 name = '%s_at_%d' % (name, k) 1047 else: 1048 name = '%s_at_k' % (name) 1049 if class_id is not None: 1050 name = '%s_class%d' % (name, class_id) 1051 return name 1052 1053 1054@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, ' 1055 'and reshape labels from [batch_size] to [batch_size, 1].') 1056def streaming_recall_at_k(predictions, labels, k, weights=None, 1057 metrics_collections=None, updates_collections=None, 1058 name=None): 1059 """Computes the recall@k of the predictions with respect to dense labels. 1060 1061 The `streaming_recall_at_k` function creates two local variables, `total` and 1062 `count`, that are used to compute the recall@k frequency. This frequency is 1063 ultimately returned as `recall_at_<k>`: an idempotent operation that simply 1064 divides `total` by `count`. 1065 1066 For estimation of the metric over a stream of data, the function creates an 1067 `update_op` operation that updates these variables and returns the 1068 `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with 1069 shape [batch_size] whose elements indicate whether or not the corresponding 1070 label is in the top `k` `predictions`. Then `update_op` increments `total` 1071 with the reduced sum of `weights` where `in_top_k` is `True`, and it 1072 increments `count` with the reduced sum of `weights`. 1073 1074 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1075 1076 Args: 1077 predictions: A float `Tensor` of dimension [batch_size, num_classes]. 1078 labels: A `Tensor` of dimension [batch_size] whose type is in `int32`, 1079 `int64`. 1080 k: The number of top elements to look at for computing recall. 1081 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1082 must be broadcastable to `labels` (i.e., all dimensions must be either 1083 `1`, or the same as the corresponding `labels` dimension). 1084 metrics_collections: An optional list of collections that `recall_at_k` 1085 should be added to. 1086 updates_collections: An optional list of collections `update_op` should be 1087 added to. 1088 name: An optional variable_scope name. 1089 1090 Returns: 1091 recall_at_k: A `Tensor` representing the recall@k, the fraction of labels 1092 which fall into the top `k` predictions. 1093 update_op: An operation that increments the `total` and `count` variables 1094 appropriately and whose value matches `recall_at_k`. 1095 1096 Raises: 1097 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1098 `weights` is not `None` and its shape doesn't match `predictions`, or if 1099 either `metrics_collections` or `updates_collections` are not a list or 1100 tuple. 1101 """ 1102 in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k)) 1103 return streaming_mean(in_top_k, 1104 weights, 1105 metrics_collections, 1106 updates_collections, 1107 name or _at_k_name('recall', k)) 1108 1109 1110# TODO(ptucker): Validate range of values in labels? 1111def streaming_sparse_recall_at_k(predictions, 1112 labels, 1113 k, 1114 class_id=None, 1115 weights=None, 1116 metrics_collections=None, 1117 updates_collections=None, 1118 name=None): 1119 """Computes recall@k of the predictions with respect to sparse labels. 1120 1121 If `class_id` is not specified, we'll calculate recall as the ratio of true 1122 positives (i.e., correct predictions, items in the top `k` highest 1123 `predictions` that are found in the corresponding row in `labels`) to 1124 actual positives (the full `labels` row). 1125 If `class_id` is specified, we calculate recall by considering only the rows 1126 in the batch for which `class_id` is in `labels`, and computing the 1127 fraction of them for which `class_id` is in the corresponding row in 1128 `labels`. 1129 1130 `streaming_sparse_recall_at_k` creates two local variables, 1131 `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute 1132 the recall_at_k frequency. This frequency is ultimately returned as 1133 `recall_at_<k>`: an idempotent operation that simply divides 1134 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 1135 `false_negative_at_<k>`). 1136 1137 For estimation of the metric over a stream of data, the function creates an 1138 `update_op` operation that updates these variables and returns the 1139 `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 1140 indicating the top `k` `predictions`. Set operations applied to `top_k` and 1141 `labels` calculate the true positives and false negatives weighted by 1142 `weights`. Then `update_op` increments `true_positive_at_<k>` and 1143 `false_negative_at_<k>` using these values. 1144 1145 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1146 1147 Args: 1148 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1149 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 1150 The final dimension contains the logit values for each class. [D1, ... DN] 1151 must match `labels`. 1152 labels: `int64` `Tensor` or `SparseTensor` with shape 1153 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1154 target classes for the associated prediction. Commonly, N=1 and `labels` 1155 has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`. 1156 Values should be in range [0, num_classes), where num_classes is the last 1157 dimension of `predictions`. Values outside this range always count 1158 towards `false_negative_at_<k>`. 1159 k: Integer, k for @k metric. 1160 class_id: Integer class ID for which we want binary metrics. This should be 1161 in range [0, num_classes), where num_classes is the last dimension of 1162 `predictions`. If class_id is outside this range, the method returns NAN. 1163 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1164 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1165 dimensions must be either `1`, or the same as the corresponding `labels` 1166 dimension). 1167 metrics_collections: An optional list of collections that values should 1168 be added to. 1169 updates_collections: An optional list of collections that updates should 1170 be added to. 1171 name: Name of new update operation, and namespace for other dependent ops. 1172 1173 Returns: 1174 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 1175 by the sum of `true_positives` and `false_negatives`. 1176 update_op: `Operation` that increments `true_positives` and 1177 `false_negatives` variables appropriately, and whose value matches 1178 `recall`. 1179 1180 Raises: 1181 ValueError: If `weights` is not `None` and its shape doesn't match 1182 `predictions`, or if either `metrics_collections` or `updates_collections` 1183 are not a list or tuple. 1184 """ 1185 return metrics.recall_at_k( 1186 k=k, class_id=class_id, 1187 predictions=predictions, labels=labels, weights=weights, 1188 metrics_collections=metrics_collections, 1189 updates_collections=updates_collections, name=name) 1190 1191 1192def _streaming_sparse_precision_at_k(top_k_idx, 1193 labels, 1194 k=None, 1195 class_id=None, 1196 weights=None, 1197 metrics_collections=None, 1198 updates_collections=None, 1199 name=None): 1200 """Computes precision@k of the top-k indices with respect to sparse labels. 1201 1202 This method contains the code shared by streaming_sparse_precision_at_k and 1203 streaming_sparse_precision_at_top_k. Refer to those methods for more details. 1204 1205 Args: 1206 top_k_idx: Integer `Tensor` with shape [D1, ... DN, k] where 1207 N >= 1. Commonly, N=1 and top_k_idx has shape [batch size, k]. 1208 The final dimension contains the indices of top-k labels. [D1, ... DN] 1209 must match `labels`. 1210 labels: `int64` `Tensor` or `SparseTensor` with shape 1211 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1212 target classes for the associated prediction. Commonly, N=1 and `labels` 1213 has shape [batch_size, num_labels]. [D1, ... DN] must match 1214 `predictions_idx`. Values should be in range [0, num_classes), where 1215 num_classes is the last dimension of `predictions`. Values outside this 1216 range are ignored. 1217 k: Integer, k for @k metric or `None`. Only used for default op name. 1218 class_id: Integer class ID for which we want binary metrics. This should be 1219 in range [0, num_classes), where num_classes is the last dimension of 1220 `predictions`. If `class_id` is outside this range, the method returns 1221 NAN. 1222 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1223 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1224 dimensions must be either `1`, or the same as the corresponding `labels` 1225 dimension). 1226 metrics_collections: An optional list of collections that values should 1227 be added to. 1228 updates_collections: An optional list of collections that updates should 1229 be added to. 1230 name: Name of the metric and of the enclosing scope. 1231 1232 Returns: 1233 precision: Scalar `float64` `Tensor` with the value of `true_positives` 1234 divided by the sum of `true_positives` and `false_positives`. 1235 update_op: `Operation` that increments `true_positives` and 1236 `false_positives` variables appropriately, and whose value matches 1237 `precision`. 1238 1239 Raises: 1240 ValueError: If `weights` is not `None` and its shape doesn't match 1241 `predictions`, or if either `metrics_collections` or `updates_collections` 1242 are not a list or tuple. 1243 """ 1244 top_k_idx = math_ops.to_int64(top_k_idx) 1245 tp, tp_update = _streaming_sparse_true_positive_at_k( 1246 predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id, 1247 weights=weights) 1248 fp, fp_update = _streaming_sparse_false_positive_at_k( 1249 predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id, 1250 weights=weights) 1251 1252 metric = math_ops.div(tp, math_ops.add(tp, fp), name=name) 1253 update = math_ops.div( 1254 tp_update, math_ops.add(tp_update, fp_update), name='update') 1255 if metrics_collections: 1256 ops.add_to_collections(metrics_collections, metric) 1257 if updates_collections: 1258 ops.add_to_collections(updates_collections, update) 1259 return metric, update 1260 1261 1262# TODO(ptucker): Validate range of values in labels? 1263def streaming_sparse_precision_at_k(predictions, 1264 labels, 1265 k, 1266 class_id=None, 1267 weights=None, 1268 metrics_collections=None, 1269 updates_collections=None, 1270 name=None): 1271 """Computes precision@k of the predictions with respect to sparse labels. 1272 1273 If `class_id` is not specified, we calculate precision as the ratio of true 1274 positives (i.e., correct predictions, items in the top `k` highest 1275 `predictions` that are found in the corresponding row in `labels`) to 1276 positives (all top `k` `predictions`). 1277 If `class_id` is specified, we calculate precision by considering only the 1278 rows in the batch for which `class_id` is in the top `k` highest 1279 `predictions`, and computing the fraction of them for which `class_id` is 1280 in the corresponding row in `labels`. 1281 1282 We expect precision to decrease as `k` increases. 1283 1284 `streaming_sparse_precision_at_k` creates two local variables, 1285 `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute 1286 the precision@k frequency. This frequency is ultimately returned as 1287 `precision_at_<k>`: an idempotent operation that simply divides 1288 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 1289 `false_positive_at_<k>`). 1290 1291 For estimation of the metric over a stream of data, the function creates an 1292 `update_op` operation that updates these variables and returns the 1293 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 1294 indicating the top `k` `predictions`. Set operations applied to `top_k` and 1295 `labels` calculate the true positives and false positives weighted by 1296 `weights`. Then `update_op` increments `true_positive_at_<k>` and 1297 `false_positive_at_<k>` using these values. 1298 1299 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1300 1301 Args: 1302 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1303 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 1304 The final dimension contains the logit values for each class. [D1, ... DN] 1305 must match `labels`. 1306 labels: `int64` `Tensor` or `SparseTensor` with shape 1307 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1308 target classes for the associated prediction. Commonly, N=1 and `labels` 1309 has shape [batch_size, num_labels]. [D1, ... DN] must match 1310 `predictions`. Values should be in range [0, num_classes), where 1311 num_classes is the last dimension of `predictions`. Values outside this 1312 range are ignored. 1313 k: Integer, k for @k metric. 1314 class_id: Integer class ID for which we want binary metrics. This should be 1315 in range [0, num_classes], where num_classes is the last dimension of 1316 `predictions`. If `class_id` is outside this range, the method returns 1317 NAN. 1318 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1319 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1320 dimensions must be either `1`, or the same as the corresponding `labels` 1321 dimension). 1322 metrics_collections: An optional list of collections that values should 1323 be added to. 1324 updates_collections: An optional list of collections that updates should 1325 be added to. 1326 name: Name of new update operation, and namespace for other dependent ops. 1327 1328 Returns: 1329 precision: Scalar `float64` `Tensor` with the value of `true_positives` 1330 divided by the sum of `true_positives` and `false_positives`. 1331 update_op: `Operation` that increments `true_positives` and 1332 `false_positives` variables appropriately, and whose value matches 1333 `precision`. 1334 1335 Raises: 1336 ValueError: If `weights` is not `None` and its shape doesn't match 1337 `predictions`, or if either `metrics_collections` or `updates_collections` 1338 are not a list or tuple. 1339 """ 1340 return metrics.sparse_precision_at_k( 1341 k=k, class_id=class_id, 1342 predictions=predictions, labels=labels, weights=weights, 1343 metrics_collections=metrics_collections, 1344 updates_collections=updates_collections, name=name) 1345 1346 1347# TODO(ptucker): Validate range of values in labels? 1348def streaming_sparse_precision_at_top_k(top_k_predictions, 1349 labels, 1350 class_id=None, 1351 weights=None, 1352 metrics_collections=None, 1353 updates_collections=None, 1354 name=None): 1355 """Computes precision@k of top-k predictions with respect to sparse labels. 1356 1357 If `class_id` is not specified, we calculate precision as the ratio of 1358 true positives (i.e., correct predictions, items in `top_k_predictions` 1359 that are found in the corresponding row in `labels`) to positives (all 1360 `top_k_predictions`). 1361 If `class_id` is specified, we calculate precision by considering only the 1362 rows in the batch for which `class_id` is in the top `k` highest 1363 `predictions`, and computing the fraction of them for which `class_id` is 1364 in the corresponding row in `labels`. 1365 1366 We expect precision to decrease as `k` increases. 1367 1368 `streaming_sparse_precision_at_top_k` creates two local variables, 1369 `true_positive_at_k` and `false_positive_at_k`, that are used to compute 1370 the precision@k frequency. This frequency is ultimately returned as 1371 `precision_at_k`: an idempotent operation that simply divides 1372 `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`). 1373 1374 For estimation of the metric over a stream of data, the function creates an 1375 `update_op` operation that updates these variables and returns the 1376 `precision_at_k`. Internally, set operations applied to `top_k_predictions` 1377 and `labels` calculate the true positives and false positives weighted by 1378 `weights`. Then `update_op` increments `true_positive_at_k` and 1379 `false_positive_at_k` using these values. 1380 1381 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1382 1383 Args: 1384 top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where 1385 N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k]. 1386 The final dimension contains the indices of top-k labels. [D1, ... DN] 1387 must match `labels`. 1388 labels: `int64` `Tensor` or `SparseTensor` with shape 1389 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1390 target classes for the associated prediction. Commonly, N=1 and `labels` 1391 has shape [batch_size, num_labels]. [D1, ... DN] must match 1392 `top_k_predictions`. Values should be in range [0, num_classes), where 1393 num_classes is the last dimension of `predictions`. Values outside this 1394 range are ignored. 1395 class_id: Integer class ID for which we want binary metrics. This should be 1396 in range [0, num_classes), where num_classes is the last dimension of 1397 `predictions`. If `class_id` is outside this range, the method returns 1398 NAN. 1399 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1400 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1401 dimensions must be either `1`, or the same as the corresponding `labels` 1402 dimension). 1403 metrics_collections: An optional list of collections that values should 1404 be added to. 1405 updates_collections: An optional list of collections that updates should 1406 be added to. 1407 name: Name of new update operation, and namespace for other dependent ops. 1408 1409 Returns: 1410 precision: Scalar `float64` `Tensor` with the value of `true_positives` 1411 divided by the sum of `true_positives` and `false_positives`. 1412 update_op: `Operation` that increments `true_positives` and 1413 `false_positives` variables appropriately, and whose value matches 1414 `precision`. 1415 1416 Raises: 1417 ValueError: If `weights` is not `None` and its shape doesn't match 1418 `predictions`, or if either `metrics_collections` or `updates_collections` 1419 are not a list or tuple. 1420 ValueError: If `top_k_predictions` has rank < 2. 1421 """ 1422 default_name = _at_k_name('precision', class_id=class_id) 1423 with ops.name_scope( 1424 name, default_name, 1425 (top_k_predictions, labels, weights)) as name_scope: 1426 return _streaming_sparse_precision_at_k( 1427 top_k_idx=top_k_predictions, 1428 labels=labels, 1429 class_id=class_id, 1430 weights=weights, 1431 metrics_collections=metrics_collections, 1432 updates_collections=updates_collections, 1433 name=name_scope) 1434 1435 1436def num_relevant(labels, k): 1437 """Computes number of relevant values for each row in labels. 1438 1439 For labels with shape [D1, ... DN, num_labels], this is the minimum of 1440 `num_labels` and `k`. 1441 1442 Args: 1443 labels: `int64` `Tensor` or `SparseTensor` with shape 1444 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1445 target classes for the associated prediction. Commonly, N=1 and `labels` 1446 has shape [batch_size, num_labels]. 1447 k: Integer, k for @k metric. 1448 1449 Returns: 1450 Integer `Tensor` of shape [D1, ... DN], where each value is the number of 1451 relevant values for that row. 1452 1453 Raises: 1454 ValueError: if inputs have invalid dtypes or values. 1455 """ 1456 if k < 1: 1457 raise ValueError('Invalid k=%s.' % k) 1458 with ops.name_scope(None, 'num_relevant', (labels,)) as scope: 1459 # For SparseTensor, calculate separate count for each row. 1460 if isinstance( 1461 labels, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): 1462 labels_sizes = set_ops.set_size(labels) 1463 return math_ops.minimum(labels_sizes, k, name=scope) 1464 1465 # For dense Tensor, calculate scalar count based on last dimension, and 1466 # tile across labels shape. 1467 labels_shape = array_ops.shape(labels) 1468 labels_size = labels_shape[-1] 1469 num_relevant_scalar = math_ops.minimum(labels_size, k) 1470 return array_ops.fill(labels_shape[0:-1], num_relevant_scalar, name=scope) 1471 1472 1473def expand_and_tile(tensor, multiple, dim=0, name=None): 1474 """Slice `tensor` shape in 2, then tile along the sliced dimension. 1475 1476 A new dimension is inserted in shape of `tensor` before `dim`, then values are 1477 tiled `multiple` times along the new dimension. 1478 1479 Args: 1480 tensor: Input `Tensor` or `SparseTensor`. 1481 multiple: Integer, number of times to tile. 1482 dim: Integer, dimension along which to tile. 1483 name: Name of operation. 1484 1485 Returns: 1486 `Tensor` result of expanding and tiling `tensor`. 1487 1488 Raises: 1489 ValueError: if `multiple` is less than 1, or `dim` is not in 1490 `[-rank(tensor), rank(tensor)]`. 1491 """ 1492 if multiple < 1: 1493 raise ValueError('Invalid multiple %s, must be > 0.' % multiple) 1494 with ops.name_scope( 1495 name, 'expand_and_tile', (tensor, multiple, dim)) as scope: 1496 # Sparse. 1497 if isinstance(tensor, sparse_tensor.SparseTensorValue): 1498 tensor = sparse_tensor.SparseTensor.from_value(tensor) 1499 if isinstance(tensor, sparse_tensor.SparseTensor): 1500 if dim < 0: 1501 expand_dims = array_ops.reshape( 1502 array_ops.size(tensor.dense_shape) + dim, [1]) 1503 else: 1504 expand_dims = [dim] 1505 expanded_shape = array_ops.concat( 1506 (array_ops.strided_slice(tensor.dense_shape, [0], expand_dims), [1], 1507 array_ops.strided_slice( 1508 tensor.dense_shape, expand_dims, [-1], end_mask=1 << 0)), 1509 0, 1510 name='expanded_shape') 1511 expanded = sparse_ops.sparse_reshape( 1512 tensor, shape=expanded_shape, name='expand') 1513 if multiple == 1: 1514 return expanded 1515 return sparse_ops.sparse_concat( 1516 dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope) 1517 1518 # Dense. 1519 expanded = array_ops.expand_dims( 1520 tensor, dim if (dim >= 0) else (dim - 1), name='expand') 1521 if multiple == 1: 1522 return expanded 1523 ones = array_ops.ones_like(array_ops.shape(tensor)) 1524 tile_multiples = array_ops.concat( 1525 (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples') 1526 return array_ops.tile(expanded, tile_multiples, name=scope) 1527 1528 1529def sparse_average_precision_at_k(predictions, labels, k): 1530 """Computes average precision@k of predictions with respect to sparse labels. 1531 1532 From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula 1533 for each row is: 1534 1535 AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items 1536 1537 A "row" is the elements in dimension [D1, ... DN] of `predictions`, `labels`, 1538 and the result `Tensors`. In the common case, this is [batch_size]. Each row 1539 of the results contains the average precision for that row. 1540 1541 Internally, a `top_k` operation computes a `Tensor` indicating the top `k` 1542 `predictions`. Set operations applied to `top_k` and `labels` calculate the 1543 true positives, which are used to calculate the precision ("P_{i}" term, 1544 above). 1545 1546 Args: 1547 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1548 N >= 1. Commonly, N=1 and `predictions` has shape 1549 [batch size, num_classes]. The final dimension contains the logit values 1550 for each class. [D1, ... DN] must match `labels`. 1551 labels: `int64` `Tensor` or `SparseTensor` with shape 1552 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1553 target classes for the associated prediction. Commonly, N=1 and `labels` 1554 has shape [batch_size, num_labels]. [D1, ... DN] must match 1555 `predictions`. Values should be in range [0, num_classes), where 1556 num_classes is the last dimension of `predictions`. Values outside this 1557 range are ignored. 1558 k: Integer, k for @k metric. This will calculate an average precision for 1559 range `[1,k]`, as documented above. 1560 1561 Returns: 1562 `float64` `Tensor` of shape [D1, ... DN], where each value is the average 1563 precision for that row. 1564 1565 Raises: 1566 ValueError: if k is invalid. 1567 """ 1568 if k < 1: 1569 raise ValueError('Invalid k=%s.' % k) 1570 with ops.name_scope( 1571 None, 'average_precision', (predictions, labels, k)) as scope: 1572 # Calculate top k indices to produce [D1, ... DN, k] tensor. 1573 _, predictions_idx = nn.top_k(predictions, k) 1574 predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx') 1575 1576 # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate 1577 # prediction for each k, so we can calculate separate true positive values 1578 # for each k. 1579 predictions_idx_per_k = array_ops.expand_dims( 1580 predictions_idx, -1, name='predictions_idx_per_k') 1581 1582 # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor. 1583 labels_per_k = expand_and_tile( 1584 labels, multiple=k, dim=-1, name='labels_per_k') 1585 1586 # The following tensors are all of shape [D1, ... DN, k], containing values 1587 # per row, per k value. 1588 # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at 1589 # that k value is correct, 0 otherwise. This is the "rel_{i}" term from 1590 # the formula above. 1591 # `tp_per_k` (int32) - True positive counts. 1592 # `retrieved_per_k` (int32) - Number of predicted values at each k. This is 1593 # the precision denominator. 1594 # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}" 1595 # term from the formula above. 1596 # `relevant_precision_per_k` (float64) - Relevant precisions; i.e., 1597 # precisions at all k for which relevance indicator is true. 1598 relevant_per_k = _sparse_true_positive_at_k( 1599 predictions_idx_per_k, labels_per_k, name='relevant_per_k') 1600 tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k') 1601 retrieved_per_k = math_ops.cumsum( 1602 array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k') 1603 precision_per_k = math_ops.div( 1604 math_ops.to_double(tp_per_k), math_ops.to_double(retrieved_per_k), 1605 name='precision_per_k') 1606 relevant_precision_per_k = math_ops.multiply( 1607 precision_per_k, math_ops.to_double(relevant_per_k), 1608 name='relevant_precision_per_k') 1609 1610 # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor. 1611 precision_sum = math_ops.reduce_sum( 1612 relevant_precision_per_k, reduction_indices=(-1,), name='precision_sum') 1613 1614 # Divide by number of relevant items to get average precision. These are 1615 # the "num_relevant_items" and "AveP" terms from the formula above. 1616 num_relevant_items = math_ops.to_double(num_relevant(labels, k)) 1617 return math_ops.div(precision_sum, num_relevant_items, name=scope) 1618 1619 1620def streaming_sparse_average_precision_at_k(predictions, 1621 labels, 1622 k, 1623 weights=None, 1624 metrics_collections=None, 1625 updates_collections=None, 1626 name=None): 1627 """Computes average precision@k of predictions with respect to sparse labels. 1628 1629 See `sparse_average_precision_at_k` for details on formula. `weights` are 1630 applied to the result of `sparse_average_precision_at_k` 1631 1632 `streaming_sparse_average_precision_at_k` creates two local variables, 1633 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 1634 are used to compute the frequency. This frequency is ultimately returned as 1635 `average_precision_at_<k>`: an idempotent operation that simply divides 1636 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 1637 1638 For estimation of the metric over a stream of data, the function creates an 1639 `update_op` operation that updates these variables and returns the 1640 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 1641 indicating the top `k` `predictions`. Set operations applied to `top_k` and 1642 `labels` calculate the true positives and false positives weighted by 1643 `weights`. Then `update_op` increments `true_positive_at_<k>` and 1644 `false_positive_at_<k>` using these values. 1645 1646 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1647 1648 Args: 1649 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1650 N >= 1. Commonly, N=1 and `predictions` has shape 1651 [batch size, num_classes]. The final dimension contains the logit values 1652 for each class. [D1, ... DN] must match `labels`. 1653 labels: `int64` `Tensor` or `SparseTensor` with shape 1654 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1655 target classes for the associated prediction. Commonly, N=1 and `labels` 1656 has shape [batch_size, num_labels]. [D1, ... DN] must match 1657 `predictions_`. Values should be in range [0, num_classes), where 1658 num_classes is the last dimension of `predictions`. Values outside this 1659 range are ignored. 1660 k: Integer, k for @k metric. This will calculate an average precision for 1661 range `[1,k]`, as documented above. 1662 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1663 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1664 dimensions must be either `1`, or the same as the corresponding `labels` 1665 dimension). 1666 metrics_collections: An optional list of collections that values should 1667 be added to. 1668 updates_collections: An optional list of collections that updates should 1669 be added to. 1670 name: Name of new update operation, and namespace for other dependent ops. 1671 1672 Returns: 1673 mean_average_precision: Scalar `float64` `Tensor` with the mean average 1674 precision values. 1675 update: `Operation` that increments variables appropriately, and whose 1676 value matches `metric`. 1677 """ 1678 return metrics.sparse_average_precision_at_k( 1679 k=k, predictions=predictions, labels=labels, weights=weights, 1680 metrics_collections=metrics_collections, 1681 updates_collections=updates_collections, name=name) 1682 1683 1684def _select_class_id(ids, selected_id): 1685 """Filter all but `selected_id` out of `ids`. 1686 1687 Args: 1688 ids: `int64` `Tensor` or `SparseTensor` of IDs. 1689 selected_id: Int id to select. 1690 1691 Returns: 1692 `SparseTensor` of same dimensions as `ids`. This contains only the entries 1693 equal to `selected_id`. 1694 """ 1695 if isinstance( 1696 ids, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): 1697 return sparse_ops.sparse_retain( 1698 ids, math_ops.equal(ids.values, selected_id)) 1699 1700 # TODO(ptucker): Make this more efficient, maybe add a sparse version of 1701 # tf.equal and tf.reduce_any? 1702 1703 # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1. 1704 ids_shape = array_ops.shape(ids, out_type=dtypes.int64) 1705 ids_last_dim = array_ops.size(ids_shape) - 1 1706 filled_selected_id_shape = math_ops.reduced_shape( 1707 ids_shape, array_ops.reshape(ids_last_dim, [1])) 1708 1709 # Intersect `ids` with the selected ID. 1710 filled_selected_id = array_ops.fill( 1711 filled_selected_id_shape, math_ops.to_int64(selected_id)) 1712 result = set_ops.set_intersection(filled_selected_id, ids) 1713 return sparse_tensor.SparseTensor( 1714 indices=result.indices, values=result.values, dense_shape=ids_shape) 1715 1716 1717def _maybe_select_class_id(labels, predictions_idx, selected_id=None): 1718 """If class ID is specified, filter all other classes. 1719 1720 Args: 1721 labels: `int64` `Tensor` or `SparseTensor` with shape 1722 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1723 target classes for the associated prediction. Commonly, N=1 and `labels` 1724 has shape [batch_size, num_labels]. [D1, ... DN] must match 1725 `predictions_idx`. 1726 predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k] 1727 where N >= 1. Commonly, N=1 and `predictions_idx` has shape 1728 [batch size, k]. 1729 selected_id: Int id to select. 1730 1731 Returns: 1732 Tuple of `labels` and `predictions_idx`, possibly with classes removed. 1733 """ 1734 if selected_id is None: 1735 return labels, predictions_idx 1736 return (_select_class_id(labels, selected_id), 1737 _select_class_id(predictions_idx, selected_id)) 1738 1739 1740def _sparse_true_positive_at_k(predictions_idx, 1741 labels, 1742 class_id=None, 1743 weights=None, 1744 name=None): 1745 """Calculates true positives for recall@k and precision@k. 1746 1747 If `class_id` is specified, calculate binary true positives for `class_id` 1748 only. 1749 If `class_id` is not specified, calculate metrics for `k` predicted vs 1750 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 1751 1752 Args: 1753 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1754 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1755 match `labels`. 1756 labels: `int64` `Tensor` or `SparseTensor` with shape 1757 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1758 target classes for the associated prediction. Commonly, N=1 and `labels` 1759 has shape [batch_size, num_labels]. [D1, ... DN] must match 1760 `predictions_idx`. 1761 class_id: Class for which we want binary metrics. 1762 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1763 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1764 dimensions must be either `1`, or the same as the corresponding `labels` 1765 dimension). 1766 name: Name of operation. 1767 1768 Returns: 1769 A [D1, ... DN] `Tensor` of true positive counts. 1770 """ 1771 with ops.name_scope( 1772 name, 'true_positives', (predictions_idx, labels, weights)): 1773 labels, predictions_idx = _maybe_select_class_id( 1774 labels, predictions_idx, class_id) 1775 tp = set_ops.set_size(set_ops.set_intersection(predictions_idx, labels)) 1776 tp = math_ops.to_double(tp) 1777 if weights is not None: 1778 weights = math_ops.to_double(weights) 1779 with ops.control_dependencies((_assert_weights_rank(weights, tp),)): 1780 tp = math_ops.multiply(tp, weights) 1781 return tp 1782 1783 1784def _streaming_sparse_true_positive_at_k(predictions_idx, 1785 labels, 1786 k=None, 1787 class_id=None, 1788 weights=None, 1789 name=None): 1790 """Calculates weighted per step true positives for recall@k and precision@k. 1791 1792 If `class_id` is specified, calculate binary true positives for `class_id` 1793 only. 1794 If `class_id` is not specified, calculate metrics for `k` predicted vs 1795 `n` label classes, where `n` is the 2nd dimension of `labels`. 1796 1797 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1798 1799 Args: 1800 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1801 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1802 match `labels`. 1803 labels: `int64` `Tensor` or `SparseTensor` with shape 1804 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1805 target classes for the associated prediction. Commonly, N=1 and `labels` 1806 has shape [batch_size, num_labels]. [D1, ... DN] must match 1807 `predictions_idx`. 1808 k: Integer, k for @k metric. This is only used for default op name. 1809 class_id: Class for which we want binary metrics. 1810 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1811 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1812 dimensions must be either `1`, or the same as the corresponding `labels` 1813 dimension). 1814 name: Name of new variable, and namespace for other dependent ops. 1815 1816 Returns: 1817 A tuple of `Variable` and update `Operation`. 1818 1819 Raises: 1820 ValueError: If `weights` is not `None` and has an incomptable shape. 1821 """ 1822 default_name = _at_k_name('true_positive', k, class_id=class_id) 1823 with ops.name_scope( 1824 name, default_name, (predictions_idx, labels, weights)) as scope: 1825 tp = _sparse_true_positive_at_k( 1826 predictions_idx=predictions_idx, labels=labels, class_id=class_id, 1827 weights=weights) 1828 batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp)) 1829 1830 var = contrib_variables.local_variable( 1831 array_ops.zeros([], dtype=dtypes.float64), name=scope) 1832 return var, state_ops.assign_add(var, batch_total_tp, name='update') 1833 1834 1835def _sparse_false_positive_at_k(predictions_idx, 1836 labels, 1837 class_id=None, 1838 weights=None): 1839 """Calculates false positives for precision@k. 1840 1841 If `class_id` is specified, calculate binary true positives for `class_id` 1842 only. 1843 If `class_id` is not specified, calculate metrics for `k` predicted vs 1844 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 1845 1846 Args: 1847 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1848 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1849 match `labels`. 1850 labels: `int64` `Tensor` or `SparseTensor` with shape 1851 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1852 target classes for the associated prediction. Commonly, N=1 and `labels` 1853 has shape [batch_size, num_labels]. [D1, ... DN] must match 1854 `predictions_idx`. 1855 class_id: Class for which we want binary metrics. 1856 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1857 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1858 dimensions must be either `1`, or the same as the corresponding `labels` 1859 dimension). 1860 1861 Returns: 1862 A [D1, ... DN] `Tensor` of false positive counts. 1863 """ 1864 with ops.name_scope( 1865 None, 'false_positives', (predictions_idx, labels, weights)): 1866 labels, predictions_idx = _maybe_select_class_id(labels, 1867 predictions_idx, 1868 class_id) 1869 fp = set_ops.set_size(set_ops.set_difference( 1870 predictions_idx, labels, aminusb=True)) 1871 fp = math_ops.to_double(fp) 1872 if weights is not None: 1873 weights = math_ops.to_double(weights) 1874 with ops.control_dependencies((_assert_weights_rank(weights, fp),)): 1875 fp = math_ops.multiply(fp, weights) 1876 return fp 1877 1878 1879def _streaming_sparse_false_positive_at_k(predictions_idx, 1880 labels, 1881 k=None, 1882 class_id=None, 1883 weights=None, 1884 name=None): 1885 """Calculates weighted per step false positives for precision@k. 1886 1887 If `class_id` is specified, calculate binary true positives for `class_id` 1888 only. 1889 If `class_id` is not specified, calculate metrics for `k` predicted vs 1890 `n` label classes, where `n` is the 2nd dimension of `labels`. 1891 1892 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1893 1894 Args: 1895 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1896 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1897 match `labels`. 1898 labels: `int64` `Tensor` or `SparseTensor` with shape 1899 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1900 target classes for the associated prediction. Commonly, N=1 and `labels` 1901 has shape [batch_size, num_labels]. [D1, ... DN] must match 1902 `predictions_idx`. 1903 k: Integer, k for @k metric. This is only used for default op name. 1904 class_id: Class for which we want binary metrics. 1905 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1906 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1907 dimensions must be either `1`, or the same as the corresponding `labels` 1908 dimension). 1909 name: Name of new variable, and namespace for other dependent ops. 1910 1911 Returns: 1912 A tuple of `Variable` and update `Operation`. 1913 1914 Raises: 1915 ValueError: If `weights` is not `None` and has an incomptable shape. 1916 """ 1917 with ops.name_scope( 1918 name, _at_k_name('false_positive', k, class_id=class_id), 1919 (predictions_idx, labels, weights)) as scope: 1920 fp = _sparse_false_positive_at_k( 1921 predictions_idx=predictions_idx, labels=labels, class_id=class_id, 1922 weights=weights) 1923 batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp)) 1924 1925 var = contrib_variables.local_variable( 1926 array_ops.zeros([], dtype=dtypes.float64), name=scope) 1927 return var, state_ops.assign_add(var, batch_total_fp, name='update') 1928 1929 1930def _sparse_false_negative_at_k(predictions_idx, 1931 labels, 1932 class_id=None, 1933 weights=None): 1934 """Calculates false negatives for recall@k. 1935 1936 If `class_id` is specified, calculate binary true positives for `class_id` 1937 only. 1938 If `class_id` is not specified, calculate metrics for `k` predicted vs 1939 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 1940 1941 Args: 1942 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1943 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1944 match `labels`. 1945 labels: `int64` `Tensor` or `SparseTensor` with shape 1946 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1947 target classes for the associated prediction. Commonly, N=1 and `labels` 1948 has shape [batch_size, num_labels]. [D1, ... DN] must match 1949 `predictions_idx`. 1950 class_id: Class for which we want binary metrics. 1951 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1952 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1953 dimensions must be either `1`, or the same as the corresponding `labels` 1954 dimension). 1955 1956 Returns: 1957 A [D1, ... DN] `Tensor` of false negative counts. 1958 """ 1959 with ops.name_scope( 1960 None, 'false_negatives', (predictions_idx, labels, weights)): 1961 labels, predictions_idx = _maybe_select_class_id(labels, 1962 predictions_idx, 1963 class_id) 1964 fn = set_ops.set_size(set_ops.set_difference(predictions_idx, 1965 labels, 1966 aminusb=False)) 1967 fn = math_ops.to_double(fn) 1968 if weights is not None: 1969 weights = math_ops.to_double(weights) 1970 with ops.control_dependencies((_assert_weights_rank(weights, fn),)): 1971 fn = math_ops.multiply(fn, weights) 1972 return fn 1973 1974 1975def _streaming_sparse_false_negative_at_k(predictions_idx, 1976 labels, 1977 k, 1978 class_id=None, 1979 weights=None, 1980 name=None): 1981 """Calculates weighted per step false negatives for recall@k. 1982 1983 If `class_id` is specified, calculate binary true positives for `class_id` 1984 only. 1985 If `class_id` is not specified, calculate metrics for `k` predicted vs 1986 `n` label classes, where `n` is the 2nd dimension of `labels`. 1987 1988 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1989 1990 Args: 1991 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 1992 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 1993 match `labels`. 1994 labels: `int64` `Tensor` or `SparseTensor` with shape 1995 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1996 target classes for the associated prediction. Commonly, N=1 and `labels` 1997 has shape [batch_size, num_labels]. [D1, ... DN] must match 1998 `predictions_idx`. 1999 k: Integer, k for @k metric. This is only used for default op name. 2000 class_id: Class for which we want binary metrics. 2001 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2002 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2003 dimensions must be either `1`, or the same as the corresponding `labels` 2004 dimension). 2005 name: Name of new variable, and namespace for other dependent ops. 2006 2007 Returns: 2008 A tuple of `Variable` and update `Operation`. 2009 2010 Raises: 2011 ValueError: If `weights` is not `None` and has an incomptable shape. 2012 """ 2013 with ops.name_scope( 2014 name, _at_k_name('false_negative', k, class_id=class_id), 2015 (predictions_idx, labels, weights)) as scope: 2016 fn = _sparse_false_negative_at_k( 2017 predictions_idx=predictions_idx, labels=labels, class_id=class_id, 2018 weights=weights) 2019 batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn)) 2020 2021 var = contrib_variables.local_variable( 2022 array_ops.zeros([], dtype=dtypes.float64), name=scope) 2023 return var, state_ops.assign_add(var, batch_total_fn, name='update') 2024 2025 2026def streaming_mean_absolute_error(predictions, labels, weights=None, 2027 metrics_collections=None, 2028 updates_collections=None, 2029 name=None): 2030 """Computes the mean absolute error between the labels and predictions. 2031 2032 The `streaming_mean_absolute_error` function creates two local variables, 2033 `total` and `count` that are used to compute the mean absolute error. This 2034 average is weighted by `weights`, and it is ultimately returned as 2035 `mean_absolute_error`: an idempotent operation that simply divides `total` by 2036 `count`. 2037 2038 For estimation of the metric over a stream of data, the function creates an 2039 `update_op` operation that updates these variables and returns the 2040 `mean_absolute_error`. Internally, an `absolute_errors` operation computes the 2041 absolute value of the differences between `predictions` and `labels`. Then 2042 `update_op` increments `total` with the reduced sum of the product of 2043 `weights` and `absolute_errors`, and it increments `count` with the reduced 2044 sum of `weights` 2045 2046 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2047 2048 Args: 2049 predictions: A `Tensor` of arbitrary shape. 2050 labels: A `Tensor` of the same shape as `predictions`. 2051 weights: Optional `Tensor` indicating the frequency with which an example is 2052 sampled. Rank must be 0, or the same rank as `labels`, and must be 2053 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 2054 the same as the corresponding `labels` dimension). 2055 metrics_collections: An optional list of collections that 2056 `mean_absolute_error` should be added to. 2057 updates_collections: An optional list of collections that `update_op` should 2058 be added to. 2059 name: An optional variable_scope name. 2060 2061 Returns: 2062 mean_absolute_error: A `Tensor` representing the current mean, the value of 2063 `total` divided by `count`. 2064 update_op: An operation that increments the `total` and `count` variables 2065 appropriately and whose value matches `mean_absolute_error`. 2066 2067 Raises: 2068 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2069 `weights` is not `None` and its shape doesn't match `predictions`, or if 2070 either `metrics_collections` or `updates_collections` are not a list or 2071 tuple. 2072 """ 2073 return metrics.mean_absolute_error( 2074 predictions=predictions, labels=labels, weights=weights, 2075 metrics_collections=metrics_collections, 2076 updates_collections=updates_collections, name=name) 2077 2078 2079def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, 2080 metrics_collections=None, 2081 updates_collections=None, 2082 name=None): 2083 """Computes the mean relative error by normalizing with the given values. 2084 2085 The `streaming_mean_relative_error` function creates two local variables, 2086 `total` and `count` that are used to compute the mean relative absolute error. 2087 This average is weighted by `weights`, and it is ultimately returned as 2088 `mean_relative_error`: an idempotent operation that simply divides `total` by 2089 `count`. 2090 2091 For estimation of the metric over a stream of data, the function creates an 2092 `update_op` operation that updates these variables and returns the 2093 `mean_reative_error`. Internally, a `relative_errors` operation divides the 2094 absolute value of the differences between `predictions` and `labels` by the 2095 `normalizer`. Then `update_op` increments `total` with the reduced sum of the 2096 product of `weights` and `relative_errors`, and it increments `count` with the 2097 reduced sum of `weights`. 2098 2099 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2100 2101 Args: 2102 predictions: A `Tensor` of arbitrary shape. 2103 labels: A `Tensor` of the same shape as `predictions`. 2104 normalizer: A `Tensor` of the same shape as `predictions`. 2105 weights: Optional `Tensor` indicating the frequency with which an example is 2106 sampled. Rank must be 0, or the same rank as `labels`, and must be 2107 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 2108 the same as the corresponding `labels` dimension). 2109 metrics_collections: An optional list of collections that 2110 `mean_relative_error` should be added to. 2111 updates_collections: An optional list of collections that `update_op` should 2112 be added to. 2113 name: An optional variable_scope name. 2114 2115 Returns: 2116 mean_relative_error: A `Tensor` representing the current mean, the value of 2117 `total` divided by `count`. 2118 update_op: An operation that increments the `total` and `count` variables 2119 appropriately and whose value matches `mean_relative_error`. 2120 2121 Raises: 2122 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2123 `weights` is not `None` and its shape doesn't match `predictions`, or if 2124 either `metrics_collections` or `updates_collections` are not a list or 2125 tuple. 2126 """ 2127 return metrics.mean_relative_error( 2128 normalizer=normalizer, predictions=predictions, labels=labels, 2129 weights=weights, metrics_collections=metrics_collections, 2130 updates_collections=updates_collections, name=name) 2131 2132 2133def streaming_mean_squared_error(predictions, labels, weights=None, 2134 metrics_collections=None, 2135 updates_collections=None, 2136 name=None): 2137 """Computes the mean squared error between the labels and predictions. 2138 2139 The `streaming_mean_squared_error` function creates two local variables, 2140 `total` and `count` that are used to compute the mean squared error. 2141 This average is weighted by `weights`, and it is ultimately returned as 2142 `mean_squared_error`: an idempotent operation that simply divides `total` by 2143 `count`. 2144 2145 For estimation of the metric over a stream of data, the function creates an 2146 `update_op` operation that updates these variables and returns the 2147 `mean_squared_error`. Internally, a `squared_error` operation computes the 2148 element-wise square of the difference between `predictions` and `labels`. Then 2149 `update_op` increments `total` with the reduced sum of the product of 2150 `weights` and `squared_error`, and it increments `count` with the reduced sum 2151 of `weights`. 2152 2153 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2154 2155 Args: 2156 predictions: A `Tensor` of arbitrary shape. 2157 labels: A `Tensor` of the same shape as `predictions`. 2158 weights: Optional `Tensor` indicating the frequency with which an example is 2159 sampled. Rank must be 0, or the same rank as `labels`, and must be 2160 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 2161 the same as the corresponding `labels` dimension). 2162 metrics_collections: An optional list of collections that 2163 `mean_squared_error` should be added to. 2164 updates_collections: An optional list of collections that `update_op` should 2165 be added to. 2166 name: An optional variable_scope name. 2167 2168 Returns: 2169 mean_squared_error: A `Tensor` representing the current mean, the value of 2170 `total` divided by `count`. 2171 update_op: An operation that increments the `total` and `count` variables 2172 appropriately and whose value matches `mean_squared_error`. 2173 2174 Raises: 2175 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2176 `weights` is not `None` and its shape doesn't match `predictions`, or if 2177 either `metrics_collections` or `updates_collections` are not a list or 2178 tuple. 2179 """ 2180 return metrics.mean_squared_error( 2181 predictions=predictions, labels=labels, weights=weights, 2182 metrics_collections=metrics_collections, 2183 updates_collections=updates_collections, name=name) 2184 2185 2186def streaming_root_mean_squared_error(predictions, labels, weights=None, 2187 metrics_collections=None, 2188 updates_collections=None, 2189 name=None): 2190 """Computes the root mean squared error between the labels and predictions. 2191 2192 The `streaming_root_mean_squared_error` function creates two local variables, 2193 `total` and `count` that are used to compute the root mean squared error. 2194 This average is weighted by `weights`, and it is ultimately returned as 2195 `root_mean_squared_error`: an idempotent operation that takes the square root 2196 of the division of `total` by `count`. 2197 2198 For estimation of the metric over a stream of data, the function creates an 2199 `update_op` operation that updates these variables and returns the 2200 `root_mean_squared_error`. Internally, a `squared_error` operation computes 2201 the element-wise square of the difference between `predictions` and `labels`. 2202 Then `update_op` increments `total` with the reduced sum of the product of 2203 `weights` and `squared_error`, and it increments `count` with the reduced sum 2204 of `weights`. 2205 2206 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2207 2208 Args: 2209 predictions: A `Tensor` of arbitrary shape. 2210 labels: A `Tensor` of the same shape as `predictions`. 2211 weights: Optional `Tensor` indicating the frequency with which an example is 2212 sampled. Rank must be 0, or the same rank as `labels`, and must be 2213 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 2214 the same as the corresponding `labels` dimension). 2215 metrics_collections: An optional list of collections that 2216 `root_mean_squared_error` should be added to. 2217 updates_collections: An optional list of collections that `update_op` should 2218 be added to. 2219 name: An optional variable_scope name. 2220 2221 Returns: 2222 root_mean_squared_error: A `Tensor` representing the current mean, the value 2223 of `total` divided by `count`. 2224 update_op: An operation that increments the `total` and `count` variables 2225 appropriately and whose value matches `root_mean_squared_error`. 2226 2227 Raises: 2228 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2229 `weights` is not `None` and its shape doesn't match `predictions`, or if 2230 either `metrics_collections` or `updates_collections` are not a list or 2231 tuple. 2232 """ 2233 return metrics.root_mean_squared_error( 2234 predictions=predictions, labels=labels, weights=weights, 2235 metrics_collections=metrics_collections, 2236 updates_collections=updates_collections, name=name) 2237 2238 2239def streaming_covariance(predictions, 2240 labels, 2241 weights=None, 2242 metrics_collections=None, 2243 updates_collections=None, 2244 name=None): 2245 """Computes the unbiased sample covariance between `predictions` and `labels`. 2246 2247 The `streaming_covariance` function creates four local variables, 2248 `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to 2249 compute the sample covariance between predictions and labels across multiple 2250 batches of data. The covariance is ultimately returned as an idempotent 2251 operation that simply divides `comoment` by `count` - 1. We use `count` - 1 2252 in order to get an unbiased estimate. 2253 2254 The algorithm used for this online computation is described in 2255 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. 2256 Specifically, the formula used to combine two sample comoments is 2257 `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB` 2258 The comoment for a single batch of data is simply 2259 `sum((x - E[x]) * (y - E[y]))`, optionally weighted. 2260 2261 If `weights` is not None, then it is used to compute weighted comoments, 2262 means, and count. NOTE: these weights are treated as "frequency weights", as 2263 opposed to "reliability weights". See discussion of the difference on 2264 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 2265 2266 To facilitate the computation of covariance across multiple batches of data, 2267 the function creates an `update_op` operation, which updates underlying 2268 variables and returns the updated covariance. 2269 2270 Args: 2271 predictions: A `Tensor` of arbitrary size. 2272 labels: A `Tensor` of the same size as `predictions`. 2273 weights: Optional `Tensor` indicating the frequency with which an example is 2274 sampled. Rank must be 0, or the same rank as `labels`, and must be 2275 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 2276 the same as the corresponding `labels` dimension). 2277 metrics_collections: An optional list of collections that the metric 2278 value variable should be added to. 2279 updates_collections: An optional list of collections that the metric update 2280 ops should be added to. 2281 name: An optional variable_scope name. 2282 2283 Returns: 2284 covariance: A `Tensor` representing the current unbiased sample covariance, 2285 `comoment` / (`count` - 1). 2286 update_op: An operation that updates the local variables appropriately. 2287 2288 Raises: 2289 ValueError: If labels and predictions are of different sizes or if either 2290 `metrics_collections` or `updates_collections` are not a list or tuple. 2291 """ 2292 with variable_scope.variable_scope( 2293 name, 'covariance', (predictions, labels, weights)): 2294 predictions, labels, weights = _remove_squeezable_dimensions( 2295 predictions, labels, weights) 2296 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2297 count = _create_local('count', []) 2298 mean_prediction = _create_local('mean_prediction', []) 2299 mean_label = _create_local('mean_label', []) 2300 comoment = _create_local('comoment', []) # C_A in update equation 2301 2302 if weights is None: 2303 batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn 2304 weighted_predictions = predictions 2305 weighted_labels = labels 2306 else: 2307 weights = _broadcast_weights(weights, labels) 2308 batch_count = math_ops.reduce_sum(weights) # n_B in eqn 2309 weighted_predictions = math_ops.multiply(predictions, weights) 2310 weighted_labels = math_ops.multiply(labels, weights) 2311 2312 update_count = state_ops.assign_add(count, batch_count) # n_AB in eqn 2313 prev_count = update_count - batch_count # n_A in update equation 2314 2315 # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) 2316 # batch_mean_prediction is E[x_B] in the update equation 2317 batch_mean_prediction = _safe_div( 2318 math_ops.reduce_sum(weighted_predictions), batch_count, 2319 'batch_mean_prediction') 2320 delta_mean_prediction = _safe_div( 2321 (batch_mean_prediction - mean_prediction) * batch_count, update_count, 2322 'delta_mean_prediction') 2323 update_mean_prediction = state_ops.assign_add(mean_prediction, 2324 delta_mean_prediction) 2325 # prev_mean_prediction is E[x_A] in the update equation 2326 prev_mean_prediction = update_mean_prediction - delta_mean_prediction 2327 2328 # batch_mean_label is E[y_B] in the update equation 2329 batch_mean_label = _safe_div( 2330 math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label') 2331 delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count, 2332 update_count, 'delta_mean_label') 2333 update_mean_label = state_ops.assign_add(mean_label, delta_mean_label) 2334 # prev_mean_label is E[y_A] in the update equation 2335 prev_mean_label = update_mean_label - delta_mean_label 2336 2337 unweighted_batch_coresiduals = ( 2338 (predictions - batch_mean_prediction) * (labels - batch_mean_label)) 2339 # batch_comoment is C_B in the update equation 2340 if weights is None: 2341 batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals) 2342 else: 2343 batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals * 2344 weights) 2345 2346 # View delta_comoment as = C_AB - C_A in the update equation above. 2347 # Since C_A is stored in a var, by how much do we need to increment that var 2348 # to make the var = C_AB? 2349 delta_comoment = (batch_comoment + 2350 (prev_mean_prediction - batch_mean_prediction) * 2351 (prev_mean_label - batch_mean_label) * 2352 (prev_count * batch_count / update_count)) 2353 update_comoment = state_ops.assign_add(comoment, delta_comoment) 2354 2355 covariance = _safe_div(comoment, count - 1, 'covariance') 2356 with ops.control_dependencies([update_comoment]): 2357 update_op = _safe_div(comoment, count - 1, 'update_op') 2358 2359 if metrics_collections: 2360 ops.add_to_collections(metrics_collections, covariance) 2361 2362 if updates_collections: 2363 ops.add_to_collections(updates_collections, update_op) 2364 2365 return covariance, update_op 2366 2367 2368def streaming_pearson_correlation(predictions, 2369 labels, 2370 weights=None, 2371 metrics_collections=None, 2372 updates_collections=None, 2373 name=None): 2374 """Computes Pearson correlation coefficient between `predictions`, `labels`. 2375 2376 The `streaming_pearson_correlation` function delegates to 2377 `streaming_covariance` the tracking of three [co]variances: 2378 2379 - `streaming_covariance(predictions, labels)`, i.e. covariance 2380 - `streaming_covariance(predictions, predictions)`, i.e. variance 2381 - `streaming_covariance(labels, labels)`, i.e. variance 2382 2383 The product-moment correlation ultimately returned is an idempotent operation 2384 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To 2385 facilitate correlation computation across multiple batches, the function 2386 groups the `update_op`s of the underlying streaming_covariance and returns an 2387 `update_op`. 2388 2389 If `weights` is not None, then it is used to compute a weighted correlation. 2390 NOTE: these weights are treated as "frequency weights", as opposed to 2391 "reliability weights". See discussion of the difference on 2392 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 2393 2394 Args: 2395 predictions: A `Tensor` of arbitrary size. 2396 labels: A `Tensor` of the same size as predictions. 2397 weights: Optional `Tensor` indicating the frequency with which an example is 2398 sampled. Rank must be 0, or the same rank as `labels`, and must be 2399 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 2400 the same as the corresponding `labels` dimension). 2401 metrics_collections: An optional list of collections that the metric 2402 value variable should be added to. 2403 updates_collections: An optional list of collections that the metric update 2404 ops should be added to. 2405 name: An optional variable_scope name. 2406 2407 Returns: 2408 pearson_r: A `Tensor` representing the current Pearson product-moment 2409 correlation coefficient, the value of 2410 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. 2411 update_op: An operation that updates the underlying variables appropriately. 2412 2413 Raises: 2414 ValueError: If `labels` and `predictions` are of different sizes, or if 2415 `weights` is the wrong size, or if either `metrics_collections` or 2416 `updates_collections` are not a `list` or `tuple`. 2417 """ 2418 with variable_scope.variable_scope( 2419 name, 'pearson_r', (predictions, labels, weights)): 2420 predictions, labels, weights = _remove_squeezable_dimensions( 2421 predictions, labels, weights) 2422 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2423 # Broadcast weights here to avoid duplicate broadcasting in each call to 2424 # `streaming_covariance`. 2425 if weights is not None: 2426 weights = _broadcast_weights(weights, labels) 2427 cov, update_cov = streaming_covariance( 2428 predictions, labels, weights=weights, name='covariance') 2429 var_predictions, update_var_predictions = streaming_covariance( 2430 predictions, predictions, weights=weights, name='variance_predictions') 2431 var_labels, update_var_labels = streaming_covariance( 2432 labels, labels, weights=weights, name='variance_labels') 2433 2434 pearson_r = _safe_div( 2435 cov, 2436 math_ops.multiply(math_ops.sqrt(var_predictions), 2437 math_ops.sqrt(var_labels)), 2438 'pearson_r') 2439 with ops.control_dependencies( 2440 [update_cov, update_var_predictions, update_var_labels]): 2441 update_op = _safe_div(update_cov, math_ops.multiply( 2442 math_ops.sqrt(update_var_predictions), 2443 math_ops.sqrt(update_var_labels)), 'update_op') 2444 2445 if metrics_collections: 2446 ops.add_to_collections(metrics_collections, pearson_r) 2447 2448 if updates_collections: 2449 ops.add_to_collections(updates_collections, update_op) 2450 2451 return pearson_r, update_op 2452 2453 2454# TODO(nsilberman): add a 'normalized' flag so that the user can request 2455# normalization if the inputs are not normalized. 2456def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, 2457 metrics_collections=None, 2458 updates_collections=None, 2459 name=None): 2460 """Computes the cosine distance between the labels and predictions. 2461 2462 The `streaming_mean_cosine_distance` function creates two local variables, 2463 `total` and `count` that are used to compute the average cosine distance 2464 between `predictions` and `labels`. This average is weighted by `weights`, 2465 and it is ultimately returned as `mean_distance`, which is an idempotent 2466 operation that simply divides `total` by `count`. 2467 2468 For estimation of the metric over a stream of data, the function creates an 2469 `update_op` operation that updates these variables and returns the 2470 `mean_distance`. 2471 2472 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2473 2474 Args: 2475 predictions: A `Tensor` of the same shape as `labels`. 2476 labels: A `Tensor` of arbitrary shape. 2477 dim: The dimension along which the cosine distance is computed. 2478 weights: An optional `Tensor` whose shape is broadcastable to `predictions`, 2479 and whose dimension `dim` is 1. 2480 metrics_collections: An optional list of collections that the metric 2481 value variable should be added to. 2482 updates_collections: An optional list of collections that the metric update 2483 ops should be added to. 2484 name: An optional variable_scope name. 2485 2486 Returns: 2487 mean_distance: A `Tensor` representing the current mean, the value of 2488 `total` divided by `count`. 2489 update_op: An operation that increments the `total` and `count` variables 2490 appropriately. 2491 2492 Raises: 2493 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2494 `weights` is not `None` and its shape doesn't match `predictions`, or if 2495 either `metrics_collections` or `updates_collections` are not a list or 2496 tuple. 2497 """ 2498 predictions, labels, weights = _remove_squeezable_dimensions( 2499 predictions, labels, weights) 2500 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2501 radial_diffs = math_ops.multiply(predictions, labels) 2502 radial_diffs = math_ops.reduce_sum(radial_diffs, 2503 reduction_indices=[dim,], 2504 keep_dims=True) 2505 mean_distance, update_op = streaming_mean(radial_diffs, weights, 2506 None, 2507 None, 2508 name or 'mean_cosine_distance') 2509 mean_distance = math_ops.subtract(1.0, mean_distance) 2510 update_op = math_ops.subtract(1.0, update_op) 2511 2512 if metrics_collections: 2513 ops.add_to_collections(metrics_collections, mean_distance) 2514 2515 if updates_collections: 2516 ops.add_to_collections(updates_collections, update_op) 2517 2518 return mean_distance, update_op 2519 2520 2521def streaming_percentage_less(values, threshold, weights=None, 2522 metrics_collections=None, 2523 updates_collections=None, 2524 name=None): 2525 """Computes the percentage of values less than the given threshold. 2526 2527 The `streaming_percentage_less` function creates two local variables, 2528 `total` and `count` that are used to compute the percentage of `values` that 2529 fall below `threshold`. This rate is weighted by `weights`, and it is 2530 ultimately returned as `percentage` which is an idempotent operation that 2531 simply divides `total` by `count`. 2532 2533 For estimation of the metric over a stream of data, the function creates an 2534 `update_op` operation that updates these variables and returns the 2535 `percentage`. 2536 2537 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2538 2539 Args: 2540 values: A numeric `Tensor` of arbitrary size. 2541 threshold: A scalar threshold. 2542 weights: An optional `Tensor` whose shape is broadcastable to `values`. 2543 metrics_collections: An optional list of collections that the metric 2544 value variable should be added to. 2545 updates_collections: An optional list of collections that the metric update 2546 ops should be added to. 2547 name: An optional variable_scope name. 2548 2549 Returns: 2550 percentage: A `Tensor` representing the current mean, the value of `total` 2551 divided by `count`. 2552 update_op: An operation that increments the `total` and `count` variables 2553 appropriately. 2554 2555 Raises: 2556 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 2557 or if either `metrics_collections` or `updates_collections` are not a list 2558 or tuple. 2559 """ 2560 return metrics.percentage_below( 2561 values=values, threshold=threshold, weights=weights, 2562 metrics_collections=metrics_collections, 2563 updates_collections=updates_collections, name=name) 2564 2565 2566def streaming_mean_iou(predictions, 2567 labels, 2568 num_classes, 2569 weights=None, 2570 metrics_collections=None, 2571 updates_collections=None, 2572 name=None): 2573 """Calculate per-step mean Intersection-Over-Union (mIOU). 2574 2575 Mean Intersection-Over-Union is a common evaluation metric for 2576 semantic image segmentation, which first computes the IOU for each 2577 semantic class and then computes the average over classes. 2578 IOU is defined as follows: 2579 IOU = true_positive / (true_positive + false_positive + false_negative). 2580 The predictions are accumulated in a confusion matrix, weighted by `weights`, 2581 and mIOU is then calculated from it. 2582 2583 For estimation of the metric over a stream of data, the function creates an 2584 `update_op` operation that updates these variables and returns the `mean_iou`. 2585 2586 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2587 2588 Args: 2589 predictions: A `Tensor` of prediction results for semantic labels, whose 2590 shape is [batch size] and type `int32` or `int64`. The tensor will be 2591 flattened, if its rank > 1. 2592 labels: A `Tensor` of ground truth labels with shape [batch size] and of 2593 type `int32` or `int64`. The tensor will be flattened, if its rank > 1. 2594 num_classes: The possible number of labels the prediction task can 2595 have. This value must be provided, since a confusion matrix of 2596 dimension = [num_classes, num_classes] will be allocated. 2597 weights: An optional `Tensor` whose shape is broadcastable to `predictions`. 2598 metrics_collections: An optional list of collections that `mean_iou` 2599 should be added to. 2600 updates_collections: An optional list of collections `update_op` should be 2601 added to. 2602 name: An optional variable_scope name. 2603 2604 Returns: 2605 mean_iou: A `Tensor` representing the mean intersection-over-union. 2606 update_op: An operation that increments the confusion matrix. 2607 2608 Raises: 2609 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2610 `weights` is not `None` and its shape doesn't match `predictions`, or if 2611 either `metrics_collections` or `updates_collections` are not a list or 2612 tuple. 2613 """ 2614 return metrics.mean_iou( 2615 num_classes=num_classes, predictions=predictions, labels=labels, 2616 weights=weights, metrics_collections=metrics_collections, 2617 updates_collections=updates_collections, name=name) 2618 2619 2620def _next_array_size(required_size, growth_factor=1.5): 2621 """Calculate the next size for reallocating a dynamic array. 2622 2623 Args: 2624 required_size: number or tf.Tensor specifying required array capacity. 2625 growth_factor: optional number or tf.Tensor specifying the growth factor 2626 between subsequent allocations. 2627 2628 Returns: 2629 tf.Tensor with dtype=int32 giving the next array size. 2630 """ 2631 exponent = math_ops.ceil( 2632 math_ops.log(math_ops.cast(required_size, dtypes.float32)) 2633 / math_ops.log(math_ops.cast(growth_factor, dtypes.float32))) 2634 return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32) 2635 2636 2637def streaming_concat(values, 2638 axis=0, 2639 max_size=None, 2640 metrics_collections=None, 2641 updates_collections=None, 2642 name=None): 2643 """Concatenate values along an axis across batches. 2644 2645 The function `streaming_concat` creates two local variables, `array` and 2646 `size`, that are used to store concatenated values. Internally, `array` is 2647 used as storage for a dynamic array (if `maxsize` is `None`), which ensures 2648 that updates can be run in amortized constant time. 2649 2650 For estimation of the metric over a stream of data, the function creates an 2651 `update_op` operation that appends the values of a tensor and returns the 2652 `value` of the concatenated tensors. 2653 2654 This op allows for evaluating metrics that cannot be updated incrementally 2655 using the same framework as other streaming metrics. 2656 2657 Args: 2658 values: `Tensor` to concatenate. Rank and the shape along all axes other 2659 than the axis to concatenate along must be statically known. 2660 axis: optional integer axis to concatenate along. 2661 max_size: optional integer maximum size of `value` along the given axis. 2662 Once the maximum size is reached, further updates are no-ops. By default, 2663 there is no maximum size: the array is resized as necessary. 2664 metrics_collections: An optional list of collections that `value` 2665 should be added to. 2666 updates_collections: An optional list of collections `update_op` should be 2667 added to. 2668 name: An optional variable_scope name. 2669 2670 Returns: 2671 value: A `Tensor` representing the concatenated values. 2672 update_op: An operation that concatenates the next values. 2673 2674 Raises: 2675 ValueError: if `values` does not have a statically known rank, `axis` is 2676 not in the valid range or the size of `values` is not statically known 2677 along any axis other than `axis`. 2678 """ 2679 with variable_scope.variable_scope(name, 'streaming_concat', (values,)): 2680 # pylint: disable=invalid-slice-index 2681 values_shape = values.get_shape() 2682 if values_shape.dims is None: 2683 raise ValueError('`values` must have known statically known rank') 2684 2685 ndim = len(values_shape) 2686 if axis < 0: 2687 axis += ndim 2688 if not 0 <= axis < ndim: 2689 raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) 2690 2691 fixed_shape = [dim.value for n, dim in enumerate(values_shape) 2692 if n != axis] 2693 if any(value is None for value in fixed_shape): 2694 raise ValueError('all dimensions of `values` other than the dimension to ' 2695 'concatenate along must have statically known size') 2696 2697 # We move `axis` to the front of the internal array so assign ops can be 2698 # applied to contiguous slices 2699 init_size = 0 if max_size is None else max_size 2700 init_shape = [init_size] + fixed_shape 2701 array = _create_local( 2702 'array', shape=init_shape, validate_shape=False, dtype=values.dtype) 2703 size = _create_local('size', shape=[], dtype=dtypes.int32) 2704 2705 perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] 2706 valid_array = array[:size] 2707 valid_array.set_shape([None] + fixed_shape) 2708 value = array_ops.transpose(valid_array, perm, name='concat') 2709 2710 values_size = array_ops.shape(values)[axis] 2711 if max_size is None: 2712 batch_size = values_size 2713 else: 2714 batch_size = math_ops.minimum(values_size, max_size - size) 2715 2716 perm = [axis] + [n for n in range(ndim) if n != axis] 2717 batch_values = array_ops.transpose(values, perm)[:batch_size] 2718 2719 def reallocate(): 2720 next_size = _next_array_size(new_size) 2721 next_shape = array_ops.stack([next_size] + fixed_shape) 2722 new_value = array_ops.zeros(next_shape, dtype=values.dtype) 2723 old_value = array.value() 2724 assign_op = state_ops.assign(array, new_value, validate_shape=False) 2725 with ops.control_dependencies([assign_op]): 2726 copy_op = array[:size].assign(old_value[:size]) 2727 # return value needs to be the same dtype as no_op() for cond 2728 with ops.control_dependencies([copy_op]): 2729 return control_flow_ops.no_op() 2730 2731 new_size = size + batch_size 2732 array_size = array_ops.shape_internal(array, optimize=False)[0] 2733 maybe_reallocate_op = control_flow_ops.cond( 2734 new_size > array_size, reallocate, control_flow_ops.no_op) 2735 with ops.control_dependencies([maybe_reallocate_op]): 2736 append_values_op = array[size:new_size].assign(batch_values) 2737 with ops.control_dependencies([append_values_op]): 2738 update_op = size.assign(new_size) 2739 2740 if metrics_collections: 2741 ops.add_to_collections(metrics_collections, value) 2742 2743 if updates_collections: 2744 ops.add_to_collections(updates_collections, update_op) 2745 2746 return value, update_op 2747 # pylint: enable=invalid-slice-index 2748 2749 2750def aggregate_metrics(*value_update_tuples): 2751 """Aggregates the metric value tensors and update ops into two lists. 2752 2753 Args: 2754 *value_update_tuples: a variable number of tuples, each of which contain the 2755 pair of (value_tensor, update_op) from a streaming metric. 2756 2757 Returns: 2758 A list of value `Tensor` objects and a list of update ops. 2759 2760 Raises: 2761 ValueError: if `value_update_tuples` is empty. 2762 """ 2763 if not value_update_tuples: 2764 raise ValueError('Expected at least one value_tensor/update_op pair') 2765 value_ops, update_ops = zip(*value_update_tuples) 2766 return list(value_ops), list(update_ops) 2767 2768 2769def aggregate_metric_map(names_to_tuples): 2770 """Aggregates the metric names to tuple dictionary. 2771 2772 This function is useful for pairing metric names with their associated value 2773 and update ops when the list of metrics is long. For example: 2774 2775 ```python 2776 metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({ 2777 'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error( 2778 predictions, labels, weights), 2779 'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error( 2780 predictions, labels, labels, weights), 2781 'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error( 2782 predictions, labels, weights), 2783 'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error( 2784 predictions, labels, weights), 2785 }) 2786 ``` 2787 2788 Args: 2789 names_to_tuples: a map of metric names to tuples, each of which contain the 2790 pair of (value_tensor, update_op) from a streaming metric. 2791 2792 Returns: 2793 A dictionary from metric names to value ops and a dictionary from metric 2794 names to update ops. 2795 """ 2796 metric_names = names_to_tuples.keys() 2797 value_ops, update_ops = zip(*names_to_tuples.values()) 2798 return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops)) 2799 2800 2801def _remove_squeezable_dimensions(predictions, labels, weights): 2802 """Squeeze last dim if needed. 2803 2804 Squeezes `predictions` and `labels` if their rank differs by 1. 2805 Squeezes `weights` if its rank is 1 more than the new rank of `predictions` 2806 2807 This will use static shape if available. Otherwise, it will add graph 2808 operations, which could result in a performance hit. 2809 2810 Args: 2811 predictions: Predicted values, a `Tensor` of arbitrary dimensions. 2812 labels: Label values, a `Tensor` whose dimensions match `predictions`. 2813 weights: Optional weight `Tensor`. It will be squeezed if its rank is 1 2814 more than the new rank of `predictions` 2815 2816 Returns: 2817 Tuple of `predictions`, `labels` and `weights`, possibly with the last 2818 dimension squeezed. 2819 """ 2820 predictions, labels = tensor_util.remove_squeezable_dimensions( 2821 predictions, labels) 2822 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 2823 2824 if weights is not None: 2825 weights = ops.convert_to_tensor(weights) 2826 predictions_shape = predictions.get_shape() 2827 predictions_rank = predictions_shape.ndims 2828 weights_shape = weights.get_shape() 2829 weights_rank = weights_shape.ndims 2830 2831 if (predictions_rank is not None) and (weights_rank is not None): 2832 # Use static rank. 2833 if weights_rank - predictions_rank == 1: 2834 weights = array_ops.squeeze(weights, [-1]) 2835 elif (weights_rank is None) or ( 2836 weights_shape.dims[-1].is_compatible_with(1)): 2837 # Use dynamic rank 2838 weights = control_flow_ops.cond( 2839 math_ops.equal(array_ops.rank(weights), 2840 math_ops.add(array_ops.rank(predictions), 1)), 2841 lambda: array_ops.squeeze(weights, [-1]), 2842 lambda: weights) 2843 return predictions, labels, weights 2844 2845 2846__all__ = [ 2847 'aggregate_metric_map', 2848 'aggregate_metrics', 2849 'streaming_accuracy', 2850 'streaming_auc', 2851 'streaming_false_negatives', 2852 'streaming_false_negatives_at_thresholds', 2853 'streaming_false_positives', 2854 'streaming_false_positives_at_thresholds', 2855 'streaming_mean', 2856 'streaming_mean_absolute_error', 2857 'streaming_mean_cosine_distance', 2858 'streaming_mean_iou', 2859 'streaming_mean_relative_error', 2860 'streaming_mean_squared_error', 2861 'streaming_mean_tensor', 2862 'streaming_percentage_less', 2863 'streaming_precision', 2864 'streaming_precision_at_thresholds', 2865 'streaming_recall', 2866 'streaming_recall_at_k', 2867 'streaming_recall_at_thresholds', 2868 'streaming_root_mean_squared_error', 2869 'streaming_sensitivity_at_specificity', 2870 'streaming_sparse_average_precision_at_k', 2871 'streaming_sparse_precision_at_k', 2872 'streaming_sparse_recall_at_k', 2873 'streaming_specificity_at_sensitivity', 2874 'streaming_true_negatives', 2875 'streaming_true_negatives_at_thresholds', 2876 'streaming_true_positives', 2877 'streaming_true_positives_at_thresholds', 2878] 2879