metric_ops.py revision ebcae4a5e3bf5c840d73a0d90f1b5bf01a68f82c
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains metric-computing operations on streamed tensors. 16 17Module documentation, including "@@" callouts, should be put in 18third_party/tensorflow/contrib/metrics/__init__.py 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import collections as collections_lib 26 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import check_ops 31from tensorflow.python.ops import confusion_matrix 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops import metrics 35from tensorflow.python.ops import metrics_impl 36from tensorflow.python.ops import nn 37from tensorflow.python.ops import state_ops 38from tensorflow.python.ops import variable_scope 39from tensorflow.python.ops import weights_broadcast_ops 40from tensorflow.python.util.deprecation import deprecated 41 42 43def _safe_div(numerator, denominator, name): 44 """Divides two values, returning 0 if the denominator is <= 0. 45 46 Args: 47 numerator: A real `Tensor`. 48 denominator: A real `Tensor`, with dtype matching `numerator`. 
49 name: Name for the returned op. 50 51 Returns: 52 0 if `denominator` <= 0, else `numerator` / `denominator` 53 """ 54 return array_ops.where( 55 math_ops.greater(denominator, 0), 56 math_ops.truediv(numerator, denominator), 57 0, 58 name=name) 59 60 61def _create_local(name, 62 shape, 63 collections=None, 64 validate_shape=True, 65 dtype=dtypes.float32): 66 """Creates a new local variable. 67 68 Args: 69 name: The name of the new or existing variable. 70 shape: Shape of the new or existing variable. 71 collections: A list of collection names to which the Variable will be added. 72 validate_shape: Whether to validate the shape of the variable. 73 dtype: Data type of the variables. 74 75 Returns: 76 The created variable. 77 """ 78 # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES 79 collections = list(collections or []) 80 collections += [ops.GraphKeys.LOCAL_VARIABLES] 81 return variable_scope.variable( 82 initial_value=array_ops.zeros(shape, dtype=dtype), 83 name=name, 84 trainable=False, 85 collections=collections, 86 validate_shape=validate_shape) 87 88 89# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. 90def _assert_weights_rank(weights, values): 91 """`weights` rank must be either `0`, or the same as 'values'.""" 92 return check_ops.assert_rank_in(weights, (0, array_ops.rank(values))) 93 94 95def _count_condition(values, 96 weights=None, 97 metrics_collections=None, 98 updates_collections=None): 99 """Sums the weights of cases where the given values are True. 100 101 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 102 103 Args: 104 values: A `bool` `Tensor` of arbitrary size. 105 weights: Optional `Tensor` whose rank is either 0, or the same rank as 106 `values`, and must be broadcastable to `values` (i.e., all dimensions 107 must be either `1`, or the same as the corresponding `values` 108 dimension). 
109 metrics_collections: An optional list of collections that the metric 110 value variable should be added to. 111 updates_collections: An optional list of collections that the metric update 112 ops should be added to. 113 114 Returns: 115 value_tensor: A `Tensor` representing the current value of the metric. 116 update_op: An operation that accumulates the error from a batch of data. 117 118 Raises: 119 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 120 or if either `metrics_collections` or `updates_collections` are not a list 121 or tuple. 122 """ 123 check_ops.assert_type(values, dtypes.bool) 124 count = _create_local('count', shape=[]) 125 126 values = math_ops.to_float(values) 127 if weights is not None: 128 weights = math_ops.to_float(weights) 129 with ops.control_dependencies((_assert_weights_rank(weights, values),)): 130 values = math_ops.multiply(values, weights) 131 132 value_tensor = array_ops.identity(count) 133 update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) 134 135 if metrics_collections: 136 ops.add_to_collections(metrics_collections, value_tensor) 137 138 if updates_collections: 139 ops.add_to_collections(updates_collections, update_op) 140 141 return value_tensor, update_op 142 143 144def streaming_true_positives(predictions, 145 labels, 146 weights=None, 147 metrics_collections=None, 148 updates_collections=None, 149 name=None): 150 """Sum the weights of true_positives. 151 152 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 153 154 Args: 155 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 156 be cast to `bool`. 157 labels: The ground truth values, a `Tensor` whose dimensions must match 158 `predictions`. Will be cast to `bool`. 
159 weights: Optional `Tensor` whose rank is either 0, or the same rank as 160 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 161 must be either `1`, or the same as the corresponding `labels` 162 dimension). 163 metrics_collections: An optional list of collections that the metric 164 value variable should be added to. 165 updates_collections: An optional list of collections that the metric update 166 ops should be added to. 167 name: An optional variable_scope name. 168 169 Returns: 170 value_tensor: A `Tensor` representing the current value of the metric. 171 update_op: An operation that accumulates the error from a batch of data. 172 173 Raises: 174 ValueError: If `predictions` and `labels` have mismatched shapes, or if 175 `weights` is not `None` and its shape doesn't match `predictions`, or if 176 either `metrics_collections` or `updates_collections` are not a list or 177 tuple. 178 """ 179 return metrics.true_positives( 180 predictions=predictions, 181 labels=labels, 182 weights=weights, 183 metrics_collections=metrics_collections, 184 updates_collections=updates_collections, 185 name=name) 186 187 188def streaming_true_negatives(predictions, 189 labels, 190 weights=None, 191 metrics_collections=None, 192 updates_collections=None, 193 name=None): 194 """Sum the weights of true_negatives. 195 196 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 197 198 Args: 199 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 200 be cast to `bool`. 201 labels: The ground truth values, a `Tensor` whose dimensions must match 202 `predictions`. Will be cast to `bool`. 203 weights: Optional `Tensor` whose rank is either 0, or the same rank as 204 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 205 must be either `1`, or the same as the corresponding `labels` 206 dimension). 
207 metrics_collections: An optional list of collections that the metric 208 value variable should be added to. 209 updates_collections: An optional list of collections that the metric update 210 ops should be added to. 211 name: An optional variable_scope name. 212 213 Returns: 214 value_tensor: A `Tensor` representing the current value of the metric. 215 update_op: An operation that accumulates the error from a batch of data. 216 217 Raises: 218 ValueError: If `predictions` and `labels` have mismatched shapes, or if 219 `weights` is not `None` and its shape doesn't match `predictions`, or if 220 either `metrics_collections` or `updates_collections` are not a list or 221 tuple. 222 """ 223 with variable_scope.variable_scope(name, 'true_negatives', 224 (predictions, labels, weights)): 225 226 predictions, labels, weights = _remove_squeezable_dimensions( 227 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 228 labels=math_ops.cast(labels, dtype=dtypes.bool), 229 weights=weights) 230 is_true_negative = math_ops.logical_and( 231 math_ops.equal(labels, False), math_ops.equal(predictions, False)) 232 return _count_condition(is_true_negative, weights, metrics_collections, 233 updates_collections) 234 235 236def streaming_false_positives(predictions, 237 labels, 238 weights=None, 239 metrics_collections=None, 240 updates_collections=None, 241 name=None): 242 """Sum the weights of false positives. 243 244 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 245 246 Args: 247 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 248 be cast to `bool`. 249 labels: The ground truth values, a `Tensor` whose dimensions must match 250 `predictions`. Will be cast to `bool`. 251 weights: Optional `Tensor` whose rank is either 0, or the same rank as 252 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 253 must be either `1`, or the same as the corresponding `labels` 254 dimension). 
255 metrics_collections: An optional list of collections that the metric 256 value variable should be added to. 257 updates_collections: An optional list of collections that the metric update 258 ops should be added to. 259 name: An optional variable_scope name. 260 261 Returns: 262 value_tensor: A `Tensor` representing the current value of the metric. 263 update_op: An operation that accumulates the error from a batch of data. 264 265 Raises: 266 ValueError: If `predictions` and `labels` have mismatched shapes, or if 267 `weights` is not `None` and its shape doesn't match `predictions`, or if 268 either `metrics_collections` or `updates_collections` are not a list or 269 tuple. 270 """ 271 return metrics.false_positives( 272 predictions=predictions, 273 labels=labels, 274 weights=weights, 275 metrics_collections=metrics_collections, 276 updates_collections=updates_collections, 277 name=name) 278 279 280def streaming_false_negatives(predictions, 281 labels, 282 weights=None, 283 metrics_collections=None, 284 updates_collections=None, 285 name=None): 286 """Computes the total number of false negatives. 287 288 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 289 290 Args: 291 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 292 be cast to `bool`. 293 labels: The ground truth values, a `Tensor` whose dimensions must match 294 `predictions`. Will be cast to `bool`. 295 weights: Optional `Tensor` whose rank is either 0, or the same rank as 296 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 297 must be either `1`, or the same as the corresponding `labels` 298 dimension). 299 metrics_collections: An optional list of collections that the metric 300 value variable should be added to. 301 updates_collections: An optional list of collections that the metric update 302 ops should be added to. 303 name: An optional variable_scope name. 
304 305 Returns: 306 value_tensor: A `Tensor` representing the current value of the metric. 307 update_op: An operation that accumulates the error from a batch of data. 308 309 Raises: 310 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 311 or if either `metrics_collections` or `updates_collections` are not a list 312 or tuple. 313 """ 314 return metrics.false_negatives( 315 predictions=predictions, 316 labels=labels, 317 weights=weights, 318 metrics_collections=metrics_collections, 319 updates_collections=updates_collections, 320 name=name) 321 322 323# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. 324def _broadcast_weights(weights, values): 325 """Broadcast `weights` to the same shape as `values`. 326 327 This returns a version of `weights` following the same broadcast rules as 328 `mul(weights, values)`. When computing a weighted average, use this function 329 to broadcast `weights` before summing them; e.g., 330 `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`. 331 332 Args: 333 weights: `Tensor` whose rank is either 0, or the same rank as `values`, and 334 must be broadcastable to `values` (i.e., all dimensions must be either 335 `1`, or the same as the corresponding `values` dimension). 336 values: `Tensor` of any shape. 337 338 Returns: 339 `weights` broadcast to `values` shape. 
340 """ 341 with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope: 342 weights_shape = weights.get_shape() 343 values_shape = values.get_shape() 344 if (weights_shape.is_fully_defined() and values_shape.is_fully_defined() and 345 weights_shape.is_compatible_with(values_shape)): 346 return weights 347 with ops.control_dependencies((_assert_weights_rank(weights, values),)): 348 return math_ops.multiply(weights, array_ops.ones_like(values), name=scope) 349 350 351def streaming_mean(values, 352 weights=None, 353 metrics_collections=None, 354 updates_collections=None, 355 name=None): 356 """Computes the (weighted) mean of the given values. 357 358 The `streaming_mean` function creates two local variables, `total` and `count` 359 that are used to compute the average of `values`. This average is ultimately 360 returned as `mean` which is an idempotent operation that simply divides 361 `total` by `count`. 362 363 For estimation of the metric over a stream of data, the function creates an 364 `update_op` operation that updates these variables and returns the `mean`. 365 `update_op` increments `total` with the reduced sum of the product of `values` 366 and `weights`, and it increments `count` with the reduced sum of `weights`. 367 368 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 369 370 Args: 371 values: A `Tensor` of arbitrary dimensions. 372 weights: `Tensor` whose rank is either 0, or the same rank as `values`, and 373 must be broadcastable to `values` (i.e., all dimensions must be either 374 `1`, or the same as the corresponding `values` dimension). 375 metrics_collections: An optional list of collections that `mean` 376 should be added to. 377 updates_collections: An optional list of collections that `update_op` 378 should be added to. 379 name: An optional variable_scope name. 380 381 Returns: 382 mean: A `Tensor` representing the current mean, the value of `total` divided 383 by `count`. 
384 update_op: An operation that increments the `total` and `count` variables 385 appropriately and whose value matches `mean`. 386 387 Raises: 388 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 389 or if either `metrics_collections` or `updates_collections` are not a list 390 or tuple. 391 """ 392 return metrics.mean( 393 values=values, 394 weights=weights, 395 metrics_collections=metrics_collections, 396 updates_collections=updates_collections, 397 name=name) 398 399 400def streaming_mean_tensor(values, 401 weights=None, 402 metrics_collections=None, 403 updates_collections=None, 404 name=None): 405 """Computes the element-wise (weighted) mean of the given tensors. 406 407 In contrast to the `streaming_mean` function which returns a scalar with the 408 mean, this function returns an average tensor with the same shape as the 409 input tensors. 410 411 The `streaming_mean_tensor` function creates two local variables, 412 `total_tensor` and `count_tensor` that are used to compute the average of 413 `values`. This average is ultimately returned as `mean` which is an idempotent 414 operation that simply divides `total` by `count`. 415 416 For estimation of the metric over a stream of data, the function creates an 417 `update_op` operation that updates these variables and returns the `mean`. 418 `update_op` increments `total` with the reduced sum of the product of `values` 419 and `weights`, and it increments `count` with the reduced sum of `weights`. 420 421 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 422 423 Args: 424 values: A `Tensor` of arbitrary dimensions. 425 weights: `Tensor` whose rank is either 0, or the same rank as `values`, and 426 must be broadcastable to `values` (i.e., all dimensions must be either 427 `1`, or the same as the corresponding `values` dimension). 428 metrics_collections: An optional list of collections that `mean` 429 should be added to. 
430 updates_collections: An optional list of collections that `update_op` 431 should be added to. 432 name: An optional variable_scope name. 433 434 Returns: 435 mean: A float `Tensor` representing the current mean, the value of `total` 436 divided by `count`. 437 update_op: An operation that increments the `total` and `count` variables 438 appropriately and whose value matches `mean`. 439 440 Raises: 441 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 442 or if either `metrics_collections` or `updates_collections` are not a list 443 or tuple. 444 """ 445 return metrics.mean_tensor( 446 values=values, 447 weights=weights, 448 metrics_collections=metrics_collections, 449 updates_collections=updates_collections, 450 name=name) 451 452 453def streaming_accuracy(predictions, 454 labels, 455 weights=None, 456 metrics_collections=None, 457 updates_collections=None, 458 name=None): 459 """Calculates how often `predictions` matches `labels`. 460 461 The `streaming_accuracy` function creates two local variables, `total` and 462 `count` that are used to compute the frequency with which `predictions` 463 matches `labels`. This frequency is ultimately returned as `accuracy`: an 464 idempotent operation that simply divides `total` by `count`. 465 466 For estimation of the metric over a stream of data, the function creates an 467 `update_op` operation that updates these variables and returns the `accuracy`. 468 Internally, an `is_correct` operation computes a `Tensor` with elements 1.0 469 where the corresponding elements of `predictions` and `labels` match and 0.0 470 otherwise. Then `update_op` increments `total` with the reduced sum of the 471 product of `weights` and `is_correct`, and it increments `count` with the 472 reduced sum of `weights`. 473 474 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 475 476 Args: 477 predictions: The predicted values, a `Tensor` of any shape. 
478 labels: The ground truth values, a `Tensor` whose shape matches 479 `predictions`. 480 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 481 must be broadcastable to `labels` (i.e., all dimensions must be either 482 `1`, or the same as the corresponding `labels` dimension). 483 metrics_collections: An optional list of collections that `accuracy` should 484 be added to. 485 updates_collections: An optional list of collections that `update_op` should 486 be added to. 487 name: An optional variable_scope name. 488 489 Returns: 490 accuracy: A `Tensor` representing the accuracy, the value of `total` divided 491 by `count`. 492 update_op: An operation that increments the `total` and `count` variables 493 appropriately and whose value matches `accuracy`. 494 495 Raises: 496 ValueError: If `predictions` and `labels` have mismatched shapes, or if 497 `weights` is not `None` and its shape doesn't match `predictions`, or if 498 either `metrics_collections` or `updates_collections` are not a list or 499 tuple. 500 """ 501 return metrics.accuracy( 502 predictions=predictions, 503 labels=labels, 504 weights=weights, 505 metrics_collections=metrics_collections, 506 updates_collections=updates_collections, 507 name=name) 508 509 510def streaming_precision(predictions, 511 labels, 512 weights=None, 513 metrics_collections=None, 514 updates_collections=None, 515 name=None): 516 """Computes the precision of the predictions with respect to the labels. 517 518 The `streaming_precision` function creates two local variables, 519 `true_positives` and `false_positives`, that are used to compute the 520 precision. This value is ultimately returned as `precision`, an idempotent 521 operation that simply divides `true_positives` by the sum of `true_positives` 522 and `false_positives`. 523 524 For estimation of the metric over a stream of data, the function creates an 525 `update_op` operation that updates these variables and returns the 526 `precision`. 
`update_op` weights each prediction by the corresponding value in 527 `weights`. 528 529 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 530 531 Args: 532 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 533 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 534 match `predictions`. 535 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 536 must be broadcastable to `labels` (i.e., all dimensions must be either 537 `1`, or the same as the corresponding `labels` dimension). 538 metrics_collections: An optional list of collections that `precision` should 539 be added to. 540 updates_collections: An optional list of collections that `update_op` should 541 be added to. 542 name: An optional variable_scope name. 543 544 Returns: 545 precision: Scalar float `Tensor` with the value of `true_positives` 546 divided by the sum of `true_positives` and `false_positives`. 547 update_op: `Operation` that increments `true_positives` and 548 `false_positives` variables appropriately and whose value matches 549 `precision`. 550 551 Raises: 552 ValueError: If `predictions` and `labels` have mismatched shapes, or if 553 `weights` is not `None` and its shape doesn't match `predictions`, or if 554 either `metrics_collections` or `updates_collections` are not a list or 555 tuple. 556 """ 557 return metrics.precision( 558 predictions=predictions, 559 labels=labels, 560 weights=weights, 561 metrics_collections=metrics_collections, 562 updates_collections=updates_collections, 563 name=name) 564 565 566def streaming_recall(predictions, 567 labels, 568 weights=None, 569 metrics_collections=None, 570 updates_collections=None, 571 name=None): 572 """Computes the recall of the predictions with respect to the labels. 573 574 The `streaming_recall` function creates two local variables, `true_positives` 575 and `false_negatives`, that are used to compute the recall. 
This value is 576 ultimately returned as `recall`, an idempotent operation that simply divides 577 `true_positives` by the sum of `true_positives` and `false_negatives`. 578 579 For estimation of the metric over a stream of data, the function creates an 580 `update_op` that updates these variables and returns the `recall`. `update_op` 581 weights each prediction by the corresponding value in `weights`. 582 583 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 584 585 Args: 586 predictions: The predicted values, a `bool` `Tensor` of arbitrary shape. 587 labels: The ground truth values, a `bool` `Tensor` whose dimensions must 588 match `predictions`. 589 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 590 must be broadcastable to `labels` (i.e., all dimensions must be either 591 `1`, or the same as the corresponding `labels` dimension). 592 metrics_collections: An optional list of collections that `recall` should 593 be added to. 594 updates_collections: An optional list of collections that `update_op` should 595 be added to. 596 name: An optional variable_scope name. 597 598 Returns: 599 recall: Scalar float `Tensor` with the value of `true_positives` divided 600 by the sum of `true_positives` and `false_negatives`. 601 update_op: `Operation` that increments `true_positives` and 602 `false_negatives` variables appropriately and whose value matches 603 `recall`. 604 605 Raises: 606 ValueError: If `predictions` and `labels` have mismatched shapes, or if 607 `weights` is not `None` and its shape doesn't match `predictions`, or if 608 either `metrics_collections` or `updates_collections` are not a list or 609 tuple. 
610 """ 611 return metrics.recall( 612 predictions=predictions, 613 labels=labels, 614 weights=weights, 615 metrics_collections=metrics_collections, 616 updates_collections=updates_collections, 617 name=name) 618 619 620def _true_negatives(labels, 621 predictions, 622 weights=None, 623 metrics_collections=None, 624 updates_collections=None, 625 name=None): 626 """Sum the weights of true negatives. 627 628 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 629 630 Args: 631 labels: The ground truth values, a `Tensor` whose dimensions must match 632 `predictions`. Will be cast to `bool`. 633 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 634 be cast to `bool`. 635 weights: Optional `Tensor` whose rank is either 0, or the same rank as 636 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 637 be either `1`, or the same as the corresponding `labels` dimension). 638 metrics_collections: An optional list of collections that the metric 639 value variable should be added to. 640 updates_collections: An optional list of collections that the metric update 641 ops should be added to. 642 name: An optional variable_scope name. 643 644 Returns: 645 value_tensor: A `Tensor` representing the current value of the metric. 646 update_op: An operation that accumulates the error from a batch of data. 647 648 Raises: 649 ValueError: If `predictions` and `labels` have mismatched shapes, or if 650 `weights` is not `None` and its shape doesn't match `predictions`, or if 651 either `metrics_collections` or `updates_collections` are not a list or 652 tuple. 
653 """ 654 with variable_scope.variable_scope(name, 'true_negatives', 655 (predictions, labels, weights)): 656 657 predictions, labels, weights = _remove_squeezable_dimensions( 658 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 659 labels=math_ops.cast(labels, dtype=dtypes.bool), 660 weights=weights) 661 is_true_negative = math_ops.logical_and( 662 math_ops.equal(labels, False), math_ops.equal(predictions, False)) 663 return _count_condition(is_true_negative, weights, metrics_collections, 664 updates_collections) 665 666 667def streaming_false_positive_rate(predictions, 668 labels, 669 weights=None, 670 metrics_collections=None, 671 updates_collections=None, 672 name=None): 673 """Computes the false positive rate of predictions with respect to labels. 674 675 The `false_positive_rate` function creates two local variables, 676 `false_positives` and `true_negatives`, that are used to compute the 677 false positive rate. This value is ultimately returned as 678 `false_positive_rate`, an idempotent operation that simply divides 679 `false_positives` by the sum of `false_positives` and `true_negatives`. 680 681 For estimation of the metric over a stream of data, the function creates an 682 `update_op` operation that updates these variables and returns the 683 `false_positive_rate`. `update_op` weights each prediction by the 684 corresponding value in `weights`. 685 686 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 687 688 Args: 689 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 690 be cast to `bool`. 691 labels: The ground truth values, a `Tensor` whose dimensions must match 692 `predictions`. Will be cast to `bool`. 693 weights: Optional `Tensor` whose rank is either 0, or the same rank as 694 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 695 be either `1`, or the same as the corresponding `labels` dimension). 
696 metrics_collections: An optional list of collections that 697 `false_positive_rate` should be added to. 698 updates_collections: An optional list of collections that `update_op` should 699 be added to. 700 name: An optional variable_scope name. 701 702 Returns: 703 false_positive_rate: Scalar float `Tensor` with the value of 704 `false_positives` divided by the sum of `false_positives` and 705 `true_negatives`. 706 update_op: `Operation` that increments `false_positives` and 707 `true_negatives` variables appropriately and whose value matches 708 `false_positive_rate`. 709 710 Raises: 711 ValueError: If `predictions` and `labels` have mismatched shapes, or if 712 `weights` is not `None` and its shape doesn't match `predictions`, or if 713 either `metrics_collections` or `updates_collections` are not a list or 714 tuple. 715 """ 716 with variable_scope.variable_scope(name, 'false_positive_rate', 717 (predictions, labels, weights)): 718 predictions, labels, weights = _remove_squeezable_dimensions( 719 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 720 labels=math_ops.cast(labels, dtype=dtypes.bool), 721 weights=weights) 722 723 false_p, false_positives_update_op = metrics.false_positives( 724 labels, 725 predictions, 726 weights, 727 metrics_collections=None, 728 updates_collections=None, 729 name=None) 730 true_n, true_negatives_update_op = _true_negatives( 731 labels, 732 predictions, 733 weights, 734 metrics_collections=None, 735 updates_collections=None, 736 name=None) 737 738 def compute_fpr(fp, tn, name): 739 return array_ops.where( 740 math_ops.greater(fp + tn, 0), math_ops.div(fp, fp + tn), 0, name) 741 742 fpr = compute_fpr(false_p, true_n, 'value') 743 update_op = compute_fpr(false_positives_update_op, true_negatives_update_op, 744 'update_op') 745 746 if metrics_collections: 747 ops.add_to_collections(metrics_collections, fpr) 748 749 if updates_collections: 750 ops.add_to_collections(updates_collections, update_op) 751 752 return fpr, 
update_op


