metric_ops.py revision 972fa89023f8f27948321c388fa3f1f7857833c3
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains metric-computing operations on streamed tensors. 16 17Module documentation, including "@@" callouts, should be put in 18third_party/tensorflow/contrib/metrics/__init__.py 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import collections as collections_lib 26 27from tensorflow.python.eager import context 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import check_ops 32from tensorflow.python.ops import confusion_matrix 33from tensorflow.python.ops import control_flow_ops 34from tensorflow.python.ops import math_ops 35from tensorflow.python.ops import metrics 36from tensorflow.python.ops import metrics_impl 37from tensorflow.python.ops import nn 38from tensorflow.python.ops import state_ops 39from tensorflow.python.ops import variable_scope 40from tensorflow.python.ops import weights_broadcast_ops 41from tensorflow.python.ops.distributions.normal import Normal 42from tensorflow.python.util.deprecation import deprecated 43 44# Epsilon constant used to represent extremely small quantity. 
# Epsilon constant used to represent extremely small quantity.
_EPSILON = 1e-7


def _safe_div(numerator, denominator, name):
  """Divides two values, substituting 0 wherever the denominator is <= 0.

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` <= 0, else `numerator` / `denominator`
  """
  denominator_is_positive = math_ops.greater(denominator, 0)
  quotient = math_ops.truediv(numerator, denominator)
  return array_ops.where(denominator_is_positive, quotient, 0, name=name)


def streaming_true_positives(predictions,
                             labels,
                             weights=None,
                             metrics_collections=None,
                             updates_collections=None,
                             name=None):
  """Accumulates the total weight of true positives.

  When `weights` is `None` every value carries weight 1; a weight of 0
  masks the corresponding value out of the metric.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions.
      Will be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and broadcastable to `labels`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric
      update ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a
      list or tuple.
  """
  # Thin forwarding wrapper: the core implementation lives in tf.metrics.
  return metrics.true_positives(
      labels=labels,
      predictions=predictions,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
def streaming_true_negatives(predictions,
                             labels,
                             weights=None,
                             metrics_collections=None,
                             updates_collections=None,
                             name=None):
  """Accumulates the total weight of true negatives.

  When `weights` is `None` every value carries weight 1; a weight of 0
  masks the corresponding value out of the metric.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions.
      Will be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and broadcastable to `labels`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric
      update ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a
      list or tuple.
  """
  # Delegate to the canonical implementation in tf.metrics.
  kwargs = dict(
      labels=labels,
      predictions=predictions,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
  return metrics.true_negatives(**kwargs)
def streaming_false_positives(predictions,
                              labels,
                              weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Accumulates the total weight of false positives.

  When `weights` is `None` every value carries weight 1; a weight of 0
  masks the corresponding value out of the metric.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions.
      Will be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and broadcastable to `labels`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric
      update ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a
      list or tuple.
  """
  # Delegate to the canonical implementation in tf.metrics.
  return metrics.false_positives(
      labels=labels,
      predictions=predictions,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
def streaming_false_negatives(predictions,
                              labels,
                              weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Accumulates the total weight of false negatives.

  When `weights` is `None` every value carries weight 1; a weight of 0
  masks the corresponding value out of the metric.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions.
      Will be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and broadcastable to `labels`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric
      update ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `values`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  # Delegate to the canonical implementation in tf.metrics.
  kwargs = dict(
      labels=labels,
      predictions=predictions,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
  return metrics.false_negatives(**kwargs)
def streaming_mean(values,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Computes the (weighted) mean of the given values.

  Two local variables, `total` and `count`, accumulate the running sum and
  weight; the returned `mean` is the idempotent ratio `total / count`. The
  returned `update_op` folds a new batch into both variables and evaluates
  to the updated mean, so it can be run repeatedly over a data stream.

  When `weights` is `None` every value carries weight 1; a weight of 0
  masks the corresponding value out of the metric.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: `Tensor` whose rank is either 0, or the same rank as `values`,
      and broadcastable to `values`.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count`
      variables appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `values`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  # Delegate to the canonical implementation in tf.metrics.
  return metrics.mean(
      values=values,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
def streaming_mean_tensor(values,
                          weights=None,
                          metrics_collections=None,
                          updates_collections=None,
                          name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  Unlike `streaming_mean`, which reduces to a scalar, this returns an
  average tensor with the same shape as the inputs. Two local variables,
  `total_tensor` and `count_tensor`, accumulate the running sums; the
  returned `mean` is the idempotent ratio `total / count`, and `update_op`
  folds a new batch into both variables and evaluates to the updated mean.

  When `weights` is `None` every value carries weight 1; a weight of 0
  masks the corresponding value out of the metric.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: `Tensor` whose rank is either 0, or the same rank as `values`,
      and broadcastable to `values`.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count`
      variables appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `values`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  # Delegate to the canonical implementation in tf.metrics.
  kwargs = dict(
      values=values,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
  return metrics.mean_tensor(**kwargs)
@deprecated(None,
            'Please switch to tf.metrics.accuracy. Note that the order of the '
            'labels and predictions arguments has been switched.')
def streaming_accuracy(predictions,
                       labels,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Calculates how often `predictions` matches `labels`.

  Two local variables, `total` and `count`, track the (weighted) number of
  matches between `predictions` and `labels`; the returned `accuracy` is
  the idempotent ratio `total / count`. The returned `update_op` folds a
  new batch into both variables (weighting each element-wise match by the
  corresponding entry of `weights`) and evaluates to the updated accuracy.

  When `weights` is `None` every value carries weight 1; a weight of 0
  masks the corresponding value out of the metric.

  Args:
    predictions: The predicted values, a `Tensor` of any shape.
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`,
      and broadcastable to `labels`.
    metrics_collections: An optional list of collections that `accuracy`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count`
      variables appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a
      list or tuple.
  """
  # Delegate to the canonical implementation in tf.metrics.
  return metrics.accuracy(
      labels=labels,
      predictions=predictions,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
def streaming_precision(predictions,
                        labels,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the precision of the predictions with respect to the labels.

  Two local variables, `true_positives` and `false_positives`, back this
  metric; the returned `precision` is the idempotent ratio
  `true_positives / (true_positives + false_positives)`. The returned
  `update_op` folds a new batch into both variables, weighting each
  prediction by the corresponding entry of `weights`.

  When `weights` is `None` every value carries weight 1; a weight of 0
  masks the corresponding value out of the metric.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions
      must match `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`,
      and broadcastable to `labels`.
    metrics_collections: An optional list of collections that `precision`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    precision: Scalar float `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately and whose value matches
      `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a
      list or tuple.
  """
  # Delegate to the canonical implementation in tf.metrics.
  kwargs = dict(
      labels=labels,
      predictions=predictions,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
  return metrics.precision(**kwargs)
def streaming_recall(predictions,
                     labels,
                     weights=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
  """Computes the recall of the predictions with respect to the labels.

  Two local variables, `true_positives` and `false_negatives`, back this
  metric; the returned `recall` is the idempotent ratio
  `true_positives / (true_positives + false_negatives)`. The returned
  `update_op` folds a new batch into both variables, weighting each
  prediction by the corresponding entry of `weights`.

  When `weights` is `None` every value carries weight 1; a weight of 0
  masks the corresponding value out of the metric.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions
      must match `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`,
      and broadcastable to `labels`.
    metrics_collections: An optional list of collections that `recall`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a
      list or tuple.
  """
  # Delegate to the canonical implementation in tf.metrics.
  return metrics.recall(
      labels=labels,
      predictions=predictions,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
def streaming_false_positive_rate(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the false positive rate of predictions with respect to labels.

  Two local variables, `false_positives` and `true_negatives`, back this
  metric; the returned `false_positive_rate` is the idempotent ratio
  `false_positives / (false_positives + true_negatives)`. The returned
  `update_op` folds a new batch into both variables, weighting each
  prediction by the corresponding entry of `weights`.

  When `weights` is `None` every value carries weight 1; a weight of 0
  masks the corresponding value out of the metric.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions.
      Will be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and broadcastable to `labels`.
    metrics_collections: An optional list of collections that
      `false_positive_rate` should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    false_positive_rate: Scalar float `Tensor` with the value of
      `false_positives` divided by the sum of `false_positives` and
      `true_negatives`.
    update_op: `Operation` that increments `false_positives` and
      `true_negatives` variables appropriately and whose value matches
      `false_positive_rate`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a
      list or tuple.
  """
  with variable_scope.variable_scope(name, 'false_positive_rate',
                                     (predictions, labels, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

    fp_count, fp_update = metrics.false_positives(
        labels=labels,
        predictions=predictions,
        weights=weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)
    tn_count, tn_update = metrics.true_negatives(
        labels=labels,
        predictions=predictions,
        weights=weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)

    def compute_fpr(fp, tn, name):
      # FPR = FP / (FP + TN); defined as 0 when there are no negatives at
      # all (zero denominator).
      return array_ops.where(
          math_ops.greater(fp + tn, 0), math_ops.div(fp, fp + tn), 0, name)

    fpr = compute_fpr(fp_count, tn_count, 'value')
    update_op = compute_fpr(fp_update, tn_update, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, fpr)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return fpr, update_op
def streaming_false_negative_rate(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the false negative rate of predictions with respect to labels.

  Two local variables, `false_negatives` and `true_positives`, back this
  metric; the returned `false_negative_rate` is the idempotent ratio
  `false_negatives / (false_negatives + true_positives)`. The returned
  `update_op` folds a new batch into both variables, weighting each
  prediction by the corresponding entry of `weights`.

  When `weights` is `None` every value carries weight 1; a weight of 0
  masks the corresponding value out of the metric.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions.
      Will be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and broadcastable to `labels`.
    metrics_collections: An optional list of collections that
      `false_negative_rate` should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    false_negative_rate: Scalar float `Tensor` with the value of
      `false_negatives` divided by the sum of `false_negatives` and
      `true_positives`.
    update_op: `Operation` that increments `false_negatives` and
      `true_positives` variables appropriately and whose value matches
      `false_negative_rate`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a
      list or tuple.
  """
  with variable_scope.variable_scope(name, 'false_negative_rate',
                                     (predictions, labels, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

    fn_count, fn_update = metrics.false_negatives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)
    tp_count, tp_update = metrics.true_positives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)

    def compute_fnr(fn, tp, name):
      # FNR = FN / (FN + TP); defined as 0 when there are no positives at
      # all (zero denominator).
      return array_ops.where(
          math_ops.greater(fn + tp, 0), math_ops.div(fn, fn + tp), 0, name)

    fnr = compute_fnr(fn_count, tp_count, 'value')
    update_op = compute_fnr(fn_update, tp_update, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, fnr)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return fnr, update_op
def _streaming_confusion_matrix_at_thresholds(predictions,
                                              labels,
                                              thresholds,
                                              weights=None,
                                              includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast
      to `bool`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
      default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
      `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
      `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        # Fixed typo in the error message ('Invaild' -> 'Invalid').
        raise ValueError('Invalid key: %s.' % include)

  predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  num_thresholds = len(thresholds)

  # Reshape predictions and labels so each can be tiled against the
  # thresholds: predictions as a column, labels as a row.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    broadcast_weights = weights_broadcast_ops.broadcast_weights(
        math_ops.to_float(weights), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_positives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='true_positives')
    is_true_positive = math_ops.to_float(
        math_ops.logical_and(label_is_pos, pred_is_pos))
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(true_positives,
                                            math_ops.reduce_sum(
                                                is_true_positive, 1))
    values['tp'] = true_positives

  if 'fn' in includes:
    false_negatives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='false_negatives')
    is_false_negative = math_ops.to_float(
        math_ops.logical_and(label_is_pos, pred_is_neg))
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(false_negatives,
                                            math_ops.reduce_sum(
                                                is_false_negative, 1))
    values['fn'] = false_negatives

  if 'tn' in includes:
    true_negatives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='true_negatives')
    is_true_negative = math_ops.to_float(
        math_ops.logical_and(label_is_neg, pred_is_neg))
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(true_negatives,
                                            math_ops.reduce_sum(
                                                is_true_negative, 1))
    values['tn'] = true_negatives

  if 'fp' in includes:
    false_positives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.to_float(
        math_ops.logical_and(label_is_neg, pred_is_pos))
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_positives,
                                            math_ops.reduce_sum(
                                                is_false_positive, 1))
    values['fp'] = false_positives

  return values, update_ops
pred_is_neg)) 820 if weights_tiled is not None: 821 is_true_negative *= weights_tiled 822 update_ops['tn'] = state_ops.assign_add(true_negatives, 823 math_ops.reduce_sum( 824 is_true_negative, 1)) 825 values['tn'] = true_negatives 826 827 if 'fp' in includes: 828 false_positives = metrics_impl.metric_variable( 829 [num_thresholds], dtypes.float32, name='false_positives') 830 is_false_positive = math_ops.to_float( 831 math_ops.logical_and(label_is_neg, pred_is_pos)) 832 if weights_tiled is not None: 833 is_false_positive *= weights_tiled 834 update_ops['fp'] = state_ops.assign_add(false_positives, 835 math_ops.reduce_sum( 836 is_false_positive, 1)) 837 values['fp'] = false_positives 838 839 return values, update_ops 840 841 842def streaming_true_positives_at_thresholds(predictions, 843 labels, 844 thresholds, 845 weights=None): 846 values, update_ops = _streaming_confusion_matrix_at_thresholds( 847 predictions, labels, thresholds, weights=weights, includes=('tp',)) 848 return values['tp'], update_ops['tp'] 849 850 851def streaming_false_negatives_at_thresholds(predictions, 852 labels, 853 thresholds, 854 weights=None): 855 values, update_ops = _streaming_confusion_matrix_at_thresholds( 856 predictions, labels, thresholds, weights=weights, includes=('fn',)) 857 return values['fn'], update_ops['fn'] 858 859 860def streaming_false_positives_at_thresholds(predictions, 861 labels, 862 thresholds, 863 weights=None): 864 values, update_ops = _streaming_confusion_matrix_at_thresholds( 865 predictions, labels, thresholds, weights=weights, includes=('fp',)) 866 return values['fp'], update_ops['fp'] 867 868 869def streaming_true_negatives_at_thresholds(predictions, 870 labels, 871 thresholds, 872 weights=None): 873 values, update_ops = _streaming_confusion_matrix_at_thresholds( 874 predictions, labels, thresholds, weights=weights, includes=('tn',)) 875 return values['tn'], update_ops['tn'] 876 877 878def streaming_curve_points(labels=None, 879 predictions=None, 880 
                           weights=None,
                           num_thresholds=200,
                           metrics_collections=None,
                           updates_collections=None,
                           curve='ROC',
                           name=None):
  """Computes curve (ROC or PR) values for a prespecified number of points.

  The `streaming_curve_points` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  that are used to compute the curve values. To discretize the curve, a linearly
  spaced set of thresholds is used to compute pairs of recall and precision
  values.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    points: A `Tensor` with shape [num_thresholds, 2] that contains points of
      the curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.

  TODO(chizeng): Consider rewriting this method to make use of logic within the
  precision_recall_at_equal_thresholds method (to improve run time).
  """
  with variable_scope.variable_scope(name, 'curve_points',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
    kepsilon = _EPSILON  # to account for floating point imprecisions
    # Evenly spaced interior thresholds, bracketed by -epsilon and 1+epsilon so
    # the endpoints classify everything as positive / negative respectively.
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def compute_points(tp, fn, tn, fp):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        return fp_rate, rec
      else:  # curve == 'PR'.
        prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
        return rec, prec

    xs, ys = compute_points(values['tp'], values['fn'], values['tn'],
                            values['fp'])
    points = array_ops.stack([xs, ys], axis=1)
    update_op = control_flow_ops.group(*update_ops.values())

    if metrics_collections:
      ops.add_to_collections(metrics_collections, points)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return points, update_op


@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of the '
            'labels and predictions arguments has been switched.')
def streaming_auc(predictions,
                  labels,
                  weights=None,
                  num_thresholds=200,
                  metrics_collections=None,
                  updates_collections=None,
                  curve='ROC',
                  name=None):
  """Computes the approximate AUC via a Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Deprecated thin wrapper: delegates to the core implementation, which takes
  # (labels, predictions) in the opposite order (see the @deprecated note).
  return metrics.auc(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      num_thresholds=num_thresholds,
      curve=curve,
      updates_collections=updates_collections,
      name=name)


def _compute_dynamic_auc(labels, predictions, curve='ROC'):
  """Computes the approximate AUC by a Riemann sum with data-derived thresholds.

  Computes the area under the ROC or PR curve using each prediction as a
  threshold. This could be slow for large batches, but has the advantage of not
  having its results degrade depending on the distribution of predictions.

  Args:
    labels: A `Tensor` of ground truth labels with the same shape as
      `predictions` with values of 0 or 1 and type `int64`.
    predictions: A 1-D `Tensor` of predictions whose values are `float64`.
    curve: The name of the curve to be computed, 'ROC' for the Receiver
      Operating Characteristic or 'PR' for the Precision-Recall curve.

  Returns:
    A scalar `Tensor` containing the area-under-curve value for the input.
  """
  # Count the total number of positive and negative labels in the input.
  size = array_ops.size(predictions)
  total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32)

  def continue_computing_dynamic_auc():
    """Continues dynamic auc computation, entered if labels are not all equal.

    Returns:
      A scalar `Tensor` containing the area-under-curve value.
    """
    # Sort the predictions descending, and the corresponding labels as well.
    ordered_predictions, indices = nn.top_k(predictions, k=size)
    ordered_labels = array_ops.gather(labels, indices)

    # Get the counts of the unique ordered predictions.
    _, _, counts = array_ops.unique_with_counts(ordered_predictions)

    # Compute the indices of the split points between different predictions.
    splits = math_ops.cast(
        array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32)

    # Count the positives to the left of the split indices.
    positives = math_ops.cast(
        array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]),
        dtypes.int32)
    true_positives = array_ops.gather(positives, splits)
    if curve == 'ROC':
      # Count the negatives to the left of every split point and the total
      # number of negatives for computing the FPR.
      false_positives = math_ops.subtract(splits, true_positives)
      total_negative = size - total_positive
      x_axis_values = math_ops.truediv(false_positives, total_negative)
      y_axis_values = math_ops.truediv(true_positives, total_positive)
    elif curve == 'PR':
      x_axis_values = math_ops.truediv(true_positives, total_positive)
      # For conformance, set precision to 1 when the number of positive
      # classifications is 0.
      y_axis_values = array_ops.where(
          math_ops.greater(splits, 0), math_ops.truediv(true_positives, splits),
          array_ops.ones_like(true_positives, dtype=dtypes.float64))

    # Calculate trapezoid areas.
    heights = math_ops.add(y_axis_values[1:], y_axis_values[:-1]) / 2.0
    widths = math_ops.abs(
        math_ops.subtract(x_axis_values[1:], x_axis_values[:-1]))
    return math_ops.reduce_sum(math_ops.multiply(heights, widths))

  # If all the labels are the same, AUC isn't well-defined (but raising an
  # exception seems excessive) so we return 0, otherwise we finish computing.
  return control_flow_ops.cond(
      math_ops.logical_or(
          math_ops.equal(total_positive, 0), math_ops.equal(
              total_positive, size)),
      true_fn=lambda: array_ops.constant(0, dtypes.float64),
      false_fn=continue_computing_dynamic_auc)


def streaming_dynamic_auc(labels,
                          predictions,
                          curve='ROC',
                          metrics_collections=(),
                          updates_collections=(),
                          name=None):
  """Computes the approximate AUC by a Riemann sum with data-derived thresholds.

  USAGE NOTE: this approach requires storing all of the predictions and labels
  for a single evaluation in memory, so it may not be usable when the evaluation
  batch size and/or the number of evaluation steps is very large.

  Computes the area under the ROC or PR curve using each prediction as a
  threshold. This has the advantage of being resilient to the distribution of
  predictions by aggregating across batches, accumulating labels and predictions
  and performing the final calculation using all of the concatenated values.

  Args:
    labels: A `Tensor` of ground truth labels with the same shape as
      `predictions` and with values of 0 or 1 whose values are castable to
      `int64`.
    predictions: A `Tensor` of predictions whose values are castable to
      `float64`. Will be flattened into a 1-D `Tensor`.
    curve: The name of the curve for which to compute AUC, 'ROC' for the
      Receiver Operating Characteristic or 'PR' for the Precision-Recall curve.
    metrics_collections: An optional iterable of collections that `auc` should
      be added to.
    updates_collections: An optional iterable of collections that `update_op`
      should be added to.
    name: An optional name for the variable_scope that contains the metric
      variables.

  Returns:
    auc: A scalar `Tensor` containing the current area-under-curve value.
    update_op: An operation that concatenates the input labels and predictions
      to the accumulated values.

  Raises:
    ValueError: If `labels` and `predictions` have mismatched shapes or if
      `curve` isn't a recognized curve type.
  """

  if curve not in ['PR', 'ROC']:
    raise ValueError('curve must be either ROC or PR, %s unknown' % curve)

  with variable_scope.variable_scope(name, default_name='dynamic_auc'):
    labels.get_shape().assert_is_compatible_with(predictions.get_shape())
    predictions = array_ops.reshape(
        math_ops.cast(predictions, dtypes.float64), [-1])
    labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])
    # Validate that labels are binary before accumulating them.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            labels,
            array_ops.zeros_like(labels, dtypes.int64),
            message='labels must be 0 or 1, at least one is <0'),
        check_ops.assert_less_equal(
            labels,
            array_ops.ones_like(labels, dtypes.int64),
            message='labels must be 0 or 1, at least one is >1')
    ]):
      # Accumulate every batch's labels/predictions; the AUC is recomputed
      # from the full concatenated history on each evaluation.
      preds_accum, update_preds = streaming_concat(
          predictions, name='concat_preds')
      labels_accum, update_labels = streaming_concat(
          labels, name='concat_labels')
      update_op = control_flow_ops.group(update_labels, update_preds)
      auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve)
      if updates_collections:
        ops.add_to_collections(updates_collections, update_op)
      if metrics_collections:
        ops.add_to_collections(metrics_collections, auc)
      return auc, update_op


