metric_ops.py revision 0845c9c976974ced632cb85e52808b22c1e7c669
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains metric-computing operations on streamed tensors. 16 17Module documentation, including "@@" callouts, should be put in 18third_party/tensorflow/contrib/metrics/__init__.py 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25from tensorflow.contrib.framework import deprecated 26from tensorflow.contrib.framework.python.framework import tensor_util 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import check_ops 31from tensorflow.python.ops import control_flow_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.ops import metrics 34from tensorflow.python.ops import metrics_impl 35from tensorflow.python.ops import nn 36from tensorflow.python.ops import state_ops 37from tensorflow.python.ops import variable_scope 38 39 40def _safe_div(numerator, denominator, name): 41 """Divides two values, returning 0 if the denominator is <= 0. 42 43 Args: 44 numerator: A real `Tensor`. 45 denominator: A real `Tensor`, with dtype matching `numerator`. 46 name: Name for the returned op. 
47 48 Returns: 49 0 if `denominator` <= 0, else `numerator` / `denominator` 50 """ 51 return array_ops.where( 52 math_ops.greater(denominator, 0), 53 math_ops.truediv(numerator, denominator), 54 0, 55 name=name) 56 57 58def _create_local(name, shape, collections=None, validate_shape=True, 59 dtype=dtypes.float32): 60 """Creates a new local variable. 61 62 Args: 63 name: The name of the new or existing variable. 64 shape: Shape of the new or existing variable. 65 collections: A list of collection names to which the Variable will be added. 66 validate_shape: Whether to validate the shape of the variable. 67 dtype: Data type of the variables. 68 69 Returns: 70 The created variable. 71 """ 72 # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES 73 collections = list(collections or []) 74 collections += [ops.GraphKeys.LOCAL_VARIABLES] 75 return variable_scope.variable( 76 initial_value=array_ops.zeros(shape, dtype=dtype), 77 name=name, 78 trainable=False, 79 collections=collections, 80 validate_shape=validate_shape) 81 82 83# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. 84def _assert_weights_rank(weights, values): 85 """`weights` rank must be either `0`, or the same as 'values'.""" 86 return check_ops.assert_rank_in(weights, (0, array_ops.rank(values))) 87 88 89def _count_condition(values, weights=None, metrics_collections=None, 90 updates_collections=None): 91 """Sums the weights of cases where the given values are True. 92 93 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 94 95 Args: 96 values: A `bool` `Tensor` of arbitrary size. 97 weights: Optional `Tensor` whose rank is either 0, or the same rank as 98 `values`, and must be broadcastable to `values` (i.e., all dimensions 99 must be either `1`, or the same as the corresponding `values` 100 dimension). 101 metrics_collections: An optional list of collections that the metric 102 value variable should be added to. 
103 updates_collections: An optional list of collections that the metric update 104 ops should be added to. 105 106 Returns: 107 value_tensor: A `Tensor` representing the current value of the metric. 108 update_op: An operation that accumulates the error from a batch of data. 109 110 Raises: 111 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 112 or if either `metrics_collections` or `updates_collections` are not a list 113 or tuple. 114 """ 115 check_ops.assert_type(values, dtypes.bool) 116 count = _create_local('count', shape=[]) 117 118 values = math_ops.to_float(values) 119 if weights is not None: 120 weights = math_ops.to_float(weights) 121 with ops.control_dependencies((_assert_weights_rank(weights, values),)): 122 values = math_ops.multiply(values, weights) 123 124 value_tensor = array_ops.identity(count) 125 update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) 126 127 if metrics_collections: 128 ops.add_to_collections(metrics_collections, value_tensor) 129 130 if updates_collections: 131 ops.add_to_collections(updates_collections, update_op) 132 133 return value_tensor, update_op 134 135 136def streaming_true_positives(predictions, labels, weights=None, 137 metrics_collections=None, 138 updates_collections=None, 139 name=None): 140 """Sum the weights of true_positives. 141 142 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 143 144 Args: 145 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 146 be cast to `bool`. 147 labels: The ground truth values, a `Tensor` whose dimensions must match 148 `predictions`. Will be cast to `bool`. 149 weights: Optional `Tensor` whose rank is either 0, or the same rank as 150 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 151 must be either `1`, or the same as the corresponding `labels` 152 dimension). 
153 metrics_collections: An optional list of collections that the metric 154 value variable should be added to. 155 updates_collections: An optional list of collections that the metric update 156 ops should be added to. 157 name: An optional variable_scope name. 158 159 Returns: 160 value_tensor: A `Tensor` representing the current value of the metric. 161 update_op: An operation that accumulates the error from a batch of data. 162 163 Raises: 164 ValueError: If `predictions` and `labels` have mismatched shapes, or if 165 `weights` is not `None` and its shape doesn't match `predictions`, or if 166 either `metrics_collections` or `updates_collections` are not a list or 167 tuple. 168 """ 169 return metrics.true_positives( 170 predictions=predictions, labels=labels, weights=weights, 171 metrics_collections=metrics_collections, 172 updates_collections=updates_collections, name=name) 173 174 175def streaming_true_negatives(predictions, labels, weights=None, 176 metrics_collections=None, 177 updates_collections=None, 178 name=None): 179 """Sum the weights of true_negatives. 180 181 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 182 183 Args: 184 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 185 be cast to `bool`. 186 labels: The ground truth values, a `Tensor` whose dimensions must match 187 `predictions`. Will be cast to `bool`. 188 weights: Optional `Tensor` whose rank is either 0, or the same rank as 189 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 190 must be either `1`, or the same as the corresponding `labels` 191 dimension). 192 metrics_collections: An optional list of collections that the metric 193 value variable should be added to. 194 updates_collections: An optional list of collections that the metric update 195 ops should be added to. 196 name: An optional variable_scope name. 197 198 Returns: 199 value_tensor: A `Tensor` representing the current value of the metric. 
200 update_op: An operation that accumulates the error from a batch of data. 201 202 Raises: 203 ValueError: If `predictions` and `labels` have mismatched shapes, or if 204 `weights` is not `None` and its shape doesn't match `predictions`, or if 205 either `metrics_collections` or `updates_collections` are not a list or 206 tuple. 207 """ 208 with variable_scope.variable_scope( 209 name, 'true_negatives', (predictions, labels, weights)): 210 211 predictions = math_ops.cast(predictions, dtype=dtypes.bool) 212 labels = math_ops.cast(labels, dtype=dtypes.bool) 213 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 214 is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), 215 math_ops.equal(predictions, False)) 216 return _count_condition(is_true_negative, weights, metrics_collections, 217 updates_collections) 218 219 220def streaming_false_positives(predictions, labels, weights=None, 221 metrics_collections=None, 222 updates_collections=None, 223 name=None): 224 """Sum the weights of false positives. 225 226 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 227 228 Args: 229 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 230 be cast to `bool`. 231 labels: The ground truth values, a `Tensor` whose dimensions must match 232 `predictions`. Will be cast to `bool`. 233 weights: Optional `Tensor` whose rank is either 0, or the same rank as 234 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 235 must be either `1`, or the same as the corresponding `labels` 236 dimension). 237 metrics_collections: An optional list of collections that the metric 238 value variable should be added to. 239 updates_collections: An optional list of collections that the metric update 240 ops should be added to. 241 name: An optional variable_scope name. 242 243 Returns: 244 value_tensor: A `Tensor` representing the current value of the metric. 
245 update_op: An operation that accumulates the error from a batch of data. 246 247 Raises: 248 ValueError: If `predictions` and `labels` have mismatched shapes, or if 249 `weights` is not `None` and its shape doesn't match `predictions`, or if 250 either `metrics_collections` or `updates_collections` are not a list or 251 tuple. 252 """ 253 return metrics.false_positives( 254 predictions=predictions, labels=labels, weights=weights, 255 metrics_collections=metrics_collections, 256 updates_collections=updates_collections, name=name) 257 258 259def streaming_false_negatives(predictions, labels, weights=None, 260 metrics_collections=None, 261 updates_collections=None, 262 name=None): 263 """Computes the total number of false negatives. 264 265 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 266 267 Args: 268 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 269 be cast to `bool`. 270 labels: The ground truth values, a `Tensor` whose dimensions must match 271 `predictions`. Will be cast to `bool`. 272 weights: Optional `Tensor` whose rank is either 0, or the same rank as 273 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 274 must be either `1`, or the same as the corresponding `labels` 275 dimension). 276 metrics_collections: An optional list of collections that the metric 277 value variable should be added to. 278 updates_collections: An optional list of collections that the metric update 279 ops should be added to. 280 name: An optional variable_scope name. 281 282 Returns: 283 value_tensor: A `Tensor` representing the current value of the metric. 284 update_op: An operation that accumulates the error from a batch of data. 285 286 Raises: 287 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 288 or if either `metrics_collections` or `updates_collections` are not a list 289 or tuple. 
290 """ 291 return metrics.false_negatives( 292 predictions=predictions, labels=labels, weights=weights, 293 metrics_collections=metrics_collections, 294 updates_collections=updates_collections, name=name) 295 296 297# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. 298def _broadcast_weights(weights, values): 299 """Broadcast `weights` to the same shape as `values`. 300 301 This returns a version of `weights` following the same broadcast rules as 302 `mul(weights, values)`. When computing a weighted average, use this function 303 to broadcast `weights` before summing them; e.g., 304 `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`. 305 306 Args: 307 weights: `Tensor` whose rank is either 0, or the same rank as `values`, and 308 must be broadcastable to `values` (i.e., all dimensions must be either 309 `1`, or the same as the corresponding `values` dimension). 310 values: `Tensor` of any shape. 311 312 Returns: 313 `weights` broadcast to `values` shape. 314 """ 315 with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope: 316 weights_shape = weights.get_shape() 317 values_shape = values.get_shape() 318 if (weights_shape.is_fully_defined() and 319 values_shape.is_fully_defined() and 320 weights_shape.is_compatible_with(values_shape)): 321 return weights 322 with ops.control_dependencies((_assert_weights_rank(weights, values),)): 323 return math_ops.multiply( 324 weights, array_ops.ones_like(values), name=scope) 325 326 327def streaming_mean(values, weights=None, metrics_collections=None, 328 updates_collections=None, name=None): 329 """Computes the (weighted) mean of the given values. 330 331 The `streaming_mean` function creates two local variables, `total` and `count` 332 that are used to compute the average of `values`. This average is ultimately 333 returned as `mean` which is an idempotent operation that simply divides 334 `total` by `count`. 
335 336 For estimation of the metric over a stream of data, the function creates an 337 `update_op` operation that updates these variables and returns the `mean`. 338 `update_op` increments `total` with the reduced sum of the product of `values` 339 and `weights`, and it increments `count` with the reduced sum of `weights`. 340 341 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 342 343 Args: 344 values: A `Tensor` of arbitrary dimensions. 345 weights: `Tensor` whose rank is either 0, or the same rank as `values`, and 346 must be broadcastable to `values` (i.e., all dimensions must be either 347 `1`, or the same as the corresponding `values` dimension). 348 metrics_collections: An optional list of collections that `mean` 349 should be added to. 350 updates_collections: An optional list of collections that `update_op` 351 should be added to. 352 name: An optional variable_scope name. 353 354 Returns: 355 mean: A `Tensor` representing the current mean, the value of `total` divided 356 by `count`. 357 update_op: An operation that increments the `total` and `count` variables 358 appropriately and whose value matches `mean`. 359 360 Raises: 361 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 362 or if either `metrics_collections` or `updates_collections` are not a list 363 or tuple. 364 """ 365 return metrics.mean( 366 values=values, weights=weights, metrics_collections=metrics_collections, 367 updates_collections=updates_collections, name=name) 368 369 370def streaming_mean_tensor(values, weights=None, metrics_collections=None, 371 updates_collections=None, name=None): 372 """Computes the element-wise (weighted) mean of the given tensors. 373 374 In contrast to the `streaming_mean` function which returns a scalar with the 375 mean, this function returns an average tensor with the same shape as the 376 input tensors. 
377 378 The `streaming_mean_tensor` function creates two local variables, 379 `total_tensor` and `count_tensor` that are used to compute the average of 380 `values`. This average is ultimately returned as `mean` which is an idempotent 381 operation that simply divides `total` by `count`. 382 383 For estimation of the metric over a stream of data, the function creates an 384 `update_op` operation that updates these variables and returns the `mean`. 385 `update_op` increments `total` with the reduced sum of the product of `values` 386 and `weights`, and it increments `count` with the reduced sum of `weights`. 387 388 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 389 390 Args: 391 values: A `Tensor` of arbitrary dimensions. 392 weights: `Tensor` whose rank is either 0, or the same rank as `values`, and 393 must be broadcastable to `values` (i.e., all dimensions must be either 394 `1`, or the same as the corresponding `values` dimension). 395 metrics_collections: An optional list of collections that `mean` 396 should be added to. 397 updates_collections: An optional list of collections that `update_op` 398 should be added to. 399 name: An optional variable_scope name. 400 401 Returns: 402 mean: A float `Tensor` representing the current mean, the value of `total` 403 divided by `count`. 404 update_op: An operation that increments the `total` and `count` variables 405 appropriately and whose value matches `mean`. 406 407 Raises: 408 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 409 or if either `metrics_collections` or `updates_collections` are not a list 410 or tuple. 
411 """ 412 return metrics.mean_tensor( 413 values=values, weights=weights, metrics_collections=metrics_collections, 414 updates_collections=updates_collections, name=name) 415 416 417def streaming_accuracy(predictions, labels, weights=None, 418 metrics_collections=None, updates_collections=None, 419 name=None): 420 """Calculates how often `predictions` matches `labels`. 421 422 The `streaming_accuracy` function creates two local variables, `total` and 423 `count` that are used to compute the frequency with which `predictions` 424 matches `labels`. This frequency is ultimately returned as `accuracy`: an 425 idempotent operation that simply divides `total` by `count`. 426 427 For estimation of the metric over a stream of data, the function creates an 428 `update_op` operation that updates these variables and returns the `accuracy`. 429 Internally, an `is_correct` operation computes a `Tensor` with elements 1.0 430 where the corresponding elements of `predictions` and `labels` match and 0.0 431 otherwise. Then `update_op` increments `total` with the reduced sum of the 432 product of `weights` and `is_correct`, and it increments `count` with the 433 reduced sum of `weights`. 434 435 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 436 437 Args: 438 predictions: The predicted values, a `Tensor` of any shape. 439 labels: The ground truth values, a `Tensor` whose shape matches 440 `predictions`. 441 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 442 must be broadcastable to `labels` (i.e., all dimensions must be either 443 `1`, or the same as the corresponding `labels` dimension). 444 metrics_collections: An optional list of collections that `accuracy` should 445 be added to. 446 updates_collections: An optional list of collections that `update_op` should 447 be added to. 448 name: An optional variable_scope name. 
449 450 Returns: 451 accuracy: A `Tensor` representing the accuracy, the value of `total` divided 452 by `count`. 453 update_op: An operation that increments the `total` and `count` variables 454 appropriately and whose value matches `accuracy`. 455 456 Raises: 457 ValueError: If `predictions` and `labels` have mismatched shapes, or if 458 `weights` is not `None` and its shape doesn't match `predictions`, or if 459 either `metrics_collections` or `updates_collections` are not a list or 460 tuple. 461 """ 462 return metrics.accuracy( 463 predictions=predictions, labels=labels, weights=weights, 464 metrics_collections=metrics_collections, 465 updates_collections=updates_collections, name=name) 466 467 468def streaming_precision(predictions, labels, weights=None, 469 metrics_collections=None, updates_collections=None, 470 name=None): 471 """Computes the precision of the predictions with respect to the labels. 472 473 The `streaming_precision` function creates two local variables, 474 `true_positives` and `false_positives`, that are used to compute the 475 precision. This value is ultimately returned as `precision`, an idempotent 476 operation that simply divides `true_positives` by the sum of `true_positives` 477 and `false_positives`. 478 479 For estimation of the metric over a stream of data, the function creates an 480 `update_op` operation that updates these variables and returns the 481 `precision`. `update_op` weights each prediction by the corresponding value in 482 `weights`. 483 484 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 485 486 Args: 487 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 488 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 489 match `predictions`. 
490 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 491 must be broadcastable to `labels` (i.e., all dimensions must be either 492 `1`, or the same as the corresponding `labels` dimension). 493 metrics_collections: An optional list of collections that `precision` should 494 be added to. 495 updates_collections: An optional list of collections that `update_op` should 496 be added to. 497 name: An optional variable_scope name. 498 499 Returns: 500 precision: Scalar float `Tensor` with the value of `true_positives` 501 divided by the sum of `true_positives` and `false_positives`. 502 update_op: `Operation` that increments `true_positives` and 503 `false_positives` variables appropriately and whose value matches 504 `precision`. 505 506 Raises: 507 ValueError: If `predictions` and `labels` have mismatched shapes, or if 508 `weights` is not `None` and its shape doesn't match `predictions`, or if 509 either `metrics_collections` or `updates_collections` are not a list or 510 tuple. 511 """ 512 return metrics.precision( 513 predictions=predictions, labels=labels, weights=weights, 514 metrics_collections=metrics_collections, 515 updates_collections=updates_collections, name=name) 516 517 518def streaming_recall(predictions, labels, weights=None, 519 metrics_collections=None, updates_collections=None, 520 name=None): 521 """Computes the recall of the predictions with respect to the labels. 522 523 The `streaming_recall` function creates two local variables, `true_positives` 524 and `false_negatives`, that are used to compute the recall. This value is 525 ultimately returned as `recall`, an idempotent operation that simply divides 526 `true_positives` by the sum of `true_positives` and `false_negatives`. 527 528 For estimation of the metric over a stream of data, the function creates an 529 `update_op` that updates these variables and returns the `recall`. `update_op` 530 weights each prediction by the corresponding value in `weights`. 
531 532 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 533 534 Args: 535 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 536 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 537 match `predictions`. 538 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 539 must be broadcastable to `labels` (i.e., all dimensions must be either 540 `1`, or the same as the corresponding `labels` dimension). 541 metrics_collections: An optional list of collections that `recall` should 542 be added to. 543 updates_collections: An optional list of collections that `update_op` should 544 be added to. 545 name: An optional variable_scope name. 546 547 Returns: 548 recall: Scalar float `Tensor` with the value of `true_positives` divided 549 by the sum of `true_positives` and `false_negatives`. 550 update_op: `Operation` that increments `true_positives` and 551 `false_negatives` variables appropriately and whose value matches 552 `recall`. 553 554 Raises: 555 ValueError: If `predictions` and `labels` have mismatched shapes, or if 556 `weights` is not `None` and its shape doesn't match `predictions`, or if 557 either `metrics_collections` or `updates_collections` are not a list or 558 tuple. 559 """ 560 return metrics.recall( 561 predictions=predictions, labels=labels, weights=weights, 562 metrics_collections=metrics_collections, 563 updates_collections=updates_collections, name=name) 564 565 566def _streaming_confusion_matrix_at_thresholds( 567 predictions, labels, thresholds, weights=None, includes=None): 568 """Computes true_positives, false_negatives, true_negatives, false_positives. 569 570 This function creates up to four local variables, `true_positives`, 571 `true_negatives`, `false_positives` and `false_negatives`. 572 `true_positive[i]` is defined as the total weight of values in `predictions` 573 above `thresholds[i]` whose corresponding entry in `labels` is `True`. 
574 `false_negatives[i]` is defined as the total weight of values in `predictions` 575 at most `thresholds[i]` whose corresponding entry in `labels` is `True`. 576 `true_negatives[i]` is defined as the total weight of values in `predictions` 577 at most `thresholds[i]` whose corresponding entry in `labels` is `False`. 578 `false_positives[i]` is defined as the total weight of values in `predictions` 579 above `thresholds[i]` whose corresponding entry in `labels` is `False`. 580 581 For estimation of these metrics over a stream of data, for each metric the 582 function respectively creates an `update_op` operation that updates the 583 variable and returns its value. 584 585 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 586 587 Args: 588 predictions: A floating point `Tensor` of arbitrary shape and whose values 589 are in the range `[0, 1]`. 590 labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast 591 to `bool`. 592 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 593 weights: Optional `Tensor` whose rank is either 0, or the same rank as 594 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 595 must be either `1`, or the same as the corresponding `labels` 596 dimension). 597 includes: Tuple of keys to return, from 'tp', 'fn', 'tn', fp'. If `None`, 598 default to all four. 599 600 Returns: 601 values: Dict of variables of shape `[len(thresholds)]`. Keys are from 602 `includes`. 603 update_ops: Dict of operations that increments the `values`. Keys are from 604 `includes`. 605 606 Raises: 607 ValueError: If `predictions` and `labels` have mismatched shapes, or if 608 `weights` is not `None` and its shape doesn't match `predictions`, or if 609 `includes` contains invalid keys. 
610 """ 611 all_includes = ('tp', 'fn', 'tn', 'fp') 612 if includes is None: 613 includes = all_includes 614 else: 615 for include in includes: 616 if include not in all_includes: 617 raise ValueError('Invaild key: %s.' % include) 618 619 predictions, labels, weights = _remove_squeezable_dimensions( 620 predictions, labels, weights) 621 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 622 623 num_thresholds = len(thresholds) 624 625 # Reshape predictions and labels. 626 predictions_2d = array_ops.reshape(predictions, [-1, 1]) 627 labels_2d = array_ops.reshape( 628 math_ops.cast(labels, dtype=dtypes.bool), [1, -1]) 629 630 # Use static shape if known. 631 num_predictions = predictions_2d.get_shape().as_list()[0] 632 633 # Otherwise use dynamic shape. 634 if num_predictions is None: 635 num_predictions = array_ops.shape(predictions_2d)[0] 636 thresh_tiled = array_ops.tile( 637 array_ops.expand_dims(array_ops.constant(thresholds), [1]), 638 array_ops.stack([1, num_predictions])) 639 640 # Tile the predictions after thresholding them across different thresholds. 
641 pred_is_pos = math_ops.greater( 642 array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]), 643 thresh_tiled) 644 if ('fn' in includes) or ('tn' in includes): 645 pred_is_neg = math_ops.logical_not(pred_is_pos) 646 647 # Tile labels by number of thresholds 648 label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1]) 649 if ('fp' in includes) or ('tn' in includes): 650 label_is_neg = math_ops.logical_not(label_is_pos) 651 652 if weights is not None: 653 broadcast_weights = _broadcast_weights( 654 math_ops.to_float(weights), predictions) 655 weights_tiled = array_ops.tile(array_ops.reshape( 656 broadcast_weights, [1, -1]), [num_thresholds, 1]) 657 thresh_tiled.get_shape().assert_is_compatible_with( 658 weights_tiled.get_shape()) 659 else: 660 weights_tiled = None 661 662 values = {} 663 update_ops = {} 664 665 if 'tp' in includes: 666 true_positives = _create_local('true_positives', shape=[num_thresholds]) 667 is_true_positive = math_ops.to_float( 668 math_ops.logical_and(label_is_pos, pred_is_pos)) 669 if weights_tiled is not None: 670 is_true_positive *= weights_tiled 671 update_ops['tp'] = state_ops.assign_add( 672 true_positives, math_ops.reduce_sum(is_true_positive, 1)) 673 values['tp'] = true_positives 674 675 if 'fn' in includes: 676 false_negatives = _create_local('false_negatives', shape=[num_thresholds]) 677 is_false_negative = math_ops.to_float( 678 math_ops.logical_and(label_is_pos, pred_is_neg)) 679 if weights_tiled is not None: 680 is_false_negative *= weights_tiled 681 update_ops['fn'] = state_ops.assign_add( 682 false_negatives, math_ops.reduce_sum(is_false_negative, 1)) 683 values['fn'] = false_negatives 684 685 if 'tn' in includes: 686 true_negatives = _create_local('true_negatives', shape=[num_thresholds]) 687 is_true_negative = math_ops.to_float( 688 math_ops.logical_and(label_is_neg, pred_is_neg)) 689 if weights_tiled is not None: 690 is_true_negative *= weights_tiled 691 update_ops['tn'] = state_ops.assign_add( 692 
true_negatives, math_ops.reduce_sum(is_true_negative, 1)) 693 values['tn'] = true_negatives 694 695 if 'fp' in includes: 696 false_positives = _create_local('false_positives', shape=[num_thresholds]) 697 is_false_positive = math_ops.to_float( 698 math_ops.logical_and(label_is_neg, pred_is_pos)) 699 if weights_tiled is not None: 700 is_false_positive *= weights_tiled 701 update_ops['fp'] = state_ops.assign_add( 702 false_positives, math_ops.reduce_sum(is_false_positive, 1)) 703 values['fp'] = false_positives 704 705 return values, update_ops 706 707 708def streaming_true_positives_at_thresholds( 709 predictions, labels, thresholds, weights=None): 710 values, update_ops = _streaming_confusion_matrix_at_thresholds( 711 predictions, labels, thresholds, weights=weights, includes=('tp',)) 712 return values['tp'], update_ops['tp'] 713 714 715def streaming_false_negatives_at_thresholds( 716 predictions, labels, thresholds, weights=None): 717 values, update_ops = _streaming_confusion_matrix_at_thresholds( 718 predictions, labels, thresholds, weights=weights, includes=('fn',)) 719 return values['fn'], update_ops['fn'] 720 721 722def streaming_false_positives_at_thresholds( 723 predictions, labels, thresholds, weights=None): 724 values, update_ops = _streaming_confusion_matrix_at_thresholds( 725 predictions, labels, thresholds, weights=weights, includes=('fp',)) 726 return values['fp'], update_ops['fp'] 727 728 729def streaming_true_negatives_at_thresholds( 730 predictions, labels, thresholds, weights=None): 731 values, update_ops = _streaming_confusion_matrix_at_thresholds( 732 predictions, labels, thresholds, weights=weights, includes=('tn',)) 733 return values['tn'], update_ops['tn'] 734 735 736def streaming_auc(predictions, labels, weights=None, num_thresholds=200, 737 metrics_collections=None, updates_collections=None, 738 curve='ROC', name=None): 739 """Computes the approximate AUC via a Riemann sum. 
740 741 The `streaming_auc` function creates four local variables, `true_positives`, 742 `true_negatives`, `false_positives` and `false_negatives` that are used to 743 compute the AUC. To discretize the AUC curve, a linearly spaced set of 744 thresholds is used to compute pairs of recall and precision values. The area 745 under the ROC-curve is therefore computed using the height of the recall 746 values by the false positive rate, while the area under the PR-curve is the 747 computed using the height of the precision values by the recall. 748 749 This value is ultimately returned as `auc`, an idempotent operation that 750 computes the area under a discretized curve of precision versus recall values 751 (computed using the aforementioned variables). The `num_thresholds` variable 752 controls the degree of discretization with larger numbers of thresholds more 753 closely approximating the true AUC. The quality of the approximation may vary 754 dramatically depending on `num_thresholds`. 755 756 For best results, `predictions` should be distributed approximately uniformly 757 in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC 758 approximation may be poor if this is not the case. 759 760 For estimation of the metric over a stream of data, the function creates an 761 `update_op` operation that updates these variables and returns the `auc`. 762 763 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 764 765 Args: 766 predictions: A floating point `Tensor` of arbitrary shape and whose values 767 are in the range `[0, 1]`. 768 labels: A `bool` `Tensor` whose shape matches `predictions`. 769 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 770 must be broadcastable to `labels` (i.e., all dimensions must be either 771 `1`, or the same as the corresponding `labels` dimension). 772 num_thresholds: The number of thresholds to use when discretizing the roc 773 curve. 
774 metrics_collections: An optional list of collections that `auc` should be 775 added to. 776 updates_collections: An optional list of collections that `update_op` should 777 be added to. 778 curve: Specifies the name of the curve to be computed, 'ROC' [default] or 779 'PR' for the Precision-Recall-curve. 780 name: An optional variable_scope name. 781 782 Returns: 783 auc: A scalar `Tensor` representing the current area-under-curve. 784 update_op: An operation that increments the `true_positives`, 785 `true_negatives`, `false_positives` and `false_negatives` variables 786 appropriately and whose value matches `auc`. 787 788 Raises: 789 ValueError: If `predictions` and `labels` have mismatched shapes, or if 790 `weights` is not `None` and its shape doesn't match `predictions`, or if 791 either `metrics_collections` or `updates_collections` are not a list or 792 tuple. 793 """ 794 return metrics.auc( 795 predictions=predictions, labels=labels, weights=weights, 796 metrics_collections=metrics_collections, num_thresholds=num_thresholds, 797 curve=curve, updates_collections=updates_collections, name=name) 798 799 800def streaming_specificity_at_sensitivity( 801 predictions, labels, sensitivity, weights=None, num_thresholds=200, 802 metrics_collections=None, updates_collections=None, name=None): 803 """Computes the specificity at a given sensitivity. 804 805 The `streaming_specificity_at_sensitivity` function creates four local 806 variables, `true_positives`, `true_negatives`, `false_positives` and 807 `false_negatives` that are used to compute the specificity at the given 808 sensitivity value. The threshold for the given sensitivity value is computed 809 and used to evaluate the corresponding specificity. 810 811 For estimation of the metric over a stream of data, the function creates an 812 `update_op` operation that updates these variables and returns the 813 `specificity`. 
`update_op` increments the `true_positives`, `true_negatives`, 814 `false_positives` and `false_negatives` counts with the weight of each case 815 found in the `predictions` and `labels`. 816 817 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 818 819 For additional information about specificity and sensitivity, see the 820 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 821 822 Args: 823 predictions: A floating point `Tensor` of arbitrary shape and whose values 824 are in the range `[0, 1]`. 825 labels: A `bool` `Tensor` whose shape matches `predictions`. 826 sensitivity: A scalar value in range `[0, 1]`. 827 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 828 must be broadcastable to `labels` (i.e., all dimensions must be either 829 `1`, or the same as the corresponding `labels` dimension). 830 num_thresholds: The number of thresholds to use for matching the given 831 sensitivity. 832 metrics_collections: An optional list of collections that `specificity` 833 should be added to. 834 updates_collections: An optional list of collections that `update_op` should 835 be added to. 836 name: An optional variable_scope name. 837 838 Returns: 839 specificity: A scalar `Tensor` representing the specificity at the given 840 `specificity` value. 841 update_op: An operation that increments the `true_positives`, 842 `true_negatives`, `false_positives` and `false_negatives` variables 843 appropriately and whose value matches `specificity`. 844 845 Raises: 846 ValueError: If `predictions` and `labels` have mismatched shapes, if 847 `weights` is not `None` and its shape doesn't match `predictions`, or if 848 `sensitivity` is not between 0 and 1, or if either `metrics_collections` 849 or `updates_collections` are not a list or tuple. 
850 """ 851 return metrics.specificity_at_sensitivity( 852 sensitivity=sensitivity, num_thresholds=num_thresholds, 853 predictions=predictions, labels=labels, weights=weights, 854 metrics_collections=metrics_collections, 855 updates_collections=updates_collections, name=name) 856 857 858def streaming_sensitivity_at_specificity( 859 predictions, labels, specificity, weights=None, num_thresholds=200, 860 metrics_collections=None, updates_collections=None, name=None): 861 """Computes the specificity at a given sensitivity. 862 863 The `streaming_sensitivity_at_specificity` function creates four local 864 variables, `true_positives`, `true_negatives`, `false_positives` and 865 `false_negatives` that are used to compute the sensitivity at the given 866 specificity value. The threshold for the given specificity value is computed 867 and used to evaluate the corresponding sensitivity. 868 869 For estimation of the metric over a stream of data, the function creates an 870 `update_op` operation that updates these variables and returns the 871 `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`, 872 `false_positives` and `false_negatives` counts with the weight of each case 873 found in the `predictions` and `labels`. 874 875 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 876 877 For additional information about specificity and sensitivity, see the 878 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 879 880 Args: 881 predictions: A floating point `Tensor` of arbitrary shape and whose values 882 are in the range `[0, 1]`. 883 labels: A `bool` `Tensor` whose shape matches `predictions`. 884 specificity: A scalar value in range `[0, 1]`. 885 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 886 must be broadcastable to `labels` (i.e., all dimensions must be either 887 `1`, or the same as the corresponding `labels` dimension). 
888 num_thresholds: The number of thresholds to use for matching the given 889 specificity. 890 metrics_collections: An optional list of collections that `sensitivity` 891 should be added to. 892 updates_collections: An optional list of collections that `update_op` should 893 be added to. 894 name: An optional variable_scope name. 895 896 Returns: 897 sensitivity: A scalar `Tensor` representing the sensitivity at the given 898 `specificity` value. 899 update_op: An operation that increments the `true_positives`, 900 `true_negatives`, `false_positives` and `false_negatives` variables 901 appropriately and whose value matches `sensitivity`. 902 903 Raises: 904 ValueError: If `predictions` and `labels` have mismatched shapes, if 905 `weights` is not `None` and its shape doesn't match `predictions`, or if 906 `specificity` is not between 0 and 1, or if either `metrics_collections` 907 or `updates_collections` are not a list or tuple. 908 """ 909 return metrics.sensitivity_at_specificity( 910 specificity=specificity, num_thresholds=num_thresholds, 911 predictions=predictions, labels=labels, weights=weights, 912 metrics_collections=metrics_collections, 913 updates_collections=updates_collections, name=name) 914 915 916def streaming_precision_at_thresholds(predictions, labels, thresholds, 917 weights=None, 918 metrics_collections=None, 919 updates_collections=None, name=None): 920 """Computes precision values for different `thresholds` on `predictions`. 921 922 The `streaming_precision_at_thresholds` function creates four local variables, 923 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 924 for various values of thresholds. `precision[i]` is defined as the total 925 weight of values in `predictions` above `thresholds[i]` whose corresponding 926 entry in `labels` is `True`, divided by the total weight of values in 927 `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] + 928 false_positives[i])`). 
929 930 For estimation of the metric over a stream of data, the function creates an 931 `update_op` operation that updates these variables and returns the 932 `precision`. 933 934 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 935 936 Args: 937 predictions: A floating point `Tensor` of arbitrary shape and whose values 938 are in the range `[0, 1]`. 939 labels: A `bool` `Tensor` whose shape matches `predictions`. 940 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 941 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 942 must be broadcastable to `labels` (i.e., all dimensions must be either 943 `1`, or the same as the corresponding `labels` dimension). 944 metrics_collections: An optional list of collections that `auc` should be 945 added to. 946 updates_collections: An optional list of collections that `update_op` should 947 be added to. 948 name: An optional variable_scope name. 949 950 Returns: 951 precision: A float `Tensor` of shape `[len(thresholds)]`. 952 update_op: An operation that increments the `true_positives`, 953 `true_negatives`, `false_positives` and `false_negatives` variables that 954 are used in the computation of `precision`. 955 956 Raises: 957 ValueError: If `predictions` and `labels` have mismatched shapes, or if 958 `weights` is not `None` and its shape doesn't match `predictions`, or if 959 either `metrics_collections` or `updates_collections` are not a list or 960 tuple. 961 """ 962 return metrics.precision_at_thresholds( 963 thresholds=thresholds, 964 predictions=predictions, labels=labels, weights=weights, 965 metrics_collections=metrics_collections, 966 updates_collections=updates_collections, name=name) 967 968 969def streaming_recall_at_thresholds(predictions, labels, thresholds, 970 weights=None, metrics_collections=None, 971 updates_collections=None, name=None): 972 """Computes various recall values for different `thresholds` on `predictions`. 
973 974 The `streaming_recall_at_thresholds` function creates four local variables, 975 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 976 for various values of thresholds. `recall[i]` is defined as the total weight 977 of values in `predictions` above `thresholds[i]` whose corresponding entry in 978 `labels` is `True`, divided by the total weight of `True` values in `labels` 979 (`true_positives[i] / (true_positives[i] + false_negatives[i])`). 980 981 For estimation of the metric over a stream of data, the function creates an 982 `update_op` operation that updates these variables and returns the `recall`. 983 984 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 985 986 Args: 987 predictions: A floating point `Tensor` of arbitrary shape and whose values 988 are in the range `[0, 1]`. 989 labels: A `bool` `Tensor` whose shape matches `predictions`. 990 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 991 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 992 must be broadcastable to `labels` (i.e., all dimensions must be either 993 `1`, or the same as the corresponding `labels` dimension). 994 metrics_collections: An optional list of collections that `recall` should be 995 added to. 996 updates_collections: An optional list of collections that `update_op` should 997 be added to. 998 name: An optional variable_scope name. 999 1000 Returns: 1001 recall: A float `Tensor` of shape `[len(thresholds)]`. 1002 update_op: An operation that increments the `true_positives`, 1003 `true_negatives`, `false_positives` and `false_negatives` variables that 1004 are used in the computation of `recall`. 1005 1006 Raises: 1007 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1008 `weights` is not `None` and its shape doesn't match `predictions`, or if 1009 either `metrics_collections` or `updates_collections` are not a list or 1010 tuple. 
1011 """ 1012 return metrics.recall_at_thresholds( 1013 thresholds=thresholds, 1014 predictions=predictions, labels=labels, weights=weights, 1015 metrics_collections=metrics_collections, 1016 updates_collections=updates_collections, name=name) 1017 1018 1019def _at_k_name(name, k=None, class_id=None): 1020 if k is not None: 1021 name = '%s_at_%d' % (name, k) 1022 else: 1023 name = '%s_at_k' % (name) 1024 if class_id is not None: 1025 name = '%s_class%d' % (name, class_id) 1026 return name 1027 1028 1029@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, ' 1030 'and reshape labels from [batch_size] to [batch_size, 1].') 1031def streaming_recall_at_k(predictions, labels, k, weights=None, 1032 metrics_collections=None, updates_collections=None, 1033 name=None): 1034 """Computes the recall@k of the predictions with respect to dense labels. 1035 1036 The `streaming_recall_at_k` function creates two local variables, `total` and 1037 `count`, that are used to compute the recall@k frequency. This frequency is 1038 ultimately returned as `recall_at_<k>`: an idempotent operation that simply 1039 divides `total` by `count`. 1040 1041 For estimation of the metric over a stream of data, the function creates an 1042 `update_op` operation that updates these variables and returns the 1043 `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with 1044 shape [batch_size] whose elements indicate whether or not the corresponding 1045 label is in the top `k` `predictions`. Then `update_op` increments `total` 1046 with the reduced sum of `weights` where `in_top_k` is `True`, and it 1047 increments `count` with the reduced sum of `weights`. 1048 1049 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1050 1051 Args: 1052 predictions: A float `Tensor` of dimension [batch_size, num_classes]. 1053 labels: A `Tensor` of dimension [batch_size] whose type is in `int32`, 1054 `int64`. 
1055 k: The number of top elements to look at for computing recall. 1056 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1057 must be broadcastable to `labels` (i.e., all dimensions must be either 1058 `1`, or the same as the corresponding `labels` dimension). 1059 metrics_collections: An optional list of collections that `recall_at_k` 1060 should be added to. 1061 updates_collections: An optional list of collections `update_op` should be 1062 added to. 1063 name: An optional variable_scope name. 1064 1065 Returns: 1066 recall_at_k: A `Tensor` representing the recall@k, the fraction of labels 1067 which fall into the top `k` predictions. 1068 update_op: An operation that increments the `total` and `count` variables 1069 appropriately and whose value matches `recall_at_k`. 1070 1071 Raises: 1072 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1073 `weights` is not `None` and its shape doesn't match `predictions`, or if 1074 either `metrics_collections` or `updates_collections` are not a list or 1075 tuple. 1076 """ 1077 in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k)) 1078 return streaming_mean(in_top_k, 1079 weights, 1080 metrics_collections, 1081 updates_collections, 1082 name or _at_k_name('recall', k)) 1083 1084 1085# TODO(ptucker): Validate range of values in labels? 1086def streaming_sparse_recall_at_k(predictions, 1087 labels, 1088 k, 1089 class_id=None, 1090 weights=None, 1091 metrics_collections=None, 1092 updates_collections=None, 1093 name=None): 1094 """Computes recall@k of the predictions with respect to sparse labels. 1095 1096 If `class_id` is not specified, we'll calculate recall as the ratio of true 1097 positives (i.e., correct predictions, items in the top `k` highest 1098 `predictions` that are found in the corresponding row in `labels`) to 1099 actual positives (the full `labels` row). 
1100 If `class_id` is specified, we calculate recall by considering only the rows 1101 in the batch for which `class_id` is in `labels`, and computing the 1102 fraction of them for which `class_id` is in the corresponding row in 1103 `labels`. 1104 1105 `streaming_sparse_recall_at_k` creates two local variables, 1106 `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute 1107 the recall_at_k frequency. This frequency is ultimately returned as 1108 `recall_at_<k>`: an idempotent operation that simply divides 1109 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 1110 `false_negative_at_<k>`). 1111 1112 For estimation of the metric over a stream of data, the function creates an 1113 `update_op` operation that updates these variables and returns the 1114 `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 1115 indicating the top `k` `predictions`. Set operations applied to `top_k` and 1116 `labels` calculate the true positives and false negatives weighted by 1117 `weights`. Then `update_op` increments `true_positive_at_<k>` and 1118 `false_negative_at_<k>` using these values. 1119 1120 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1121 1122 Args: 1123 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1124 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 1125 The final dimension contains the logit values for each class. [D1, ... DN] 1126 must match `labels`. 1127 labels: `int64` `Tensor` or `SparseTensor` with shape 1128 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1129 target classes for the associated prediction. Commonly, N=1 and `labels` 1130 has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`. 1131 Values should be in range [0, num_classes), where num_classes is the last 1132 dimension of `predictions`. Values outside this range always count 1133 towards `false_negative_at_<k>`. 
1134 k: Integer, k for @k metric. 1135 class_id: Integer class ID for which we want binary metrics. This should be 1136 in range [0, num_classes), where num_classes is the last dimension of 1137 `predictions`. If class_id is outside this range, the method returns NAN. 1138 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1139 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1140 dimensions must be either `1`, or the same as the corresponding `labels` 1141 dimension). 1142 metrics_collections: An optional list of collections that values should 1143 be added to. 1144 updates_collections: An optional list of collections that updates should 1145 be added to. 1146 name: Name of new update operation, and namespace for other dependent ops. 1147 1148 Returns: 1149 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 1150 by the sum of `true_positives` and `false_negatives`. 1151 update_op: `Operation` that increments `true_positives` and 1152 `false_negatives` variables appropriately, and whose value matches 1153 `recall`. 1154 1155 Raises: 1156 ValueError: If `weights` is not `None` and its shape doesn't match 1157 `predictions`, or if either `metrics_collections` or `updates_collections` 1158 are not a list or tuple. 1159 """ 1160 return metrics.recall_at_k( 1161 k=k, class_id=class_id, 1162 predictions=predictions, labels=labels, weights=weights, 1163 metrics_collections=metrics_collections, 1164 updates_collections=updates_collections, name=name) 1165 1166 1167# TODO(ptucker): Validate range of values in labels? 1168def streaming_sparse_precision_at_k(predictions, 1169 labels, 1170 k, 1171 class_id=None, 1172 weights=None, 1173 metrics_collections=None, 1174 updates_collections=None, 1175 name=None): 1176 """Computes precision@k of the predictions with respect to sparse labels. 
1177 1178 If `class_id` is not specified, we calculate precision as the ratio of true 1179 positives (i.e., correct predictions, items in the top `k` highest 1180 `predictions` that are found in the corresponding row in `labels`) to 1181 positives (all top `k` `predictions`). 1182 If `class_id` is specified, we calculate precision by considering only the 1183 rows in the batch for which `class_id` is in the top `k` highest 1184 `predictions`, and computing the fraction of them for which `class_id` is 1185 in the corresponding row in `labels`. 1186 1187 We expect precision to decrease as `k` increases. 1188 1189 `streaming_sparse_precision_at_k` creates two local variables, 1190 `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute 1191 the precision@k frequency. This frequency is ultimately returned as 1192 `precision_at_<k>`: an idempotent operation that simply divides 1193 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 1194 `false_positive_at_<k>`). 1195 1196 For estimation of the metric over a stream of data, the function creates an 1197 `update_op` operation that updates these variables and returns the 1198 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 1199 indicating the top `k` `predictions`. Set operations applied to `top_k` and 1200 `labels` calculate the true positives and false positives weighted by 1201 `weights`. Then `update_op` increments `true_positive_at_<k>` and 1202 `false_positive_at_<k>` using these values. 1203 1204 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1205 1206 Args: 1207 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1208 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 1209 The final dimension contains the logit values for each class. [D1, ... DN] 1210 must match `labels`. 1211 labels: `int64` `Tensor` or `SparseTensor` with shape 1212 [D1, ... 
DN, num_labels], where N >= 1 and num_labels is the number of 1213 target classes for the associated prediction. Commonly, N=1 and `labels` 1214 has shape [batch_size, num_labels]. [D1, ... DN] must match 1215 `predictions`. Values should be in range [0, num_classes), where 1216 num_classes is the last dimension of `predictions`. Values outside this 1217 range are ignored. 1218 k: Integer, k for @k metric. 1219 class_id: Integer class ID for which we want binary metrics. This should be 1220 in range [0, num_classes], where num_classes is the last dimension of 1221 `predictions`. If `class_id` is outside this range, the method returns 1222 NAN. 1223 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1224 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1225 dimensions must be either `1`, or the same as the corresponding `labels` 1226 dimension). 1227 metrics_collections: An optional list of collections that values should 1228 be added to. 1229 updates_collections: An optional list of collections that updates should 1230 be added to. 1231 name: Name of new update operation, and namespace for other dependent ops. 1232 1233 Returns: 1234 precision: Scalar `float64` `Tensor` with the value of `true_positives` 1235 divided by the sum of `true_positives` and `false_positives`. 1236 update_op: `Operation` that increments `true_positives` and 1237 `false_positives` variables appropriately, and whose value matches 1238 `precision`. 1239 1240 Raises: 1241 ValueError: If `weights` is not `None` and its shape doesn't match 1242 `predictions`, or if either `metrics_collections` or `updates_collections` 1243 are not a list or tuple. 
1244 """ 1245 return metrics.sparse_precision_at_k( 1246 k=k, class_id=class_id, 1247 predictions=predictions, labels=labels, weights=weights, 1248 metrics_collections=metrics_collections, 1249 updates_collections=updates_collections, name=name) 1250 1251 1252# TODO(ptucker): Validate range of values in labels? 1253def streaming_sparse_precision_at_top_k(top_k_predictions, 1254 labels, 1255 class_id=None, 1256 weights=None, 1257 metrics_collections=None, 1258 updates_collections=None, 1259 name=None): 1260 """Computes precision@k of top-k predictions with respect to sparse labels. 1261 1262 If `class_id` is not specified, we calculate precision as the ratio of 1263 true positives (i.e., correct predictions, items in `top_k_predictions` 1264 that are found in the corresponding row in `labels`) to positives (all 1265 `top_k_predictions`). 1266 If `class_id` is specified, we calculate precision by considering only the 1267 rows in the batch for which `class_id` is in the top `k` highest 1268 `predictions`, and computing the fraction of them for which `class_id` is 1269 in the corresponding row in `labels`. 1270 1271 We expect precision to decrease as `k` increases. 1272 1273 `streaming_sparse_precision_at_top_k` creates two local variables, 1274 `true_positive_at_k` and `false_positive_at_k`, that are used to compute 1275 the precision@k frequency. This frequency is ultimately returned as 1276 `precision_at_k`: an idempotent operation that simply divides 1277 `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`). 1278 1279 For estimation of the metric over a stream of data, the function creates an 1280 `update_op` operation that updates these variables and returns the 1281 `precision_at_k`. Internally, set operations applied to `top_k_predictions` 1282 and `labels` calculate the true positives and false positives weighted by 1283 `weights`. Then `update_op` increments `true_positive_at_k` and 1284 `false_positive_at_k` using these values. 
1285 1286 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1287 1288 Args: 1289 top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where 1290 N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k]. 1291 The final dimension contains the indices of top-k labels. [D1, ... DN] 1292 must match `labels`. 1293 labels: `int64` `Tensor` or `SparseTensor` with shape 1294 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1295 target classes for the associated prediction. Commonly, N=1 and `labels` 1296 has shape [batch_size, num_labels]. [D1, ... DN] must match 1297 `top_k_predictions`. Values should be in range [0, num_classes), where 1298 num_classes is the last dimension of `predictions`. Values outside this 1299 range are ignored. 1300 class_id: Integer class ID for which we want binary metrics. This should be 1301 in range [0, num_classes), where num_classes is the last dimension of 1302 `predictions`. If `class_id` is outside this range, the method returns 1303 NAN. 1304 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1305 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1306 dimensions must be either `1`, or the same as the corresponding `labels` 1307 dimension). 1308 metrics_collections: An optional list of collections that values should 1309 be added to. 1310 updates_collections: An optional list of collections that updates should 1311 be added to. 1312 name: Name of new update operation, and namespace for other dependent ops. 1313 1314 Returns: 1315 precision: Scalar `float64` `Tensor` with the value of `true_positives` 1316 divided by the sum of `true_positives` and `false_positives`. 1317 update_op: `Operation` that increments `true_positives` and 1318 `false_positives` variables appropriately, and whose value matches 1319 `precision`. 
1320 1321 Raises: 1322 ValueError: If `weights` is not `None` and its shape doesn't match 1323 `predictions`, or if either `metrics_collections` or `updates_collections` 1324 are not a list or tuple. 1325 ValueError: If `top_k_predictions` has rank < 2. 1326 """ 1327 default_name = _at_k_name('precision', class_id=class_id) 1328 with ops.name_scope( 1329 name, default_name, 1330 (top_k_predictions, labels, weights)) as name_scope: 1331 return metrics_impl._sparse_precision_at_top_k( # pylint: disable=protected-access 1332 labels=labels, 1333 predictions_idx=top_k_predictions, 1334 class_id=class_id, 1335 weights=weights, 1336 metrics_collections=metrics_collections, 1337 updates_collections=updates_collections, 1338 name=name_scope) 1339 1340 1341def streaming_sparse_average_precision_at_k(predictions, 1342 labels, 1343 k, 1344 weights=None, 1345 metrics_collections=None, 1346 updates_collections=None, 1347 name=None): 1348 """Computes average precision@k of predictions with respect to sparse labels. 1349 1350 See `sparse_average_precision_at_k` for details on formula. `weights` are 1351 applied to the result of `sparse_average_precision_at_k` 1352 1353 `streaming_sparse_average_precision_at_k` creates two local variables, 1354 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 1355 are used to compute the frequency. This frequency is ultimately returned as 1356 `average_precision_at_<k>`: an idempotent operation that simply divides 1357 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 1358 1359 For estimation of the metric over a stream of data, the function creates an 1360 `update_op` operation that updates these variables and returns the 1361 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 1362 indicating the top `k` `predictions`. Set operations applied to `top_k` and 1363 `labels` calculate the true positives and false positives weighted by 1364 `weights`. 
Then `update_op` increments `true_positive_at_<k>` and 1365 `false_positive_at_<k>` using these values. 1366 1367 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1368 1369 Args: 1370 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1371 N >= 1. Commonly, N=1 and `predictions` has shape 1372 [batch size, num_classes]. The final dimension contains the logit values 1373 for each class. [D1, ... DN] must match `labels`. 1374 labels: `int64` `Tensor` or `SparseTensor` with shape 1375 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1376 target classes for the associated prediction. Commonly, N=1 and `labels` 1377 has shape [batch_size, num_labels]. [D1, ... DN] must match 1378 `predictions_`. Values should be in range [0, num_classes), where 1379 num_classes is the last dimension of `predictions`. Values outside this 1380 range are ignored. 1381 k: Integer, k for @k metric. This will calculate an average precision for 1382 range `[1,k]`, as documented above. 1383 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1384 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1385 dimensions must be either `1`, or the same as the corresponding `labels` 1386 dimension). 1387 metrics_collections: An optional list of collections that values should 1388 be added to. 1389 updates_collections: An optional list of collections that updates should 1390 be added to. 1391 name: Name of new update operation, and namespace for other dependent ops. 1392 1393 Returns: 1394 mean_average_precision: Scalar `float64` `Tensor` with the mean average 1395 precision values. 1396 update: `Operation` that increments variables appropriately, and whose 1397 value matches `metric`. 
def streaming_sparse_average_precision_at_top_k(top_k_predictions,
                                                labels,
                                                weights=None,
                                                metrics_collections=None,
                                                updates_collections=None,
                                                name=None):
  """Computes average precision@k of predictions with respect to sparse labels.

  This is a thin wrapper: `top_k_predictions` is forwarded as
  `predictions_idx`, together with every other argument, to
  `metrics_impl._streaming_sparse_average_precision_at_top_k`.

  Two local variables, `average_precision_at_<k>/total` and
  `average_precision_at_<k>/max`, back the metric. The value ultimately
  returned as `average_precision_at_<k>` is an idempotent operation that
  divides `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.

  For estimation of the metric over a stream of data, an `update_op` is also
  returned; running it updates these variables and yields the updated metric
  value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and `top_k_predictions` has shape
      [batch size, k]. The final dimension must be set and contains the top
      `k` predicted class indices. [D1, ... DN] must match `labels`. Values
      should be in range [0, num_classes).
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `top_k_predictions`.
      Values should be in range [0, num_classes).
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `mean_average_precision`.

  Raises:
    ValueError: if the last dimension of top_k_predictions is not set.
  """
  # The implementation lives in a private helper of the core metrics module.
  # pylint: disable=protected-access
  average_precision_fn = (
      metrics_impl._streaming_sparse_average_precision_at_top_k)
  # pylint: enable=protected-access
  return average_precision_fn(
      predictions_idx=top_k_predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
1479 1480 For estimation of the metric over a stream of data, the function creates an 1481 `update_op` operation that updates these variables and returns the 1482 `mean_absolute_error`. Internally, an `absolute_errors` operation computes the 1483 absolute value of the differences between `predictions` and `labels`. Then 1484 `update_op` increments `total` with the reduced sum of the product of 1485 `weights` and `absolute_errors`, and it increments `count` with the reduced 1486 sum of `weights` 1487 1488 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1489 1490 Args: 1491 predictions: A `Tensor` of arbitrary shape. 1492 labels: A `Tensor` of the same shape as `predictions`. 1493 weights: Optional `Tensor` indicating the frequency with which an example is 1494 sampled. Rank must be 0, or the same rank as `labels`, and must be 1495 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 1496 the same as the corresponding `labels` dimension). 1497 metrics_collections: An optional list of collections that 1498 `mean_absolute_error` should be added to. 1499 updates_collections: An optional list of collections that `update_op` should 1500 be added to. 1501 name: An optional variable_scope name. 1502 1503 Returns: 1504 mean_absolute_error: A `Tensor` representing the current mean, the value of 1505 `total` divided by `count`. 1506 update_op: An operation that increments the `total` and `count` variables 1507 appropriately and whose value matches `mean_absolute_error`. 1508 1509 Raises: 1510 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1511 `weights` is not `None` and its shape doesn't match `predictions`, or if 1512 either `metrics_collections` or `updates_collections` are not a list or 1513 tuple. 
def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean relative error by normalizing with the given values.

  This function forwards all of its arguments to `metrics.mean_relative_error`.

  Two local variables, `total` and `count`, are used to compute the mean
  relative absolute error. The average is weighted by `weights` and is
  ultimately returned as `mean_relative_error`: an idempotent operation that
  simply divides `total` by `count`.

  For estimation of the metric over a stream of data, an `update_op` operation
  is created that updates these variables and returns the
  `mean_relative_error`. Internally, a `relative_errors` operation divides the
  absolute value of the differences between `predictions` and `labels` by the
  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `relative_errors`, and it increments `count` with
  the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_relative_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_relative_error(
      labels=labels,
      predictions=predictions,
      normalizer=normalizer,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
Then 1591 `update_op` increments `total` with the reduced sum of the product of 1592 `weights` and `squared_error`, and it increments `count` with the reduced sum 1593 of `weights`. 1594 1595 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1596 1597 Args: 1598 predictions: A `Tensor` of arbitrary shape. 1599 labels: A `Tensor` of the same shape as `predictions`. 1600 weights: Optional `Tensor` indicating the frequency with which an example is 1601 sampled. Rank must be 0, or the same rank as `labels`, and must be 1602 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 1603 the same as the corresponding `labels` dimension). 1604 metrics_collections: An optional list of collections that 1605 `mean_squared_error` should be added to. 1606 updates_collections: An optional list of collections that `update_op` should 1607 be added to. 1608 name: An optional variable_scope name. 1609 1610 Returns: 1611 mean_squared_error: A `Tensor` representing the current mean, the value of 1612 `total` divided by `count`. 1613 update_op: An operation that increments the `total` and `count` variables 1614 appropriately and whose value matches `mean_squared_error`. 1615 1616 Raises: 1617 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1618 `weights` is not `None` and its shape doesn't match `predictions`, or if 1619 either `metrics_collections` or `updates_collections` are not a list or 1620 tuple. 1621 """ 1622 return metrics.mean_squared_error( 1623 predictions=predictions, labels=labels, weights=weights, 1624 metrics_collections=metrics_collections, 1625 updates_collections=updates_collections, name=name) 1626 1627 1628def streaming_root_mean_squared_error(predictions, labels, weights=None, 1629 metrics_collections=None, 1630 updates_collections=None, 1631 name=None): 1632 """Computes the root mean squared error between the labels and predictions. 
1633 1634 The `streaming_root_mean_squared_error` function creates two local variables, 1635 `total` and `count` that are used to compute the root mean squared error. 1636 This average is weighted by `weights`, and it is ultimately returned as 1637 `root_mean_squared_error`: an idempotent operation that takes the square root 1638 of the division of `total` by `count`. 1639 1640 For estimation of the metric over a stream of data, the function creates an 1641 `update_op` operation that updates these variables and returns the 1642 `root_mean_squared_error`. Internally, a `squared_error` operation computes 1643 the element-wise square of the difference between `predictions` and `labels`. 1644 Then `update_op` increments `total` with the reduced sum of the product of 1645 `weights` and `squared_error`, and it increments `count` with the reduced sum 1646 of `weights`. 1647 1648 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1649 1650 Args: 1651 predictions: A `Tensor` of arbitrary shape. 1652 labels: A `Tensor` of the same shape as `predictions`. 1653 weights: Optional `Tensor` indicating the frequency with which an example is 1654 sampled. Rank must be 0, or the same rank as `labels`, and must be 1655 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 1656 the same as the corresponding `labels` dimension). 1657 metrics_collections: An optional list of collections that 1658 `root_mean_squared_error` should be added to. 1659 updates_collections: An optional list of collections that `update_op` should 1660 be added to. 1661 name: An optional variable_scope name. 1662 1663 Returns: 1664 root_mean_squared_error: A `Tensor` representing the current mean, the value 1665 of `total` divided by `count`. 1666 update_op: An operation that increments the `total` and `count` variables 1667 appropriately and whose value matches `root_mean_squared_error`. 
def streaming_covariance(predictions,
                         labels,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the unbiased sample covariance between `predictions` and `labels`.

  The `streaming_covariance` function creates four local variables,
  `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to
  compute the sample covariance between predictions and labels across multiple
  batches of data. The covariance is ultimately returned as an idempotent
  operation that simply divides `comoment` by `count` - 1. We use `count` - 1
  in order to get an unbiased estimate. While `count` - 1 is not positive the
  returned value is 0.

  The algorithm used for this online computation is described in
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
  Specifically, the formula used to combine two sample comoments is
  `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB`
  The comoment for a single batch of data is simply
  `sum((x - E[x]) * (y - E[y]))`, optionally weighted.

  If `weights` is not None, then it is used to compute weighted comoments,
  means, and count. NOTE: these weights are treated as "frequency weights", as
  opposed to "reliability weights". See discussion of the difference on
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

  To facilitate the computation of covariance across multiple batches of data,
  the function creates an `update_op` operation, which updates underlying
  variables and returns the updated covariance.

  Batches that contribute no weight (an empty batch, or one whose weights are
  all 0) leave the accumulated statistics unchanged, including when no weight
  has been accumulated yet.

  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    covariance: A `Tensor` representing the current unbiased sample covariance,
      `comoment` / (`count` - 1).
    update_op: An operation that updates the local variables appropriately.

  Raises:
    ValueError: If labels and predictions are of different sizes or if either
      `metrics_collections` or `updates_collections` are not a list or tuple.
  """
  with variable_scope.variable_scope(
      name, 'covariance', (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    count = _create_local('count', [])
    mean_prediction = _create_local('mean_prediction', [])
    mean_label = _create_local('mean_label', [])
    comoment = _create_local('comoment', [])  # C_A in update equation

    if weights is None:
      batch_count = math_ops.to_float(array_ops.size(labels))  # n_B in eqn
      weighted_predictions = predictions
      weighted_labels = labels
    else:
      weights = _broadcast_weights(weights, labels)
      batch_count = math_ops.reduce_sum(weights)  # n_B in eqn
      weighted_predictions = math_ops.multiply(predictions, weights)
      weighted_labels = math_ops.multiply(labels, weights)

    update_count = state_ops.assign_add(count, batch_count)  # n_AB in eqn
    prev_count = update_count - batch_count  # n_A in update equation

    # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
    # batch_mean_prediction is E[x_B] in the update equation
    batch_mean_prediction = _safe_div(
        math_ops.reduce_sum(weighted_predictions), batch_count,
        'batch_mean_prediction')
    delta_mean_prediction = _safe_div(
        (batch_mean_prediction - mean_prediction) * batch_count, update_count,
        'delta_mean_prediction')
    update_mean_prediction = state_ops.assign_add(mean_prediction,
                                                  delta_mean_prediction)
    # prev_mean_prediction is E[x_A] in the update equation
    prev_mean_prediction = update_mean_prediction - delta_mean_prediction

    # batch_mean_label is E[y_B] in the update equation
    batch_mean_label = _safe_div(
        math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
    delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
                                 update_count, 'delta_mean_label')
    update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
    # prev_mean_label is E[y_A] in the update equation
    prev_mean_label = update_mean_label - delta_mean_label

    unweighted_batch_coresiduals = (
        (predictions - batch_mean_prediction) * (labels - batch_mean_label))
    # batch_comoment is C_B in the update equation
    if weights is None:
      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals)
    else:
      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals *
                                           weights)

    # n_A * n_B / n_AB in the update equation. This goes through `_safe_div`
    # because a plain division is 0 / 0 = NaN while the accumulated count is
    # still zero (first batch empty or fully masked); that NaN would be added
    # into `comoment` and every later covariance would stay NaN.
    count_ratio = _safe_div(prev_count * batch_count, update_count,
                            'count_ratio')

    # View delta_comoment as = C_AB - C_A in the update equation above.
    # Since C_A is stored in a var, by how much do we need to increment that var
    # to make the var = C_AB?
    delta_comoment = (batch_comoment +
                      (prev_mean_prediction - batch_mean_prediction) *
                      (prev_mean_label - batch_mean_label) * count_ratio)
    update_comoment = state_ops.assign_add(comoment, delta_comoment)

    covariance = _safe_div(comoment, count - 1, 'covariance')
    # Read the variables only after the comoment update has been applied, so
    # that `update_op` evaluates to the post-update covariance.
    with ops.control_dependencies([update_comoment]):
      update_op = _safe_div(comoment, count - 1, 'update_op')

  if metrics_collections:
    ops.add_to_collections(metrics_collections, covariance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return covariance, update_op
variance 1823 - `streaming_covariance(labels, labels)`, i.e. variance 1824 1825 The product-moment correlation ultimately returned is an idempotent operation 1826 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To 1827 facilitate correlation computation across multiple batches, the function 1828 groups the `update_op`s of the underlying streaming_covariance and returns an 1829 `update_op`. 1830 1831 If `weights` is not None, then it is used to compute a weighted correlation. 1832 NOTE: these weights are treated as "frequency weights", as opposed to 1833 "reliability weights". See discussion of the difference on 1834 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 1835 1836 Args: 1837 predictions: A `Tensor` of arbitrary size. 1838 labels: A `Tensor` of the same size as predictions. 1839 weights: Optional `Tensor` indicating the frequency with which an example is 1840 sampled. Rank must be 0, or the same rank as `labels`, and must be 1841 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 1842 the same as the corresponding `labels` dimension). 1843 metrics_collections: An optional list of collections that the metric 1844 value variable should be added to. 1845 updates_collections: An optional list of collections that the metric update 1846 ops should be added to. 1847 name: An optional variable_scope name. 1848 1849 Returns: 1850 pearson_r: A `Tensor` representing the current Pearson product-moment 1851 correlation coefficient, the value of 1852 `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. 1853 update_op: An operation that updates the underlying variables appropriately. 1854 1855 Raises: 1856 ValueError: If `labels` and `predictions` are of different sizes, or if 1857 `weights` is the wrong size, or if either `metrics_collections` or 1858 `updates_collections` are not a `list` or `tuple`. 
# TODO(nsilberman): add a 'normalized' flag so that the user can request
# normalization if the inputs are not normalized.
def streaming_mean_cosine_distance(predictions, labels, dim, weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
  """Computes the cosine distance between the labels and predictions.

  The distance is one minus the streaming mean of the dot products of
  `predictions` and `labels` taken along `dim`. The inputs are not normalized
  here, so the result is a cosine distance only when both are already
  unit-normalized along `dim`.

  The underlying `streaming_mean` creates two local variables, `total` and
  `count`, that are used to compute the average, weighted by `weights`. The
  result is ultimately returned as `mean_distance`, an idempotent operation.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of the same shape as `labels`.
    labels: A `Tensor` of arbitrary shape.
    dim: The dimension along which the cosine distance is computed.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`,
      and whose dimension `dim` is 1.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean cosine distance.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  # Dot product along `dim`; the reduced axis is kept (size 1) so the result
  # still lines up with `weights`.
  dot_products = math_ops.reduce_sum(
      math_ops.multiply(predictions, labels),
      reduction_indices=[dim],
      keep_dims=True)

  # Collections are handled below on the final (1 - mean) tensors, so none are
  # passed to the inner streaming mean.
  mean_similarity, update_similarity = streaming_mean(
      dot_products,
      weights,
      metrics_collections=None,
      updates_collections=None,
      name=name or 'mean_cosine_distance')
  mean_distance = math_ops.subtract(1.0, mean_similarity)
  update_op = math_ops.subtract(1.0, update_similarity)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op