def streaming_false_negative_rate(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the false negative rate of predictions with respect to labels.

  The `false_negative_rate` function creates two local variables,
  `false_negatives` and `true_positives`, that are used to compute the
  false negative rate. This value is ultimately returned as
  `false_negative_rate`, an idempotent operation that simply divides
  `false_negatives` by the sum of `false_negatives` and `true_positives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `false_negative_rate`. `update_op` weights each prediction by the
  corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `false_negative_rate` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    false_negative_rate: Scalar float `Tensor` with the value of
      `false_negatives` divided by the sum of `false_negatives` and
      `true_positives`.
    update_op: `Operation` that increments `false_negatives` and
      `true_positives` variables appropriately and whose value matches
      `false_negative_rate`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'false_negative_rate',
                                     (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

    false_n, false_negatives_update_op = metrics.false_negatives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)
    true_p, true_positives_update_op = metrics.true_positives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)

    # FNR = FN / (FN + TP); reports 0 when there are no positive examples
    # (fn + tp == 0) instead of dividing by zero.
    def compute_fnr(fn, tp, name):
      return array_ops.where(
          math_ops.greater(fn + tp, 0), math_ops.div(fn, fn + tp), 0, name)

    fnr = compute_fnr(false_n, true_p, 'value')
    update_op = compute_fnr(false_negatives_update_op, true_positives_update_op,
                            'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, fnr)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return fnr, update_op


def _streaming_confusion_matrix_at_thresholds(predictions,
                                              labels,
                                              thresholds,
                                              weights=None,
                                              includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
854 `false_negatives[i]` is defined as the total weight of values in `predictions` 855 at most `thresholds[i]` whose corresponding entry in `labels` is `True`. 856 `true_negatives[i]` is defined as the total weight of values in `predictions` 857 at most `thresholds[i]` whose corresponding entry in `labels` is `False`. 858 `false_positives[i]` is defined as the total weight of values in `predictions` 859 above `thresholds[i]` whose corresponding entry in `labels` is `False`. 860 861 For estimation of these metrics over a stream of data, for each metric the 862 function respectively creates an `update_op` operation that updates the 863 variable and returns its value. 864 865 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 866 867 Args: 868 predictions: A floating point `Tensor` of arbitrary shape and whose values 869 are in the range `[0, 1]`. 870 labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast 871 to `bool`. 872 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 873 weights: Optional `Tensor` whose rank is either 0, or the same rank as 874 `labels`, and must be broadcastable to `labels` (i.e., all dimensions 875 must be either `1`, or the same as the corresponding `labels` 876 dimension). 877 includes: Tuple of keys to return, from 'tp', 'fn', 'tn', fp'. If `None`, 878 default to all four. 879 880 Returns: 881 values: Dict of variables of shape `[len(thresholds)]`. Keys are from 882 `includes`. 883 update_ops: Dict of operations that increments the `values`. Keys are from 884 `includes`. 885 886 Raises: 887 ValueError: If `predictions` and `labels` have mismatched shapes, or if 888 `weights` is not `None` and its shape doesn't match `predictions`, or if 889 `includes` contains invalid keys. 
890 """ 891 all_includes = ('tp', 'fn', 'tn', 'fp') 892 if includes is None: 893 includes = all_includes 894 else: 895 for include in includes: 896 if include not in all_includes: 897 raise ValueError('Invaild key: %s.' % include) 898 899 predictions, labels, weights = _remove_squeezable_dimensions( 900 predictions, labels, weights) 901 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 902 903 num_thresholds = len(thresholds) 904 905 # Reshape predictions and labels. 906 predictions_2d = array_ops.reshape(predictions, [-1, 1]) 907 labels_2d = array_ops.reshape( 908 math_ops.cast(labels, dtype=dtypes.bool), [1, -1]) 909 910 # Use static shape if known. 911 num_predictions = predictions_2d.get_shape().as_list()[0] 912 913 # Otherwise use dynamic shape. 914 if num_predictions is None: 915 num_predictions = array_ops.shape(predictions_2d)[0] 916 thresh_tiled = array_ops.tile( 917 array_ops.expand_dims(array_ops.constant(thresholds), [1]), 918 array_ops.stack([1, num_predictions])) 919 920 # Tile the predictions after thresholding them across different thresholds. 
921 pred_is_pos = math_ops.greater( 922 array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]), 923 thresh_tiled) 924 if ('fn' in includes) or ('tn' in includes): 925 pred_is_neg = math_ops.logical_not(pred_is_pos) 926 927 # Tile labels by number of thresholds 928 label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1]) 929 if ('fp' in includes) or ('tn' in includes): 930 label_is_neg = math_ops.logical_not(label_is_pos) 931 932 if weights is not None: 933 broadcast_weights = weights_broadcast_ops.broadcast_weights( 934 math_ops.to_float(weights), predictions) 935 weights_tiled = array_ops.tile( 936 array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1]) 937 thresh_tiled.get_shape().assert_is_compatible_with( 938 weights_tiled.get_shape()) 939 else: 940 weights_tiled = None 941 942 values = {} 943 update_ops = {} 944 945 if 'tp' in includes: 946 true_positives = _create_local('true_positives', shape=[num_thresholds]) 947 is_true_positive = math_ops.to_float( 948 math_ops.logical_and(label_is_pos, pred_is_pos)) 949 if weights_tiled is not None: 950 is_true_positive *= weights_tiled 951 update_ops['tp'] = state_ops.assign_add(true_positives, 952 math_ops.reduce_sum( 953 is_true_positive, 1)) 954 values['tp'] = true_positives 955 956 if 'fn' in includes: 957 false_negatives = _create_local('false_negatives', shape=[num_thresholds]) 958 is_false_negative = math_ops.to_float( 959 math_ops.logical_and(label_is_pos, pred_is_neg)) 960 if weights_tiled is not None: 961 is_false_negative *= weights_tiled 962 update_ops['fn'] = state_ops.assign_add(false_negatives, 963 math_ops.reduce_sum( 964 is_false_negative, 1)) 965 values['fn'] = false_negatives 966 967 if 'tn' in includes: 968 true_negatives = _create_local('true_negatives', shape=[num_thresholds]) 969 is_true_negative = math_ops.to_float( 970 math_ops.logical_and(label_is_neg, pred_is_neg)) 971 if weights_tiled is not None: 972 is_true_negative *= weights_tiled 973 update_ops['tn'] = 
state_ops.assign_add(true_negatives, 974 math_ops.reduce_sum( 975 is_true_negative, 1)) 976 values['tn'] = true_negatives 977 978 if 'fp' in includes: 979 false_positives = _create_local('false_positives', shape=[num_thresholds]) 980 is_false_positive = math_ops.to_float( 981 math_ops.logical_and(label_is_neg, pred_is_pos)) 982 if weights_tiled is not None: 983 is_false_positive *= weights_tiled 984 update_ops['fp'] = state_ops.assign_add(false_positives, 985 math_ops.reduce_sum( 986 is_false_positive, 1)) 987 values['fp'] = false_positives 988 989 return values, update_ops 990 991 992def streaming_true_positives_at_thresholds(predictions, 993 labels, 994 thresholds, 995 weights=None): 996 values, update_ops = _streaming_confusion_matrix_at_thresholds( 997 predictions, labels, thresholds, weights=weights, includes=('tp',)) 998 return values['tp'], update_ops['tp'] 999 1000 1001def streaming_false_negatives_at_thresholds(predictions, 1002 labels, 1003 thresholds, 1004 weights=None): 1005 values, update_ops = _streaming_confusion_matrix_at_thresholds( 1006 predictions, labels, thresholds, weights=weights, includes=('fn',)) 1007 return values['fn'], update_ops['fn'] 1008 1009 1010def streaming_false_positives_at_thresholds(predictions, 1011 labels, 1012 thresholds, 1013 weights=None): 1014 values, update_ops = _streaming_confusion_matrix_at_thresholds( 1015 predictions, labels, thresholds, weights=weights, includes=('fp',)) 1016 return values['fp'], update_ops['fp'] 1017 1018 1019def streaming_true_negatives_at_thresholds(predictions, 1020 labels, 1021 thresholds, 1022 weights=None): 1023 values, update_ops = _streaming_confusion_matrix_at_thresholds( 1024 predictions, labels, thresholds, weights=weights, includes=('tn',)) 1025 return values['tn'], update_ops['tn'] 1026 1027 1028def streaming_curve_points(labels=None, 1029 predictions=None, 1030 weights=None, 1031 num_thresholds=200, 1032 metrics_collections=None, 1033 updates_collections=None, 1034 curve='ROC', 1035 
                           name=None):
  """Computes curve (ROC or PR) values for a prespecified number of points.

  The `streaming_curve_points` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  that are used to compute the curve values. To discretize the curve, a linearly
  spaced set of thresholds is used to compute pairs of recall and precision
  values.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    points: A `Tensor` with shape [num_thresholds, 2] that contains points of
      the curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.

  TODO(chizeng): Consider rewriting this method to make use of logic within the
  streaming_precision_recall_at_equal_thresholds method (to improve run time).
  """
  with variable_scope.variable_scope(name, 'curve_points',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
    kepsilon = 1e-7  # to account for floating point imprecisions
    # Interior thresholds are evenly spaced; the endpoints sit just outside
    # [0, 1] so that predictions of exactly 0.0 and 1.0 are counted.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def compute_points(tp, fn, tn, fp):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        return fp_rate, rec
      else:  # curve == 'PR'.
        prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
        return rec, prec

    xs, ys = compute_points(values['tp'], values['fn'], values['tn'],
                            values['fp'])
    points = array_ops.stack([xs, ys], axis=1)
    update_op = control_flow_ops.group(*update_ops.values())

    if metrics_collections:
      ops.add_to_collections(metrics_collections, points)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return points, update_op


def streaming_auc(predictions,
                  labels,
                  weights=None,
                  num_thresholds=200,
                  metrics_collections=None,
                  updates_collections=None,
                  curve='ROC',
                  name=None):
  """Computes the approximate AUC via a Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Thin wrapper: forwards directly to the core implementation in
  # tf.metrics.auc.
  return metrics.auc(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      num_thresholds=num_thresholds,
      curve=curve,
      updates_collections=updates_collections,
      name=name)


def streaming_precision_recall_at_equal_thresholds(predictions,
                                                   labels,
                                                   num_thresholds=None,
                                                   weights=None,
                                                   name=None,
                                                   use_locking=None):
  """A helper method for creating metrics related to precision-recall curves.

  These values are true positives, false negatives, true negatives, false
  positives, precision, and recall. This function returns a data structure that
  contains ops within it.

  Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N)
  space and run time), this op exhibits O(T + N) space and run time, where T is
  the number of thresholds and N is the size of the predictions tensor. Hence,
  it may be advantageous to use this function when `predictions` is big.

  For instance, prefer this method for per-pixel classification tasks, for which
  the predictions tensor may be very large.

  Each number in `predictions`, a float in `[0, 1]`, is compared with its
  corresponding label in `labels`, and counts as a single tp/fp/tn/fn value at
  each threshold. This is then multiplied with `weights` which can be used to
  reweight certain values, or more commonly used for masking values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A bool `Tensor` whose shape matches `predictions`.
    num_thresholds: Optional; Number of thresholds, evenly distributed in
      `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins
      is 1 less than `num_thresholds`. Using an even `num_thresholds` value
      instead of an odd one may yield unfriendly edges for bins.
    weights: Optional; If provided, a `Tensor` that has the same dtype as,
      and broadcastable to, `predictions`. This tensor is multiplied by counts.
    name: Optional; variable_scope name. If not provided, the string
      'precision_recall_at_equal_thresholds' is used.
    use_locking: Optional; If True, the op will be protected by a lock.
      Otherwise, the behavior is undefined, but may exhibit less contention.
      Defaults to True.

  Returns:
    result: A named tuple (See PrecisionRecallData within the implementation of
      this function) with properties that are variables of shape
      `[num_thresholds]`. The names of the properties are tp, fp, tn, fn,
      precision, recall, thresholds.
    update_op: An op that accumulates values.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  # Disable the invalid-name checker so that we can capitalize the name.
  # pylint: disable=invalid-name
  PrecisionRecallData = collections_lib.namedtuple(
      'PrecisionRecallData',
      ['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds'])
  # pylint: enable=invalid-name

  if num_thresholds is None:
    num_thresholds = 201

  if weights is None:
    weights = 1.0

  if use_locking is None:
    use_locking = True

  check_ops.assert_type(labels, dtypes.bool)

  dtype = predictions.dtype
  with variable_scope.variable_scope(name,
                                     'precision_recall_at_equal_thresholds',
                                     (labels, predictions, weights)):
    # Make sure that predictions are within [0.0, 1.0].
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            predictions,
            math_ops.cast(0.0, dtype=predictions.dtype),
            message='predictions must be in [0, 1]'),
        check_ops.assert_less_equal(
            predictions,
            math_ops.cast(1.0, dtype=predictions.dtype),
            message='predictions must be in [0, 1]')
    ]):
      predictions, labels, weights = _remove_squeezable_dimensions(
          predictions=predictions, labels=labels, weights=weights)

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    # We cast to float to ensure we have 0.0 or 1.0.
    f_labels = math_ops.cast(labels, dtype)

    # Get weighted true/false labels.
    true_labels = f_labels * weights
    false_labels = (1.0 - f_labels) * weights

    # Flatten predictions and labels.
    predictions = array_ops.reshape(predictions, [-1])
    true_labels = array_ops.reshape(true_labels, [-1])
    false_labels = array_ops.reshape(false_labels, [-1])

    # To compute TP/FP/TN/FN, we are measuring a binary classifier
    #   C(t) = (predictions >= t)
    # at each threshold 't'. So we have
    #   TP(t) = sum( C(t) * true_labels )
    #   FP(t) = sum( C(t) * false_labels )
    #
    # But, computing C(t) requires computation for each t. To make it fast,
    # observe that C(t) is a cumulative integral, and so if we have
    #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
    # where n = num_thresholds, and if we can compute the bucket function
    #   B(i) = Sum( (predictions == t), t_i <= t < t{i+1} )
    # then we get
    #   C(t_i) = sum( B(j), j >= i )
    # which is the reversed cumulative sum in tf.cumsum().
    #
    # We can compute B(i) efficiently by taking advantage of the fact that
    # our thresholds are evenly distributed, in that
    #   width = 1.0 / (num_thresholds - 1)
    #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
    # Given a prediction value p, we can map it to its bucket by
    #   bucket_index(p) = floor( p * (num_thresholds - 1) )
    # so we can use tf.scatter_add() to update the buckets in one pass.
    #
    # This implementation exhibits a run time and space complexity of O(T + N),
    # where T is the number of thresholds and N is the size of predictions.
    # Metrics that rely on _streaming_confusion_matrix_at_thresholds instead
    # exhibit a complexity of O(T * N).

    # Compute the bucket indices for each prediction value.
    bucket_indices = math_ops.cast(
        math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32)

    with ops.name_scope('variables'):
      tp_buckets_v = _create_local(
          'tp_buckets', shape=[num_thresholds], dtype=dtype)
      fp_buckets_v = _create_local(
          'fp_buckets', shape=[num_thresholds], dtype=dtype)

    with ops.name_scope('update_op'):
      update_tp = state_ops.scatter_add(
          tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking)
      update_fp = state_ops.scatter_add(
          fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking)

    # Set up the cumulative sums to compute the actual metrics.
    tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp')
    fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp')
    # fn = sum(true_labels) - tp
    #    = sum(tp_buckets) - tp
    #    = tp[0] - tp
    # Similarly,
    # tn = fp[0] - fp
    tn = fp[0] - fp
    fn = tp[0] - tp

    # We use a minimum to prevent division by 0.
    epsilon = 1e-7
    precision = tp / math_ops.maximum(epsilon, tp + fp)
    recall = tp / math_ops.maximum(epsilon, tp + fn)

    result = PrecisionRecallData(
        tp=tp,
        fp=fp,
        tn=tn,
        fn=fn,
        precision=precision,
        recall=recall,
        thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds))
    update_op = control_flow_ops.group(update_tp, update_fp)
    return result, update_op