def _compute_placement_auc(labels, predictions, weights, alpha,
                           logit_transformation, is_valid):
  """Computes the AUC and asymptotic normally distributed confidence interval.

  The calculations are achieved using the fact that AUC = P(Y_1>Y_0) and the
  concept of placement values for each labeled group, as presented by Delong and
  Delong (1988).
  The actual algorithm used is a more computationally efficient
  approach presented by Sun and Xu (2014). This could be slow for large batches,
  but has the advantage of not having its results degrade depending on the
  distribution of predictions.

  Args:
    labels: A `Tensor` of ground truth labels with the same shape as
      `predictions` with values of 0 or 1 and type `int64`.
    predictions: A 1-D `Tensor` of predictions whose values are `float64`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`.
    alpha: Confidence interval level desired.
    logit_transformation: A boolean value indicating whether the estimate should
      be logit transformed prior to calculating the confidence interval. Doing
      so enforces the restriction that the AUC should never be outside the
      interval [0,1].
    is_valid: A bool tensor describing whether the input is valid.

  Returns:
    A 1-D `Tensor` containing the area-under-curve, lower, and upper confidence
    interval values.
  """
  # Disable the invalid-name checker so that we can capitalize the name.
  # pylint: disable=invalid-name
  AucData = collections_lib.namedtuple('AucData', ['auc', 'lower', 'upper'])
  # pylint: enable=invalid-name

  # If all the labels are the same or if number of observations are too few,
  # AUC isn't well-defined
  size = array_ops.size(predictions, out_type=dtypes.int32)

  # Count the total number of positive and negative labels in the input.
  total_0 = math_ops.reduce_sum(
      math_ops.cast(1 - labels, weights.dtype) * weights)
  total_1 = math_ops.reduce_sum(
      math_ops.cast(labels, weights.dtype) * weights)

  # Sort the predictions ascending, as well as
  # (i) the corresponding labels and
  # (ii) the corresponding weights.
  # top_k sorts descending, so the results are reversed to get ascending order.
  ordered_predictions, indices = nn.top_k(predictions, k=size, sorted=True)
  ordered_predictions = array_ops.reverse(
      ordered_predictions, axis=array_ops.zeros(1, dtypes.int32))
  indices = array_ops.reverse(indices, axis=array_ops.zeros(1, dtypes.int32))
  ordered_labels = array_ops.gather(labels, indices)
  ordered_weights = array_ops.gather(weights, indices)

  # We now compute values required for computing placement values.

  # We generate a list of indices (segmented_indices) of increasing order. An
  # index is assigned for each unique prediction float value. Prediction
  # values that are the same share the same index.
  _, segmented_indices = array_ops.unique(ordered_predictions)

  # We create 2 tensors of weights. weights_for_true is non-zero for true
  # labels. weights_for_false is non-zero for false labels.
  float_labels_for_true = math_ops.cast(ordered_labels, dtypes.float32)
  float_labels_for_false = 1.0 - float_labels_for_true
  weights_for_true = ordered_weights * float_labels_for_true
  weights_for_false = ordered_weights * float_labels_for_false

  # For each set of weights with the same segmented indices, we add up the
  # weight values. Note that for each label, we deliberately rely on weights
  # for the opposite label.
  weight_totals_for_true = math_ops.segment_sum(weights_for_false,
                                                segmented_indices)
  weight_totals_for_false = math_ops.segment_sum(weights_for_true,
                                                 segmented_indices)

  # These cumulative sums of weights importantly exclude the current weight
  # sums.
  cum_weight_totals_for_true = math_ops.cumsum(weight_totals_for_true,
                                               exclusive=True)
  cum_weight_totals_for_false = math_ops.cumsum(weight_totals_for_false,
                                                exclusive=True)

  # Compute placement values using the formula. Values with the same segmented
  # indices and labels share the same placement values.
  placements_for_true = (
      (cum_weight_totals_for_true + weight_totals_for_true / 2.0) /
      (math_ops.reduce_sum(weight_totals_for_true) + _EPSILON))
  placements_for_false = (
      (cum_weight_totals_for_false + weight_totals_for_false / 2.0) /
      (math_ops.reduce_sum(weight_totals_for_false) + _EPSILON))

  # We expand the tensors of placement values (for each label) so that their
  # shapes match that of predictions.
  placements_for_true = array_ops.gather(placements_for_true, segmented_indices)
  placements_for_false = array_ops.gather(placements_for_false,
                                          segmented_indices)

  # Select placement values based on the label for each index.
  placement_values = (
      placements_for_true * float_labels_for_true +
      placements_for_false * float_labels_for_false)

  # Split placement values by labeled groups.
  placement_values_0 = placement_values * math_ops.cast(
      1 - ordered_labels, weights.dtype)
  weights_0 = ordered_weights * math_ops.cast(
      1 - ordered_labels, weights.dtype)
  placement_values_1 = placement_values * math_ops.cast(
      ordered_labels, weights.dtype)
  weights_1 = ordered_weights * math_ops.cast(
      ordered_labels, weights.dtype)

  # Calculate AUC using placement values
  auc_0 = (math_ops.reduce_sum(weights_0 * (1. - placement_values_0)) /
           (total_0 + _EPSILON))
  auc_1 = (math_ops.reduce_sum(weights_1 * (placement_values_1)) /
           (total_1 + _EPSILON))
  # Use the estimate derived from the larger labeled group.
  auc = array_ops.where(math_ops.less(total_0, total_1), auc_1, auc_0)

  # Calculate variance and standard error using the placement values.
  var_0 = (
      math_ops.reduce_sum(
          weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) /
      (total_0 - 1. + _EPSILON))
  var_1 = (
      math_ops.reduce_sum(
          weights_1 * math_ops.square(placement_values_1 - auc_1)) /
      (total_1 - 1. + _EPSILON))
  auc_std_err = math_ops.sqrt(
      (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON)))

  # Calculate asymptotic normal confidence intervals
  std_norm_dist = Normal(loc=0., scale=1.)
  # For alpha in (0, 1), (1.0 - alpha) / 2.0 < 0.5, so z_value is negative;
  # hence `estimate + z_value * std_err` below is the LOWER bound and
  # `estimate - z_value * std_err` the upper one.
  z_value = std_norm_dist.quantile((1.0 - alpha) / 2.0)
  if logit_transformation:
    estimate = math_ops.log(auc / (1. - auc + _EPSILON))
    std_err = auc_std_err / (auc * (1. - auc + _EPSILON))
    transformed_auc_lower = estimate + (z_value * std_err)
    transformed_auc_upper = estimate - (z_value * std_err)
    def inverse_logit_transformation(x):
      # Sigmoid with an epsilon in the denominator, mirroring the epsilons
      # used in the forward transformation above.
      exp_negative = math_ops.exp(math_ops.negative(x))
      return 1. / (1. + exp_negative + _EPSILON)

    auc_lower = inverse_logit_transformation(transformed_auc_lower)
    auc_upper = inverse_logit_transformation(transformed_auc_upper)
  else:
    estimate = auc
    std_err = auc_std_err
    auc_lower = estimate + (z_value * std_err)
    auc_upper = estimate - (z_value * std_err)

  ## If estimate is 1 or 0, no variance is present so CI = 1
  ## n.b. This can be misleading, since number obs can just be too low.
  lower = array_ops.where(
      math_ops.logical_or(
          math_ops.equal(auc, array_ops.ones_like(auc)),
          math_ops.equal(auc, array_ops.zeros_like(auc))),
      auc, auc_lower)
  upper = array_ops.where(
      math_ops.logical_or(
          math_ops.equal(auc, array_ops.ones_like(auc)),
          math_ops.equal(auc, array_ops.zeros_like(auc))),
      auc, auc_upper)

  # If all the labels are the same, AUC isn't well-defined (but raising an
  # exception seems excessive) so we return 0, otherwise we finish computing.
  trivial_value = array_ops.constant(0.0)

  # When `is_valid` is False, all three fields are zero instead of raising.
  return AucData(*control_flow_ops.cond(
      is_valid, lambda: [auc, lower, upper], lambda: [trivial_value]*3))