1987 updates_collections: An optional list of collections that the metric update 1988 ops should be added to. 1989 name: An optional variable_scope name. 1990 1991 Returns: 1992 percentage: A `Tensor` representing the current mean, the value of `total` 1993 divided by `count`. 1994 update_op: An operation that increments the `total` and `count` variables 1995 appropriately. 1996 1997 Raises: 1998 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1999 or if either `metrics_collections` or `updates_collections` are not a list 2000 or tuple. 2001 """ 2002 return metrics.percentage_below( 2003 values=values, threshold=threshold, weights=weights, 2004 metrics_collections=metrics_collections, 2005 updates_collections=updates_collections, name=name) 2006 2007 2008def streaming_mean_iou(predictions, 2009 labels, 2010 num_classes, 2011 weights=None, 2012 metrics_collections=None, 2013 updates_collections=None, 2014 name=None): 2015 """Calculate per-step mean Intersection-Over-Union (mIOU). 2016 2017 Mean Intersection-Over-Union is a common evaluation metric for 2018 semantic image segmentation, which first computes the IOU for each 2019 semantic class and then computes the average over classes. 2020 IOU is defined as follows: 2021 IOU = true_positive / (true_positive + false_positive + false_negative). 2022 The predictions are accumulated in a confusion matrix, weighted by `weights`, 2023 and mIOU is then calculated from it. 2024 2025 For estimation of the metric over a stream of data, the function creates an 2026 `update_op` operation that updates these variables and returns the `mean_iou`. 2027 2028 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2029 2030 Args: 2031 predictions: A `Tensor` of prediction results for semantic labels, whose 2032 shape is [batch size] and type `int32` or `int64`. The tensor will be 2033 flattened, if its rank > 1. 
def _next_array_size(required_size, growth_factor=1.5):
  """Calculate the next size for reallocating a dynamic array.

  The result is `ceil(growth_factor ** e)`, where `e` is the smallest integer
  exponent with `growth_factor ** e >= required_size`. All arithmetic is done
  in float32.

  Args:
    required_size: number or tf.Tensor specifying required array capacity.
    growth_factor: optional number or tf.Tensor specifying the growth factor
      between subsequent allocations.

  Returns:
    tf.Tensor with dtype=int32 giving the next array size.
  """
  required = math_ops.cast(required_size, dtypes.float32)
  # Cast once and reuse for both the logarithm and the power. Using the raw
  # `growth_factor` in the power would mix dtypes with the float32 exponent
  # when the caller passes a tensor that is not float32.
  growth = math_ops.cast(growth_factor, dtypes.float32)
  exponent = math_ops.ceil(math_ops.log(required) / math_ops.log(growth))
  return math_ops.cast(
      math_ops.ceil(math_ops.pow(growth, exponent)), dtypes.int32)