def streaming_specificity_at_sensitivity(predictions,
                                         labels,
                                         sensitivity,
                                         weights=None,
                                         num_thresholds=200,
                                         metrics_collections=None,
                                         updates_collections=None,
                                         name=None):
  """Computes the specificity at a given sensitivity.

  The `streaming_specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar `Tensor` representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  # Thin wrapper: forwards directly to the core implementation in
  # tf.metrics.specificity_at_sensitivity.
  return metrics.specificity_at_sensitivity(
      sensitivity=sensitivity,
      num_thresholds=num_thresholds,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_sensitivity_at_specificity(predictions,
                                         labels,
                                         specificity,
                                         weights=None,
                                         num_thresholds=200,
                                         metrics_collections=None,
                                         updates_collections=None,
                                         name=None):
  """Computes the sensitivity at a given specificity.

  The `streaming_sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is computed
  and used to evaluate the corresponding sensitivity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    specificity: A scalar value in range `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar `Tensor` representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `specificity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  # Thin wrapper: forwards directly to the core implementation in
  # tf.metrics.sensitivity_at_specificity.
  return metrics.sensitivity_at_specificity(
      specificity=specificity,
      num_thresholds=num_thresholds,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_precision_at_thresholds(predictions,
                                      labels,
                                      thresholds,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None,
                                      name=None):
  """Computes precision values for different `thresholds` on `predictions`.

  The `streaming_precision_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `precision[i]` is defined as the total
  weight of values in `predictions` above `thresholds[i]` whose corresponding
  entry in `labels` is `True`, divided by the total weight of values in
  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
  false_positives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Thin wrapper: forwards directly to the core implementation in
  # tf.metrics.precision_at_thresholds.
  return metrics.precision_at_thresholds(
      thresholds=thresholds,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_recall_at_thresholds(predictions,
                                   labels,
                                   thresholds,
                                   weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  The `streaming_recall_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `recall[i]` is defined as the total weight
  of values in `predictions` above `thresholds[i]` whose corresponding entry in
  `labels` is `True`, divided by the total weight of `True` values in `labels`
  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
1615 """ 1616 return metrics.recall_at_thresholds( 1617 thresholds=thresholds, 1618 predictions=predictions, 1619 labels=labels, 1620 weights=weights, 1621 metrics_collections=metrics_collections, 1622 updates_collections=updates_collections, 1623 name=name) 1624 1625 1626def streaming_false_positive_rate_at_thresholds(predictions, 1627 labels, 1628 thresholds, 1629 weights=None, 1630 metrics_collections=None, 1631 updates_collections=None, 1632 name=None): 1633 """Computes various fpr values for different `thresholds` on `predictions`. 1634 1635 The `streaming_false_positive_rate_at_thresholds` function creates two 1636 local variables, `false_positives`, `true_negatives`, for various values of 1637 thresholds. `false_positive_rate[i]` is defined as the total weight 1638 of values in `predictions` above `thresholds[i]` whose corresponding entry in 1639 `labels` is `False`, divided by the total weight of `False` values in `labels` 1640 (`false_positives[i] / (false_positives[i] + true_negatives[i])`). 1641 1642 For estimation of the metric over a stream of data, the function creates an 1643 `update_op` operation that updates these variables and returns the 1644 `false_positive_rate`. 1645 1646 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1647 1648 Args: 1649 predictions: A floating point `Tensor` of arbitrary shape and whose values 1650 are in the range `[0, 1]`. 1651 labels: A `bool` `Tensor` whose shape matches `predictions`. 1652 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1653 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1654 must be broadcastable to `labels` (i.e., all dimensions must be either 1655 `1`, or the same as the corresponding `labels` dimension). 1656 metrics_collections: An optional list of collections that 1657 `false_positive_rate` should be added to. 1658 updates_collections: An optional list of collections that `update_op` should 1659 be added to. 
1660 name: An optional variable_scope name. 1661 1662 Returns: 1663 false_positive_rate: A float `Tensor` of shape `[len(thresholds)]`. 1664 update_op: An operation that increments the `false_positives` and 1665 `true_negatives` variables that are used in the computation of 1666 `false_positive_rate`. 1667 1668 Raises: 1669 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1670 `weights` is not `None` and its shape doesn't match `predictions`, or if 1671 either `metrics_collections` or `updates_collections` are not a list or 1672 tuple. 1673 """ 1674 with variable_scope.variable_scope(name, 'false_positive_rate_at_thresholds', 1675 (predictions, labels, weights)): 1676 values, update_ops = _streaming_confusion_matrix_at_thresholds( 1677 predictions, labels, thresholds, weights, includes=('fp', 'tn')) 1678 1679 # Avoid division by zero. 1680 epsilon = 1e-7 1681 1682 def compute_fpr(fp, tn, name): 1683 return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name) 1684 1685 fpr = compute_fpr(values['fp'], values['tn'], 'value') 1686 update_op = compute_fpr(update_ops['fp'], update_ops['tn'], 'update_op') 1687 1688 if metrics_collections: 1689 ops.add_to_collections(metrics_collections, fpr) 1690 1691 if updates_collections: 1692 ops.add_to_collections(updates_collections, update_op) 1693 1694 return fpr, update_op 1695 1696 1697def streaming_false_negative_rate_at_thresholds(predictions, 1698 labels, 1699 thresholds, 1700 weights=None, 1701 metrics_collections=None, 1702 updates_collections=None, 1703 name=None): 1704 """Computes various fnr values for different `thresholds` on `predictions`. 1705 1706 The `streaming_false_negative_rate_at_thresholds` function creates two 1707 local variables, `false_negatives`, `true_positives`, for various values of 1708 thresholds. 
`false_negative_rate[i]` is defined as the total weight 1709 of values in `predictions` above `thresholds[i]` whose corresponding entry in 1710 `labels` is `False`, divided by the total weight of `True` values in `labels` 1711 (`false_negatives[i] / (false_negatives[i] + true_positives[i])`). 1712 1713 For estimation of the metric over a stream of data, the function creates an 1714 `update_op` operation that updates these variables and returns the 1715 `false_positive_rate`. 1716 1717 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1718 1719 Args: 1720 predictions: A floating point `Tensor` of arbitrary shape and whose values 1721 are in the range `[0, 1]`. 1722 labels: A `bool` `Tensor` whose shape matches `predictions`. 1723 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1724 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1725 must be broadcastable to `labels` (i.e., all dimensions must be either 1726 `1`, or the same as the corresponding `labels` dimension). 1727 metrics_collections: An optional list of collections that 1728 `false_negative_rate` should be added to. 1729 updates_collections: An optional list of collections that `update_op` should 1730 be added to. 1731 name: An optional variable_scope name. 1732 1733 Returns: 1734 false_negative_rate: A float `Tensor` of shape `[len(thresholds)]`. 1735 update_op: An operation that increments the `false_negatives` and 1736 `true_positives` variables that are used in the computation of 1737 `false_negative_rate`. 1738 1739 Raises: 1740 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1741 `weights` is not `None` and its shape doesn't match `predictions`, or if 1742 either `metrics_collections` or `updates_collections` are not a list or 1743 tuple. 
1744 """ 1745 with variable_scope.variable_scope(name, 'false_negative_rate_at_thresholds', 1746 (predictions, labels, weights)): 1747 values, update_ops = _streaming_confusion_matrix_at_thresholds( 1748 predictions, labels, thresholds, weights, includes=('fn', 'tp')) 1749 1750 # Avoid division by zero. 1751 epsilon = 1e-7 1752 1753 def compute_fnr(fn, tp, name): 1754 return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name) 1755 1756 fnr = compute_fnr(values['fn'], values['tp'], 'value') 1757 update_op = compute_fnr(update_ops['fn'], update_ops['tp'], 'update_op') 1758 1759 if metrics_collections: 1760 ops.add_to_collections(metrics_collections, fnr) 1761 1762 if updates_collections: 1763 ops.add_to_collections(updates_collections, update_op) 1764 1765 return fnr, update_op 1766 1767 1768def _at_k_name(name, k=None, class_id=None): 1769 if k is not None: 1770 name = '%s_at_%d' % (name, k) 1771 else: 1772 name = '%s_at_k' % (name) 1773 if class_id is not None: 1774 name = '%s_class%d' % (name, class_id) 1775 return name 1776 1777 1778@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, ' 1779 'and reshape labels from [batch_size] to [batch_size, 1].') 1780def streaming_recall_at_k(predictions, 1781 labels, 1782 k, 1783 weights=None, 1784 metrics_collections=None, 1785 updates_collections=None, 1786 name=None): 1787 """Computes the recall@k of the predictions with respect to dense labels. 1788 1789 The `streaming_recall_at_k` function creates two local variables, `total` and 1790 `count`, that are used to compute the recall@k frequency. This frequency is 1791 ultimately returned as `recall_at_<k>`: an idempotent operation that simply 1792 divides `total` by `count`. 1793 1794 For estimation of the metric over a stream of data, the function creates an 1795 `update_op` operation that updates these variables and returns the 1796 `recall_at_<k>`. 
Internally, an `in_top_k` operation computes a `Tensor` with 1797 shape [batch_size] whose elements indicate whether or not the corresponding 1798 label is in the top `k` `predictions`. Then `update_op` increments `total` 1799 with the reduced sum of `weights` where `in_top_k` is `True`, and it 1800 increments `count` with the reduced sum of `weights`. 1801 1802 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1803 1804 Args: 1805 predictions: A float `Tensor` of dimension [batch_size, num_classes]. 1806 labels: A `Tensor` of dimension [batch_size] whose type is in `int32`, 1807 `int64`. 1808 k: The number of top elements to look at for computing recall. 1809 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1810 must be broadcastable to `labels` (i.e., all dimensions must be either 1811 `1`, or the same as the corresponding `labels` dimension). 1812 metrics_collections: An optional list of collections that `recall_at_k` 1813 should be added to. 1814 updates_collections: An optional list of collections `update_op` should be 1815 added to. 1816 name: An optional variable_scope name. 1817 1818 Returns: 1819 recall_at_k: A `Tensor` representing the recall@k, the fraction of labels 1820 which fall into the top `k` predictions. 1821 update_op: An operation that increments the `total` and `count` variables 1822 appropriately and whose value matches `recall_at_k`. 1823 1824 Raises: 1825 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1826 `weights` is not `None` and its shape doesn't match `predictions`, or if 1827 either `metrics_collections` or `updates_collections` are not a list or 1828 tuple. 1829 """ 1830 in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k)) 1831 return streaming_mean(in_top_k, weights, metrics_collections, 1832 updates_collections, name or _at_k_name('recall', k)) 1833 1834 1835# TODO(ptucker): Validate range of values in labels? 
1836def streaming_sparse_recall_at_k(predictions, 1837 labels, 1838 k, 1839 class_id=None, 1840 weights=None, 1841 metrics_collections=None, 1842 updates_collections=None, 1843 name=None): 1844 """Computes recall@k of the predictions with respect to sparse labels. 1845 1846 If `class_id` is not specified, we'll calculate recall as the ratio of true 1847 positives (i.e., correct predictions, items in the top `k` highest 1848 `predictions` that are found in the corresponding row in `labels`) to 1849 actual positives (the full `labels` row). 1850 If `class_id` is specified, we calculate recall by considering only the rows 1851 in the batch for which `class_id` is in `labels`, and computing the 1852 fraction of them for which `class_id` is in the corresponding row in 1853 `labels`. 1854 1855 `streaming_sparse_recall_at_k` creates two local variables, 1856 `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute 1857 the recall_at_k frequency. This frequency is ultimately returned as 1858 `recall_at_<k>`: an idempotent operation that simply divides 1859 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 1860 `false_negative_at_<k>`). 1861 1862 For estimation of the metric over a stream of data, the function creates an 1863 `update_op` operation that updates these variables and returns the 1864 `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 1865 indicating the top `k` `predictions`. Set operations applied to `top_k` and 1866 `labels` calculate the true positives and false negatives weighted by 1867 `weights`. Then `update_op` increments `true_positive_at_<k>` and 1868 `false_negative_at_<k>` using these values. 1869 1870 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1871 1872 Args: 1873 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1874 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 
1875 The final dimension contains the logit values for each class. [D1, ... DN] 1876 must match `labels`. 1877 labels: `int64` `Tensor` or `SparseTensor` with shape 1878 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1879 target classes for the associated prediction. Commonly, N=1 and `labels` 1880 has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`. 1881 Values should be in range [0, num_classes), where num_classes is the last 1882 dimension of `predictions`. Values outside this range always count 1883 towards `false_negative_at_<k>`. 1884 k: Integer, k for @k metric. 1885 class_id: Integer class ID for which we want binary metrics. This should be 1886 in range [0, num_classes), where num_classes is the last dimension of 1887 `predictions`. If class_id is outside this range, the method returns NAN. 1888 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1889 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1890 dimensions must be either `1`, or the same as the corresponding `labels` 1891 dimension). 1892 metrics_collections: An optional list of collections that values should 1893 be added to. 1894 updates_collections: An optional list of collections that updates should 1895 be added to. 1896 name: Name of new update operation, and namespace for other dependent ops. 1897 1898 Returns: 1899 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 1900 by the sum of `true_positives` and `false_negatives`. 1901 update_op: `Operation` that increments `true_positives` and 1902 `false_negatives` variables appropriately, and whose value matches 1903 `recall`. 1904 1905 Raises: 1906 ValueError: If `weights` is not `None` and its shape doesn't match 1907 `predictions`, or if either `metrics_collections` or `updates_collections` 1908 are not a list or tuple. 
1909 """ 1910 return metrics.recall_at_k( 1911 k=k, 1912 class_id=class_id, 1913 predictions=predictions, 1914 labels=labels, 1915 weights=weights, 1916 metrics_collections=metrics_collections, 1917 updates_collections=updates_collections, 1918 name=name) 1919 1920 1921# TODO(ptucker): Validate range of values in labels? 1922def streaming_sparse_precision_at_k(predictions, 1923 labels, 1924 k, 1925 class_id=None, 1926 weights=None, 1927 metrics_collections=None, 1928 updates_collections=None, 1929 name=None): 1930 """Computes precision@k of the predictions with respect to sparse labels. 1931 1932 If `class_id` is not specified, we calculate precision as the ratio of true 1933 positives (i.e., correct predictions, items in the top `k` highest 1934 `predictions` that are found in the corresponding row in `labels`) to 1935 positives (all top `k` `predictions`). 1936 If `class_id` is specified, we calculate precision by considering only the 1937 rows in the batch for which `class_id` is in the top `k` highest 1938 `predictions`, and computing the fraction of them for which `class_id` is 1939 in the corresponding row in `labels`. 1940 1941 We expect precision to decrease as `k` increases. 1942 1943 `streaming_sparse_precision_at_k` creates two local variables, 1944 `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute 1945 the precision@k frequency. This frequency is ultimately returned as 1946 `precision_at_<k>`: an idempotent operation that simply divides 1947 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 1948 `false_positive_at_<k>`). 1949 1950 For estimation of the metric over a stream of data, the function creates an 1951 `update_op` operation that updates these variables and returns the 1952 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 1953 indicating the top `k` `predictions`. 
Set operations applied to `top_k` and 1954 `labels` calculate the true positives and false positives weighted by 1955 `weights`. Then `update_op` increments `true_positive_at_<k>` and 1956 `false_positive_at_<k>` using these values. 1957 1958 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1959 1960 Args: 1961 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 1962 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 1963 The final dimension contains the logit values for each class. [D1, ... DN] 1964 must match `labels`. 1965 labels: `int64` `Tensor` or `SparseTensor` with shape 1966 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 1967 target classes for the associated prediction. Commonly, N=1 and `labels` 1968 has shape [batch_size, num_labels]. [D1, ... DN] must match 1969 `predictions`. Values should be in range [0, num_classes), where 1970 num_classes is the last dimension of `predictions`. Values outside this 1971 range are ignored. 1972 k: Integer, k for @k metric. 1973 class_id: Integer class ID for which we want binary metrics. This should be 1974 in range [0, num_classes], where num_classes is the last dimension of 1975 `predictions`. If `class_id` is outside this range, the method returns 1976 NAN. 1977 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 1978 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 1979 dimensions must be either `1`, or the same as the corresponding `labels` 1980 dimension). 1981 metrics_collections: An optional list of collections that values should 1982 be added to. 1983 updates_collections: An optional list of collections that updates should 1984 be added to. 1985 name: Name of new update operation, and namespace for other dependent ops. 
1986 1987 Returns: 1988 precision: Scalar `float64` `Tensor` with the value of `true_positives` 1989 divided by the sum of `true_positives` and `false_positives`. 1990 update_op: `Operation` that increments `true_positives` and 1991 `false_positives` variables appropriately, and whose value matches 1992 `precision`. 1993 1994 Raises: 1995 ValueError: If `weights` is not `None` and its shape doesn't match 1996 `predictions`, or if either `metrics_collections` or `updates_collections` 1997 are not a list or tuple. 1998 """ 1999 return metrics.sparse_precision_at_k( 2000 k=k, 2001 class_id=class_id, 2002 predictions=predictions, 2003 labels=labels, 2004 weights=weights, 2005 metrics_collections=metrics_collections, 2006 updates_collections=updates_collections, 2007 name=name) 2008 2009 2010# TODO(ptucker): Validate range of values in labels? 2011def streaming_sparse_precision_at_top_k(top_k_predictions, 2012 labels, 2013 class_id=None, 2014 weights=None, 2015 metrics_collections=None, 2016 updates_collections=None, 2017 name=None): 2018 """Computes precision@k of top-k predictions with respect to sparse labels. 2019 2020 If `class_id` is not specified, we calculate precision as the ratio of 2021 true positives (i.e., correct predictions, items in `top_k_predictions` 2022 that are found in the corresponding row in `labels`) to positives (all 2023 `top_k_predictions`). 2024 If `class_id` is specified, we calculate precision by considering only the 2025 rows in the batch for which `class_id` is in the top `k` highest 2026 `predictions`, and computing the fraction of them for which `class_id` is 2027 in the corresponding row in `labels`. 2028 2029 We expect precision to decrease as `k` increases. 2030 2031 `streaming_sparse_precision_at_top_k` creates two local variables, 2032 `true_positive_at_k` and `false_positive_at_k`, that are used to compute 2033 the precision@k frequency. 
This frequency is ultimately returned as 2034 `precision_at_k`: an idempotent operation that simply divides 2035 `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`). 2036 2037 For estimation of the metric over a stream of data, the function creates an 2038 `update_op` operation that updates these variables and returns the 2039 `precision_at_k`. Internally, set operations applied to `top_k_predictions` 2040 and `labels` calculate the true positives and false positives weighted by 2041 `weights`. Then `update_op` increments `true_positive_at_k` and 2042 `false_positive_at_k` using these values. 2043 2044 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2045 2046 Args: 2047 top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where 2048 N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k]. 2049 The final dimension contains the indices of top-k labels. [D1, ... DN] 2050 must match `labels`. 2051 labels: `int64` `Tensor` or `SparseTensor` with shape 2052 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2053 target classes for the associated prediction. Commonly, N=1 and `labels` 2054 has shape [batch_size, num_labels]. [D1, ... DN] must match 2055 `top_k_predictions`. Values should be in range [0, num_classes), where 2056 num_classes is the last dimension of `predictions`. Values outside this 2057 range are ignored. 2058 class_id: Integer class ID for which we want binary metrics. This should be 2059 in range [0, num_classes), where num_classes is the last dimension of 2060 `predictions`. If `class_id` is outside this range, the method returns 2061 NAN. 2062 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2063 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2064 dimensions must be either `1`, or the same as the corresponding `labels` 2065 dimension). 
2066 metrics_collections: An optional list of collections that values should 2067 be added to. 2068 updates_collections: An optional list of collections that updates should 2069 be added to. 2070 name: Name of new update operation, and namespace for other dependent ops. 2071 2072 Returns: 2073 precision: Scalar `float64` `Tensor` with the value of `true_positives` 2074 divided by the sum of `true_positives` and `false_positives`. 2075 update_op: `Operation` that increments `true_positives` and 2076 `false_positives` variables appropriately, and whose value matches 2077 `precision`. 2078 2079 Raises: 2080 ValueError: If `weights` is not `None` and its shape doesn't match 2081 `predictions`, or if either `metrics_collections` or `updates_collections` 2082 are not a list or tuple. 2083 ValueError: If `top_k_predictions` has rank < 2. 2084 """ 2085 default_name = _at_k_name('precision', class_id=class_id) 2086 with ops.name_scope(name, default_name, 2087 (top_k_predictions, labels, weights)) as name_scope: 2088 return metrics_impl._sparse_precision_at_top_k( # pylint: disable=protected-access 2089 labels=labels, 2090 predictions_idx=top_k_predictions, 2091 class_id=class_id, 2092 weights=weights, 2093 metrics_collections=metrics_collections, 2094 updates_collections=updates_collections, 2095 name=name_scope) 2096 2097 2098def sparse_recall_at_top_k(labels, 2099 top_k_predictions, 2100 class_id=None, 2101 weights=None, 2102 metrics_collections=None, 2103 updates_collections=None, 2104 name=None): 2105 """Computes recall@k of top-k predictions with respect to sparse labels. 2106 2107 If `class_id` is specified, we calculate recall by considering only the 2108 entries in the batch for which `class_id` is in the label, and computing 2109 the fraction of them for which `class_id` is in the top-k `predictions`. 2110 If `class_id` is not specified, we'll calculate recall as how often on 2111 average a class among the labels of a batch entry is in the top-k 2112 `predictions`. 
2113 2114 `sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>` 2115 and `false_negative_at_<k>`, that are used to compute the recall_at_k 2116 frequency. This frequency is ultimately returned as `recall_at_<k>`: an 2117 idempotent operation that simply divides `true_positive_at_<k>` by total 2118 (`true_positive_at_<k>` + `false_negative_at_<k>`). 2119 2120 For estimation of the metric over a stream of data, the function creates an 2121 `update_op` operation that updates these variables and returns the 2122 `recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the 2123 true positives and false negatives weighted by `weights`. Then `update_op` 2124 increments `true_positive_at_<k>` and `false_negative_at_<k>` using these 2125 values. 2126 2127 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2128 2129 Args: 2130 labels: `int64` `Tensor` or `SparseTensor` with shape 2131 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2132 target classes for the associated prediction. Commonly, N=1 and `labels` 2133 has shape [batch_size, num_labels]. [D1, ... DN] must match 2134 `top_k_predictions`. Values should be in range [0, num_classes), where 2135 num_classes is the last dimension of `predictions`. Values outside this 2136 range always count towards `false_negative_at_<k>`. 2137 top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where 2138 N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k]. 2139 The final dimension contains the indices of top-k labels. [D1, ... DN] 2140 must match `labels`. 2141 class_id: Integer class ID for which we want binary metrics. This should be 2142 in range [0, num_classes), where num_classes is the last dimension of 2143 `predictions`. If class_id is outside this range, the method returns NAN. 2144 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2145 `labels`. 
If the latter, it must be broadcastable to `labels` (i.e., all 2146 dimensions must be either `1`, or the same as the corresponding `labels` 2147 dimension). 2148 metrics_collections: An optional list of collections that values should 2149 be added to. 2150 updates_collections: An optional list of collections that updates should 2151 be added to. 2152 name: Name of new update operation, and namespace for other dependent ops. 2153 2154 Returns: 2155 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2156 by the sum of `true_positives` and `false_negatives`. 2157 update_op: `Operation` that increments `true_positives` and 2158 `false_negatives` variables appropriately, and whose value matches 2159 `recall`. 2160 2161 Raises: 2162 ValueError: If `weights` is not `None` and its shape doesn't match 2163 `predictions`, or if either `metrics_collections` or `updates_collections` 2164 are not a list or tuple. 2165 """ 2166 default_name = _at_k_name('recall', class_id=class_id) 2167 with ops.name_scope(name, default_name, 2168 (top_k_predictions, labels, weights)) as name_scope: 2169 return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access 2170 labels=labels, 2171 predictions_idx=top_k_predictions, 2172 class_id=class_id, 2173 weights=weights, 2174 metrics_collections=metrics_collections, 2175 updates_collections=updates_collections, 2176 name=name_scope) 2177 2178 2179def streaming_sparse_average_precision_at_k(predictions, 2180 labels, 2181 k, 2182 weights=None, 2183 metrics_collections=None, 2184 updates_collections=None, 2185 name=None): 2186 """Computes average precision@k of predictions with respect to sparse labels. 2187 2188 See `sparse_average_precision_at_k` for details on formula. 
`weights` are 2189 applied to the result of `sparse_average_precision_at_k` 2190 2191 `streaming_sparse_average_precision_at_k` creates two local variables, 2192 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 2193 are used to compute the frequency. This frequency is ultimately returned as 2194 `average_precision_at_<k>`: an idempotent operation that simply divides 2195 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 2196 2197 For estimation of the metric over a stream of data, the function creates an 2198 `update_op` operation that updates these variables and returns the 2199 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 2200 indicating the top `k` `predictions`. Set operations applied to `top_k` and 2201 `labels` calculate the true positives and false positives weighted by 2202 `weights`. Then `update_op` increments `true_positive_at_<k>` and 2203 `false_positive_at_<k>` using these values. 2204 2205 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2206 2207 Args: 2208 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 2209 N >= 1. Commonly, N=1 and `predictions` has shape 2210 [batch size, num_classes]. The final dimension contains the logit values 2211 for each class. [D1, ... DN] must match `labels`. 2212 labels: `int64` `Tensor` or `SparseTensor` with shape 2213 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2214 target classes for the associated prediction. Commonly, N=1 and `labels` 2215 has shape [batch_size, num_labels]. [D1, ... DN] must match 2216 `predictions_`. Values should be in range [0, num_classes), where 2217 num_classes is the last dimension of `predictions`. Values outside this 2218 range are ignored. 2219 k: Integer, k for @k metric. This will calculate an average precision for 2220 range `[1,k]`, as documented above. 
2221 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2222 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2223 dimensions must be either `1`, or the same as the corresponding `labels` 2224 dimension). 2225 metrics_collections: An optional list of collections that values should 2226 be added to. 2227 updates_collections: An optional list of collections that updates should 2228 be added to. 2229 name: Name of new update operation, and namespace for other dependent ops. 2230 2231 Returns: 2232 mean_average_precision: Scalar `float64` `Tensor` with the mean average 2233 precision values. 2234 update: `Operation` that increments variables appropriately, and whose 2235 value matches `metric`. 2236 """ 2237 return metrics.sparse_average_precision_at_k( 2238 k=k, 2239 predictions=predictions, 2240 labels=labels, 2241 weights=weights, 2242 metrics_collections=metrics_collections, 2243 updates_collections=updates_collections, 2244 name=name) 2245 2246 2247def streaming_sparse_average_precision_at_top_k(top_k_predictions, 2248 labels, 2249 weights=None, 2250 metrics_collections=None, 2251 updates_collections=None, 2252 name=None): 2253 """Computes average precision@k of predictions with respect to sparse labels. 2254 2255 `streaming_sparse_average_precision_at_top_k` creates two local variables, 2256 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 2257 are used to compute the frequency. This frequency is ultimately returned as 2258 `average_precision_at_<k>`: an idempotent operation that simply divides 2259 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 2260 2261 For estimation of the metric over a stream of data, the function creates an 2262 `update_op` operation that updates these variables and returns the 2263 `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate 2264 the true positives and false positives weighted by `weights`. 
  increments `true_positive_at_<k>` and `false_positive_at_<k>` using these
  values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
      dimension must be set and contains the top `k` predicted class indices.
      [D1, ... DN] must match `labels`. Values should be in range
      [0, num_classes).
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `top_k_predictions`.
      Values should be in range [0, num_classes).
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `metric`.

  Raises:
    ValueError: if the last dimension of top_k_predictions is not set.
  """
  # Delegates to the (intentionally private) core implementation, which
  # accepts already-computed top-k indices instead of raw logits.
  return metrics_impl._streaming_sparse_average_precision_at_top_k(  # pylint: disable=protected-access
      predictions_idx=top_k_predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_mean_absolute_error(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean absolute error between the labels and predictions.

  The `streaming_mean_absolute_error` function creates two local variables,
  `total` and `count` that are used to compute the mean absolute error. This
  average is weighted by `weights`, and it is ultimately returned as
  `mean_absolute_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
  absolute value of the differences between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `absolute_errors`, and it increments `count` with the reduced
  sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Thin wrapper: the actual computation lives in `tf.metrics`.
  return metrics.mean_absolute_error(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_mean_relative_error(predictions,
                                  labels,
                                  normalizer,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean relative error by normalizing with the given values.

  The `streaming_mean_relative_error` function creates two local variables,
  `total` and `count` that are used to compute the mean relative absolute error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_relative_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_relative_error`. Internally, a `relative_errors` operation divides the
  absolute value of the differences between `predictions` and `labels` by the
  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `relative_errors`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_relative_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Thin wrapper: the actual computation lives in `tf.metrics`.
  return metrics.mean_relative_error(
      normalizer=normalizer,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_mean_squared_error(predictions,
                                 labels,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes the mean squared error between the labels and predictions.

  The `streaming_mean_squared_error` function creates two local variables,
  `total` and `count` that are used to compute the mean squared error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_squared_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_squared_error`. Internally, a `squared_error` operation computes the
  element-wise square of the difference between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `squared_error`, and it increments `count` with the reduced sum
  of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_squared_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Thin wrapper: the actual computation lives in `tf.metrics`.
  return metrics.mean_squared_error(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_root_mean_squared_error(predictions,
                                      labels,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None,
                                      name=None):
  """Computes the root mean squared error between the labels and predictions.

  The `streaming_root_mean_squared_error` function creates two local variables,
  `total` and `count` that are used to compute the root mean squared error.
  This average is weighted by `weights`, and it is ultimately returned as
  `root_mean_squared_error`: an idempotent operation that takes the square root
  of the division of `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `root_mean_squared_error`. Internally, a `squared_error` operation computes
  the element-wise square of the difference between `predictions` and `labels`.
  Then `update_op` increments `total` with the reduced sum of the product of
  `weights` and `squared_error`, and it increments `count` with the reduced sum
  of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `root_mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    root_mean_squared_error: A `Tensor` representing the current mean, the value
      of `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `root_mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Thin wrapper: the actual computation lives in `tf.metrics`.
  return metrics.root_mean_squared_error(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_covariance(predictions,
                         labels,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the unbiased sample covariance between `predictions` and `labels`.

  The `streaming_covariance` function creates four local variables,
  `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to
  compute the sample covariance between predictions and labels across multiple
  batches of data. The covariance is ultimately returned as an idempotent
  operation that simply divides `comoment` by `count` - 1. We use `count` - 1
  in order to get an unbiased estimate.

  The algorithm used for this online computation is described in
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
  Specifically, the formula used to combine two sample comoments is
  `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB`
  The comoment for a single batch of data is simply
  `sum((x - E[x]) * (y - E[y]))`, optionally weighted.

  If `weights` is not None, then it is used to compute weighted comoments,
  means, and count. NOTE: these weights are treated as "frequency weights", as
  opposed to "reliability weights". See discussion of the difference on
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

  To facilitate the computation of covariance across multiple batches of data,
  the function creates an `update_op` operation, which updates underlying
  variables and returns the updated covariance.

  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    covariance: A `Tensor` representing the current unbiased sample covariance,
      `comoment` / (`count` - 1).
    update_op: An operation that updates the local variables appropriately.

  Raises:
    ValueError: If labels and predictions are of different sizes or if either
      `metrics_collections` or `updates_collections` are not a list or tuple.
  """
  with variable_scope.variable_scope(name, 'covariance',
                                     (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # Scalar accumulator variables; these hold the running statistics across
    # all batches seen so far (the "A" side of the merge formula).
    count = _create_local('count', [])
    mean_prediction = _create_local('mean_prediction', [])
    mean_label = _create_local('mean_label', [])
    comoment = _create_local('comoment', [])  # C_A in update equation

    if weights is None:
      batch_count = math_ops.to_float(array_ops.size(labels))  # n_B in eqn
      weighted_predictions = predictions
      weighted_labels = labels
    else:
      weights = weights_broadcast_ops.broadcast_weights(weights, labels)
      batch_count = math_ops.reduce_sum(weights)  # n_B in eqn
      weighted_predictions = math_ops.multiply(predictions, weights)
      weighted_labels = math_ops.multiply(labels, weights)

    update_count = state_ops.assign_add(count, batch_count)  # n_AB in eqn
    prev_count = update_count - batch_count  # n_A in update equation

    # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
    # batch_mean_prediction is E[x_B] in the update equation
    batch_mean_prediction = _safe_div(
        math_ops.reduce_sum(weighted_predictions), batch_count,
        'batch_mean_prediction')
    delta_mean_prediction = _safe_div(
        (batch_mean_prediction - mean_prediction) * batch_count, update_count,
        'delta_mean_prediction')
    update_mean_prediction = state_ops.assign_add(mean_prediction,
                                                  delta_mean_prediction)
    # prev_mean_prediction is E[x_A] in the update equation
    prev_mean_prediction = update_mean_prediction - delta_mean_prediction

    # batch_mean_label is E[y_B] in the update equation
    batch_mean_label = _safe_div(
        math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
    delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
                                 update_count, 'delta_mean_label')
    update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
    # prev_mean_label is E[y_A] in the update equation
    prev_mean_label = update_mean_label - delta_mean_label

    unweighted_batch_coresiduals = ((predictions - batch_mean_prediction) *
                                    (labels - batch_mean_label))
    # batch_comoment is C_B in the update equation
    if weights is None:
      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals)
    else:
      batch_comoment = math_ops.reduce_sum(
          unweighted_batch_coresiduals * weights)

    # View delta_comoment as = C_AB - C_A in the update equation above.
    # Since C_A is stored in a var, by how much do we need to increment that var
    # to make the var = C_AB?
    delta_comoment = (
        batch_comoment + (prev_mean_prediction - batch_mean_prediction) *
        (prev_mean_label - batch_mean_label) *
        (prev_count * batch_count / update_count))
    update_comoment = state_ops.assign_add(comoment, delta_comoment)

    # The unbiased estimate comoment / (count - 1) is undefined until at
    # least two samples have been seen; report NaN in that case.
    covariance = array_ops.where(
        math_ops.less_equal(count, 1.),
        float('nan'),
        math_ops.truediv(comoment, count - 1),
        name='covariance')
    with ops.control_dependencies([update_comoment]):
      update_op = array_ops.where(
          math_ops.less_equal(count, 1.),
          float('nan'),
          math_ops.truediv(comoment, count - 1),
          name='update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, covariance)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return covariance, update_op


def streaming_pearson_correlation(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes Pearson correlation coefficient between `predictions`, `labels`.

  The `streaming_pearson_correlation` function delegates to
  `streaming_covariance` the tracking of three [co]variances:

  - `streaming_covariance(predictions, labels)`, i.e. covariance
  - `streaming_covariance(predictions, predictions)`, i.e. variance
  - `streaming_covariance(labels, labels)`, i.e. variance

  The product-moment correlation ultimately returned is an idempotent operation
  `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To
  facilitate correlation computation across multiple batches, the function
  groups the `update_op`s of the underlying streaming_covariance and returns an
  `update_op`.

  If `weights` is not None, then it is used to compute a weighted correlation.
  NOTE: these weights are treated as "frequency weights", as opposed to
  "reliability weights". See discussion of the difference on
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as predictions.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    pearson_r: A `Tensor` representing the current Pearson product-moment
      correlation coefficient, the value of
      `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`.
    update_op: An operation that updates the underlying variables appropriately.

  Raises:
    ValueError: If `labels` and `predictions` are of different sizes, or if
      `weights` is the wrong size, or if either `metrics_collections` or
      `updates_collections` are not a `list` or `tuple`.
  """
  with variable_scope.variable_scope(name, 'pearson_r',
                                     (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # Broadcast weights here to avoid duplicate broadcasting in each call to
    # `streaming_covariance`.
    if weights is not None:
      weights = weights_broadcast_ops.broadcast_weights(weights, labels)
    cov, update_cov = streaming_covariance(
        predictions, labels, weights=weights, name='covariance')
    var_predictions, update_var_predictions = streaming_covariance(
        predictions, predictions, weights=weights, name='variance_predictions')
    var_labels, update_var_labels = streaming_covariance(
        labels, labels, weights=weights, name='variance_labels')

    # `pearson_r` reads the current accumulator values; `update_op` instead
    # consumes the freshly-updated covariance/variance tensors so that its
    # value reflects this batch.
    pearson_r = math_ops.truediv(
        cov,
        math_ops.multiply(
            math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)),
        name='pearson_r')
    update_op = math_ops.truediv(
        update_cov,
        math_ops.multiply(
            math_ops.sqrt(update_var_predictions),
            math_ops.sqrt(update_var_labels)),
        name='update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, pearson_r)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return pearson_r, update_op


# TODO(nsilberman): add a 'normalized' flag so that the user can request
# normalization if the inputs are not normalized.
def streaming_mean_cosine_distance(predictions,
                                   labels,
                                   dim,
                                   weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
  """Computes the cosine distance between the labels and predictions.

  The `streaming_mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is weighted by `weights`,
  and it is ultimately returned as `mean_distance`, which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of the same shape as `labels`.
    labels: A `Tensor` of arbitrary shape.
    dim: The dimension along which the cosine distance is computed.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`,
      and whose dimension `dim` is 1.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  # The dot product along `dim` equals the cosine similarity only for
  # unit-normalized inputs — presumably callers pass normalized vectors
  # (see the TODO above about a 'normalized' flag); verify at call sites.
  radial_diffs = math_ops.multiply(predictions, labels)
  radial_diffs = math_ops.reduce_sum(
      radial_diffs, reduction_indices=[
          dim,
      ], keep_dims=True)
  mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None,
                                            name or 'mean_cosine_distance')
  # Cosine distance = 1 - mean similarity.
  mean_distance = math_ops.subtract(1.0, mean_distance)
  update_op = math_ops.subtract(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op


def streaming_percentage_less(values,
                              threshold,
                              weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Computes the percentage of values less than the given threshold.

  The `streaming_percentage_less` function creates two local variables,
  `total` and `count` that are used to compute the percentage of `values` that
  fall below `threshold`. This rate is weighted by `weights`, and it is
  ultimately returned as `percentage` which is an idempotent operation that
  simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `percentage`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A numeric `Tensor` of arbitrary size.
    threshold: A scalar threshold.
    weights: An optional `Tensor` whose shape is broadcastable to `values`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    percentage: A `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  # Thin wrapper: the actual computation lives in `tf.metrics`.
  return metrics.percentage_below(
      values=values,
      threshold=threshold,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_mean_iou(predictions,
                       labels,
                       num_classes,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for
  semantic image segmentation, which first computes the IOU for each
  semantic class and then computes the average over classes.
  IOU is defined as follows:
    IOU = true_positive / (true_positive + false_positive + false_negative).
  The predictions are accumulated in a confusion matrix, weighted by `weights`,
  and mIOU is then calculated from it.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean_iou`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened, if its rank > 1.
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened, if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A `Tensor` representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Thin wrapper: the actual computation lives in `tf.metrics`.
  return metrics.mean_iou(
      num_classes=num_classes,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def _next_array_size(required_size, growth_factor=1.5):
  """Calculate the next size for reallocating a dynamic array.

  Args:
    required_size: number or tf.Tensor specifying required array capacity.
    growth_factor: optional number or tf.Tensor specifying the growth factor
      between subsequent allocations.

  Returns:
    tf.Tensor with dtype=int32 giving the next array size.
2956 """ 2957 exponent = math_ops.ceil( 2958 math_ops.log(math_ops.cast(required_size, dtypes.float32)) / math_ops.log( 2959 math_ops.cast(growth_factor, dtypes.float32))) 2960 return math_ops.cast(math_ops.ceil(growth_factor**exponent), dtypes.int32) 2961 2962 2963def streaming_concat(values, 2964 axis=0, 2965 max_size=None, 2966 metrics_collections=None, 2967 updates_collections=None, 2968 name=None): 2969 """Concatenate values along an axis across batches. 2970 2971 The function `streaming_concat` creates two local variables, `array` and 2972 `size`, that are used to store concatenated values. Internally, `array` is 2973 used as storage for a dynamic array (if `maxsize` is `None`), which ensures 2974 that updates can be run in amortized constant time. 2975 2976 For estimation of the metric over a stream of data, the function creates an 2977 `update_op` operation that appends the values of a tensor and returns the 2978 length of the concatenated axis. 2979 2980 This op allows for evaluating metrics that cannot be updated incrementally 2981 using the same framework as other streaming metrics. 2982 2983 Args: 2984 values: `Tensor` to concatenate. Rank and the shape along all axes other 2985 than the axis to concatenate along must be statically known. 2986 axis: optional integer axis to concatenate along. 2987 max_size: optional integer maximum size of `value` along the given axis. 2988 Once the maximum size is reached, further updates are no-ops. By default, 2989 there is no maximum size: the array is resized as necessary. 2990 metrics_collections: An optional list of collections that `value` 2991 should be added to. 2992 updates_collections: An optional list of collections `update_op` should be 2993 added to. 2994 name: An optional variable_scope name. 2995 2996 Returns: 2997 value: A `Tensor` representing the concatenated values. 2998 update_op: An operation that concatenates the next values. 
2999 3000 Raises: 3001 ValueError: if `values` does not have a statically known rank, `axis` is 3002 not in the valid range or the size of `values` is not statically known 3003 along any axis other than `axis`. 3004 """ 3005 with variable_scope.variable_scope(name, 'streaming_concat', (values,)): 3006 # pylint: disable=invalid-slice-index 3007 values_shape = values.get_shape() 3008 if values_shape.dims is None: 3009 raise ValueError('`values` must have known statically known rank') 3010 3011 ndim = len(values_shape) 3012 if axis < 0: 3013 axis += ndim 3014 if not 0 <= axis < ndim: 3015 raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) 3016 3017 fixed_shape = [dim.value for n, dim in enumerate(values_shape) if n != axis] 3018 if any(value is None for value in fixed_shape): 3019 raise ValueError('all dimensions of `values` other than the dimension to ' 3020 'concatenate along must have statically known size') 3021 3022 # We move `axis` to the front of the internal array so assign ops can be 3023 # applied to contiguous slices 3024 init_size = 0 if max_size is None else max_size 3025 init_shape = [init_size] + fixed_shape 3026 array = _create_local( 3027 'array', shape=init_shape, validate_shape=False, dtype=values.dtype) 3028 size = _create_local('size', shape=[], dtype=dtypes.int32) 3029 3030 perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] 3031 valid_array = array[:size] 3032 valid_array.set_shape([None] + fixed_shape) 3033 value = array_ops.transpose(valid_array, perm, name='concat') 3034 3035 values_size = array_ops.shape(values)[axis] 3036 if max_size is None: 3037 batch_size = values_size 3038 else: 3039 batch_size = math_ops.minimum(values_size, max_size - size) 3040 3041 perm = [axis] + [n for n in range(ndim) if n != axis] 3042 batch_values = array_ops.transpose(values, perm)[:batch_size] 3043 3044 def reallocate(): 3045 next_size = _next_array_size(new_size) 3046 next_shape = array_ops.stack([next_size] + fixed_shape) 
3047 new_value = array_ops.zeros(next_shape, dtype=values.dtype) 3048 old_value = array.value() 3049 assign_op = state_ops.assign(array, new_value, validate_shape=False) 3050 with ops.control_dependencies([assign_op]): 3051 copy_op = array[:size].assign(old_value[:size]) 3052 # return value needs to be the same dtype as no_op() for cond 3053 with ops.control_dependencies([copy_op]): 3054 return control_flow_ops.no_op() 3055 3056 new_size = size + batch_size 3057 array_size = array_ops.shape_internal(array, optimize=False)[0] 3058 maybe_reallocate_op = control_flow_ops.cond( 3059 new_size > array_size, reallocate, control_flow_ops.no_op) 3060 with ops.control_dependencies([maybe_reallocate_op]): 3061 append_values_op = array[size:new_size].assign(batch_values) 3062 with ops.control_dependencies([append_values_op]): 3063 update_op = size.assign(new_size) 3064 3065 if metrics_collections: 3066 ops.add_to_collections(metrics_collections, value) 3067 3068 if updates_collections: 3069 ops.add_to_collections(updates_collections, update_op) 3070 3071 return value, update_op 3072 # pylint: enable=invalid-slice-index 3073 3074 3075def aggregate_metrics(*value_update_tuples): 3076 """Aggregates the metric value tensors and update ops into two lists. 3077 3078 Args: 3079 *value_update_tuples: a variable number of tuples, each of which contain the 3080 pair of (value_tensor, update_op) from a streaming metric. 3081 3082 Returns: 3083 A list of value `Tensor` objects and a list of update ops. 3084 3085 Raises: 3086 ValueError: if `value_update_tuples` is empty. 3087 """ 3088 if not value_update_tuples: 3089 raise ValueError('Expected at least one value_tensor/update_op pair') 3090 value_ops, update_ops = zip(*value_update_tuples) 3091 return list(value_ops), list(update_ops) 3092 3093 3094def aggregate_metric_map(names_to_tuples): 3095 """Aggregates the metric names to tuple dictionary. 
3096 3097 This function is useful for pairing metric names with their associated value 3098 and update ops when the list of metrics is long. For example: 3099 3100 ```python 3101 metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({ 3102 'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error( 3103 predictions, labels, weights), 3104 'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error( 3105 predictions, labels, labels, weights), 3106 'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error( 3107 predictions, labels, weights), 3108 'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error( 3109 predictions, labels, weights), 3110 }) 3111 ``` 3112 3113 Args: 3114 names_to_tuples: a map of metric names to tuples, each of which contain the 3115 pair of (value_tensor, update_op) from a streaming metric. 3116 3117 Returns: 3118 A dictionary from metric names to value ops and a dictionary from metric 3119 names to update ops. 3120 """ 3121 metric_names = names_to_tuples.keys() 3122 value_ops, update_ops = zip(*names_to_tuples.values()) 3123 return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops)) 3124 3125 3126def _remove_squeezable_dimensions(predictions, labels, weights): 3127 """Squeeze last dim if needed. 3128 3129 Squeezes `predictions` and `labels` if their rank differs by 1. 3130 Squeezes `weights` if its rank is 1 more than the new rank of `predictions` 3131 3132 This will use static shape if available. Otherwise, it will add graph 3133 operations, which could result in a performance hit. 3134 3135 Args: 3136 predictions: Predicted values, a `Tensor` of arbitrary dimensions. 3137 labels: Label values, a `Tensor` whose dimensions match `predictions`. 3138 weights: Optional weight `Tensor`. 
It will be squeezed if its rank is 1 3139 more than the new rank of `predictions` 3140 3141 Returns: 3142 Tuple of `predictions`, `labels` and `weights`, possibly with the last 3143 dimension squeezed. 3144 """ 3145 labels, predictions = confusion_matrix.remove_squeezable_dimensions( 3146 labels, predictions) 3147 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 3148 3149 if weights is not None: 3150 weights = ops.convert_to_tensor(weights) 3151 predictions_shape = predictions.get_shape() 3152 predictions_rank = predictions_shape.ndims 3153 weights_shape = weights.get_shape() 3154 weights_rank = weights_shape.ndims 3155 3156 if (predictions_rank is not None) and (weights_rank is not None): 3157 # Use static rank. 3158 if weights_rank - predictions_rank == 1: 3159 weights = array_ops.squeeze(weights, [-1]) 3160 elif (weights_rank is 3161 None) or (weights_shape.dims[-1].is_compatible_with(1)): 3162 # Use dynamic rank 3163 weights = control_flow_ops.cond( 3164 math_ops.equal( 3165 array_ops.rank(weights), 3166 math_ops.add(array_ops.rank(predictions), 1)), 3167 lambda: array_ops.squeeze(weights, [-1]), lambda: weights) 3168 return predictions, labels, weights 3169 3170 3171__all__ = [ 3172 'aggregate_metric_map', 3173 'aggregate_metrics', 3174 'sparse_recall_at_top_k', 3175 'streaming_accuracy', 3176 'streaming_auc', 3177 'streaming_curve_points', 3178 'streaming_false_negative_rate', 3179 'streaming_false_negative_rate_at_thresholds', 3180 'streaming_false_negatives', 3181 'streaming_false_negatives_at_thresholds', 3182 'streaming_false_positive_rate', 3183 'streaming_false_positive_rate_at_thresholds', 3184 'streaming_false_positives', 3185 'streaming_false_positives_at_thresholds', 3186 'streaming_mean', 3187 'streaming_mean_absolute_error', 3188 'streaming_mean_cosine_distance', 3189 'streaming_mean_iou', 3190 'streaming_mean_relative_error', 3191 'streaming_mean_squared_error', 3192 'streaming_mean_tensor', 3193 'streaming_percentage_less', 
3194 'streaming_precision', 3195 'streaming_precision_at_thresholds', 3196 'streaming_recall', 3197 'streaming_recall_at_k', 3198 'streaming_recall_at_thresholds', 3199 'streaming_root_mean_squared_error', 3200 'streaming_sensitivity_at_specificity', 3201 'streaming_sparse_average_precision_at_k', 3202 'streaming_sparse_average_precision_at_top_k', 3203 'streaming_sparse_precision_at_k', 3204 'streaming_sparse_precision_at_top_k', 3205 'streaming_sparse_recall_at_k', 3206 'streaming_specificity_at_sensitivity', 3207 'streaming_true_negatives', 3208 'streaming_true_negatives_at_thresholds', 3209 'streaming_true_positives', 3210 'streaming_true_positives_at_thresholds', 3211] 3212