def auc_with_confidence_intervals(labels,
                                  predictions,
                                  weights=None,
                                  alpha=0.95,
                                  logit_transformation=True,
                                  metrics_collections=(),
                                  updates_collections=(),
                                  name=None):
  """Computes the AUC and asymptotic normally distributed confidence interval.

  USAGE NOTE: this approach requires storing all of the predictions and labels
  for a single evaluation in memory, so it may not be usable when the evaluation
  batch size and/or the number of evaluation steps is very large.

  Computes the area under the ROC curve and its confidence interval using
  placement values. This has the advantage of being resilient to the
  distribution of predictions by aggregating across batches, accumulating labels
  and predictions and performing the final calculation using all of the
  concatenated values.

  Args:
    labels: A `Tensor` of ground truth labels with the same shape as
      `predictions` and with values of 0 or 1 whose values are castable to
      `int64`.
    predictions: A `Tensor` of predictions whose values are castable to
      `float64`. Will be flattened into a 1-D `Tensor`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`.
    alpha: Confidence interval level desired.
    logit_transformation: A boolean value indicating whether the estimate should
      be logit transformed prior to calculating the confidence interval. Doing
      so enforces the restriction that the AUC should never be outside the
      interval [0,1].
    metrics_collections: An optional iterable of collections that `auc` should
      be added to.
    updates_collections: An optional iterable of collections that `update_op`
      should be added to.
    name: An optional name for the variable_scope that contains the metric
      variables.

  Returns:
    auc: A 1-D `Tensor` containing the current area-under-curve, lower, and
      upper confidence interval values.
    update_op: An operation that concatenates the input labels and predictions
      to the accumulated values.

  Raises:
    ValueError: If `labels`, `predictions`, and `weights` have mismatched shapes
      or if `alpha` isn't in the range (0,1).
  """
  if not (alpha > 0 and alpha < 1):
    raise ValueError('alpha must be between 0 and 1; currently %.02f' % alpha)

  if weights is None:
    weights = array_ops.ones_like(predictions)

  with variable_scope.variable_scope(
      name,
      default_name='auc_with_confidence_intervals',
      values=[labels, predictions, weights]):

    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=predictions,
        labels=labels,
        weights=weights)

    # Summed before flattening; used below to decide whether enough weight has
    # been accumulated for the metric to be meaningful.
    total_weight = math_ops.reduce_sum(weights)

    weights = array_ops.reshape(weights, [-1])
    predictions = array_ops.reshape(
        math_ops.cast(predictions, dtypes.float64), [-1])
    labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])

    # Validate that labels are binary before accumulating them.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            labels,
            array_ops.zeros_like(labels, dtypes.int64),
            message='labels must be 0 or 1, at least one is <0'),
        check_ops.assert_less_equal(
            labels,
            array_ops.ones_like(labels, dtypes.int64),
            message='labels must be 0 or 1, at least one is >1'),
    ]):
      # Accumulate the full history of predictions, labels and weights; the
      # AUC and its CI are recomputed from scratch on each evaluation.
      preds_accum, update_preds = streaming_concat(
          predictions, name='concat_preds')
      labels_accum, update_labels = streaming_concat(labels,
                                                     name='concat_labels')
      weights_accum, update_weights = streaming_concat(
          weights, name='concat_weights')
      update_op_for_valid_case = control_flow_ops.group(
          update_labels, update_preds, update_weights)

      # Only perform updates if this case is valid.
      all_labels_positive_or_0 = math_ops.logical_and(
          math_ops.equal(math_ops.reduce_min(labels), 0),
          math_ops.equal(math_ops.reduce_max(labels), 1))
      sums_of_weights_at_least_1 = math_ops.greater_equal(total_weight, 1.0)
      is_valid = math_ops.logical_and(all_labels_positive_or_0,
                                      sums_of_weights_at_least_1)

      # NOTE(review): the update is gated only on the weight-sum condition;
      # label validity is handled separately via `is_valid` inside
      # `_compute_placement_auc` (which returns zeros when invalid).
      update_op = control_flow_ops.cond(
          sums_of_weights_at_least_1,
          lambda: update_op_for_valid_case, control_flow_ops.no_op)

      auc = _compute_placement_auc(
          labels_accum,
          preds_accum,
          weights_accum,
          alpha=alpha,
          logit_transformation=logit_transformation,
          is_valid=is_valid)

      if updates_collections:
        ops.add_to_collections(updates_collections, update_op)
      if metrics_collections:
        ops.add_to_collections(metrics_collections, auc)
      return auc, update_op


def precision_recall_at_equal_thresholds(labels,
                                         predictions,
                                         weights=None,
                                         num_thresholds=None,
                                         use_locking=None,
                                         name=None):
  """A helper method for creating metrics related to precision-recall curves.

  These values are true positives, false negatives, true negatives, false
  positives, precision, and recall. This function returns a data structure that
  contains ops within it.

  Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N)
  space and run time), this op exhibits O(T + N) space and run time, where T is
  the number of thresholds and N is the size of the predictions tensor. Hence,
  it may be advantageous to use this function when `predictions` is big.

  For instance, prefer this method for per-pixel classification tasks, for which
  the predictions tensor may be very large.

  Each number in `predictions`, a float in `[0, 1]`, is compared with its
  corresponding label in `labels`, and counts as a single tp/fp/tn/fn value at
  each threshold. This is then multiplied with `weights` which can be used to
  reweight certain values, or more commonly used for masking values.

  Args:
    labels: A bool `Tensor` whose shape matches `predictions`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional; If provided, a `Tensor` that has the same dtype as,
      and broadcastable to, `predictions`. This tensor is multiplied by counts.
    num_thresholds: Optional; Number of thresholds, evenly distributed in
      `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins
      is 1 less than `num_thresholds`. Using an even `num_thresholds` value
      instead of an odd one may yield unfriendly edges for bins.
    use_locking: Optional; If True, the op will be protected by a lock.
      Otherwise, the behavior is undefined, but may exhibit less contention.
      Defaults to True.
    name: Optional; variable_scope name. If not provided, the string
      'precision_recall_at_equal_thresholds' is used.

  Returns:
    result: A named tuple (See PrecisionRecallData within the implementation of
      this function) with properties that are variables of shape
      `[num_thresholds]`. The names of the properties are tp, fp, tn, fn,
      precision, recall, thresholds.
    update_op: An op that accumulates values.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`.
  """
  # Disable the invalid-name checker so that we can capitalize the name.
  # pylint: disable=invalid-name
  PrecisionRecallData = collections_lib.namedtuple(
      'PrecisionRecallData',
      ['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds'])
  # pylint: enable=invalid-name

  if num_thresholds is None:
    num_thresholds = 201

  if weights is None:
    weights = 1.0

  if use_locking is None:
    use_locking = True

  check_ops.assert_type(labels, dtypes.bool)

  dtype = predictions.dtype
  with variable_scope.variable_scope(name,
                                     'precision_recall_at_equal_thresholds',
                                     (labels, predictions, weights)):
    # Make sure that predictions are within [0.0, 1.0].
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            predictions,
            math_ops.cast(0.0, dtype=predictions.dtype),
            message='predictions must be in [0, 1]'),
        check_ops.assert_less_equal(
            predictions,
            math_ops.cast(1.0, dtype=predictions.dtype),
            message='predictions must be in [0, 1]')
    ]):
      predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
          predictions=predictions,
          labels=labels,
          weights=weights)

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    # We cast to float to ensure we have 0.0 or 1.0.
    f_labels = math_ops.cast(labels, dtype)

    # Get weighted true/false labels.
    true_labels = f_labels * weights
    false_labels = (1.0 - f_labels) * weights

    # Flatten predictions and labels.
    predictions = array_ops.reshape(predictions, [-1])
    true_labels = array_ops.reshape(true_labels, [-1])
    false_labels = array_ops.reshape(false_labels, [-1])

    # To compute TP/FP/TN/FN, we are measuring a binary classifier
    #   C(t) = (predictions >= t)
    # at each threshold 't'. So we have
    #   TP(t) = sum( C(t) * true_labels )
    #   FP(t) = sum( C(t) * false_labels )
    #
    # But, computing C(t) requires computation for each t. To make it fast,
    # observe that C(t) is a cumulative integral, and so if we have
    #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
    # where n = num_thresholds, and if we can compute the bucket function
    #   B(i) = Sum( (predictions == t), t_i <= t < t{i+1} )
    # then we get
    #   C(t_i) = sum( B(j), j >= i )
    # which is the reversed cumulative sum in tf.cumsum().
    #
    # We can compute B(i) efficiently by taking advantage of the fact that
    # our thresholds are evenly distributed, in that
    #   width = 1.0 / (num_thresholds - 1)
    #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
    # Given a prediction value p, we can map it to its bucket by
    #   bucket_index(p) = floor( p * (num_thresholds - 1) )
    # so we can use tf.scatter_add() to update the buckets in one pass.
    #
    # This implementation exhibits a run time and space complexity of O(T + N),
    # where T is the number of thresholds and N is the size of predictions.
    # Metrics that rely on _streaming_confusion_matrix_at_thresholds instead
    # exhibit a complexity of O(T * N).

    # Compute the bucket indices for each prediction value.
    bucket_indices = math_ops.cast(
        math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32)

    with ops.name_scope('variables'):
      tp_buckets_v = metrics_impl.metric_variable(
          [num_thresholds], dtype, name='tp_buckets')
      fp_buckets_v = metrics_impl.metric_variable(
          [num_thresholds], dtype, name='fp_buckets')

    with ops.name_scope('update_op'):
      update_tp = state_ops.scatter_add(
          tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking)
      update_fp = state_ops.scatter_add(
          fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking)

    # Set up the cumulative sums to compute the actual metrics.
    tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp')
    fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp')
    # fn = sum(true_labels) - tp
    #    = sum(tp_buckets) - tp
    #    = tp[0] - tp
    # Similarly,
    # tn = fp[0] - fp
    tn = fp[0] - fp
    fn = tp[0] - tp

    # We use a minimum to prevent division by 0.
    epsilon = 1e-7
    precision = tp / math_ops.maximum(epsilon, tp + fp)
    recall = tp / math_ops.maximum(epsilon, tp + fn)

    result = PrecisionRecallData(
        tp=tp,
        fp=fp,
        tn=tn,
        fn=fn,
        precision=precision,
        recall=recall,
        thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds))
    update_op = control_flow_ops.group(update_tp, update_fp)
    return result, update_op