2115 2116 Raises: 2117 ValueError: if `values` does not have a statically known rank, `axis` is 2118 not in the valid range or the size of `values` is not statically known 2119 along any axis other than `axis`. 2120 """ 2121 with variable_scope.variable_scope(name, 'streaming_concat', (values,)): 2122 # pylint: disable=invalid-slice-index 2123 values_shape = values.get_shape() 2124 if values_shape.dims is None: 2125 raise ValueError('`values` must have known statically known rank') 2126 2127 ndim = len(values_shape) 2128 if axis < 0: 2129 axis += ndim 2130 if not 0 <= axis < ndim: 2131 raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) 2132 2133 fixed_shape = [dim.value for n, dim in enumerate(values_shape) 2134 if n != axis] 2135 if any(value is None for value in fixed_shape): 2136 raise ValueError('all dimensions of `values` other than the dimension to ' 2137 'concatenate along must have statically known size') 2138 2139 # We move `axis` to the front of the internal array so assign ops can be 2140 # applied to contiguous slices 2141 init_size = 0 if max_size is None else max_size 2142 init_shape = [init_size] + fixed_shape 2143 array = _create_local( 2144 'array', shape=init_shape, validate_shape=False, dtype=values.dtype) 2145 size = _create_local('size', shape=[], dtype=dtypes.int32) 2146 2147 perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] 2148 valid_array = array[:size] 2149 valid_array.set_shape([None] + fixed_shape) 2150 value = array_ops.transpose(valid_array, perm, name='concat') 2151 2152 values_size = array_ops.shape(values)[axis] 2153 if max_size is None: 2154 batch_size = values_size 2155 else: 2156 batch_size = math_ops.minimum(values_size, max_size - size) 2157 2158 perm = [axis] + [n for n in range(ndim) if n != axis] 2159 batch_values = array_ops.transpose(values, perm)[:batch_size] 2160 2161 def reallocate(): 2162 next_size = _next_array_size(new_size) 2163 next_shape = array_ops.stack([next_size] + 
fixed_shape) 2164 new_value = array_ops.zeros(next_shape, dtype=values.dtype) 2165 old_value = array.value() 2166 assign_op = state_ops.assign(array, new_value, validate_shape=False) 2167 with ops.control_dependencies([assign_op]): 2168 copy_op = array[:size].assign(old_value[:size]) 2169 # return value needs to be the same dtype as no_op() for cond 2170 with ops.control_dependencies([copy_op]): 2171 return control_flow_ops.no_op() 2172 2173 new_size = size + batch_size 2174 array_size = array_ops.shape_internal(array, optimize=False)[0] 2175 maybe_reallocate_op = control_flow_ops.cond( 2176 new_size > array_size, reallocate, control_flow_ops.no_op) 2177 with ops.control_dependencies([maybe_reallocate_op]): 2178 append_values_op = array[size:new_size].assign(batch_values) 2179 with ops.control_dependencies([append_values_op]): 2180 update_op = size.assign(new_size) 2181 2182 if metrics_collections: 2183 ops.add_to_collections(metrics_collections, value) 2184 2185 if updates_collections: 2186 ops.add_to_collections(updates_collections, update_op) 2187 2188 return value, update_op 2189 # pylint: enable=invalid-slice-index 2190 2191 2192def aggregate_metrics(*value_update_tuples): 2193 """Aggregates the metric value tensors and update ops into two lists. 2194 2195 Args: 2196 *value_update_tuples: a variable number of tuples, each of which contain the 2197 pair of (value_tensor, update_op) from a streaming metric. 2198 2199 Returns: 2200 A list of value `Tensor` objects and a list of update ops. 2201 2202 Raises: 2203 ValueError: if `value_update_tuples` is empty. 2204 """ 2205 if not value_update_tuples: 2206 raise ValueError('Expected at least one value_tensor/update_op pair') 2207 value_ops, update_ops = zip(*value_update_tuples) 2208 return list(value_ops), list(update_ops) 2209 2210 2211def aggregate_metric_map(names_to_tuples): 2212 """Aggregates the metric names to tuple dictionary. 
def aggregate_metric_map(names_to_tuples):
  """Aggregates the metric names to tuple dictionary.

  This function is useful for pairing metric names with their associated value
  and update ops when the list of metrics is long. For example:

  ```python
    metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
        'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error(
            predictions, labels, weights),
        'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error(
            predictions, labels, labels, weights),
        'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
        'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
    })
  ```

  Args:
    names_to_tuples: a map of metric names to tuples, each of which contain the
      pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    A dictionary from metric names to value ops and a dictionary from metric
    names to update ops. Both dictionaries are empty when `names_to_tuples` is
    empty.
  """
  if not names_to_tuples:
    # Transposing zero pairs yields nothing to unpack, which used to surface
    # as an opaque unpacking ValueError; an empty map simply aggregates to
    # two empty dictionaries.
    return {}, {}
  metric_names = list(names_to_tuples.keys())
  # Transpose the (value, update) pairs into a column of values and a column
  # of update ops; both stay aligned with `metric_names`.
  value_ops, update_ops = zip(*names_to_tuples.values())
  return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze last dim if needed.

  Squeezes `predictions` and `labels` if their rank differs by 1.
  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    weights: Optional weight `Tensor`. It will be squeezed if its rank is 1
      more than the new rank of `predictions`. Scalar (rank 0) weights are
      returned as a `Tensor` without any squeezing.

  Returns:
    Tuple of `predictions`, `labels` and `weights`, possibly with the last
    dimension squeezed.
  """
  predictions, labels = tensor_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is None:
    return predictions, labels, weights

  weights = ops.convert_to_tensor(weights)
  weights_shape = weights.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:
    # A scalar has no trailing dimension to squeeze. Returning here also keeps
    # the `dims[-1]` lookup below from raising IndexError on an empty dims
    # list when the rank of `predictions` is not statically known.
    return predictions, labels, weights

  predictions_rank = predictions.get_shape().ndims
  if (predictions_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - predictions_rank == 1:
      weights = array_ops.squeeze(weights, [-1])
  elif (weights_rank is None) or (
      weights_shape.dims[-1].is_compatible_with(1)):
    # Use dynamic rank. This is only attempted when the last dimension of
    # `weights` could be 1; otherwise `weights` is left untouched.
    weights = control_flow_ops.cond(
        math_ops.equal(array_ops.rank(weights),
                       math_ops.add(array_ops.rank(predictions), 1)),
        lambda: array_ops.squeeze(weights, [-1]),
        lambda: weights)
  return predictions, labels, weights


# Names exported by `from ... import *`. `streaming_concat` is listed because
# it is a public function defined in this module.
__all__ = [
    'aggregate_metric_map',
    'aggregate_metrics',
    'streaming_accuracy',
    'streaming_auc',
    'streaming_concat',
    'streaming_false_negatives',
    'streaming_false_negatives_at_thresholds',
    'streaming_false_positives',
    'streaming_false_positives_at_thresholds',
    'streaming_mean',
    'streaming_mean_absolute_error',
    'streaming_mean_cosine_distance',
    'streaming_mean_iou',
    'streaming_mean_relative_error',
    'streaming_mean_squared_error',
    'streaming_mean_tensor',
    'streaming_percentage_less',
    'streaming_precision',
    'streaming_precision_at_thresholds',
    'streaming_recall',
    'streaming_recall_at_k',
    'streaming_recall_at_thresholds',
    'streaming_root_mean_squared_error',
    'streaming_sensitivity_at_specificity',
    'streaming_sparse_average_precision_at_k',
    'streaming_sparse_precision_at_k',
    'streaming_sparse_recall_at_k',
    'streaming_specificity_at_sensitivity',
    'streaming_true_negatives',
    'streaming_true_negatives_at_thresholds',
    'streaming_true_positives',
    'streaming_true_positives_at_thresholds',
]