def streaming_specificity_at_sensitivity(predictions,
                                         labels,
                                         sensitivity,
                                         weights=None,
                                         num_thresholds=200,
                                         metrics_collections=None,
                                         updates_collections=None,
                                         name=None):
  """Computes the specificity at a given sensitivity.

  The `streaming_specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value.
The threshold for the given sensitivity value is computed 1680 and used to evaluate the corresponding specificity. 1681 1682 For estimation of the metric over a stream of data, the function creates an 1683 `update_op` operation that updates these variables and returns the 1684 `specificity`. `update_op` increments the `true_positives`, `true_negatives`, 1685 `false_positives` and `false_negatives` counts with the weight of each case 1686 found in the `predictions` and `labels`. 1687 1688 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1689 1690 For additional information about specificity and sensitivity, see the 1691 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 1692 1693 Args: 1694 predictions: A floating point `Tensor` of arbitrary shape and whose values 1695 are in the range `[0, 1]`. 1696 labels: A `bool` `Tensor` whose shape matches `predictions`. 1697 sensitivity: A scalar value in range `[0, 1]`. 1698 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1699 must be broadcastable to `labels` (i.e., all dimensions must be either 1700 `1`, or the same as the corresponding `labels` dimension). 1701 num_thresholds: The number of thresholds to use for matching the given 1702 sensitivity. 1703 metrics_collections: An optional list of collections that `specificity` 1704 should be added to. 1705 updates_collections: An optional list of collections that `update_op` should 1706 be added to. 1707 name: An optional variable_scope name. 1708 1709 Returns: 1710 specificity: A scalar `Tensor` representing the specificity at the given 1711 `specificity` value. 1712 update_op: An operation that increments the `true_positives`, 1713 `true_negatives`, `false_positives` and `false_negatives` variables 1714 appropriately and whose value matches `specificity`. 
1715 1716 Raises: 1717 ValueError: If `predictions` and `labels` have mismatched shapes, if 1718 `weights` is not `None` and its shape doesn't match `predictions`, or if 1719 `sensitivity` is not between 0 and 1, or if either `metrics_collections` 1720 or `updates_collections` are not a list or tuple. 1721 """ 1722 return metrics.specificity_at_sensitivity( 1723 sensitivity=sensitivity, 1724 num_thresholds=num_thresholds, 1725 predictions=predictions, 1726 labels=labels, 1727 weights=weights, 1728 metrics_collections=metrics_collections, 1729 updates_collections=updates_collections, 1730 name=name) 1731 1732 1733def streaming_sensitivity_at_specificity(predictions, 1734 labels, 1735 specificity, 1736 weights=None, 1737 num_thresholds=200, 1738 metrics_collections=None, 1739 updates_collections=None, 1740 name=None): 1741 """Computes the sensitivity at a given specificity. 1742 1743 The `streaming_sensitivity_at_specificity` function creates four local 1744 variables, `true_positives`, `true_negatives`, `false_positives` and 1745 `false_negatives` that are used to compute the sensitivity at the given 1746 specificity value. The threshold for the given specificity value is computed 1747 and used to evaluate the corresponding sensitivity. 1748 1749 For estimation of the metric over a stream of data, the function creates an 1750 `update_op` operation that updates these variables and returns the 1751 `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`, 1752 `false_positives` and `false_negatives` counts with the weight of each case 1753 found in the `predictions` and `labels`. 1754 1755 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1756 1757 For additional information about specificity and sensitivity, see the 1758 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 1759 1760 Args: 1761 predictions: A floating point `Tensor` of arbitrary shape and whose values 1762 are in the range `[0, 1]`. 
1763 labels: A `bool` `Tensor` whose shape matches `predictions`. 1764 specificity: A scalar value in range `[0, 1]`. 1765 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1766 must be broadcastable to `labels` (i.e., all dimensions must be either 1767 `1`, or the same as the corresponding `labels` dimension). 1768 num_thresholds: The number of thresholds to use for matching the given 1769 specificity. 1770 metrics_collections: An optional list of collections that `sensitivity` 1771 should be added to. 1772 updates_collections: An optional list of collections that `update_op` should 1773 be added to. 1774 name: An optional variable_scope name. 1775 1776 Returns: 1777 sensitivity: A scalar `Tensor` representing the sensitivity at the given 1778 `specificity` value. 1779 update_op: An operation that increments the `true_positives`, 1780 `true_negatives`, `false_positives` and `false_negatives` variables 1781 appropriately and whose value matches `sensitivity`. 1782 1783 Raises: 1784 ValueError: If `predictions` and `labels` have mismatched shapes, if 1785 `weights` is not `None` and its shape doesn't match `predictions`, or if 1786 `specificity` is not between 0 and 1, or if either `metrics_collections` 1787 or `updates_collections` are not a list or tuple. 1788 """ 1789 return metrics.sensitivity_at_specificity( 1790 specificity=specificity, 1791 num_thresholds=num_thresholds, 1792 predictions=predictions, 1793 labels=labels, 1794 weights=weights, 1795 metrics_collections=metrics_collections, 1796 updates_collections=updates_collections, 1797 name=name) 1798 1799 1800@deprecated( 1801 None, 'Please switch to tf.metrics.precision_at_thresholds. 
Note that the ' 1802 'order of the labels and predictions arguments has been switched.') 1803def streaming_precision_at_thresholds(predictions, 1804 labels, 1805 thresholds, 1806 weights=None, 1807 metrics_collections=None, 1808 updates_collections=None, 1809 name=None): 1810 """Computes precision values for different `thresholds` on `predictions`. 1811 1812 The `streaming_precision_at_thresholds` function creates four local variables, 1813 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 1814 for various values of thresholds. `precision[i]` is defined as the total 1815 weight of values in `predictions` above `thresholds[i]` whose corresponding 1816 entry in `labels` is `True`, divided by the total weight of values in 1817 `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] + 1818 false_positives[i])`). 1819 1820 For estimation of the metric over a stream of data, the function creates an 1821 `update_op` operation that updates these variables and returns the 1822 `precision`. 1823 1824 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1825 1826 Args: 1827 predictions: A floating point `Tensor` of arbitrary shape and whose values 1828 are in the range `[0, 1]`. 1829 labels: A `bool` `Tensor` whose shape matches `predictions`. 1830 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1831 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1832 must be broadcastable to `labels` (i.e., all dimensions must be either 1833 `1`, or the same as the corresponding `labels` dimension). 1834 metrics_collections: An optional list of collections that `precision` should 1835 be added to. 1836 updates_collections: An optional list of collections that `update_op` should 1837 be added to. 1838 name: An optional variable_scope name. 1839 1840 Returns: 1841 precision: A float `Tensor` of shape `[len(thresholds)]`. 
1842 update_op: An operation that increments the `true_positives`, 1843 `true_negatives`, `false_positives` and `false_negatives` variables that 1844 are used in the computation of `precision`. 1845 1846 Raises: 1847 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1848 `weights` is not `None` and its shape doesn't match `predictions`, or if 1849 either `metrics_collections` or `updates_collections` are not a list or 1850 tuple. 1851 """ 1852 return metrics.precision_at_thresholds( 1853 thresholds=thresholds, 1854 predictions=predictions, 1855 labels=labels, 1856 weights=weights, 1857 metrics_collections=metrics_collections, 1858 updates_collections=updates_collections, 1859 name=name) 1860 1861 1862@deprecated(None, 1863 'Please switch to tf.metrics.recall_at_thresholds. Note that the ' 1864 'order of the labels and predictions arguments has been switched.') 1865def streaming_recall_at_thresholds(predictions, 1866 labels, 1867 thresholds, 1868 weights=None, 1869 metrics_collections=None, 1870 updates_collections=None, 1871 name=None): 1872 """Computes various recall values for different `thresholds` on `predictions`. 1873 1874 The `streaming_recall_at_thresholds` function creates four local variables, 1875 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 1876 for various values of thresholds. `recall[i]` is defined as the total weight 1877 of values in `predictions` above `thresholds[i]` whose corresponding entry in 1878 `labels` is `True`, divided by the total weight of `True` values in `labels` 1879 (`true_positives[i] / (true_positives[i] + false_negatives[i])`). 1880 1881 For estimation of the metric over a stream of data, the function creates an 1882 `update_op` operation that updates these variables and returns the `recall`. 1883 1884 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
1885 1886 Args: 1887 predictions: A floating point `Tensor` of arbitrary shape and whose values 1888 are in the range `[0, 1]`. 1889 labels: A `bool` `Tensor` whose shape matches `predictions`. 1890 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1891 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1892 must be broadcastable to `labels` (i.e., all dimensions must be either 1893 `1`, or the same as the corresponding `labels` dimension). 1894 metrics_collections: An optional list of collections that `recall` should be 1895 added to. 1896 updates_collections: An optional list of collections that `update_op` should 1897 be added to. 1898 name: An optional variable_scope name. 1899 1900 Returns: 1901 recall: A float `Tensor` of shape `[len(thresholds)]`. 1902 update_op: An operation that increments the `true_positives`, 1903 `true_negatives`, `false_positives` and `false_negatives` variables that 1904 are used in the computation of `recall`. 1905 1906 Raises: 1907 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1908 `weights` is not `None` and its shape doesn't match `predictions`, or if 1909 either `metrics_collections` or `updates_collections` are not a list or 1910 tuple. 1911 """ 1912 return metrics.recall_at_thresholds( 1913 thresholds=thresholds, 1914 predictions=predictions, 1915 labels=labels, 1916 weights=weights, 1917 metrics_collections=metrics_collections, 1918 updates_collections=updates_collections, 1919 name=name) 1920 1921 1922def streaming_false_positive_rate_at_thresholds(predictions, 1923 labels, 1924 thresholds, 1925 weights=None, 1926 metrics_collections=None, 1927 updates_collections=None, 1928 name=None): 1929 """Computes various fpr values for different `thresholds` on `predictions`. 1930 1931 The `streaming_false_positive_rate_at_thresholds` function creates two 1932 local variables, `false_positives`, `true_negatives`, for various values of 1933 thresholds. 
`false_positive_rate[i]` is defined as the total weight 1934 of values in `predictions` above `thresholds[i]` whose corresponding entry in 1935 `labels` is `False`, divided by the total weight of `False` values in `labels` 1936 (`false_positives[i] / (false_positives[i] + true_negatives[i])`). 1937 1938 For estimation of the metric over a stream of data, the function creates an 1939 `update_op` operation that updates these variables and returns the 1940 `false_positive_rate`. 1941 1942 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1943 1944 Args: 1945 predictions: A floating point `Tensor` of arbitrary shape and whose values 1946 are in the range `[0, 1]`. 1947 labels: A `bool` `Tensor` whose shape matches `predictions`. 1948 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1949 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 1950 must be broadcastable to `labels` (i.e., all dimensions must be either 1951 `1`, or the same as the corresponding `labels` dimension). 1952 metrics_collections: An optional list of collections that 1953 `false_positive_rate` should be added to. 1954 updates_collections: An optional list of collections that `update_op` should 1955 be added to. 1956 name: An optional variable_scope name. 1957 1958 Returns: 1959 false_positive_rate: A float `Tensor` of shape `[len(thresholds)]`. 1960 update_op: An operation that increments the `false_positives` and 1961 `true_negatives` variables that are used in the computation of 1962 `false_positive_rate`. 1963 1964 Raises: 1965 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1966 `weights` is not `None` and its shape doesn't match `predictions`, or if 1967 either `metrics_collections` or `updates_collections` are not a list or 1968 tuple. 
1969 """ 1970 with variable_scope.variable_scope(name, 'false_positive_rate_at_thresholds', 1971 (predictions, labels, weights)): 1972 values, update_ops = _streaming_confusion_matrix_at_thresholds( 1973 predictions, labels, thresholds, weights, includes=('fp', 'tn')) 1974 1975 # Avoid division by zero. 1976 epsilon = _EPSILON 1977 1978 def compute_fpr(fp, tn, name): 1979 return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name) 1980 1981 fpr = compute_fpr(values['fp'], values['tn'], 'value') 1982 update_op = compute_fpr(update_ops['fp'], update_ops['tn'], 'update_op') 1983 1984 if metrics_collections: 1985 ops.add_to_collections(metrics_collections, fpr) 1986 1987 if updates_collections: 1988 ops.add_to_collections(updates_collections, update_op) 1989 1990 return fpr, update_op 1991 1992 1993def streaming_false_negative_rate_at_thresholds(predictions, 1994 labels, 1995 thresholds, 1996 weights=None, 1997 metrics_collections=None, 1998 updates_collections=None, 1999 name=None): 2000 """Computes various fnr values for different `thresholds` on `predictions`. 2001 2002 The `streaming_false_negative_rate_at_thresholds` function creates two 2003 local variables, `false_negatives`, `true_positives`, for various values of 2004 thresholds. `false_negative_rate[i]` is defined as the total weight 2005 of values in `predictions` above `thresholds[i]` whose corresponding entry in 2006 `labels` is `False`, divided by the total weight of `True` values in `labels` 2007 (`false_negatives[i] / (false_negatives[i] + true_positives[i])`). 2008 2009 For estimation of the metric over a stream of data, the function creates an 2010 `update_op` operation that updates these variables and returns the 2011 `false_positive_rate`. 2012 2013 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2014 2015 Args: 2016 predictions: A floating point `Tensor` of arbitrary shape and whose values 2017 are in the range `[0, 1]`. 
2018 labels: A `bool` `Tensor` whose shape matches `predictions`. 2019 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 2020 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 2021 must be broadcastable to `labels` (i.e., all dimensions must be either 2022 `1`, or the same as the corresponding `labels` dimension). 2023 metrics_collections: An optional list of collections that 2024 `false_negative_rate` should be added to. 2025 updates_collections: An optional list of collections that `update_op` should 2026 be added to. 2027 name: An optional variable_scope name. 2028 2029 Returns: 2030 false_negative_rate: A float `Tensor` of shape `[len(thresholds)]`. 2031 update_op: An operation that increments the `false_negatives` and 2032 `true_positives` variables that are used in the computation of 2033 `false_negative_rate`. 2034 2035 Raises: 2036 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2037 `weights` is not `None` and its shape doesn't match `predictions`, or if 2038 either `metrics_collections` or `updates_collections` are not a list or 2039 tuple. 2040 """ 2041 with variable_scope.variable_scope(name, 'false_negative_rate_at_thresholds', 2042 (predictions, labels, weights)): 2043 values, update_ops = _streaming_confusion_matrix_at_thresholds( 2044 predictions, labels, thresholds, weights, includes=('fn', 'tp')) 2045 2046 # Avoid division by zero. 
2047 epsilon = _EPSILON 2048 2049 def compute_fnr(fn, tp, name): 2050 return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name) 2051 2052 fnr = compute_fnr(values['fn'], values['tp'], 'value') 2053 update_op = compute_fnr(update_ops['fn'], update_ops['tp'], 'update_op') 2054 2055 if metrics_collections: 2056 ops.add_to_collections(metrics_collections, fnr) 2057 2058 if updates_collections: 2059 ops.add_to_collections(updates_collections, update_op) 2060 2061 return fnr, update_op 2062 2063 2064def _at_k_name(name, k=None, class_id=None): 2065 if k is not None: 2066 name = '%s_at_%d' % (name, k) 2067 else: 2068 name = '%s_at_k' % (name) 2069 if class_id is not None: 2070 name = '%s_class%d' % (name, class_id) 2071 return name 2072 2073 2074@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, ' 2075 'and reshape labels from [batch_size] to [batch_size, 1].') 2076def streaming_recall_at_k(predictions, 2077 labels, 2078 k, 2079 weights=None, 2080 metrics_collections=None, 2081 updates_collections=None, 2082 name=None): 2083 """Computes the recall@k of the predictions with respect to dense labels. 2084 2085 The `streaming_recall_at_k` function creates two local variables, `total` and 2086 `count`, that are used to compute the recall@k frequency. This frequency is 2087 ultimately returned as `recall_at_<k>`: an idempotent operation that simply 2088 divides `total` by `count`. 2089 2090 For estimation of the metric over a stream of data, the function creates an 2091 `update_op` operation that updates these variables and returns the 2092 `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with 2093 shape [batch_size] whose elements indicate whether or not the corresponding 2094 label is in the top `k` `predictions`. Then `update_op` increments `total` 2095 with the reduced sum of `weights` where `in_top_k` is `True`, and it 2096 increments `count` with the reduced sum of `weights`. 
2097 2098 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2099 2100 Args: 2101 predictions: A float `Tensor` of dimension [batch_size, num_classes]. 2102 labels: A `Tensor` of dimension [batch_size] whose type is in `int32`, 2103 `int64`. 2104 k: The number of top elements to look at for computing recall. 2105 weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and 2106 must be broadcastable to `labels` (i.e., all dimensions must be either 2107 `1`, or the same as the corresponding `labels` dimension). 2108 metrics_collections: An optional list of collections that `recall_at_k` 2109 should be added to. 2110 updates_collections: An optional list of collections `update_op` should be 2111 added to. 2112 name: An optional variable_scope name. 2113 2114 Returns: 2115 recall_at_k: A `Tensor` representing the recall@k, the fraction of labels 2116 which fall into the top `k` predictions. 2117 update_op: An operation that increments the `total` and `count` variables 2118 appropriately and whose value matches `recall_at_k`. 2119 2120 Raises: 2121 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2122 `weights` is not `None` and its shape doesn't match `predictions`, or if 2123 either `metrics_collections` or `updates_collections` are not a list or 2124 tuple. 2125 """ 2126 in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k)) 2127 return streaming_mean(in_top_k, weights, metrics_collections, 2128 updates_collections, name or _at_k_name('recall', k)) 2129 2130 2131# TODO(ptucker): Validate range of values in labels? 2132def streaming_sparse_recall_at_k(predictions, 2133 labels, 2134 k, 2135 class_id=None, 2136 weights=None, 2137 metrics_collections=None, 2138 updates_collections=None, 2139 name=None): 2140 """Computes recall@k of the predictions with respect to sparse labels. 
2141 2142 If `class_id` is not specified, we'll calculate recall as the ratio of true 2143 positives (i.e., correct predictions, items in the top `k` highest 2144 `predictions` that are found in the corresponding row in `labels`) to 2145 actual positives (the full `labels` row). 2146 If `class_id` is specified, we calculate recall by considering only the rows 2147 in the batch for which `class_id` is in `labels`, and computing the 2148 fraction of them for which `class_id` is in the corresponding row in 2149 `labels`. 2150 2151 `streaming_sparse_recall_at_k` creates two local variables, 2152 `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute 2153 the recall_at_k frequency. This frequency is ultimately returned as 2154 `recall_at_<k>`: an idempotent operation that simply divides 2155 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 2156 `false_negative_at_<k>`). 2157 2158 For estimation of the metric over a stream of data, the function creates an 2159 `update_op` operation that updates these variables and returns the 2160 `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 2161 indicating the top `k` `predictions`. Set operations applied to `top_k` and 2162 `labels` calculate the true positives and false negatives weighted by 2163 `weights`. Then `update_op` increments `true_positive_at_<k>` and 2164 `false_negative_at_<k>` using these values. 2165 2166 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2167 2168 Args: 2169 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 2170 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 2171 The final dimension contains the logit values for each class. [D1, ... DN] 2172 must match `labels`. 2173 labels: `int64` `Tensor` or `SparseTensor` with shape 2174 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2175 target classes for the associated prediction. 
Commonly, N=1 and `labels` 2176 has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`. 2177 Values should be in range [0, num_classes), where num_classes is the last 2178 dimension of `predictions`. Values outside this range always count 2179 towards `false_negative_at_<k>`. 2180 k: Integer, k for @k metric. 2181 class_id: Integer class ID for which we want binary metrics. This should be 2182 in range [0, num_classes), where num_classes is the last dimension of 2183 `predictions`. If class_id is outside this range, the method returns NAN. 2184 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2185 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2186 dimensions must be either `1`, or the same as the corresponding `labels` 2187 dimension). 2188 metrics_collections: An optional list of collections that values should 2189 be added to. 2190 updates_collections: An optional list of collections that updates should 2191 be added to. 2192 name: Name of new update operation, and namespace for other dependent ops. 2193 2194 Returns: 2195 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2196 by the sum of `true_positives` and `false_negatives`. 2197 update_op: `Operation` that increments `true_positives` and 2198 `false_negatives` variables appropriately, and whose value matches 2199 `recall`. 2200 2201 Raises: 2202 ValueError: If `weights` is not `None` and its shape doesn't match 2203 `predictions`, or if either `metrics_collections` or `updates_collections` 2204 are not a list or tuple. 2205 """ 2206 return metrics.recall_at_k( 2207 k=k, 2208 class_id=class_id, 2209 predictions=predictions, 2210 labels=labels, 2211 weights=weights, 2212 metrics_collections=metrics_collections, 2213 updates_collections=updates_collections, 2214 name=name) 2215 2216 2217# TODO(ptucker): Validate range of values in labels? 
def streaming_sparse_precision_at_k(predictions,
                                    labels,
                                    k,
                                    class_id=None,
                                    weights=None,
                                    metrics_collections=None,
                                    updates_collections=None,
                                    name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is not specified, we calculate precision as the ratio of true
  positives (i.e., correct predictions, items in the top `k` highest
  `predictions` that are found in the corresponding row in `labels`) to
  positives (all top `k` `predictions`).
  If `class_id` is specified, we calculate precision by considering only the
  rows in the batch for which `class_id` is in the top `k` highest
  `predictions`, and computing the fraction of them for which `class_id` is
  in the corresponding row in `labels`.

  We expect precision to decrease as `k` increases.

  `streaming_sparse_precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  return metrics.precision_at_k(
      k=k,
      class_id=class_id,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_top_k(top_k_predictions,
                                        labels,
                                        class_id=None,
                                        weights=None,
                                        metrics_collections=None,
                                        updates_collections=None,
                                        name=None):
  """Computes precision@k of top-k predictions with respect to sparse labels.

  If `class_id` is not specified, we calculate precision as the ratio of
  true positives (i.e., correct predictions, items in `top_k_predictions`
  that are found in the corresponding row in `labels`) to positives (all
  `top_k_predictions`).
  If `class_id` is specified, we calculate precision by considering only the
  rows in the batch for which `class_id` is in the top `k` highest
  `predictions`, and computing the fraction of them for which `class_id` is
  in the corresponding row in `labels`.

  We expect precision to decrease as `k` increases.

  `streaming_sparse_precision_at_top_k` creates two local variables,
  `true_positive_at_k` and `false_positive_at_k`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_k`: an idempotent operation that simply divides
  `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_k`. Internally, set operations applied to `top_k_predictions`
  and `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_k` and
  `false_positive_at_k` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
      The final dimension contains the indices of top-k labels. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `top_k_predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.
2374 2375 Raises: 2376 ValueError: If `weights` is not `None` and its shape doesn't match 2377 `predictions`, or if either `metrics_collections` or `updates_collections` 2378 are not a list or tuple. 2379 ValueError: If `top_k_predictions` has rank < 2. 2380 """ 2381 default_name = _at_k_name('precision', class_id=class_id) 2382 with ops.name_scope(name, default_name, 2383 (top_k_predictions, labels, weights)) as name_scope: 2384 return metrics_impl.precision_at_top_k( 2385 labels=labels, 2386 predictions_idx=top_k_predictions, 2387 class_id=class_id, 2388 weights=weights, 2389 metrics_collections=metrics_collections, 2390 updates_collections=updates_collections, 2391 name=name_scope) 2392 2393 2394def sparse_recall_at_top_k(labels, 2395 top_k_predictions, 2396 class_id=None, 2397 weights=None, 2398 metrics_collections=None, 2399 updates_collections=None, 2400 name=None): 2401 """Computes recall@k of top-k predictions with respect to sparse labels. 2402 2403 If `class_id` is specified, we calculate recall by considering only the 2404 entries in the batch for which `class_id` is in the label, and computing 2405 the fraction of them for which `class_id` is in the top-k `predictions`. 2406 If `class_id` is not specified, we'll calculate recall as how often on 2407 average a class among the labels of a batch entry is in the top-k 2408 `predictions`. 2409 2410 `sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>` 2411 and `false_negative_at_<k>`, that are used to compute the recall_at_k 2412 frequency. This frequency is ultimately returned as `recall_at_<k>`: an 2413 idempotent operation that simply divides `true_positive_at_<k>` by total 2414 (`true_positive_at_<k>` + `false_negative_at_<k>`). 2415 2416 For estimation of the metric over a stream of data, the function creates an 2417 `update_op` operation that updates these variables and returns the 2418 `recall_at_<k>`. 
def sparse_recall_at_top_k(labels,
                           top_k_predictions,
                           class_id=None,
                           weights=None,
                           metrics_collections=None,
                           updates_collections=None,
                           name=None):
  """Computes recall@k of top-k predictions with respect to sparse labels.

  If `class_id` is specified, we calculate recall by considering only the
  entries in the batch for which `class_id` is in the label, and computing
  the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, we'll calculate recall as how often on
  average a class among the labels of a batch entry is in the top-k
  `predictions`.

  `sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>`
  and `false_negative_at_<k>`, that are used to compute the recall_at_k
  frequency. This frequency is ultimately returned as `recall_at_<k>`: an
  idempotent operation that simply divides `true_positive_at_<k>` by total
  (`true_positive_at_<k>` + `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the
  true positives and false negatives weighted by `weights`. Then `update_op`
  increments `true_positive_at_<k>` and `false_negative_at_<k>` using these
  values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `top_k_predictions`. Values should be in range [0, num_classes), where
      num_classes is the number of classes in the predictions from which
      `top_k_predictions` was computed. Values outside this range always count
      towards `false_negative_at_<k>`.
    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
      The final dimension contains the indices of top-k labels. [D1, ... DN]
      must match `labels`.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the number of classes in
      the predictions from which `top_k_predictions` was computed. If class_id
      is outside this range, the method returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `top_k_predictions`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
  """
  default_name = _at_k_name('recall', class_id=class_id)
  with ops.name_scope(name, default_name,
                      (top_k_predictions, labels, weights)) as name_scope:
    return metrics_impl.recall_at_top_k(
        labels=labels,
        predictions_idx=top_k_predictions,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=name_scope)


def _compute_recall_at_precision(tp, fp, fn, precision, name):
  """Helper function to compute recall at a given `precision`.

  Args:
    tp: The number of true positives.
    fp: The number of false positives.
    fn: The number of false negatives.
    precision: The precision for which the recall will be calculated.
    name: An optional variable_scope name.

  Returns:
    The recall at the given `precision`.
  """
  # _EPSILON guards against division by zero when a threshold bucket is empty.
  precisions = math_ops.div(tp, tp + fp + _EPSILON)
  # Index of the threshold whose precision is closest to the requested value.
  # NOTE(review): if the requested precision is unattainable, this still
  # returns the recall at the nearest achievable precision.
  tf_index = math_ops.argmin(
      math_ops.abs(precisions - precision), 0, output_type=dtypes.int32)

  # Now, we have the implicit threshold, so compute the recall:
  return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
                      name)
def recall_at_precision(labels,
                        predictions,
                        precision,
                        weights=None,
                        num_thresholds=200,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes `recall` at `precision`.

  The `recall_at_precision` function creates four local variables,
  `tp` (true positives), `fp` (false positives), `fn` (false negatives) and
  `tn` (true negatives) that are used to compute the `recall` at the given
  `precision` value. The threshold for the given `precision` value is computed
  and used to evaluate the corresponding `recall`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall`. `update_op` increments the `tp`, `fp` and `fn` counts with the
  weight of each case found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    precision: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      `precision`.
    metrics_collections: An optional list of collections that `recall`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A scalar `Tensor` representing the recall at the given
      `precision` value.
    update_op: An operation that increments the `tp`, `fp` and `fn`
      variables appropriately and whose value matches `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `precision` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.

  """
  if not 0 <= precision <= 1:
    raise ValueError('`precision` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'recall_at_precision',
                                     (predictions, labels, weights)):
    # Interior thresholds are evenly spaced in (0, 1); the two end thresholds
    # are nudged just outside [0, 1] so that predictions of exactly 0.0 or 1.0
    # fall into the extreme buckets.
    thresholds = [
        i * 1.0 / (num_thresholds - 1) for i in range(1, num_thresholds - 1)
    ]
    thresholds = [0.0 - _EPSILON] + thresholds + [1.0 + _EPSILON]

    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        predictions, labels, thresholds, weights)

    recall = _compute_recall_at_precision(values['tp'], values['fp'],
                                          values['fn'], precision, 'value')
    update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'],
                                             update_ops['fn'], precision,
                                             'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op
def streaming_sparse_average_precision_at_k(predictions,
                                            labels,
                                            k,
                                            weights=None,
                                            metrics_collections=None,
                                            updates_collections=None,
                                            name=None):
  """Computes average precision@k of predictions with respect to sparse labels.

  See `sparse_average_precision_at_k` for details on formula. `weights` are
  applied to the result of `sparse_average_precision_at_k`.

  `streaming_sparse_average_precision_at_k` creates two local variables,
  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
  are used to compute the frequency. This frequency is ultimately returned as
  `average_precision_at_<k>`: an idempotent operation that simply divides
  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and `predictions` has shape
      [batch size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric. This will calculate an average precision for
      range `[1,k]`, as documented above.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `metric`.
  """
  return metrics.average_precision_at_k(
      k=k,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
def streaming_sparse_average_precision_at_top_k(top_k_predictions,
                                                labels,
                                                weights=None,
                                                metrics_collections=None,
                                                updates_collections=None,
                                                name=None):
  """Computes average precision@k of predictions with respect to sparse labels.

  `streaming_sparse_average_precision_at_top_k` creates two local variables,
  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
  are used to compute the frequency. This frequency is ultimately returned as
  `average_precision_at_<k>`: an idempotent operation that simply divides
  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate
  the true positives and false positives weighted by `weights`. Then `update_op`
  increments `true_positive_at_<k>` and `false_positive_at_<k>` using these
  values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and `top_k_predictions` has shape [batch size, k]. The
      final dimension must be set and contains the top `k` predicted class
      indices. [D1, ... DN] must match `labels`. Values should be in range
      [0, num_classes).
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `top_k_predictions`.
      Values should be in range [0, num_classes).
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `metric`.

  Raises:
    ValueError: if the last dimension of top_k_predictions is not set.
  """
  return metrics_impl._streaming_sparse_average_precision_at_top_k(  # pylint: disable=protected-access
      predictions_idx=top_k_predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
@deprecated(None, 'Please switch to tf.metrics.mean_absolute_error.')
def streaming_mean_absolute_error(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean absolute error between the labels and predictions.

  The `streaming_mean_absolute_error` function creates two local variables,
  `total` and `count` that are used to compute the mean absolute error. This
  average is weighted by `weights`, and it is ultimately returned as
  `mean_absolute_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
  absolute value of the differences between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `absolute_errors`, and it increments `count` with the reduced
  sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_absolute_error(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_mean_relative_error(predictions,
                                  labels,
                                  normalizer,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean relative error by normalizing with the given values.

  The `streaming_mean_relative_error` function creates two local variables,
  `total` and `count` that are used to compute the mean relative absolute error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_relative_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_relative_error`. Internally, a `relative_errors` operation divides the
  absolute value of the differences between `predictions` and `labels` by the
  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `relative_errors`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_relative_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_relative_error(
      normalizer=normalizer,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
def streaming_mean_squared_error(predictions,
                                 labels,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes the (weighted) mean squared error between labels and predictions.

  Two local variables, `total` and `count`, accumulate the weighted sum of
  squared differences and the weighted example count seen so far. The metric
  value is the idempotent ratio `total / count`.

  Each evaluation of the returned `update_op` adds the reduced sum of
  `weights * squared_error` (where `squared_error` is the element-wise square
  of `predictions - labels`) to `total`, adds the reduced sum of `weights` to
  `count`, and reports the updated ratio.

  When `weights` is `None`, every example counts with weight 1; a weight of 0
  masks an example out.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and broadcastable to `labels` (every dimension either `1` or
      identical to the corresponding `labels` dimension). Indicates the
      frequency with which an example is sampled.
    metrics_collections: Optional list of collections to which the
      `mean_squared_error` value should be added.
    updates_collections: Optional list of collections to which `update_op`
      should be added.
    name: An optional variable_scope name.

  Returns:
    mean_squared_error: A `Tensor` holding the current `total / count`.
    update_op: An operation that increments `total` and `count` and whose
      value matches `mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Thin wrapper: the streaming computation is implemented in tf.metrics.
  return metrics.mean_squared_error(
      labels=labels,
      predictions=predictions,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
2905 Then `update_op` increments `total` with the reduced sum of the product of 2906 `weights` and `squared_error`, and it increments `count` with the reduced sum 2907 of `weights`. 2908 2909 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2910 2911 Args: 2912 predictions: A `Tensor` of arbitrary shape. 2913 labels: A `Tensor` of the same shape as `predictions`. 2914 weights: Optional `Tensor` indicating the frequency with which an example is 2915 sampled. Rank must be 0, or the same rank as `labels`, and must be 2916 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 2917 the same as the corresponding `labels` dimension). 2918 metrics_collections: An optional list of collections that 2919 `root_mean_squared_error` should be added to. 2920 updates_collections: An optional list of collections that `update_op` should 2921 be added to. 2922 name: An optional variable_scope name. 2923 2924 Returns: 2925 root_mean_squared_error: A `Tensor` representing the current mean, the value 2926 of `total` divided by `count`. 2927 update_op: An operation that increments the `total` and `count` variables 2928 appropriately and whose value matches `root_mean_squared_error`. 2929 2930 Raises: 2931 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2932 `weights` is not `None` and its shape doesn't match `predictions`, or if 2933 either `metrics_collections` or `updates_collections` are not a list or 2934 tuple. 2935 """ 2936 return metrics.root_mean_squared_error( 2937 predictions=predictions, 2938 labels=labels, 2939 weights=weights, 2940 metrics_collections=metrics_collections, 2941 updates_collections=updates_collections, 2942 name=name) 2943 2944 2945def streaming_covariance(predictions, 2946 labels, 2947 weights=None, 2948 metrics_collections=None, 2949 updates_collections=None, 2950 name=None): 2951 """Computes the unbiased sample covariance between `predictions` and `labels`. 
2952 2953 The `streaming_covariance` function creates four local variables, 2954 `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to 2955 compute the sample covariance between predictions and labels across multiple 2956 batches of data. The covariance is ultimately returned as an idempotent 2957 operation that simply divides `comoment` by `count` - 1. We use `count` - 1 2958 in order to get an unbiased estimate. 2959 2960 The algorithm used for this online computation is described in 2961 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. 2962 Specifically, the formula used to combine two sample comoments is 2963 `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB` 2964 The comoment for a single batch of data is simply 2965 `sum((x - E[x]) * (y - E[y]))`, optionally weighted. 2966 2967 If `weights` is not None, then it is used to compute weighted comoments, 2968 means, and count. NOTE: these weights are treated as "frequency weights", as 2969 opposed to "reliability weights". See discussion of the difference on 2970 https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance 2971 2972 To facilitate the computation of covariance across multiple batches of data, 2973 the function creates an `update_op` operation, which updates underlying 2974 variables and returns the updated covariance. 2975 2976 Args: 2977 predictions: A `Tensor` of arbitrary size. 2978 labels: A `Tensor` of the same size as `predictions`. 2979 weights: Optional `Tensor` indicating the frequency with which an example is 2980 sampled. Rank must be 0, or the same rank as `labels`, and must be 2981 broadcastable to `labels` (i.e., all dimensions must be either `1`, or 2982 the same as the corresponding `labels` dimension). 2983 metrics_collections: An optional list of collections that the metric 2984 value variable should be added to. 
2985 updates_collections: An optional list of collections that the metric update 2986 ops should be added to. 2987 name: An optional variable_scope name. 2988 2989 Returns: 2990 covariance: A `Tensor` representing the current unbiased sample covariance, 2991 `comoment` / (`count` - 1). 2992 update_op: An operation that updates the local variables appropriately. 2993 2994 Raises: 2995 ValueError: If labels and predictions are of different sizes or if either 2996 `metrics_collections` or `updates_collections` are not a list or tuple. 2997 """ 2998 with variable_scope.variable_scope(name, 'covariance', 2999 (predictions, labels, weights)): 3000 predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 3001 predictions, labels, weights) 3002 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 3003 count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') 3004 mean_prediction = metrics_impl.metric_variable( 3005 [], dtypes.float32, name='mean_prediction') 3006 mean_label = metrics_impl.metric_variable( 3007 [], dtypes.float32, name='mean_label') 3008 comoment = metrics_impl.metric_variable( # C_A in update equation 3009 [], dtypes.float32, name='comoment') 3010 3011 if weights is None: 3012 batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn 3013 weighted_predictions = predictions 3014 weighted_labels = labels 3015 else: 3016 weights = weights_broadcast_ops.broadcast_weights(weights, labels) 3017 batch_count = math_ops.reduce_sum(weights) # n_B in eqn 3018 weighted_predictions = math_ops.multiply(predictions, weights) 3019 weighted_labels = math_ops.multiply(labels, weights) 3020 3021 update_count = state_ops.assign_add(count_, batch_count) # n_AB in eqn 3022 prev_count = update_count - batch_count # n_A in update equation 3023 3024 # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) 3025 # batch_mean_prediction is E[x_B] in the update equation 3026 
batch_mean_prediction = _safe_div( 3027 math_ops.reduce_sum(weighted_predictions), batch_count, 3028 'batch_mean_prediction') 3029 delta_mean_prediction = _safe_div( 3030 (batch_mean_prediction - mean_prediction) * batch_count, update_count, 3031 'delta_mean_prediction') 3032 update_mean_prediction = state_ops.assign_add(mean_prediction, 3033 delta_mean_prediction) 3034 # prev_mean_prediction is E[x_A] in the update equation 3035 prev_mean_prediction = update_mean_prediction - delta_mean_prediction 3036 3037 # batch_mean_label is E[y_B] in the update equation 3038 batch_mean_label = _safe_div( 3039 math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label') 3040 delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count, 3041 update_count, 'delta_mean_label') 3042 update_mean_label = state_ops.assign_add(mean_label, delta_mean_label) 3043 # prev_mean_label is E[y_A] in the update equation 3044 prev_mean_label = update_mean_label - delta_mean_label 3045 3046 unweighted_batch_coresiduals = ((predictions - batch_mean_prediction) * 3047 (labels - batch_mean_label)) 3048 # batch_comoment is C_B in the update equation 3049 if weights is None: 3050 batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals) 3051 else: 3052 batch_comoment = math_ops.reduce_sum( 3053 unweighted_batch_coresiduals * weights) 3054 3055 # View delta_comoment as = C_AB - C_A in the update equation above. 3056 # Since C_A is stored in a var, by how much do we need to increment that var 3057 # to make the var = C_AB? 
    delta_comoment = (
        batch_comoment + (prev_mean_prediction - batch_mean_prediction) *
        (prev_mean_label - batch_mean_label) *
        (prev_count * batch_count / update_count))
    update_comoment = state_ops.assign_add(comoment, delta_comoment)

    # The unbiased estimate divides by `count` - 1, so the covariance is
    # undefined until at least two (weighted) samples have been seen; emit NaN
    # in that case.
    covariance = array_ops.where(
        math_ops.less_equal(count_, 1.),
        float('nan'),
        math_ops.truediv(comoment, count_ - 1),
        name='covariance')
    # Recompute the same expression under a control dependency on
    # `update_comoment`, so evaluating `update_op` both applies this batch's
    # update and returns the post-update covariance.
    with ops.control_dependencies([update_comoment]):
      update_op = array_ops.where(
          math_ops.less_equal(count_, 1.),
          float('nan'),
          math_ops.truediv(comoment, count_ - 1),
          name='update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, covariance)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return covariance, update_op


def streaming_pearson_correlation(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes Pearson correlation coefficient between `predictions`, `labels`.

  The `streaming_pearson_correlation` function delegates to
  `streaming_covariance` the tracking of three [co]variances:

  - `streaming_covariance(predictions, labels)`, i.e. covariance
  - `streaming_covariance(predictions, predictions)`, i.e. variance
  - `streaming_covariance(labels, labels)`, i.e. variance

  The product-moment correlation ultimately returned is an idempotent operation
  `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To
  facilitate correlation computation across multiple batches, the function
  groups the `update_op`s of the underlying streaming_covariance and returns an
  `update_op`.

  If `weights` is not None, then it is used to compute a weighted correlation.
  NOTE: these weights are treated as "frequency weights", as opposed to
  "reliability weights". See discussion of the difference on
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as predictions.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    pearson_r: A `Tensor` representing the current Pearson product-moment
      correlation coefficient, the value of
      `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`.
    update_op: An operation that updates the underlying variables appropriately.

  Raises:
    ValueError: If `labels` and `predictions` are of different sizes, or if
      `weights` is the wrong size, or if either `metrics_collections` or
      `updates_collections` are not a `list` or `tuple`.
  """
  with variable_scope.variable_scope(name, 'pearson_r',
                                     (predictions, labels, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # Broadcast weights here to avoid duplicate broadcasting in each call to
    # `streaming_covariance`.
    if weights is not None:
      weights = weights_broadcast_ops.broadcast_weights(weights, labels)
    cov, update_cov = streaming_covariance(
        predictions, labels, weights=weights, name='covariance')
    var_predictions, update_var_predictions = streaming_covariance(
        predictions, predictions, weights=weights, name='variance_predictions')
    var_labels, update_var_labels = streaming_covariance(
        labels, labels, weights=weights, name='variance_labels')

    # Idempotent value op: cov / (stddev(predictions) * stddev(labels)).
    pearson_r = math_ops.truediv(
        cov,
        math_ops.multiply(
            math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)),
        name='pearson_r')
    # Built from the *update* tensors of the three covariances, so evaluating
    # this op applies all three underlying updates and yields the correlation
    # reflecting the current batch.
    update_op = math_ops.truediv(
        update_cov,
        math_ops.multiply(
            math_ops.sqrt(update_var_predictions),
            math_ops.sqrt(update_var_labels)),
        name='update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, pearson_r)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return pearson_r, update_op


# TODO(nsilberman): add a 'normalized' flag so that the user can request
# normalization if the inputs are not normalized.
def streaming_mean_cosine_distance(predictions,
                                   labels,
                                   dim,
                                   weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
  """Computes the cosine distance between the labels and predictions.

  The `streaming_mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is weighted by `weights`,
  and it is ultimately returned as `mean_distance`, which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of the same shape as `labels`.
    labels: A `Tensor` of arbitrary shape.
    dim: The dimension along which the cosine distance is computed.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`,
      and whose dimension `dim` is 1.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  # NOTE(review): distance is computed as 1 - sum(predictions * labels) along
  # `dim`, which equals the cosine distance only when both inputs are
  # unit-normalized along `dim` — confirm callers normalize their inputs.
  radial_diffs = math_ops.multiply(predictions, labels)
  radial_diffs = math_ops.reduce_sum(
      radial_diffs, reduction_indices=[
          dim,
      ], keep_dims=True)
  mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None,
                                            name or 'mean_cosine_distance')
  mean_distance = math_ops.subtract(1.0, mean_distance)
  update_op = math_ops.subtract(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op


def streaming_percentage_less(values,
                              threshold,
                              weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Computes the percentage of values less than the given threshold.

  The `streaming_percentage_less` function creates two local variables,
  `total` and `count` that are used to compute the percentage of `values` that
  fall below `threshold`. This rate is weighted by `weights`, and it is
  ultimately returned as `percentage` which is an idempotent operation that
  simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `percentage`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A numeric `Tensor` of arbitrary size.
    threshold: A scalar threshold.
    weights: An optional `Tensor` whose shape is broadcastable to `values`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    percentage: A `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  # Thin wrapper: all bookkeeping is delegated to the core TF metric.
  return metrics.percentage_below(
      values=values,
      threshold=threshold,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_mean_iou(predictions,
                       labels,
                       num_classes,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for
  semantic image segmentation, which first computes the IOU for each
  semantic class and then computes the average over classes.
  IOU is defined as follows:
    IOU = true_positive / (true_positive + false_positive + false_negative).
  The predictions are accumulated in a confusion matrix, weighted by `weights`,
  and mIOU is then calculated from it.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean_iou`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened, if its rank > 1.
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened, if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A `Tensor` representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Thin wrapper: the confusion-matrix accumulation is delegated to the core
  # TF metric.
  return metrics.mean_iou(
      num_classes=num_classes,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def _next_array_size(required_size, growth_factor=1.5):
  """Calculate the next size for reallocating a dynamic array.

  Args:
    required_size: number or tf.Tensor specifying required array capacity.
    growth_factor: optional number or tf.Tensor specifying the growth factor
      between subsequent allocations.

  Returns:
    tf.Tensor with dtype=int32 giving the next array size.
3359 """ 3360 exponent = math_ops.ceil( 3361 math_ops.log(math_ops.cast(required_size, dtypes.float32)) / math_ops.log( 3362 math_ops.cast(growth_factor, dtypes.float32))) 3363 return math_ops.cast(math_ops.ceil(growth_factor**exponent), dtypes.int32) 3364 3365 3366def streaming_concat(values, 3367 axis=0, 3368 max_size=None, 3369 metrics_collections=None, 3370 updates_collections=None, 3371 name=None): 3372 """Concatenate values along an axis across batches. 3373 3374 The function `streaming_concat` creates two local variables, `array` and 3375 `size`, that are used to store concatenated values. Internally, `array` is 3376 used as storage for a dynamic array (if `maxsize` is `None`), which ensures 3377 that updates can be run in amortized constant time. 3378 3379 For estimation of the metric over a stream of data, the function creates an 3380 `update_op` operation that appends the values of a tensor and returns the 3381 length of the concatenated axis. 3382 3383 This op allows for evaluating metrics that cannot be updated incrementally 3384 using the same framework as other streaming metrics. 3385 3386 Args: 3387 values: `Tensor` to concatenate. Rank and the shape along all axes other 3388 than the axis to concatenate along must be statically known. 3389 axis: optional integer axis to concatenate along. 3390 max_size: optional integer maximum size of `value` along the given axis. 3391 Once the maximum size is reached, further updates are no-ops. By default, 3392 there is no maximum size: the array is resized as necessary. 3393 metrics_collections: An optional list of collections that `value` 3394 should be added to. 3395 updates_collections: An optional list of collections `update_op` should be 3396 added to. 3397 name: An optional variable_scope name. 3398 3399 Returns: 3400 value: A `Tensor` representing the concatenated values. 3401 update_op: An operation that concatenates the next values. 
3402 3403 Raises: 3404 ValueError: if `values` does not have a statically known rank, `axis` is 3405 not in the valid range or the size of `values` is not statically known 3406 along any axis other than `axis`. 3407 """ 3408 with variable_scope.variable_scope(name, 'streaming_concat', (values,)): 3409 # pylint: disable=invalid-slice-index 3410 values_shape = values.get_shape() 3411 if values_shape.dims is None: 3412 raise ValueError('`values` must have known statically known rank') 3413 3414 ndim = len(values_shape) 3415 if axis < 0: 3416 axis += ndim 3417 if not 0 <= axis < ndim: 3418 raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) 3419 3420 fixed_shape = [dim.value for n, dim in enumerate(values_shape) if n != axis] 3421 if any(value is None for value in fixed_shape): 3422 raise ValueError('all dimensions of `values` other than the dimension to ' 3423 'concatenate along must have statically known size') 3424 3425 # We move `axis` to the front of the internal array so assign ops can be 3426 # applied to contiguous slices 3427 init_size = 0 if max_size is None else max_size 3428 init_shape = [init_size] + fixed_shape 3429 array = metrics_impl.metric_variable( 3430 init_shape, values.dtype, validate_shape=False, name='array') 3431 size = metrics_impl.metric_variable([], dtypes.int32, name='size') 3432 3433 perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] 3434 valid_array = array[:size] 3435 valid_array.set_shape([None] + fixed_shape) 3436 value = array_ops.transpose(valid_array, perm, name='concat') 3437 3438 values_size = array_ops.shape(values)[axis] 3439 if max_size is None: 3440 batch_size = values_size 3441 else: 3442 batch_size = math_ops.minimum(values_size, max_size - size) 3443 3444 perm = [axis] + [n for n in range(ndim) if n != axis] 3445 batch_values = array_ops.transpose(values, perm)[:batch_size] 3446 3447 def reallocate(): 3448 next_size = _next_array_size(new_size) 3449 next_shape = array_ops.stack([next_size] + 
fixed_shape) 3450 new_value = array_ops.zeros(next_shape, dtype=values.dtype) 3451 old_value = array.value() 3452 assign_op = state_ops.assign(array, new_value, validate_shape=False) 3453 with ops.control_dependencies([assign_op]): 3454 copy_op = array[:size].assign(old_value[:size]) 3455 # return value needs to be the same dtype as no_op() for cond 3456 with ops.control_dependencies([copy_op]): 3457 return control_flow_ops.no_op() 3458 3459 new_size = size + batch_size 3460 array_size = array_ops.shape_internal(array, optimize=False)[0] 3461 maybe_reallocate_op = control_flow_ops.cond( 3462 new_size > array_size, reallocate, control_flow_ops.no_op) 3463 with ops.control_dependencies([maybe_reallocate_op]): 3464 append_values_op = array[size:new_size].assign(batch_values) 3465 with ops.control_dependencies([append_values_op]): 3466 update_op = size.assign(new_size) 3467 3468 if metrics_collections: 3469 ops.add_to_collections(metrics_collections, value) 3470 3471 if updates_collections: 3472 ops.add_to_collections(updates_collections, update_op) 3473 3474 return value, update_op 3475 # pylint: enable=invalid-slice-index 3476 3477 3478def aggregate_metrics(*value_update_tuples): 3479 """Aggregates the metric value tensors and update ops into two lists. 3480 3481 Args: 3482 *value_update_tuples: a variable number of tuples, each of which contain the 3483 pair of (value_tensor, update_op) from a streaming metric. 3484 3485 Returns: 3486 A list of value `Tensor` objects and a list of update ops. 3487 3488 Raises: 3489 ValueError: if `value_update_tuples` is empty. 3490 """ 3491 if not value_update_tuples: 3492 raise ValueError('Expected at least one value_tensor/update_op pair') 3493 value_ops, update_ops = zip(*value_update_tuples) 3494 return list(value_ops), list(update_ops) 3495 3496 3497def aggregate_metric_map(names_to_tuples): 3498 """Aggregates the metric names to tuple dictionary. 

  This function is useful for pairing metric names with their associated value
  and update ops when the list of metrics is long. For example:

  ```python
  metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
      'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error(
          predictions, labels, weights),
      'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error(
          predictions, labels, labels, weights),
      'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error(
          predictions, labels, weights),
      'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error(
          predictions, labels, weights),
  })
  ```

  Args:
    names_to_tuples: a map of metric names to tuples, each of which contain the
      pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    A dictionary from metric names to value ops and a dictionary from metric
    names to update ops.
  """
  metric_names = names_to_tuples.keys()
  value_ops, update_ops = zip(*names_to_tuples.values())
  return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))


# NOTE: this module-level name shadows the builtin `count` inside this file.
def count(values,
          weights=None,
          metrics_collections=None,
          updates_collections=None,
          name=None):
  """Computes the number of examples, or sum of `weights`.

  When evaluating some metric (e.g. mean) on one or more subsets of the data,
  this auxiliary metric is useful for keeping track of how many examples there
  are in each subset.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions. Only its shape is used.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `values`
      dimension).
3548 metrics_collections: An optional list of collections that the metric 3549 value variable should be added to. 3550 updates_collections: An optional list of collections that the metric update 3551 ops should be added to. 3552 name: An optional variable_scope name. 3553 3554 Returns: 3555 count: A `Tensor` representing the current value of the metric. 3556 update_op: An operation that accumulates the metric from a batch of data. 3557 3558 Raises: 3559 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 3560 or if either `metrics_collections` or `updates_collections` are not a list 3561 or tuple. 3562 """ 3563 3564 with variable_scope.variable_scope(name, 'count', (values, weights)): 3565 count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') 3566 3567 if weights is None: 3568 num_values = math_ops.to_float(array_ops.size(values)) 3569 else: 3570 _, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 3571 predictions=values, 3572 labels=None, 3573 weights=weights) 3574 weights = weights_broadcast_ops.broadcast_weights( 3575 math_ops.to_float(weights), values) 3576 num_values = math_ops.reduce_sum(weights) 3577 3578 with ops.control_dependencies([values]): 3579 update_op = state_ops.assign_add(count_, num_values) 3580 3581 if metrics_collections: 3582 ops.add_to_collections(metrics_collections, count_) 3583 3584 if updates_collections: 3585 ops.add_to_collections(updates_collections, update_op) 3586 3587 return count_, update_op 3588 3589 3590def cohen_kappa(labels, 3591 predictions_idx, 3592 num_classes, 3593 weights=None, 3594 metrics_collections=None, 3595 updates_collections=None, 3596 name=None): 3597 """Calculates Cohen's kappa. 3598 3599 [Cohen's kappa](https://en.wikipedia.org/wiki/Cohen's_kappa) is a statistic 3600 that measures inter-annotator agreement. 
3601 3602 The `cohen_kappa` function calculates the confusion matrix, and creates three 3603 local variables to compute the Cohen's kappa: `po`, `pe_row`, and `pe_col`, 3604 which refer to the diagonal part, rows and columns totals of the confusion 3605 matrix, respectively. This value is ultimately returned as `kappa`, an 3606 idempotent operation that is calculated by 3607 3608 pe = (pe_row * pe_col) / N 3609 k = (sum(po) - sum(pe)) / (N - sum(pe)) 3610 3611 For estimation of the metric over a stream of data, the function creates an 3612 `update_op` operation that updates these variables and returns the 3613 `kappa`. `update_op` weights each prediction by the corresponding value in 3614 `weights`. 3615 3616 Class labels are expected to start at 0. E.g., if `num_classes` 3617 was three, then the possible labels would be [0, 1, 2]. 3618 3619 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3620 3621 NOTE: Equivalent to `sklearn.metrics.cohen_kappa_score`, but the method 3622 doesn't support weighted matrix yet. 3623 3624 Args: 3625 labels: 1-D `Tensor` of real labels for the classification task. Must be 3626 one of the following types: int16, int32, int64. 3627 predictions_idx: 1-D `Tensor` of predicted class indices for a given 3628 classification. Must have the same type as `labels`. 3629 num_classes: The possible number of labels. 3630 weights: Optional `Tensor` whose shape matches `predictions`. 3631 metrics_collections: An optional list of collections that `kappa` should 3632 be added to. 3633 updates_collections: An optional list of collections that `update_op` should 3634 be added to. 3635 name: An optional variable_scope name. 3636 3637 Returns: 3638 kappa: Scalar float `Tensor` representing the current Cohen's kappa. 3639 update_op: `Operation` that increments `po`, `pe_row` and `pe_col` 3640 variables appropriately and whose value matches `kappa`. 
3641 3642 Raises: 3643 ValueError: If `num_classes` is less than 2, or `predictions` and `labels` 3644 have mismatched shapes, or if `weights` is not `None` and its shape 3645 doesn't match `predictions`, or if either `metrics_collections` or 3646 `updates_collections` are not a list or tuple. 3647 RuntimeError: If eager execution is enabled. 3648 """ 3649 if context.in_eager_mode(): 3650 raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported' 3651 'when eager execution is enabled.') 3652 if num_classes < 2: 3653 raise ValueError('`num_classes` must be >= 2.' 3654 'Found: {}'.format(num_classes)) 3655 with variable_scope.variable_scope(name, 'cohen_kappa', 3656 (labels, predictions_idx, weights)): 3657 # Convert 2-dim (num, 1) to 1-dim (num,) 3658 labels.get_shape().with_rank_at_most(2) 3659 if labels.get_shape().ndims == 2: 3660 labels = array_ops.squeeze(labels, axis=[-1]) 3661 predictions_idx, labels, weights = ( 3662 metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access 3663 predictions=predictions_idx, 3664 labels=labels, 3665 weights=weights)) 3666 predictions_idx.get_shape().assert_is_compatible_with(labels.get_shape()) 3667 3668 stat_dtype = ( 3669 dtypes.int64 3670 if weights is None or weights.dtype.is_integer else dtypes.float32) 3671 po = metrics_impl.metric_variable((num_classes,), stat_dtype, name='po') 3672 pe_row = metrics_impl.metric_variable( 3673 (num_classes,), stat_dtype, name='pe_row') 3674 pe_col = metrics_impl.metric_variable( 3675 (num_classes,), stat_dtype, name='pe_col') 3676 3677 # Table of the counts of agreement: 3678 counts_in_table = confusion_matrix.confusion_matrix( 3679 labels, 3680 predictions_idx, 3681 num_classes=num_classes, 3682 weights=weights, 3683 dtype=stat_dtype, 3684 name='counts_in_table') 3685 3686 po_t = array_ops.diag_part(counts_in_table) 3687 pe_row_t = math_ops.reduce_sum(counts_in_table, axis=0) 3688 pe_col_t = math_ops.reduce_sum(counts_in_table, axis=1) 3689 update_po = 
state_ops.assign_add(po, po_t) 3690 update_pe_row = state_ops.assign_add(pe_row, pe_row_t) 3691 update_pe_col = state_ops.assign_add(pe_col, pe_col_t) 3692 3693 def _calculate_k(po, pe_row, pe_col, name): 3694 po_sum = math_ops.reduce_sum(po) 3695 total = math_ops.reduce_sum(pe_row) 3696 pe_sum = math_ops.reduce_sum( 3697 metrics_impl._safe_div( # pylint: disable=protected-access 3698 pe_row * pe_col, total, None)) 3699 po_sum, pe_sum, total = (math_ops.to_double(po_sum), 3700 math_ops.to_double(pe_sum), 3701 math_ops.to_double(total)) 3702 # kappa = (po - pe) / (N - pe) 3703 k = metrics_impl._safe_scalar_div( # pylint: disable=protected-access 3704 po_sum - pe_sum, 3705 total - pe_sum, 3706 name=name) 3707 return k 3708 3709 kappa = _calculate_k(po, pe_row, pe_col, name='value') 3710 update_op = _calculate_k( 3711 update_po, update_pe_row, update_pe_col, name='update_op') 3712 3713 if metrics_collections: 3714 ops.add_to_collections(metrics_collections, kappa) 3715 3716 if updates_collections: 3717 ops.add_to_collections(updates_collections, update_op) 3718 3719 return kappa, update_op 3720 3721 3722__all__ = [ 3723 'auc_with_confidence_intervals', 3724 'aggregate_metric_map', 3725 'aggregate_metrics', 3726 'cohen_kappa', 3727 'count', 3728 'precision_recall_at_equal_thresholds', 3729 'recall_at_precision', 3730 'sparse_recall_at_top_k', 3731 'streaming_accuracy', 3732 'streaming_auc', 3733 'streaming_curve_points', 3734 'streaming_dynamic_auc', 3735 'streaming_false_negative_rate', 3736 'streaming_false_negative_rate_at_thresholds', 3737 'streaming_false_negatives', 3738 'streaming_false_negatives_at_thresholds', 3739 'streaming_false_positive_rate', 3740 'streaming_false_positive_rate_at_thresholds', 3741 'streaming_false_positives', 3742 'streaming_false_positives_at_thresholds', 3743 'streaming_mean', 3744 'streaming_mean_absolute_error', 3745 'streaming_mean_cosine_distance', 3746 'streaming_mean_iou', 3747 'streaming_mean_relative_error', 3748 
    # (continuation of the module's public API list)
    'streaming_mean_squared_error',
    'streaming_mean_tensor',
    'streaming_percentage_less',
    'streaming_precision',
    'streaming_precision_at_thresholds',
    'streaming_recall',
    'streaming_recall_at_k',
    'streaming_recall_at_thresholds',
    'streaming_root_mean_squared_error',
    'streaming_sensitivity_at_specificity',
    'streaming_sparse_average_precision_at_k',
    'streaming_sparse_average_precision_at_top_k',
    'streaming_sparse_precision_at_k',
    'streaming_sparse_precision_at_top_k',
    'streaming_sparse_recall_at_k',
    'streaming_specificity_at_sensitivity',
    'streaming_true_negatives',
    'streaming_true_negatives_at_thresholds',
    'streaming_true_positives',
    'streaming_true_positives_at_thresholds',
]