metric_ops.py revision 972fa89023f8f27948321c388fa3f1f7857833c3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains metric-computing operations on streamed tensors.

Module documentation, including "@@" callouts, should be put in
third_party/tensorflow/contrib/metrics/__init__.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as collections_lib

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import metrics_impl
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.distributions.normal import Normal
from tensorflow.python.util.deprecation import deprecated

# Epsilon constant used to represent extremely small quantity.
_EPSILON = 1e-7


def _safe_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is <= 0.

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` <= 0, else `numerator` / `denominator`.
  """
  return array_ops.where(
      math_ops.greater(denominator, 0),
      math_ops.truediv(numerator, denominator),
      0,
      name=name)


def streaming_true_positives(predictions,
                             labels,
                             weights=None,
                             metrics_collections=None,
                             updates_collections=None,
                             name=None):
  """Sum the weights of true_positives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.true_positives(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_true_negatives(predictions,
                             labels,
                             weights=None,
                             metrics_collections=None,
                             updates_collections=None,
                             name=None):
  """Sum the weights of true_negatives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.true_negatives(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_false_positives(predictions,
                              labels,
                              weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Sum the weights of false positives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.false_positives(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_false_negatives(predictions,
                              labels,
                              weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Sum the weights of false negatives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.false_negatives(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_mean(values,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Computes the (weighted) mean of the given values.

  The `streaming_mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: `Tensor` whose rank is either 0, or the same rank as `values`, and
      must be broadcastable to `values` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  return metrics.mean(
      values=values,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
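
# Illustrative usage sketch for `streaming_mean` (not part of the original
# module): assumes TF 1.x graph mode, `tf` bound to the tensorflow package,
# and an active `tf.Session` named `sess`. The metric state lives in local
# variables, so they must be initialized before the first update:
#
#   values = tf.placeholder(tf.float32, shape=[None])
#   mean, update_op = streaming_mean(values)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op, feed_dict={values: [1.0, 2.0]})  # total=3, count=2
#   sess.run(update_op, feed_dict={values: [4.0]})       # total=7, count=3
#   sess.run(mean)                                       # 7 / 3 ~= 2.33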


def streaming_mean_tensor(values,
                          weights=None,
                          metrics_collections=None,
                          updates_collections=None,
                          name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  In contrast to the `streaming_mean` function which returns a scalar with the
  mean, this function returns an average tensor with the same shape as the
  input tensors.

  The `streaming_mean_tensor` function creates two local variables,
  `total_tensor` and `count_tensor` that are used to compute the average of
  `values`. This average is ultimately returned as `mean` which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: `Tensor` whose rank is either 0, or the same rank as `values`, and
      must be broadcastable to `values` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  return metrics.mean_tensor(
      values=values,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@deprecated(None,
            'Please switch to tf.metrics.accuracy. Note that the order of the '
            'labels and predictions arguments has been switched.')
def streaming_accuracy(predictions,
                       labels,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Calculates how often `predictions` matches `labels`.

  The `streaming_accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `accuracy`.
  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
  where the corresponding elements of `predictions` and `labels` match and 0.0
  otherwise. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `is_correct`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of any shape.
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.accuracy(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
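
# Illustrative sketch for `streaming_accuracy` (hypothetical example, not from
# the original module; `sess` is an assumed active `tf.Session`):
#
#   predictions = tf.constant([1, 0, 1, 1])
#   labels = tf.constant([1, 0, 0, 1])
#   accuracy, update_op = streaming_accuracy(predictions, labels)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)   # 3 of 4 predictions match: total=3, count=4
#   sess.run(accuracy)    # 0.75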


def streaming_precision(predictions,
                        labels,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the precision of the predictions with respect to the labels.

  The `streaming_precision` function creates two local variables,
  `true_positives` and `false_positives`, that are used to compute the
  precision. This value is ultimately returned as `precision`, an idempotent
  operation that simply divides `true_positives` by the sum of `true_positives`
  and `false_positives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`. `update_op` weights each prediction by the corresponding value in
  `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: Scalar float `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately and whose value matches
      `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.precision(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


def streaming_recall(predictions,
                     labels,
                     weights=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
  """Computes the recall of the predictions with respect to the labels.

  The `streaming_recall` function creates two local variables, `true_positives`
  and `false_negatives`, that are used to compute the recall. This value is
  ultimately returned as `recall`, an idempotent operation that simply divides
  `true_positives` by the sum of `true_positives` and `false_negatives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` that updates these variables and returns the `recall`. `update_op`
  weights each prediction by the corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.recall(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
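
# Illustrative sketch combining `streaming_precision` and `streaming_recall`
# (hypothetical example, not from the original module; `sess` is an assumed
# active `tf.Session`). With tp=1, fp=1, fn=1 below, both metrics are 0.5:
#
#   predictions = tf.constant([True, True, False, False])
#   labels = tf.constant([True, False, True, False])
#   precision, precision_op = streaming_precision(predictions, labels)
#   recall, recall_op = streaming_recall(predictions, labels)
#   sess.run(tf.local_variables_initializer())
#   sess.run([precision_op, recall_op])
#   sess.run([precision, recall])   # [0.5, 0.5]: tp/(tp+fp), tp/(tp+fn)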


def streaming_false_positive_rate(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the false positive rate of predictions with respect to labels.

  The `false_positive_rate` function creates two local variables,
  `false_positives` and `true_negatives`, that are used to compute the
  false positive rate. This value is ultimately returned as
  `false_positive_rate`, an idempotent operation that simply divides
  `false_positives` by the sum of `false_positives` and `true_negatives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `false_positive_rate`. `update_op` weights each prediction by the
  corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `false_positive_rate` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    false_positive_rate: Scalar float `Tensor` with the value of
      `false_positives` divided by the sum of `false_positives` and
      `true_negatives`.
    update_op: `Operation` that increments `false_positives` and
      `true_negatives` variables appropriately and whose value matches
      `false_positive_rate`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'false_positive_rate',
                                     (predictions, labels, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

    false_p, false_positives_update_op = metrics.false_positives(
        labels=labels,
        predictions=predictions,
        weights=weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)
    true_n, true_negatives_update_op = metrics.true_negatives(
        labels=labels,
        predictions=predictions,
        weights=weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)

    def compute_fpr(fp, tn, name):
      return array_ops.where(
          math_ops.greater(fp + tn, 0), math_ops.div(fp, fp + tn), 0, name)

    fpr = compute_fpr(false_p, true_n, 'value')
    update_op = compute_fpr(false_positives_update_op, true_negatives_update_op,
                            'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, fpr)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return fpr, update_op
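
# Illustrative sketch for `streaming_false_positive_rate` (hypothetical
# example, not from the original module; `sess` is an assumed active
# `tf.Session`). The inputs below produce fp=1 and tn=1, so the rate is
# fp / (fp + tn) = 0.5:
#
#   predictions = tf.constant([True, True, False, False])
#   labels = tf.constant([True, False, True, False])
#   fpr, update_op = streaming_false_positive_rate(predictions, labels)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   sess.run(fpr)   # 0.5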


def streaming_false_negative_rate(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the false negative rate of predictions with respect to labels.

  The `false_negative_rate` function creates two local variables,
  `false_negatives` and `true_positives`, that are used to compute the
  false negative rate. This value is ultimately returned as
  `false_negative_rate`, an idempotent operation that simply divides
  `false_negatives` by the sum of `false_negatives` and `true_positives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `false_negative_rate`. `update_op` weights each prediction by the
  corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `false_negative_rate` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    false_negative_rate: Scalar float `Tensor` with the value of
      `false_negatives` divided by the sum of `false_negatives` and
      `true_positives`.
    update_op: `Operation` that increments `false_negatives` and
      `true_positives` variables appropriately and whose value matches
      `false_negative_rate`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'false_negative_rate',
                                     (predictions, labels, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

    false_n, false_negatives_update_op = metrics.false_negatives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)
    true_p, true_positives_update_op = metrics.true_positives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)

    def compute_fnr(fn, tp, name):
      return array_ops.where(
          math_ops.greater(fn + tp, 0), math_ops.div(fn, fn + tp), 0, name)

    fnr = compute_fnr(false_n, true_p, 'value')
    update_op = compute_fnr(false_negatives_update_op, true_positives_update_op,
                            'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, fnr)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return fnr, update_op


def _streaming_confusion_matrix_at_thresholds(predictions,
                                              labels,
                                              thresholds,
                                              weights=None,
                                              includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast
      to `bool`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
      default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
        `includes`.
    update_ops: Dict of operations that increment the `values`. Keys are from
        `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError('Invalid key: %s.' % include)

  predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  num_thresholds = len(thresholds)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds.
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    broadcast_weights = weights_broadcast_ops.broadcast_weights(
        math_ops.to_float(weights), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_positives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='true_positives')
    is_true_positive = math_ops.to_float(
        math_ops.logical_and(label_is_pos, pred_is_pos))
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(true_positives,
                                            math_ops.reduce_sum(
                                                is_true_positive, 1))
    values['tp'] = true_positives

  if 'fn' in includes:
    false_negatives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='false_negatives')
    is_false_negative = math_ops.to_float(
        math_ops.logical_and(label_is_pos, pred_is_neg))
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(false_negatives,
                                            math_ops.reduce_sum(
                                                is_false_negative, 1))
    values['fn'] = false_negatives

  if 'tn' in includes:
    true_negatives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='true_negatives')
    is_true_negative = math_ops.to_float(
        math_ops.logical_and(label_is_neg, pred_is_neg))
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(true_negatives,
                                            math_ops.reduce_sum(
                                                is_true_negative, 1))
    values['tn'] = true_negatives

  if 'fp' in includes:
    false_positives = metrics_impl.metric_variable(
        [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.to_float(
        math_ops.logical_and(label_is_neg, pred_is_pos))
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_positives,
                                            math_ops.reduce_sum(
                                                is_false_positive, 1))
    values['fp'] = false_positives

  return values, update_ops
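
# Illustrative sketch for `_streaming_confusion_matrix_at_thresholds`
# (hypothetical example, not from the original module; `sess` is an assumed
# active `tf.Session`). At a single threshold of 0.5, predictions 0.6 and 0.9
# count as positive and 0.1 as negative:
#
#   predictions = tf.constant([0.1, 0.6, 0.9])
#   labels = tf.constant([False, True, True])
#   values, update_ops = _streaming_confusion_matrix_at_thresholds(
#       predictions, labels, thresholds=[0.5])
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_ops)
#   sess.run(values)   # {'tp': [2.], 'fn': [0.], 'tn': [1.], 'fp': [0.]}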


def streaming_true_positives_at_thresholds(predictions,
                                           labels,
                                           thresholds,
                                           weights=None):
  """Computes true positives at provided threshold values."""
  values, update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=('tp',))
  return values['tp'], update_ops['tp']


def streaming_false_negatives_at_thresholds(predictions,
                                            labels,
                                            thresholds,
                                            weights=None):
  """Computes false negatives at provided threshold values."""
  values, update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=('fn',))
  return values['fn'], update_ops['fn']


def streaming_false_positives_at_thresholds(predictions,
                                            labels,
                                            thresholds,
                                            weights=None):
  """Computes false positives at provided threshold values."""
  values, update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=('fp',))
  return values['fp'], update_ops['fp']


def streaming_true_negatives_at_thresholds(predictions,
                                           labels,
                                           thresholds,
                                           weights=None):
  """Computes true negatives at provided threshold values."""
  values, update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=('tn',))
  return values['tn'], update_ops['tn']
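
# The four thin wrappers above each select a single entry from
# `_streaming_confusion_matrix_at_thresholds`. A hypothetical usage sketch
# (`sess` is an assumed active `tf.Session`; `predictions` and `labels` are
# assumed tensors):
#
#   tp, tp_update_op = streaming_true_positives_at_thresholds(
#       predictions, labels, thresholds=[0.3, 0.5, 0.7])
#   sess.run(tf.local_variables_initializer())
#   sess.run(tp_update_op)
#   sess.run(tp)   # float32 vector of shape [3], one count per threshold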


def streaming_curve_points(labels=None,
                           predictions=None,
                           weights=None,
                           num_thresholds=200,
                           metrics_collections=None,
                           updates_collections=None,
                           curve='ROC',
                           name=None):
  """Computes curve (ROC or PR) values for a prespecified number of points.

  The `streaming_curve_points` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  that are used to compute the curve values. To discretize the curve, a linearly
  spaced set of thresholds is used to compute pairs of recall and precision
  values.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    points: A `Tensor` with shape [num_thresholds, 2] that contains points of
      the curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.

  TODO(chizeng): Consider rewriting this method to make use of logic within the
  precision_recall_at_equal_thresholds method (to improve run time).
  """
  with variable_scope.variable_scope(name, 'curve_points',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
    kepsilon = _EPSILON  # to account for floating point imprecisions
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def compute_points(tp, fn, tn, fp):
      """Computes ROC or PR curve points from confusion counts."""
      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        return fp_rate, rec
      else:  # curve == 'PR'.
        prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
        return rec, prec

    xs, ys = compute_points(values['tp'], values['fn'], values['tn'],
                            values['fp'])
    points = array_ops.stack([xs, ys], axis=1)
    update_op = control_flow_ops.group(*update_ops.values())

    if metrics_collections:
      ops.add_to_collections(metrics_collections, points)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return points, update_op
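
# Illustrative sketch for `streaming_curve_points` (hypothetical example;
# `sess`, `labels`, and `predictions` are assumptions, not from the original
# module). For curve='ROC' each row is (false positive rate, recall); for
# curve='PR' each row is (recall, precision):
#
#   points, update_op = streaming_curve_points(
#       labels=labels, predictions=predictions, num_thresholds=5, curve='PR')
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   sess.run(points)   # shape [5, 2], one (recall, precision) pair per row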


@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of the '
            'labels and predictions arguments has been switched.')
def streaming_auc(predictions,
                  labels,
                  weights=None,
                  num_thresholds=200,
                  metrics_collections=None,
                  updates_collections=None,
                  curve='ROC',
                  name=None):
  """Computes the approximate AUC via a Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.auc(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      num_thresholds=num_thresholds,
      curve=curve,
      updates_collections=updates_collections,
      name=name)
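
# Illustrative sketch for `streaming_auc` (hypothetical example; `sess`,
# `predictions`, and `labels` are assumptions, not from the original module).
# Because the Riemann sum uses fixed, linearly spaced thresholds, raising
# `num_thresholds` tightens the approximation at the cost of more metric state:
#
#   auc, update_op = streaming_auc(predictions, labels, num_thresholds=500)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)   # accumulate one batch of confusion counts
#   sess.run(auc)         # current area-under-curve estimate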


def _compute_dynamic_auc(labels, predictions, curve='ROC'):
  """Computes the approximate AUC by a Riemann sum with data-derived thresholds.

  Computes the area under the ROC or PR curve using each prediction as a
  threshold. This could be slow for large batches, but has the advantage of not
  having its results degrade depending on the distribution of predictions.

  Args:
    labels: A `Tensor` of ground truth labels with the same shape as
      `predictions` with values of 0 or 1 and type `int64`.
    predictions: A 1-D `Tensor` of predictions whose values are `float64`.
    curve: The name of the curve to be computed, 'ROC' for the Receiver
      Operating Characteristic or 'PR' for the Precision-Recall curve.

  Returns:
    A scalar `Tensor` containing the area-under-curve value for the input.
  """
  # Count the total number of positive and negative labels in the input.
  size = array_ops.size(predictions)
  total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32)

  def continue_computing_dynamic_auc():
    """Continues dynamic auc computation, entered if labels are not all equal.

    Returns:
      A scalar `Tensor` containing the area-under-curve value.
    """
    # Sort the predictions descending, and the corresponding labels as well.
    ordered_predictions, indices = nn.top_k(predictions, k=size)
    ordered_labels = array_ops.gather(labels, indices)

    # Get the counts of the unique ordered predictions.
    _, _, counts = array_ops.unique_with_counts(ordered_predictions)

    # Compute the indices of the split points between different predictions.
    splits = math_ops.cast(
        array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32)

    # Count the positives to the left of the split indices.
    positives = math_ops.cast(
        array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]),
        dtypes.int32)
    true_positives = array_ops.gather(positives, splits)
    if curve == 'ROC':
      # Count the negatives to the left of every split point and the total
      # number of negatives for computing the FPR.
      false_positives = math_ops.subtract(splits, true_positives)
      total_negative = size - total_positive
      x_axis_values = math_ops.truediv(false_positives, total_negative)
      y_axis_values = math_ops.truediv(true_positives, total_positive)
    elif curve == 'PR':
      x_axis_values = math_ops.truediv(true_positives, total_positive)
      # For conformance, set precision to 1 when the number of positive
      # classifications is 0.
      y_axis_values = array_ops.where(
          math_ops.greater(splits, 0), math_ops.truediv(true_positives, splits),
          array_ops.ones_like(true_positives, dtype=dtypes.float64))

    # Calculate trapezoid areas.
    heights = math_ops.add(y_axis_values[1:], y_axis_values[:-1]) / 2.0
    widths = math_ops.abs(
        math_ops.subtract(x_axis_values[1:], x_axis_values[:-1]))
    return math_ops.reduce_sum(math_ops.multiply(heights, widths))

  # If all the labels are the same, AUC isn't well-defined (but raising an
  # exception seems excessive) so we return 0, otherwise we finish computing.
  return control_flow_ops.cond(
      math_ops.logical_or(
          math_ops.equal(total_positive, 0), math_ops.equal(
              total_positive, size)),
      true_fn=lambda: array_ops.constant(0, dtypes.float64),
      false_fn=continue_computing_dynamic_auc)
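
# Worked example of the trapezoidal accumulation above (numbers are
# illustrative, not from the original module): with x-axis values [0, 0.5, 1]
# and y-axis values [0, 0.8, 1], the heights are [(0 + 0.8)/2, (0.8 + 1)/2] =
# [0.4, 0.9], the widths are [0.5, 0.5], and the AUC estimate is
# 0.4 * 0.5 + 0.9 * 0.5 = 0.65.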


def streaming_dynamic_auc(labels,
                          predictions,
                          curve='ROC',
                          metrics_collections=(),
                          updates_collections=(),
                          name=None):
  """Computes the approximate AUC by a Riemann sum with data-derived thresholds.

  USAGE NOTE: this approach requires storing all of the predictions and labels
  for a single evaluation in memory, so it may not be usable when the evaluation
  batch size and/or the number of evaluation steps is very large.

  Computes the area under the ROC or PR curve using each prediction as a
  threshold. This has the advantage of being resilient to the distribution of
  predictions by aggregating across batches, accumulating labels and predictions
  and performing the final calculation using all of the concatenated values.

  Args:
    labels: A `Tensor` of ground truth labels with the same shape as
      `predictions` whose values are 0 or 1 and castable to `int64`.
    predictions: A `Tensor` of predictions whose values are castable to
      `float64`. Will be flattened into a 1-D `Tensor`.
    curve: The name of the curve for which to compute AUC, 'ROC' for the
      Receiver Operating Characteristic or 'PR' for the Precision-Recall curve.
    metrics_collections: An optional iterable of collections that `auc` should
      be added to.
    updates_collections: An optional iterable of collections that `update_op`
      should be added to.
    name: An optional name for the variable_scope that contains the metric
      variables.

  Returns:
    auc: A scalar `Tensor` containing the current area-under-curve value.
    update_op: An operation that concatenates the input labels and predictions
      to the accumulated values.

  Raises:
    ValueError: If `labels` and `predictions` have mismatched shapes or if
      `curve` isn't a recognized curve type.
  """

  if curve not in ['PR', 'ROC']:
    raise ValueError('curve must be either ROC or PR, %s unknown' % curve)

  with variable_scope.variable_scope(name, default_name='dynamic_auc'):
    labels.get_shape().assert_is_compatible_with(predictions.get_shape())
    predictions = array_ops.reshape(
        math_ops.cast(predictions, dtypes.float64), [-1])
    labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            labels,
            array_ops.zeros_like(labels, dtypes.int64),
            message='labels must be 0 or 1, at least one is <0'),
        check_ops.assert_less_equal(
            labels,
            array_ops.ones_like(labels, dtypes.int64),
            message='labels must be 0 or 1, at least one is >1')
    ]):
      preds_accum, update_preds = streaming_concat(
          predictions, name='concat_preds')
      labels_accum, update_labels = streaming_concat(
          labels, name='concat_labels')
      update_op = control_flow_ops.group(update_labels, update_preds)
      auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve)
      if updates_collections:
        ops.add_to_collections(updates_collections, update_op)
      if metrics_collections:
        ops.add_to_collections(metrics_collections, auc)
      return auc, update_op
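
# Illustrative sketch for `streaming_dynamic_auc` (assumed usage; `sess`,
# `eval_batches`, and the placeholder names are hypothetical). Each update
# concatenates the batch onto the stored values, so memory grows with the
# total number of evaluated examples:
#
#   labels = tf.placeholder(tf.int64, shape=[None])
#   predictions = tf.placeholder(tf.float64, shape=[None])
#   auc, update_op = streaming_dynamic_auc(labels, predictions)
#   sess.run(tf.local_variables_initializer())
#   for batch_labels, batch_preds in eval_batches:
#     sess.run(update_op, {labels: batch_labels, predictions: batch_preds})
#   sess.run(auc)   # AUC over every example seen so far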
1198
1199
1200def _compute_placement_auc(labels, predictions, weights, alpha,
1201                           logit_transformation, is_valid):
1202  """Computes the AUC and asymptotic normally distributed confidence interval.
1203
1204  The calculations are achieved using the fact that AUC = P(Y_1>Y_0) and the
1205  concept of placement values for each labeled group, as presented by Delong and
1206  Delong (1988). The actual algorithm used is a more computationally efficient
1207  approach presented by Sun and Xu (2014). This could be slow for large batches,
1208  but has the advantage of not having its results degrade depending on the
1209  distribution of predictions.
1210
1211  Args:
1212    labels: A `Tensor` of ground truth labels with the same shape as
1213      `predictions` with values of 0 or 1 and type `int64`.
1214    predictions: A 1-D `Tensor` of predictions whose values are `float64`.
1215    weights: `Tensor` whose rank is either 0, or the same rank as `labels`.
1216    alpha: Confidence interval level desired.
1217    logit_transformation: A boolean value indicating whether the estimate should
1218      be logit transformed prior to calculating the confidence interval. Doing
1219      so enforces the restriction that the AUC should never be outside the
1220      interval [0,1].
1221    is_valid: A bool tensor describing whether the input is valid.
1222
1223  Returns:
1224    A 1-D `Tensor` containing the area-under-curve, lower, and upper confidence
1225    interval values.
1226  """
1227  # Disable the invalid-name checker so that we can capitalize the name.
1228  # pylint: disable=invalid-name
1229  AucData = collections_lib.namedtuple('AucData', ['auc', 'lower', 'upper'])
1230  # pylint: enable=invalid-name
1231
1232  # If all the labels are the same or if the number of observations is too
1233  # small, the AUC isn't well-defined.
1234  size = array_ops.size(predictions, out_type=dtypes.int32)
1235
1236  # Count the total number of positive and negative labels in the input.
1237  total_0 = math_ops.reduce_sum(
1238      math_ops.cast(1 - labels, weights.dtype) * weights)
1239  total_1 = math_ops.reduce_sum(
1240      math_ops.cast(labels, weights.dtype) * weights)
1241
1242  # Sort the predictions ascending, as well as
1243  # (i) the corresponding labels and
1244  # (ii) the corresponding weights.
1245  ordered_predictions, indices = nn.top_k(predictions, k=size, sorted=True)
1246  ordered_predictions = array_ops.reverse(
1247      ordered_predictions, axis=array_ops.zeros(1, dtypes.int32))
1248  indices = array_ops.reverse(indices, axis=array_ops.zeros(1, dtypes.int32))
1249  ordered_labels = array_ops.gather(labels, indices)
1250  ordered_weights = array_ops.gather(weights, indices)
1251
1252  # We now compute values required for computing placement values.
1253
1254  # We generate a list of indices (segmented_indices) of increasing order. An
1255  # index is assigned for each unique prediction float value. Prediction
1256  # values that are the same share the same index.
1257  _, segmented_indices = array_ops.unique(ordered_predictions)
1258
1259  # We create 2 tensors of weights. weights_for_true is non-zero for true
1260  # labels. weights_for_false is non-zero for false labels.
1261  float_labels_for_true = math_ops.cast(ordered_labels, dtypes.float32)
1262  float_labels_for_false = 1.0 - float_labels_for_true
1263  weights_for_true = ordered_weights * float_labels_for_true
1264  weights_for_false = ordered_weights * float_labels_for_false
1265
1266  # For each set of weights with the same segmented indices, we add up the
1267  # weight values. Note that for each label, we deliberately rely on weights
1268  # for the opposite label.
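  # (A label's placement value measures where its prediction falls within the
  # score distribution of the *opposite* class, hence the swap.)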
1269  weight_totals_for_true = math_ops.segment_sum(weights_for_false,
1270                                                segmented_indices)
1271  weight_totals_for_false = math_ops.segment_sum(weights_for_true,
1272                                                 segmented_indices)
1273
1274  # These cumulative sums of weights importantly exclude the current weight
1275  # sums.
1276  cum_weight_totals_for_true = math_ops.cumsum(weight_totals_for_true,
1277                                               exclusive=True)
1278  cum_weight_totals_for_false = math_ops.cumsum(weight_totals_for_false,
1279                                                exclusive=True)
1280
1281  # Compute placement values using the formula. Values with the same segmented
1282  # indices and labels share the same placement values.
1283  placements_for_true = (
1284      (cum_weight_totals_for_true + weight_totals_for_true / 2.0) /
1285      (math_ops.reduce_sum(weight_totals_for_true) + _EPSILON))
1286  placements_for_false = (
1287      (cum_weight_totals_for_false + weight_totals_for_false / 2.0) /
1288      (math_ops.reduce_sum(weight_totals_for_false) + _EPSILON))
1289
1290  # We expand the tensors of placement values (for each label) so that their
1291  # shapes match that of predictions.
1292  placements_for_true = array_ops.gather(placements_for_true, segmented_indices)
1293  placements_for_false = array_ops.gather(placements_for_false,
1294                                          segmented_indices)
1295
1296  # Select placement values based on the label for each index.
1297  placement_values = (
1298      placements_for_true * float_labels_for_true +
1299      placements_for_false * float_labels_for_false)
1300
1301  # Split placement values by labeled groups.
1302  placement_values_0 = placement_values * math_ops.cast(
1303      1 - ordered_labels, weights.dtype)
1304  weights_0 = ordered_weights * math_ops.cast(
1305      1 - ordered_labels, weights.dtype)
1306  placement_values_1 = placement_values * math_ops.cast(
1307      ordered_labels, weights.dtype)
1308  weights_1 = ordered_weights * math_ops.cast(
1309      ordered_labels, weights.dtype)
1310
1311  # Calculate AUC using placement values
1312  auc_0 = (math_ops.reduce_sum(weights_0 * (1. - placement_values_0)) /
1313           (total_0 + _EPSILON))
1314  auc_1 = (math_ops.reduce_sum(weights_1 * (placement_values_1)) /
1315           (total_1 + _EPSILON))
1316  auc = array_ops.where(math_ops.less(total_0, total_1), auc_1, auc_0)
1317
1318  # Calculate variance and standard error using the placement values.
1319  var_0 = (
1320      math_ops.reduce_sum(
1321          weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) /
1322      (total_0 - 1. + _EPSILON))
1323  var_1 = (
1324      math_ops.reduce_sum(
1325          weights_1 * math_ops.square(placement_values_1 - auc_1)) /
1326      (total_1 - 1. + _EPSILON))
1327  auc_std_err = math_ops.sqrt(
1328      (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON)))
1329
1330  # Calculate asymptotic normal confidence intervals
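  # z_value is the (1 - alpha) / 2 quantile of the standard normal and is
  # therefore negative (e.g. ~-1.96 for alpha = 0.95), so adding
  # z_value * std_err below yields the *lower* bound.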
1331  std_norm_dist = Normal(loc=0., scale=1.)
1332  z_value = std_norm_dist.quantile((1.0 - alpha) / 2.0)
1333  if logit_transformation:
1334    estimate = math_ops.log(auc / (1. - auc + _EPSILON))
1335    std_err = auc_std_err / (auc * (1. - auc + _EPSILON))
1336    transformed_auc_lower = estimate + (z_value * std_err)
1337    transformed_auc_upper = estimate - (z_value * std_err)
1338    def inverse_logit_transformation(x):
1339      exp_negative = math_ops.exp(math_ops.negative(x))
1340      return 1. / (1. + exp_negative + _EPSILON)
1341
1342    auc_lower = inverse_logit_transformation(transformed_auc_lower)
1343    auc_upper = inverse_logit_transformation(transformed_auc_upper)
1344  else:
1345    estimate = auc
1346    std_err = auc_std_err
1347    auc_lower = estimate + (z_value * std_err)
1348    auc_upper = estimate - (z_value * std_err)
1349
1350  # If the estimate is exactly 1 or 0, no variance is present, so the CI
1351  # collapses to the AUC itself. N.B. the sample may simply be too small.
1352  lower = array_ops.where(
1353      math_ops.logical_or(
1354          math_ops.equal(auc, array_ops.ones_like(auc)),
1355          math_ops.equal(auc, array_ops.zeros_like(auc))),
1356      auc, auc_lower)
1357  upper = array_ops.where(
1358      math_ops.logical_or(
1359          math_ops.equal(auc, array_ops.ones_like(auc)),
1360          math_ops.equal(auc, array_ops.zeros_like(auc))),
1361      auc, auc_upper)
1362
1363  # If all the labels are the same, AUC isn't well-defined (but raising an
1364  # exception seems excessive) so we return 0, otherwise we finish computing.
1365  trivial_value = array_ops.constant(0.0)
1366
1367  return AucData(*control_flow_ops.cond(
1368      is_valid, lambda: [auc, lower, upper], lambda: [trivial_value]*3))
1369
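# A minimal NumPy sketch (illustrative, not part of the API) of the
# placement-value identity the function above relies on, for the unweighted
# case: the AUC equals the mean placement value of the positives, i.e. the
# average (mid-rank) fraction of negatives scored below each positive.
def _example_placement_auc_numpy():
  """Cross-checks AUC == mean positive placement value on a toy sample."""
  import numpy as np
  labels = np.array([0, 0, 1, 1])
  preds = np.array([0.1, 0.4, 0.35, 0.8])
  neg, pos = preds[labels == 0], preds[labels == 1]
  # Placement of each positive: fraction of negatives strictly below it,
  # counting ties with half weight (the mid-rank convention used above).
  placements = [(np.sum(neg < p) + 0.5 * np.sum(neg == p)) / neg.size
                for p in pos]
  return np.mean(placements)  # 0.75, the Mann-Whitney AUC of this sample.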
1370
1371def auc_with_confidence_intervals(labels,
1372                                  predictions,
1373                                  weights=None,
1374                                  alpha=0.95,
1375                                  logit_transformation=True,
1376                                  metrics_collections=(),
1377                                  updates_collections=(),
1378                                  name=None):
1379  """Computes the AUC and asymptotic normally distributed confidence interval.
1380
1381  USAGE NOTE: this approach requires storing all of the predictions and labels
1382  for a single evaluation in memory, so it may not be usable when the evaluation
1383  batch size and/or the number of evaluation steps is very large.
1384
1385  Computes the area under the ROC curve and its confidence interval using
1386  placement values. This has the advantage of being resilient to the
1387  distribution of predictions by aggregating across batches, accumulating labels
1388  and predictions and performing the final calculation using all of the
1389  concatenated values.
1390
1391  Args:
1392    labels: A `Tensor` of ground truth labels with the same shape as
1393      `predictions`, with values of 0 or 1 that are castable to `int64`.
1394    predictions: A `Tensor` of predictions whose values are castable to
1395      `float64`. Will be flattened into a 1-D `Tensor`.
1396    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1397      `labels`.
1398    alpha: The desired confidence level, e.g. 0.95 for a 95% interval.
1399    logit_transformation: A boolean value indicating whether the estimate should
1400      be logit transformed prior to calculating the confidence interval. Doing
1401      so enforces the restriction that the AUC should never be outside the
1402      interval [0,1].
1403    metrics_collections: An optional iterable of collections that `auc` should
1404      be added to.
1405    updates_collections: An optional iterable of collections that `update_op`
1406      should be added to.
1407    name: An optional name for the variable_scope that contains the metric
1408      variables.
1409
1410  Returns:
1411    auc: A 1-D `Tensor` containing the current area-under-curve, lower, and
1412      upper confidence interval values.
1413    update_op: An operation that concatenates the input labels and predictions
1414      to the accumulated values.
1415
1416  Raises:
1417    ValueError: If `labels`, `predictions`, and `weights` have mismatched
1418      shapes, or if `alpha` isn't in the range (0, 1).
1419  """
1420  if not 0 < alpha < 1:
1421    raise ValueError('alpha must be between 0 and 1; currently %.02f' % alpha)
1422
1423  if weights is None:
1424    weights = array_ops.ones_like(predictions)
1425
1426  with variable_scope.variable_scope(
1427      name,
1428      default_name='auc_with_confidence_intervals',
1429      values=[labels, predictions, weights]):
1430
1431    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
1432        predictions=predictions,
1433        labels=labels,
1434        weights=weights)
1435
1436    total_weight = math_ops.reduce_sum(weights)
1437
1438    weights = array_ops.reshape(weights, [-1])
1439    predictions = array_ops.reshape(
1440        math_ops.cast(predictions, dtypes.float64), [-1])
1441    labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])
1442
1443    with ops.control_dependencies([
1444        check_ops.assert_greater_equal(
1445            labels,
1446            array_ops.zeros_like(labels, dtypes.int64),
1447            message='labels must be 0 or 1, at least one is <0'),
1448        check_ops.assert_less_equal(
1449            labels,
1450            array_ops.ones_like(labels, dtypes.int64),
1451            message='labels must be 0 or 1, at least one is >1'),
1452    ]):
1453      preds_accum, update_preds = streaming_concat(
1454          predictions, name='concat_preds')
1455      labels_accum, update_labels = streaming_concat(labels,
1456                                                     name='concat_labels')
1457      weights_accum, update_weights = streaming_concat(
1458          weights, name='concat_weights')
1459      update_op_for_valid_case = control_flow_ops.group(
1460          update_labels, update_preds, update_weights)
1461
1462      # The AUC is only well-defined when both label classes are present.
1463      both_classes_present = math_ops.logical_and(
1464          math_ops.equal(math_ops.reduce_min(labels), 0),
1465          math_ops.equal(math_ops.reduce_max(labels), 1))
1466      sums_of_weights_at_least_1 = math_ops.greater_equal(total_weight, 1.0)
1467      is_valid = math_ops.logical_and(both_classes_present,
1468                                      sums_of_weights_at_least_1)
1469
1470      update_op = control_flow_ops.cond(
1471          sums_of_weights_at_least_1,
1472          lambda: update_op_for_valid_case, control_flow_ops.no_op)
1473
1474      auc = _compute_placement_auc(
1475          labels_accum,
1476          preds_accum,
1477          weights_accum,
1478          alpha=alpha,
1479          logit_transformation=logit_transformation,
1480          is_valid=is_valid)
1481
1482      if updates_collections:
1483        ops.add_to_collections(updates_collections, update_op)
1484      if metrics_collections:
1485        ops.add_to_collections(metrics_collections, auc)
1486      return auc, update_op
1487
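# Illustrative sketch (the helper name is ours, not part of the API): a
# single update followed by fetching the (auc, lower, upper) triple that the
# function above returns as an AucData namedtuple.
def _example_auc_with_confidence_intervals():
  """Sketch: one batch through auc_with_confidence_intervals."""
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables

  labels = ops.convert_to_tensor([0, 0, 1, 1], dtypes.int64)
  predictions = ops.convert_to_tensor([0.1, 0.4, 0.35, 0.8], dtypes.float64)
  auc, update_op = auc_with_confidence_intervals(labels, predictions)
  with session_lib.Session() as sess:
    sess.run(variables.local_variables_initializer())
    sess.run(update_op)
    # The point estimate is ~0.75 for this toy batch; lower/upper bracket it.
    return sess.run(auc)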
1488
1489def precision_recall_at_equal_thresholds(labels,
1490                                         predictions,
1491                                         weights=None,
1492                                         num_thresholds=None,
1493                                         use_locking=None,
1494                                         name=None):
1495  """A helper method for creating metrics related to precision-recall curves.
1496
1497  These values are true positives, false negatives, true negatives, false
1498  positives, precision, and recall. This function returns a data structure that
1499  contains ops within it.
1500
1501  Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N)
1502  space and run time), this op exhibits O(T + N) space and run time, where T is
1503  the number of thresholds and N is the size of the predictions tensor. Hence,
1504  it may be advantageous to use this function when `predictions` is big.
1505
1506  For instance, prefer this method for per-pixel classification tasks, for which
1507  the predictions tensor may be very large.
1508
1509  Each number in `predictions`, a float in `[0, 1]`, is compared with its
1510  corresponding label in `labels`, and counts as a single tp/fp/tn/fn value at
1511  each threshold. This is then multiplied with `weights` which can be used to
1512  reweight certain values, or more commonly used for masking values.
1513  reweight certain values or, more commonly, to mask values.
1514  Args:
1515    labels: A bool `Tensor` whose shape matches `predictions`.
1516    predictions: A floating point `Tensor` of arbitrary shape and whose values
1517      are in the range `[0, 1]`.
1518    weights: Optional; If provided, a `Tensor` that has the same dtype as,
1519      and broadcastable to, `predictions`. This tensor is multiplied by counts.
1520    num_thresholds: Optional; Number of thresholds, evenly distributed in
1521      `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins
1522      is 1 less than `num_thresholds`. Using an even `num_thresholds` value
1523      instead of an odd one may yield bin edges at awkward, non-round values.
1524    use_locking: Optional; If True, the op will be protected by a lock.
1525      Otherwise, the behavior is undefined, but may exhibit less contention.
1526      Defaults to True.
1527    name: Optional; variable_scope name. If not provided, the string
1528      'precision_recall_at_equal_thresholds' is used.
1529
1530  Returns:
1531    result: A named tuple (See PrecisionRecallData within the implementation of
1532      this function) with properties that are variables of shape
1533      `[num_thresholds]`. The names of the properties are tp, fp, tn, fn,
1534      precision, recall, thresholds.
1535    update_op: An op that accumulates values.
1536
1537  Raises:
1538    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1539      `weights` is not `None` and its shape doesn't match `predictions`.
1541  """
1542  # Disable the invalid-name checker so that we can capitalize the name.
1543  # pylint: disable=invalid-name
1544  PrecisionRecallData = collections_lib.namedtuple(
1545      'PrecisionRecallData',
1546      ['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds'])
1547  # pylint: enable=invalid-name
1548
1549  if num_thresholds is None:
1550    num_thresholds = 201
1551
1552  if weights is None:
1553    weights = 1.0
1554
1555  if use_locking is None:
1556    use_locking = True
1557
1558  check_ops.assert_type(labels, dtypes.bool)
1559
1560  dtype = predictions.dtype
1561  with variable_scope.variable_scope(name,
1562                                     'precision_recall_at_equal_thresholds',
1563                                     (labels, predictions, weights)):
1564    # Make sure that predictions are within [0.0, 1.0].
1565    with ops.control_dependencies([
1566        check_ops.assert_greater_equal(
1567            predictions,
1568            math_ops.cast(0.0, dtype=predictions.dtype),
1569            message='predictions must be in [0, 1]'),
1570        check_ops.assert_less_equal(
1571            predictions,
1572            math_ops.cast(1.0, dtype=predictions.dtype),
1573            message='predictions must be in [0, 1]')
1574    ]):
1575      predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
1576          predictions=predictions,
1577          labels=labels,
1578          weights=weights)
1579
1580    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1581
1582    # We cast to float to ensure we have 0.0 or 1.0.
1583    f_labels = math_ops.cast(labels, dtype)
1584
1585    # Get weighted true/false labels.
1586    true_labels = f_labels * weights
1587    false_labels = (1.0 - f_labels) * weights
1588
1589    # Flatten predictions and labels.
1590    predictions = array_ops.reshape(predictions, [-1])
1591    true_labels = array_ops.reshape(true_labels, [-1])
1592    false_labels = array_ops.reshape(false_labels, [-1])
1593
1594    # To compute TP/FP/TN/FN, we are measuring a binary classifier
1595    #   C(t) = (predictions >= t)
1596    # at each threshold 't'. So we have
1597    #   TP(t) = sum( C(t) * true_labels )
1598    #   FP(t) = sum( C(t) * false_labels )
1599    #
1600    # But, computing C(t) requires computation for each t. To make it fast,
1601    # observe that C(t) is a cumulative integral, and so if we have
1602    #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
1603    # where n = num_thresholds, and if we can compute the bucket function
1604    #   B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
1605    # then we get
1606    #   C(t_i) = sum( B(j), j >= i )
1607    # which is the reversed cumulative sum in tf.cumsum().
1608    #
1609    # We can compute B(i) efficiently by taking advantage of the fact that
1610    # our thresholds are evenly distributed, in that
1611    #   width = 1.0 / (num_thresholds - 1)
1612    #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
1613    # Given a prediction value p, we can map it to its bucket by
1614    #   bucket_index(p) = floor( p * (num_thresholds - 1) )
1615    # so we can use tf.scatter_add() to update the buckets in one pass.
1616    #
1617    # This implementation exhibits a run time and space complexity of O(T + N),
1618    # where T is the number of thresholds and N is the size of predictions.
1619    # Metrics that rely on _streaming_confusion_matrix_at_thresholds instead
1620    # exhibit a complexity of O(T * N).
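    # A runnable NumPy sketch of this bucketing trick appears right after
    # this function.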
1621
1622    # Compute the bucket indices for each prediction value.
1623    bucket_indices = math_ops.cast(
1624        math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32)
1625
1626    with ops.name_scope('variables'):
1627      tp_buckets_v = metrics_impl.metric_variable(
1628          [num_thresholds], dtype, name='tp_buckets')
1629      fp_buckets_v = metrics_impl.metric_variable(
1630          [num_thresholds], dtype, name='fp_buckets')
1631
1632    with ops.name_scope('update_op'):
1633      update_tp = state_ops.scatter_add(
1634          tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking)
1635      update_fp = state_ops.scatter_add(
1636          fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking)
1637
1638    # Set up the cumulative sums to compute the actual metrics.
1639    tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp')
1640    fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp')
1641    # fn = sum(true_labels) - tp
1642    #    = sum(tp_buckets) - tp
1643    #    = tp[0] - tp
1644    # Similarly,
1645    # tn = fp[0] - fp
1646    tn = fp[0] - fp
1647    fn = tp[0] - tp
1648
1649    # We use a minimum to prevent division by 0.
1650    epsilon = 1e-7
1651    precision = tp / math_ops.maximum(epsilon, tp + fp)
1652    recall = tp / math_ops.maximum(epsilon, tp + fn)
1653
1654    result = PrecisionRecallData(
1655        tp=tp,
1656        fp=fp,
1657        tn=tn,
1658        fn=fn,
1659        precision=precision,
1660        recall=recall,
1661        thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds))
1662    update_op = control_flow_ops.group(update_tp, update_fp)
1663    return result, update_op
1664
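# A minimal NumPy sketch (illustrative, not part of the API) of the O(T + N)
# bucket-then-reversed-cumsum trick described in the comments above, with
# num_thresholds = 5 so the arithmetic is easy to follow by hand.
def _example_bucket_cumsum_trick():
  """Recovers TP(t)/FP(t) at all thresholds with one pass over the data."""
  import numpy as np
  num_thresholds = 5  # thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
  preds = np.array([0.2, 0.6, 0.9])
  true_labels = np.array([0., 1., 1.])  # weighted true labels
  false_labels = 1.0 - true_labels      # weighted false labels
  # bucket_index(p) = floor(p * (num_thresholds - 1)) -> [0, 2, 3]
  buckets = np.floor(preds * (num_thresholds - 1)).astype(int)
  tp_buckets = np.bincount(buckets, weights=true_labels,
                           minlength=num_thresholds)
  fp_buckets = np.bincount(buckets, weights=false_labels,
                           minlength=num_thresholds)
  # Reversed cumulative sums yield the counts at every threshold at once.
  tp = np.cumsum(tp_buckets[::-1])[::-1]  # [2., 2., 2., 1., 0.]
  fp = np.cumsum(fp_buckets[::-1])[::-1]  # [1., 0., 0., 0., 0.]
  return tp, fp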
1665
1666def streaming_specificity_at_sensitivity(predictions,
1667                                         labels,
1668                                         sensitivity,
1669                                         weights=None,
1670                                         num_thresholds=200,
1671                                         metrics_collections=None,
1672                                         updates_collections=None,
1673                                         name=None):
1674  """Computes the specificity at a given sensitivity.
1675
1676  The `streaming_specificity_at_sensitivity` function creates four local
1677  variables, `true_positives`, `true_negatives`, `false_positives` and
1678  `false_negatives` that are used to compute the specificity at the given
1679  sensitivity value. The threshold for the given sensitivity value is computed
1680  and used to evaluate the corresponding specificity.
1681
1682  For estimation of the metric over a stream of data, the function creates an
1683  `update_op` operation that updates these variables and returns the
1684  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
1685  `false_positives` and `false_negatives` counts with the weight of each case
1686  found in the `predictions` and `labels`.
1687
1688  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1689
1690  For additional information about specificity and sensitivity, see the
1691  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
1692
1693  Args:
1694    predictions: A floating point `Tensor` of arbitrary shape and whose values
1695      are in the range `[0, 1]`.
1696    labels: A `bool` `Tensor` whose shape matches `predictions`.
1697    sensitivity: A scalar value in range `[0, 1]`.
1698    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1699      must be broadcastable to `labels` (i.e., all dimensions must be either
1700      `1`, or the same as the corresponding `labels` dimension).
1701    num_thresholds: The number of thresholds to use for matching the given
1702      sensitivity.
1703    metrics_collections: An optional list of collections that `specificity`
1704      should be added to.
1705    updates_collections: An optional list of collections that `update_op` should
1706      be added to.
1707    name: An optional variable_scope name.
1708
1709  Returns:
1710    specificity: A scalar `Tensor` representing the specificity at the given
1711      `sensitivity` value.
1712    update_op: An operation that increments the `true_positives`,
1713      `true_negatives`, `false_positives` and `false_negatives` variables
1714      appropriately and whose value matches `specificity`.
1715
1716  Raises:
1717    ValueError: If `predictions` and `labels` have mismatched shapes, if
1718      `weights` is not `None` and its shape doesn't match `predictions`, or if
1719      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
1720      or `updates_collections` are not a list or tuple.
1721  """
1722  return metrics.specificity_at_sensitivity(
1723      sensitivity=sensitivity,
1724      num_thresholds=num_thresholds,
1725      predictions=predictions,
1726      labels=labels,
1727      weights=weights,
1728      metrics_collections=metrics_collections,
1729      updates_collections=updates_collections,
1730      name=name)
1731
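# Illustrative sketch (the helper name is ours): reading the specificity once
# the threshold has been tuned for 90% sensitivity. The symmetric wrapper
# below, streaming_sensitivity_at_specificity, is driven the same way.
def _example_specificity_at_sensitivity():
  """Sketch: one batch through the specificity-at-sensitivity metric."""
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables

  predictions = ops.convert_to_tensor([0.1, 0.6, 0.4, 0.9])
  labels = ops.convert_to_tensor([False, False, True, True])
  value, update_op = streaming_specificity_at_sensitivity(
      predictions, labels, sensitivity=0.9)
  with session_lib.Session() as sess:
    sess.run(variables.local_variables_initializer())
    sess.run(update_op)
    return sess.run(value)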
1732
1733def streaming_sensitivity_at_specificity(predictions,
1734                                         labels,
1735                                         specificity,
1736                                         weights=None,
1737                                         num_thresholds=200,
1738                                         metrics_collections=None,
1739                                         updates_collections=None,
1740                                         name=None):
1741  """Computes the sensitivity at a given specificity.
1742
1743  The `streaming_sensitivity_at_specificity` function creates four local
1744  variables, `true_positives`, `true_negatives`, `false_positives` and
1745  `false_negatives` that are used to compute the sensitivity at the given
1746  specificity value. The threshold for the given specificity value is computed
1747  and used to evaluate the corresponding sensitivity.
1748
1749  For estimation of the metric over a stream of data, the function creates an
1750  `update_op` operation that updates these variables and returns the
1751  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
1752  `false_positives` and `false_negatives` counts with the weight of each case
1753  found in the `predictions` and `labels`.
1754
1755  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1756
1757  For additional information about specificity and sensitivity, see the
1758  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
1759
1760  Args:
1761    predictions: A floating point `Tensor` of arbitrary shape and whose values
1762      are in the range `[0, 1]`.
1763    labels: A `bool` `Tensor` whose shape matches `predictions`.
1764    specificity: A scalar value in range `[0, 1]`.
1765    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1766      must be broadcastable to `labels` (i.e., all dimensions must be either
1767      `1`, or the same as the corresponding `labels` dimension).
1768    num_thresholds: The number of thresholds to use for matching the given
1769      specificity.
1770    metrics_collections: An optional list of collections that `sensitivity`
1771      should be added to.
1772    updates_collections: An optional list of collections that `update_op` should
1773      be added to.
1774    name: An optional variable_scope name.
1775
1776  Returns:
1777    sensitivity: A scalar `Tensor` representing the sensitivity at the given
1778      `specificity` value.
1779    update_op: An operation that increments the `true_positives`,
1780      `true_negatives`, `false_positives` and `false_negatives` variables
1781      appropriately and whose value matches `sensitivity`.
1782
1783  Raises:
1784    ValueError: If `predictions` and `labels` have mismatched shapes, if
1785      `weights` is not `None` and its shape doesn't match `predictions`, or if
1786      `specificity` is not between 0 and 1, or if either `metrics_collections`
1787      or `updates_collections` are not a list or tuple.
1788  """
1789  return metrics.sensitivity_at_specificity(
1790      specificity=specificity,
1791      num_thresholds=num_thresholds,
1792      predictions=predictions,
1793      labels=labels,
1794      weights=weights,
1795      metrics_collections=metrics_collections,
1796      updates_collections=updates_collections,
1797      name=name)
1798
1799
1800@deprecated(
1801    None, 'Please switch to tf.metrics.precision_at_thresholds. Note that the '
1802    'order of the labels and predictions arguments has been switched.')
1803def streaming_precision_at_thresholds(predictions,
1804                                      labels,
1805                                      thresholds,
1806                                      weights=None,
1807                                      metrics_collections=None,
1808                                      updates_collections=None,
1809                                      name=None):
1810  """Computes precision values for different `thresholds` on `predictions`.
1811
1812  The `streaming_precision_at_thresholds` function creates four local variables,
1813  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
1814  for various values of thresholds. `precision[i]` is defined as the total
1815  weight of values in `predictions` above `thresholds[i]` whose corresponding
1816  entry in `labels` is `True`, divided by the total weight of values in
1817  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
1818  false_positives[i])`).
1819
1820  For estimation of the metric over a stream of data, the function creates an
1821  `update_op` operation that updates these variables and returns the
1822  `precision`.
1823
1824  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1825
1826  Args:
1827    predictions: A floating point `Tensor` of arbitrary shape and whose values
1828      are in the range `[0, 1]`.
1829    labels: A `bool` `Tensor` whose shape matches `predictions`.
1830    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1831    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1832      must be broadcastable to `labels` (i.e., all dimensions must be either
1833      `1`, or the same as the corresponding `labels` dimension).
1834    metrics_collections: An optional list of collections that `precision` should
1835      be added to.
1836    updates_collections: An optional list of collections that `update_op` should
1837      be added to.
1838    name: An optional variable_scope name.
1839
1840  Returns:
1841    precision: A float `Tensor` of shape `[len(thresholds)]`.
1842    update_op: An operation that increments the `true_positives`,
1843      `true_negatives`, `false_positives` and `false_negatives` variables that
1844      are used in the computation of `precision`.
1845
1846  Raises:
1847    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1848      `weights` is not `None` and its shape doesn't match `predictions`, or if
1849      either `metrics_collections` or `updates_collections` are not a list or
1850      tuple.
1851  """
1852  return metrics.precision_at_thresholds(
1853      thresholds=thresholds,
1854      predictions=predictions,
1855      labels=labels,
1856      weights=weights,
1857      metrics_collections=metrics_collections,
1858      updates_collections=updates_collections,
1859      name=name)
1860
1861
1862@deprecated(None,
1863            'Please switch to tf.metrics.recall_at_thresholds. Note that the '
1864            'order of the labels and predictions arguments has been switched.')
1865def streaming_recall_at_thresholds(predictions,
1866                                   labels,
1867                                   thresholds,
1868                                   weights=None,
1869                                   metrics_collections=None,
1870                                   updates_collections=None,
1871                                   name=None):
1872  """Computes various recall values for different `thresholds` on `predictions`.
1873
1874  The `streaming_recall_at_thresholds` function creates four local variables,
1875  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
1876  for various values of thresholds. `recall[i]` is defined as the total weight
1877  of values in `predictions` above `thresholds[i]` whose corresponding entry in
1878  `labels` is `True`, divided by the total weight of `True` values in `labels`
1879  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).
1880
1881  For estimation of the metric over a stream of data, the function creates an
1882  `update_op` operation that updates these variables and returns the `recall`.
1883
1884  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1885
1886  Args:
1887    predictions: A floating point `Tensor` of arbitrary shape and whose values
1888      are in the range `[0, 1]`.
1889    labels: A `bool` `Tensor` whose shape matches `predictions`.
1890    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1891    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1892      must be broadcastable to `labels` (i.e., all dimensions must be either
1893      `1`, or the same as the corresponding `labels` dimension).
1894    metrics_collections: An optional list of collections that `recall` should be
1895      added to.
1896    updates_collections: An optional list of collections that `update_op` should
1897      be added to.
1898    name: An optional variable_scope name.
1899
1900  Returns:
1901    recall: A float `Tensor` of shape `[len(thresholds)]`.
1902    update_op: An operation that increments the `true_positives`,
1903      `true_negatives`, `false_positives` and `false_negatives` variables that
1904      are used in the computation of `recall`.
1905
1906  Raises:
1907    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1908      `weights` is not `None` and its shape doesn't match `predictions`, or if
1909      either `metrics_collections` or `updates_collections` are not a list or
1910      tuple.
1911  """
1912  return metrics.recall_at_thresholds(
1913      thresholds=thresholds,
1914      predictions=predictions,
1915      labels=labels,
1916      weights=weights,
1917      metrics_collections=metrics_collections,
1918      updates_collections=updates_collections,
1919      name=name)
1920
1921
1922def streaming_false_positive_rate_at_thresholds(predictions,
1923                                                labels,
1924                                                thresholds,
1925                                                weights=None,
1926                                                metrics_collections=None,
1927                                                updates_collections=None,
1928                                                name=None):
1929  """Computes various fpr values for different `thresholds` on `predictions`.
1930
1931  The `streaming_false_positive_rate_at_thresholds` function creates two local
1932  variables, `false_positives` and `true_negatives`, for various values of
1933  thresholds. `false_positive_rate[i]` is defined as the total weight
1934  of values in `predictions` above `thresholds[i]` whose corresponding entry in
1935  `labels` is `False`, divided by the total weight of `False` values in `labels`
1936  (`false_positives[i] / (false_positives[i] + true_negatives[i])`).
1937
1938  For estimation of the metric over a stream of data, the function creates an
1939  `update_op` operation that updates these variables and returns the
1940  `false_positive_rate`.
1941
1942  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1943
1944  Args:
1945    predictions: A floating point `Tensor` of arbitrary shape and whose values
1946      are in the range `[0, 1]`.
1947    labels: A `bool` `Tensor` whose shape matches `predictions`.
1948    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1949    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1950      must be broadcastable to `labels` (i.e., all dimensions must be either
1951      `1`, or the same as the corresponding `labels` dimension).
1952    metrics_collections: An optional list of collections that
1953      `false_positive_rate` should be added to.
1954    updates_collections: An optional list of collections that `update_op` should
1955      be added to.
1956    name: An optional variable_scope name.
1957
1958  Returns:
1959    false_positive_rate: A float `Tensor` of shape `[len(thresholds)]`.
1960    update_op: An operation that increments the `false_positives` and
1961      `true_negatives` variables that are used in the computation of
1962      `false_positive_rate`.
1963
1964  Raises:
1965    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1966      `weights` is not `None` and its shape doesn't match `predictions`, or if
1967      either `metrics_collections` or `updates_collections` are not a list or
1968      tuple.
1969  """
1970  with variable_scope.variable_scope(name, 'false_positive_rate_at_thresholds',
1971                                     (predictions, labels, weights)):
1972    values, update_ops = _streaming_confusion_matrix_at_thresholds(
1973        predictions, labels, thresholds, weights, includes=('fp', 'tn'))
1974
1975    # Avoid division by zero.
1976    epsilon = _EPSILON
1977
1978    def compute_fpr(fp, tn, name):
1979      return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name)
1980
1981    fpr = compute_fpr(values['fp'], values['tn'], 'value')
1982    update_op = compute_fpr(update_ops['fp'], update_ops['tn'], 'update_op')
1983
1984    if metrics_collections:
1985      ops.add_to_collections(metrics_collections, fpr)
1986
1987    if updates_collections:
1988      ops.add_to_collections(updates_collections, update_op)
1989
1990    return fpr, update_op
1991
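# Illustrative sketch (the helper name is ours): FPR at two fixed thresholds
# for one batch. The negatives score (0.2, 0.7): one of the two exceeds 0.5
# and neither exceeds 0.9, so the result is ~[0.5, 0.0] (up to the epsilon
# guard in the division above).
def _example_fpr_at_thresholds():
  """Sketch: streaming FPR at thresholds 0.5 and 0.9."""
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables

  predictions = ops.convert_to_tensor([0.2, 0.7, 0.8])
  labels = ops.convert_to_tensor([False, False, True])
  fpr, update_op = streaming_false_positive_rate_at_thresholds(
      predictions, labels, thresholds=[0.5, 0.9])
  with session_lib.Session() as sess:
    sess.run(variables.local_variables_initializer())
    sess.run(update_op)
    return sess.run(fpr)  # ~[0.5, 0.0]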
1992
1993def streaming_false_negative_rate_at_thresholds(predictions,
1994                                                labels,
1995                                                thresholds,
1996                                                weights=None,
1997                                                metrics_collections=None,
1998                                                updates_collections=None,
1999                                                name=None):
2000  """Computes various fnr values for different `thresholds` on `predictions`.
2001
2002  The `streaming_false_negative_rate_at_thresholds` function creates two local
2003  variables, `false_negatives` and `true_positives`, for various threshold
2004  values. `false_negative_rate[i]` is defined as the total weight of values in
2005  `predictions` at or below `thresholds[i]` whose corresponding entry in
2006  `labels` is `True`, divided by the total weight of `True` values in `labels`
2007  (`false_negatives[i] / (false_negatives[i] + true_positives[i])`).
2008
2009  For estimation of the metric over a stream of data, the function creates an
2010  `update_op` operation that updates these variables and returns the
2011  `false_negative_rate`.
2012
2013  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2014
2015  Args:
2016    predictions: A floating point `Tensor` of arbitrary shape and whose values
2017      are in the range `[0, 1]`.
2018    labels: A `bool` `Tensor` whose shape matches `predictions`.
2019    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2020    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
2021      must be broadcastable to `labels` (i.e., all dimensions must be either
2022      `1`, or the same as the corresponding `labels` dimension).
2023    metrics_collections: An optional list of collections that
2024      `false_negative_rate` should be added to.
2025    updates_collections: An optional list of collections that `update_op` should
2026      be added to.
2027    name: An optional variable_scope name.
2028
2029  Returns:
2030    false_negative_rate: A float `Tensor` of shape `[len(thresholds)]`.
2031    update_op: An operation that increments the `false_negatives` and
2032      `true_positives` variables that are used in the computation of
2033      `false_negative_rate`.
2034
2035  Raises:
2036    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2037      `weights` is not `None` and its shape doesn't match `predictions`, or if
2038      either `metrics_collections` or `updates_collections` are not a list or
2039      tuple.
2040  """
2041  with variable_scope.variable_scope(name, 'false_negative_rate_at_thresholds',
2042                                     (predictions, labels, weights)):
2043    values, update_ops = _streaming_confusion_matrix_at_thresholds(
2044        predictions, labels, thresholds, weights, includes=('fn', 'tp'))
2045
2046    # Avoid division by zero.
2047    epsilon = _EPSILON
2048
2049    def compute_fnr(fn, tp, name):
2050      return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name)
2051
2052    fnr = compute_fnr(values['fn'], values['tp'], 'value')
2053    update_op = compute_fnr(update_ops['fn'], update_ops['tp'], 'update_op')
2054
2055    if metrics_collections:
2056      ops.add_to_collections(metrics_collections, fnr)
2057
2058    if updates_collections:
2059      ops.add_to_collections(updates_collections, update_op)
2060
2061    return fnr, update_op
2062
2063
2064def _at_k_name(name, k=None, class_id=None):
2065  if k is not None:
2066    name = '%s_at_%d' % (name, k)
2067  else:
2068    name = '%s_at_k' % (name)
2069  if class_id is not None:
2070    name = '%s_class%d' % (name, class_id)
2071  return name
2072
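# For example: _at_k_name('recall', 5) returns 'recall_at_5', and
# _at_k_name('precision', 10, class_id=2) returns 'precision_at_10_class2'.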
2073
2074@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, '
2075            'and reshape labels from [batch_size] to [batch_size, 1].')
2076def streaming_recall_at_k(predictions,
2077                          labels,
2078                          k,
2079                          weights=None,
2080                          metrics_collections=None,
2081                          updates_collections=None,
2082                          name=None):
2083  """Computes the recall@k of the predictions with respect to dense labels.
2084
2085  The `streaming_recall_at_k` function creates two local variables, `total` and
2086  `count`, that are used to compute the recall@k frequency. This frequency is
2087  ultimately returned as `recall_at_<k>`: an idempotent operation that simply
2088  divides `total` by `count`.
2089
2090  For estimation of the metric over a stream of data, the function creates an
2091  `update_op` operation that updates these variables and returns the
2092  `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with
2093  shape [batch_size] whose elements indicate whether or not the corresponding
2094  label is in the top `k` `predictions`. Then `update_op` increments `total`
2095  with the reduced sum of `weights` where `in_top_k` is `True`, and it
2096  increments `count` with the reduced sum of `weights`.
2097
2098  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2099
2100  Args:
2101    predictions: A float `Tensor` of dimension [batch_size, num_classes].
2102    labels: A `Tensor` of dimension [batch_size] whose type is `int32` or
2103      `int64`.
2104    k: The number of top elements to look at for computing recall.
2105    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
2106      must be broadcastable to `labels` (i.e., all dimensions must be either
2107      `1`, or the same as the corresponding `labels` dimension).
2108    metrics_collections: An optional list of collections that `recall_at_k`
2109      should be added to.
2110    updates_collections: An optional list of collections `update_op` should be
2111      added to.
2112    name: An optional variable_scope name.
2113
2114  Returns:
2115    recall_at_k: A `Tensor` representing the recall@k, the fraction of labels
2116      which fall into the top `k` predictions.
2117    update_op: An operation that increments the `total` and `count` variables
2118      appropriately and whose value matches `recall_at_k`.
2119
2120  Raises:
2121    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2122      `weights` is not `None` and its shape doesn't match `predictions`, or if
2123      either `metrics_collections` or `updates_collections` are not a list or
2124      tuple.
2125  """
2126  in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k))
2127  return streaming_mean(in_top_k, weights, metrics_collections,
2128                        updates_collections, name or _at_k_name('recall', k))
2129
2130
2131# TODO(ptucker): Validate range of values in labels?
2132def streaming_sparse_recall_at_k(predictions,
2133                                 labels,
2134                                 k,
2135                                 class_id=None,
2136                                 weights=None,
2137                                 metrics_collections=None,
2138                                 updates_collections=None,
2139                                 name=None):
2140  """Computes recall@k of the predictions with respect to sparse labels.
2141
2142  If `class_id` is not specified, we'll calculate recall as the ratio of true
2143      positives (i.e., correct predictions, items in the top `k` highest
2144      `predictions` that are found in the corresponding row in `labels`) to
2145      actual positives (the full `labels` row).
2146  If `class_id` is specified, we calculate recall by considering only the rows
2147      in the batch for which `class_id` is in `labels`, and computing the
2148      fraction of them for which `class_id` is in the top `k` highest
2149      `predictions`.
2150
2151  `streaming_sparse_recall_at_k` creates two local variables,
2152  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
2153  the recall_at_k frequency. This frequency is ultimately returned as
2154  `recall_at_<k>`: an idempotent operation that simply divides
2155  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
2156  `false_negative_at_<k>`).
2157
2158  For estimation of the metric over a stream of data, the function creates an
2159  `update_op` operation that updates these variables and returns the
2160  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
2161  indicating the top `k` `predictions`. Set operations applied to `top_k` and
2162  `labels` calculate the true positives and false negatives weighted by
2163  `weights`. Then `update_op` increments `true_positive_at_<k>` and
2164  `false_negative_at_<k>` using these values.
2165
2166  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2167
2168  Args:
2169    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
2170      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
2171      The final dimension contains the logit values for each class. [D1, ... DN]
2172      must match `labels`.
2173    labels: `int64` `Tensor` or `SparseTensor` with shape
2174      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2175      target classes for the associated prediction. Commonly, N=1 and `labels`
2176      has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`.
2177      Values should be in range [0, num_classes), where num_classes is the last
2178      dimension of `predictions`. Values outside this range always count
2179      towards `false_negative_at_<k>`.
2180    k: Integer, k for @k metric.
2181    class_id: Integer class ID for which we want binary metrics. This should be
2182      in range [0, num_classes), where num_classes is the last dimension of
2183      `predictions`. If class_id is outside this range, the method returns NAN.
2184    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2185      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2186      dimensions must be either `1`, or the same as the corresponding `labels`
2187      dimension).
2188    metrics_collections: An optional list of collections that values should
2189      be added to.
2190    updates_collections: An optional list of collections that updates should
2191      be added to.
2192    name: Name of new update operation, and namespace for other dependent ops.
2193
2194  Returns:
2195    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
2196      by the sum of `true_positives` and `false_negatives`.
2197    update_op: `Operation` that increments `true_positives` and
2198      `false_negatives` variables appropriately, and whose value matches
2199      `recall`.
2200
2201  Raises:
2202    ValueError: If `weights` is not `None` and its shape doesn't match
2203      `predictions`, or if either `metrics_collections` or `updates_collections`
2204      are not a list or tuple.
2205  """
2206  return metrics.recall_at_k(
2207      k=k,
2208      class_id=class_id,
2209      predictions=predictions,
2210      labels=labels,
2211      weights=weights,
2212      metrics_collections=metrics_collections,
2213      updates_collections=updates_collections,
2214      name=name)
2215
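# Illustrative sketch (the helper name is ours): recall@2 on a two-row batch.
# Row 0's label (class 3) is among its top-2 logits {3, 2}; row 1's label
# (class 0) is not among its top-2 logits {1, 2}, so recall@2 = 1 / 2.
def _example_sparse_recall_at_k():
  """Sketch: streaming recall@k with dense int64 labels."""
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables

  predictions = ops.convert_to_tensor(
      [[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2]])
  labels = ops.convert_to_tensor([[3], [0]], dtypes.int64)
  recall, update_op = streaming_sparse_recall_at_k(predictions, labels, k=2)
  with session_lib.Session() as sess:
    sess.run(variables.local_variables_initializer())
    sess.run(update_op)
    return sess.run(recall)  # 0.5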
2216
2217# TODO(ptucker): Validate range of values in labels?
2218def streaming_sparse_precision_at_k(predictions,
2219                                    labels,
2220                                    k,
2221                                    class_id=None,
2222                                    weights=None,
2223                                    metrics_collections=None,
2224                                    updates_collections=None,
2225                                    name=None):
2226  """Computes precision@k of the predictions with respect to sparse labels.
2227
2228  If `class_id` is not specified, we calculate precision as the ratio of true
2229      positives (i.e., correct predictions, items in the top `k` highest
2230      `predictions` that are found in the corresponding row in `labels`) to
2231      positives (all top `k` `predictions`).
2232  If `class_id` is specified, we calculate precision by considering only the
2233      rows in the batch for which `class_id` is in the top `k` highest
2234      `predictions`, and computing the fraction of them for which `class_id` is
2235      in the corresponding row in `labels`.
2236
2237  We expect precision to decrease as `k` increases.
2238
2239  `streaming_sparse_precision_at_k` creates two local variables,
2240  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
2241  the precision@k frequency. This frequency is ultimately returned as
2242  `precision_at_<k>`: an idempotent operation that simply divides
2243  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
2244  `false_positive_at_<k>`).
2245
2246  For estimation of the metric over a stream of data, the function creates an
2247  `update_op` operation that updates these variables and returns the
2248  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
2249  indicating the top `k` `predictions`. Set operations applied to `top_k` and
2250  `labels` calculate the true positives and false positives weighted by
2251  `weights`. Then `update_op` increments `true_positive_at_<k>` and
2252  `false_positive_at_<k>` using these values.
2253
2254  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2255
2256  Args:
2257    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
2258      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
2259      The final dimension contains the logit values for each class. [D1, ... DN]
2260      must match `labels`.
2261    labels: `int64` `Tensor` or `SparseTensor` with shape
2262      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2263      target classes for the associated prediction. Commonly, N=1 and `labels`
2264      has shape [batch_size, num_labels]. [D1, ... DN] must match
2265      `predictions`. Values should be in range [0, num_classes), where
2266      num_classes is the last dimension of `predictions`. Values outside this
2267      range are ignored.
2268    k: Integer, k for @k metric.
2269    class_id: Integer class ID for which we want binary metrics. This should be
2270      in range [0, num_classes), where num_classes is the last dimension of
2271      `predictions`. If `class_id` is outside this range, the method returns
2272      NAN.
2273    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2274      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2275      dimensions must be either `1`, or the same as the corresponding `labels`
2276      dimension).
2277    metrics_collections: An optional list of collections that values should
2278      be added to.
2279    updates_collections: An optional list of collections that updates should
2280      be added to.
2281    name: Name of new update operation, and namespace for other dependent ops.
2282
2283  Returns:
2284    precision: Scalar `float64` `Tensor` with the value of `true_positives`
2285      divided by the sum of `true_positives` and `false_positives`.
2286    update_op: `Operation` that increments `true_positives` and
2287      `false_positives` variables appropriately, and whose value matches
2288      `precision`.
2289
2290  Raises:
2291    ValueError: If `weights` is not `None` and its shape doesn't match
2292      `predictions`, or if either `metrics_collections` or `updates_collections`
2293      are not a list or tuple.
2294  """
2295  return metrics.precision_at_k(
2296      k=k,
2297      class_id=class_id,
2298      predictions=predictions,
2299      labels=labels,
2300      weights=weights,
2301      metrics_collections=metrics_collections,
2302      updates_collections=updates_collections,
2303      name=name)
2304
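# Illustrative sketch (the helper name is ours): precision@2 on the same
# batch as the recall sketch above. The four top-2 predictions contain one
# true positive (row 0 predicts {3, 2} for label {3}; row 1 predicts {1, 2}
# for label {0}), so precision@2 = 1 / 4.
def _example_sparse_precision_at_k():
  """Sketch: streaming precision@k with dense int64 labels."""
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables

  predictions = ops.convert_to_tensor(
      [[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.3, 0.2]])
  labels = ops.convert_to_tensor([[3], [0]], dtypes.int64)
  precision, update_op = streaming_sparse_precision_at_k(
      predictions, labels, k=2)
  with session_lib.Session() as sess:
    sess.run(variables.local_variables_initializer())
    sess.run(update_op)
    return sess.run(precision)  # 0.25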
2305
2306# TODO(ptucker): Validate range of values in labels?
2307def streaming_sparse_precision_at_top_k(top_k_predictions,
2308                                        labels,
2309                                        class_id=None,
2310                                        weights=None,
2311                                        metrics_collections=None,
2312                                        updates_collections=None,
2313                                        name=None):
2314  """Computes precision@k of top-k predictions with respect to sparse labels.
2315
2316  If `class_id` is not specified, we calculate precision as the ratio of
2317      true positives (i.e., correct predictions, items in `top_k_predictions`
2318      that are found in the corresponding row in `labels`) to positives (all
2319      `top_k_predictions`).
2320  If `class_id` is specified, we calculate precision by considering only the
2321      rows in the batch for which `class_id` is in the top `k` highest
2322      `predictions`, and computing the fraction of them for which `class_id` is
2323      in the corresponding row in `labels`.
2324
2325  We expect precision to decrease as `k` increases.
2326
2327  `streaming_sparse_precision_at_top_k` creates two local variables,
2328  `true_positive_at_k` and `false_positive_at_k`, that are used to compute
2329  the precision@k frequency. This frequency is ultimately returned as
2330  `precision_at_k`: an idempotent operation that simply divides
2331  `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`).
2332
2333  For estimation of the metric over a stream of data, the function creates an
2334  `update_op` operation that updates these variables and returns the
2335  `precision_at_k`. Internally, set operations applied to `top_k_predictions`
2336  and `labels` calculate the true positives and false positives weighted by
2337  `weights`. Then `update_op` increments `true_positive_at_k` and
2338  `false_positive_at_k` using these values.
2339
2340  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2341
2342  Args:
2343    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
2344      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
2345      The final dimension contains the indices of top-k labels. [D1, ... DN]
2346      must match `labels`.
2347    labels: `int64` `Tensor` or `SparseTensor` with shape
2348      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2349      target classes for the associated prediction. Commonly, N=1 and `labels`
2350      has shape [batch_size, num_labels]. [D1, ... DN] must match
2351      `top_k_predictions`. Values should be in range [0, num_classes), where
2352      num_classes is the last dimension of `predictions`. Values outside this
2353      range are ignored.
2354    class_id: Integer class ID for which we want binary metrics. This should be
2355      in range [0, num_classes), where num_classes is the last dimension of
2356      `predictions`. If `class_id` is outside this range, the method returns
2357      NAN.
2358    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2359      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2360      dimensions must be either `1`, or the same as the corresponding `labels`
2361      dimension).
2362    metrics_collections: An optional list of collections that values should
2363      be added to.
2364    updates_collections: An optional list of collections that updates should
2365      be added to.
2366    name: Name of new update operation, and namespace for other dependent ops.
2367
2368  Returns:
2369    precision: Scalar `float64` `Tensor` with the value of `true_positives`
2370      divided by the sum of `true_positives` and `false_positives`.
2371    update_op: `Operation` that increments `true_positives` and
2372      `false_positives` variables appropriately, and whose value matches
2373      `precision`.
2374
2375  Raises:
2376    ValueError: If `weights` is not `None` and its shape doesn't match
2377      `predictions`, or if either `metrics_collections` or `updates_collections`
2378      are not a list or tuple.
2379    ValueError: If `top_k_predictions` has rank < 2.
2380  """
2381  default_name = _at_k_name('precision', class_id=class_id)
2382  with ops.name_scope(name, default_name,
2383                      (top_k_predictions, labels, weights)) as name_scope:
2384    return metrics_impl.precision_at_top_k(
2385        labels=labels,
2386        predictions_idx=top_k_predictions,
2387        class_id=class_id,
2388        weights=weights,
2389        metrics_collections=metrics_collections,
2390        updates_collections=updates_collections,
2391        name=name_scope)
2392
2393
2394def sparse_recall_at_top_k(labels,
2395                           top_k_predictions,
2396                           class_id=None,
2397                           weights=None,
2398                           metrics_collections=None,
2399                           updates_collections=None,
2400                           name=None):
2401  """Computes recall@k of top-k predictions with respect to sparse labels.
2402
2403  If `class_id` is specified, we calculate recall by considering only the
2404      entries in the batch for which `class_id` is in the label, and computing
2405      the fraction of them for which `class_id` is in the top-k `predictions`.
2406  If `class_id` is not specified, we calculate recall as how often, on
2407      average, a class among the labels of a batch entry is in the top-k
2408      `predictions`.
2409
2410  `sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>`
2411  and `false_negative_at_<k>`, that are used to compute the recall_at_k
2412  frequency. This frequency is ultimately returned as `recall_at_<k>`: an
2413  idempotent operation that simply divides `true_positive_at_<k>` by total
2414  (`true_positive_at_<k>` + `false_negative_at_<k>`).
2415
2416  For estimation of the metric over a stream of data, the function creates an
2417  `update_op` operation that updates these variables and returns the
2418  `recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the
2419  true positives and false negatives weighted by `weights`. Then `update_op`
2420  increments `true_positive_at_<k>` and `false_negative_at_<k>` using these
2421  values.
2422
2423  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
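
  As a sketch (illustrative values; assumes this function is exported as
  `tf.contrib.metrics.sparse_recall_at_top_k`):

  ```python
    import tensorflow as tf

    top_k = tf.constant([[0, 1], [1, 2]], dtype=tf.int64)
    labels = tf.constant([[0], [0]], dtype=tf.int64)
    recall, update_op = tf.contrib.metrics.sparse_recall_at_top_k(
        labels, top_k)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(recall))  # 1 of the 2 labels is in the top k: 0.5
  ```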
2424
2425  Args:
2426    labels: `int64` `Tensor` or `SparseTensor` with shape
2427      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2428      target classes for the associated prediction. Commonly, N=1 and `labels`
2429      has shape [batch_size, num_labels]. [D1, ... DN] must match
2430      `top_k_predictions`. Values should be in range [0, num_classes), where
2431      num_classes is the number of possible classes. Values outside this
2432      range always count towards `false_negative_at_<k>`.
2433    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
2434      N >= 1. Commonly, N=1 and `top_k_predictions` has shape [batch_size, k].
2435      The final dimension contains the indices of top-k labels. [D1, ... DN]
2436      must match `labels`.
2437    class_id: Integer class ID for which we want binary metrics. This should be
2438      in range [0, num_classes), where num_classes is the number of possible
2439      classes. If `class_id` is outside this range, the method returns NAN.
2440    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2441      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2442      dimensions must be either `1`, or the same as the corresponding `labels`
2443      dimension).
2444    metrics_collections: An optional list of collections that values should
2445      be added to.
2446    updates_collections: An optional list of collections that updates should
2447      be added to.
2448    name: Name of new update operation, and namespace for other dependent ops.
2449
2450  Returns:
2451    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
2452      by the sum of `true_positives` and `false_negatives`.
2453    update_op: `Operation` that increments `true_positives` and
2454      `false_negatives` variables appropriately, and whose value matches
2455      `recall`.
2456
2457  Raises:
2458    ValueError: If `weights` is not `None` and its shape doesn't match
2459      `predictions`, or if either `metrics_collections` or `updates_collections`
2460      are not a list or tuple.
2461  """
2462  default_name = _at_k_name('recall', class_id=class_id)
2463  with ops.name_scope(name, default_name,
2464                      (top_k_predictions, labels, weights)) as name_scope:
2465    return metrics_impl.recall_at_top_k(
2466        labels=labels,
2467        predictions_idx=top_k_predictions,
2468        class_id=class_id,
2469        weights=weights,
2470        metrics_collections=metrics_collections,
2471        updates_collections=updates_collections,
2472        name=name_scope)
2473
2474
2475def _compute_recall_at_precision(tp, fp, fn, precision, name):
2476  """Helper function to compute recall at a given `precision`.
2477
2478  Args:
2479    tp: The number of true positives.
2480    fp: The number of false positives.
2481    fn: The number of false negatives.
2482    precision: The precision for which the recall will be calculated.
2483    name: An optional variable_scope name.
2484
2485  Returns:
2486    The recall at the given `precision`.
2487  """
2488  precisions = math_ops.div(tp, tp + fp + _EPSILON)
2489  tf_index = math_ops.argmin(
2490      math_ops.abs(precisions - precision), 0, output_type=dtypes.int32)
2491
2492  # Now, we have the implicit threshold, so compute the recall:
2493  return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
2494                      name)
2495
2496
2497def recall_at_precision(labels,
2498                        predictions,
2499                        precision,
2500                        weights=None,
2501                        num_thresholds=200,
2502                        metrics_collections=None,
2503                        updates_collections=None,
2504                        name=None):
2505  """Computes `recall` at `precision`.
2506
2507  The `recall_at_precision` function creates three local variables,
2508  `tp` (true positives), `fp` (false positives) and `fn` (false negatives),
2509  that are used to compute the `recall` at the given `precision` value. The
2510  threshold for the given `precision` value is computed and used to evaluate the
2511  corresponding `recall`.
2512
2513  For estimation of the metric over a stream of data, the function creates an
2514  `update_op` operation that updates these variables and returns the
2515  `recall`. `update_op` increments the `tp`, `fp` and `fn` counts with the
2516  weight of each case found in the `predictions` and `labels`.
2517
2518  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
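
  A minimal sketch (illustrative values; no particular output is claimed,
  since the implied threshold depends on `num_thresholds`):

  ```python
    import tensorflow as tf

    labels = tf.constant([True, False, True, True])
    predictions = tf.constant([0.9, 0.8, 0.7, 0.2])
    recall, update_op = tf.contrib.metrics.recall_at_precision(
        labels, predictions, precision=0.6)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      # Recall at the threshold whose precision is closest to 0.6.
      print(sess.run(update_op))
  ```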
2519
2520  Args:
2521    labels: The ground truth values, a `Tensor` whose dimensions must match
2522      `predictions`. Will be cast to `bool`.
2523    predictions: A floating point `Tensor` of arbitrary shape and whose values
2524      are in the range `[0, 1]`.
2525    precision: A scalar value in range `[0, 1]`.
2526    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2527      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2528      be either `1`, or the same as the corresponding `labels` dimension).
2529    num_thresholds: The number of thresholds to use for matching the given
2530      `precision`.
2531    metrics_collections: An optional list of collections that `recall`
2532      should be added to.
2533    updates_collections: An optional list of collections that `update_op` should
2534      be added to.
2535    name: An optional variable_scope name.
2536
2537  Returns:
2538    recall: A scalar `Tensor` representing the recall at the given
2539      `precision` value.
2540    update_op: An operation that increments the `tp`, `fp` and `fn`
2541      variables appropriately and whose value matches `recall`.
2542
2543  Raises:
2544    ValueError: If `predictions` and `labels` have mismatched shapes, if
2545      `weights` is not `None` and its shape doesn't match `predictions`, or if
2546      `precision` is not between 0 and 1, or if either `metrics_collections`
2547      or `updates_collections` are not a list or tuple.
2548
2549  """
2550  if not 0 <= precision <= 1:
2551    raise ValueError('`precision` must be in the range [0, 1].')
2552
2553  with variable_scope.variable_scope(name, 'recall_at_precision',
2554                                     (predictions, labels, weights)):
2555    thresholds = [
2556        i * 1.0 / (num_thresholds - 1) for i in range(1, num_thresholds - 1)
2557    ]
2558    thresholds = [0.0 - _EPSILON] + thresholds + [1.0 + _EPSILON]
2559
2560    values, update_ops = _streaming_confusion_matrix_at_thresholds(
2561        predictions, labels, thresholds, weights)
2562
2563    recall = _compute_recall_at_precision(values['tp'], values['fp'],
2564                                          values['fn'], precision, 'value')
2565    update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'],
2566                                             update_ops['fn'], precision,
2567                                             'update_op')
2568
2569    if metrics_collections:
2570      ops.add_to_collections(metrics_collections, recall)
2571
2572    if updates_collections:
2573      ops.add_to_collections(updates_collections, update_op)
2574
2575    return recall, update_op
2576
2577
2578def streaming_sparse_average_precision_at_k(predictions,
2579                                            labels,
2580                                            k,
2581                                            weights=None,
2582                                            metrics_collections=None,
2583                                            updates_collections=None,
2584                                            name=None):
2585  """Computes average precision@k of predictions with respect to sparse labels.
2586
2587  See `sparse_average_precision_at_k` for details on formula. `weights` are
2588  applied to the result of `sparse_average_precision_at_k`
2589
2590  `streaming_sparse_average_precision_at_k` creates two local variables,
2591  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
2592  are used to compute the frequency. This frequency is ultimately returned as
2593  `average_precision_at_<k>`: an idempotent operation that simply divides
2594  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
2595
2596  For estimation of the metric over a stream of data, the function creates an
2597  `update_op` operation that updates these variables and returns the
2598  `average_precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
2599  indicating the top `k` `predictions`. Set operations applied to `top_k` and
2600  `labels` calculate the true positives and false positives weighted by
2601  `weights`. Then `update_op` increments `true_positive_at_<k>` and
2602  `false_positive_at_<k>` using these values.
2603
2604  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
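
  For example (illustrative values; the per-row average precision is worked
  out in the comment):

  ```python
    import tensorflow as tf

    # Class scores for a batch of 2 examples over 3 classes.
    predictions = tf.constant([[0.1, 0.6, 0.3], [0.8, 0.15, 0.05]])
    labels = tf.constant([[1], [1]], dtype=tf.int64)
    avg_prec, update_op = (
        tf.contrib.metrics.streaming_sparse_average_precision_at_k(
            predictions, labels, k=2))
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      # Row 0 hits at rank 1 (AP=1.0), row 1 at rank 2 (AP=0.5): mean 0.75.
      print(sess.run(avg_prec))
  ```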
2605
2606  Args:
2607    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
2608      N >= 1. Commonly, N=1 and `predictions` has shape
2609      [batch size, num_classes]. The final dimension contains the logit values
2610      for each class. [D1, ... DN] must match `labels`.
2611    labels: `int64` `Tensor` or `SparseTensor` with shape
2612      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2613      target classes for the associated prediction. Commonly, N=1 and `labels`
2614      has shape [batch_size, num_labels]. [D1, ... DN] must match
2615      `predictions`. Values should be in range [0, num_classes), where
2616      num_classes is the last dimension of `predictions`. Values outside this
2617      range are ignored.
2618    k: Integer, k for @k metric. This will calculate an average precision for
2619      range `[1,k]`, as documented above.
2620    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2621      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2622      dimensions must be either `1`, or the same as the corresponding `labels`
2623      dimension).
2624    metrics_collections: An optional list of collections that values should
2625      be added to.
2626    updates_collections: An optional list of collections that updates should
2627      be added to.
2628    name: Name of new update operation, and namespace for other dependent ops.
2629
2630  Returns:
2631    mean_average_precision: Scalar `float64` `Tensor` with the mean average
2632      precision values.
2633    update: `Operation` that increments variables appropriately, and whose
2634      value matches `mean_average_precision`.
2635  """
2636  return metrics.average_precision_at_k(
2637      k=k,
2638      predictions=predictions,
2639      labels=labels,
2640      weights=weights,
2641      metrics_collections=metrics_collections,
2642      updates_collections=updates_collections,
2643      name=name)
2644
2645
2646def streaming_sparse_average_precision_at_top_k(top_k_predictions,
2647                                                labels,
2648                                                weights=None,
2649                                                metrics_collections=None,
2650                                                updates_collections=None,
2651                                                name=None):
2652  """Computes average precision@k of predictions with respect to sparse labels.
2653
2654  `streaming_sparse_average_precision_at_top_k` creates two local variables,
2655  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
2656  are used to compute the frequency. This frequency is ultimately returned as
2657  `average_precision_at_<k>`: an idempotent operation that simply divides
2658  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
2659
2660  For estimation of the metric over a stream of data, the function creates an
2661  `update_op` operation that updates these variables and returns the
2662  `average_precision_at_<k>`. Set operations applied to `top_k_predictions`
2663  and `labels` calculate the true positives and false positives weighted by
2664  `weights`. Then `update_op` increments `true_positive_at_<k>` and
2665  `false_positive_at_<k>` using these values.
2666
2667  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
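
  This mirrors `streaming_sparse_average_precision_at_k`, except that the
  top-k indices are supplied directly. A sketch with illustrative values:

  ```python
    import tensorflow as tf

    # Pre-computed top-2 class indices, e.g. from tf.nn.top_k.
    top_k = tf.constant([[1, 2], [0, 1]], dtype=tf.int64)
    labels = tf.constant([[1], [1]], dtype=tf.int64)
    avg_prec, update_op = (
        tf.contrib.metrics.streaming_sparse_average_precision_at_top_k(
            top_k, labels))
    # Row 0 hits at rank 1 (AP=1.0), row 1 at rank 2 (AP=0.5): mean 0.75.
  ```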
2668
2669  Args:
2670    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
2671      Commonly, N=1 and `top_k_predictions` has shape [batch_size, k]. The final
2672      dimension must be set and contains the top `k` predicted class indices.
2673      [D1, ... DN] must match `labels`. Values should be in range
2674      [0, num_classes).
2675    labels: `int64` `Tensor` or `SparseTensor` with shape
2676      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
2677      num_labels=1. N >= 1 and num_labels is the number of target classes for
2678      the associated prediction. Commonly, N=1 and `labels` has shape
2679      [batch_size, num_labels]. [D1, ... DN] must match `top_k_predictions`.
2680      Values should be in range [0, num_classes).
2681    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2682      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2683      dimensions must be either `1`, or the same as the corresponding `labels`
2684      dimension).
2685    metrics_collections: An optional list of collections that values should
2686      be added to.
2687    updates_collections: An optional list of collections that updates should
2688      be added to.
2689    name: Name of new update operation, and namespace for other dependent ops.
2690
2691  Returns:
2692    mean_average_precision: Scalar `float64` `Tensor` with the mean average
2693      precision values.
2694    update: `Operation` that increments variables appropriately, and whose
2695      value matches `mean_average_precision`.
2696
2697  Raises:
2698    ValueError: if the last dimension of top_k_predictions is not set.
2699  """
2700  return metrics_impl._streaming_sparse_average_precision_at_top_k(  # pylint: disable=protected-access
2701      predictions_idx=top_k_predictions,
2702      labels=labels,
2703      weights=weights,
2704      metrics_collections=metrics_collections,
2705      updates_collections=updates_collections,
2706      name=name)
2707
2708
2709@deprecated(None, 'Please switch to tf.metrics.mean.')
2710def streaming_mean_absolute_error(predictions,
2711                                  labels,
2712                                  weights=None,
2713                                  metrics_collections=None,
2714                                  updates_collections=None,
2715                                  name=None):
2716  """Computes the mean absolute error between the labels and predictions.
2717
2718  The `streaming_mean_absolute_error` function creates two local variables,
2719  `total` and `count` that are used to compute the mean absolute error. This
2720  average is weighted by `weights`, and it is ultimately returned as
2721  `mean_absolute_error`: an idempotent operation that simply divides `total` by
2722  `count`.
2723
2724  For estimation of the metric over a stream of data, the function creates an
2725  `update_op` operation that updates these variables and returns the
2726  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
2727  absolute value of the differences between `predictions` and `labels`. Then
2728  `update_op` increments `total` with the reduced sum of the product of
2729  `weights` and `absolute_errors`, and it increments `count` with the reduced
2730  sum of `weights`.
2731
2732  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
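
  For example (illustrative values):

  ```python
    import tensorflow as tf

    predictions = tf.constant([1.0, 2.0, 4.0])
    labels = tf.constant([1.0, 3.0, 2.0])
    mae, update_op = tf.contrib.metrics.streaming_mean_absolute_error(
        predictions, labels)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(mae))  # (0 + 1 + 2) / 3 = 1.0
  ```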
2733
2734  Args:
2735    predictions: A `Tensor` of arbitrary shape.
2736    labels: A `Tensor` of the same shape as `predictions`.
2737    weights: Optional `Tensor` indicating the frequency with which an example is
2738      sampled. Rank must be 0, or the same rank as `labels`, and must be
2739      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
2740      the same as the corresponding `labels` dimension).
2741    metrics_collections: An optional list of collections that
2742      `mean_absolute_error` should be added to.
2743    updates_collections: An optional list of collections that `update_op` should
2744      be added to.
2745    name: An optional variable_scope name.
2746
2747  Returns:
2748    mean_absolute_error: A `Tensor` representing the current mean, the value of
2749      `total` divided by `count`.
2750    update_op: An operation that increments the `total` and `count` variables
2751      appropriately and whose value matches `mean_absolute_error`.
2752
2753  Raises:
2754    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2755      `weights` is not `None` and its shape doesn't match `predictions`, or if
2756      either `metrics_collections` or `updates_collections` are not a list or
2757      tuple.
2758  """
2759  return metrics.mean_absolute_error(
2760      predictions=predictions,
2761      labels=labels,
2762      weights=weights,
2763      metrics_collections=metrics_collections,
2764      updates_collections=updates_collections,
2765      name=name)
2766
2767
2768def streaming_mean_relative_error(predictions,
2769                                  labels,
2770                                  normalizer,
2771                                  weights=None,
2772                                  metrics_collections=None,
2773                                  updates_collections=None,
2774                                  name=None):
2775  """Computes the mean relative error by normalizing with the given values.
2776
2777  The `streaming_mean_relative_error` function creates two local variables,
2778  `total` and `count` that are used to compute the mean relative absolute error.
2779  This average is weighted by `weights`, and it is ultimately returned as
2780  `mean_relative_error`: an idempotent operation that simply divides `total` by
2781  `count`.
2782
2783  For estimation of the metric over a stream of data, the function creates an
2784  `update_op` operation that updates these variables and returns the
2785  `mean_relative_error`. Internally, a `relative_errors` operation divides the
2786  absolute value of the differences between `predictions` and `labels` by the
2787  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
2788  product of `weights` and `relative_errors`, and it increments `count` with the
2789  reduced sum of `weights`.
2790
2791  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
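
  For example, normalizing by the labels themselves (illustrative values):

  ```python
    import tensorflow as tf

    predictions = tf.constant([2.0, 6.0])
    labels = tf.constant([1.0, 4.0])
    mre, update_op = tf.contrib.metrics.streaming_mean_relative_error(
        predictions, labels, normalizer=labels)
    # |2-1|/1 = 1.0 and |6-4|/4 = 0.5, so the streamed mean is 0.75.
  ```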
2792
2793  Args:
2794    predictions: A `Tensor` of arbitrary shape.
2795    labels: A `Tensor` of the same shape as `predictions`.
2796    normalizer: A `Tensor` of the same shape as `predictions`.
2797    weights: Optional `Tensor` indicating the frequency with which an example is
2798      sampled. Rank must be 0, or the same rank as `labels`, and must be
2799      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
2800      the same as the corresponding `labels` dimension).
2801    metrics_collections: An optional list of collections that
2802      `mean_relative_error` should be added to.
2803    updates_collections: An optional list of collections that `update_op` should
2804      be added to.
2805    name: An optional variable_scope name.
2806
2807  Returns:
2808    mean_relative_error: A `Tensor` representing the current mean, the value of
2809      `total` divided by `count`.
2810    update_op: An operation that increments the `total` and `count` variables
2811      appropriately and whose value matches `mean_relative_error`.
2812
2813  Raises:
2814    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2815      `weights` is not `None` and its shape doesn't match `predictions`, or if
2816      either `metrics_collections` or `updates_collections` are not a list or
2817      tuple.
2818  """
2819  return metrics.mean_relative_error(
2820      normalizer=normalizer,
2821      predictions=predictions,
2822      labels=labels,
2823      weights=weights,
2824      metrics_collections=metrics_collections,
2825      updates_collections=updates_collections,
2826      name=name)
2827
2828
2829def streaming_mean_squared_error(predictions,
2830                                 labels,
2831                                 weights=None,
2832                                 metrics_collections=None,
2833                                 updates_collections=None,
2834                                 name=None):
2835  """Computes the mean squared error between the labels and predictions.
2836
2837  The `streaming_mean_squared_error` function creates two local variables,
2838  `total` and `count` that are used to compute the mean squared error.
2839  This average is weighted by `weights`, and it is ultimately returned as
2840  `mean_squared_error`: an idempotent operation that simply divides `total` by
2841  `count`.
2842
2843  For estimation of the metric over a stream of data, the function creates an
2844  `update_op` operation that updates these variables and returns the
2845  `mean_squared_error`. Internally, a `squared_error` operation computes the
2846  element-wise square of the difference between `predictions` and `labels`. Then
2847  `update_op` increments `total` with the reduced sum of the product of
2848  `weights` and `squared_error`, and it increments `count` with the reduced sum
2849  of `weights`.
2850
2851  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
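
  For example (illustrative values):

  ```python
    import tensorflow as tf

    predictions = tf.constant([2.0, 4.0])
    labels = tf.constant([1.0, 1.0])
    mse, update_op = tf.contrib.metrics.streaming_mean_squared_error(
        predictions, labels)
    # Squared errors are 1.0 and 9.0, so the streamed mean is 5.0.
  ```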
2852
2853  Args:
2854    predictions: A `Tensor` of arbitrary shape.
2855    labels: A `Tensor` of the same shape as `predictions`.
2856    weights: Optional `Tensor` indicating the frequency with which an example is
2857      sampled. Rank must be 0, or the same rank as `labels`, and must be
2858      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
2859      the same as the corresponding `labels` dimension).
2860    metrics_collections: An optional list of collections that
2861      `mean_squared_error` should be added to.
2862    updates_collections: An optional list of collections that `update_op` should
2863      be added to.
2864    name: An optional variable_scope name.
2865
2866  Returns:
2867    mean_squared_error: A `Tensor` representing the current mean, the value of
2868      `total` divided by `count`.
2869    update_op: An operation that increments the `total` and `count` variables
2870      appropriately and whose value matches `mean_squared_error`.
2871
2872  Raises:
2873    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2874      `weights` is not `None` and its shape doesn't match `predictions`, or if
2875      either `metrics_collections` or `updates_collections` are not a list or
2876      tuple.
2877  """
2878  return metrics.mean_squared_error(
2879      predictions=predictions,
2880      labels=labels,
2881      weights=weights,
2882      metrics_collections=metrics_collections,
2883      updates_collections=updates_collections,
2884      name=name)
2885
2886
2887def streaming_root_mean_squared_error(predictions,
2888                                      labels,
2889                                      weights=None,
2890                                      metrics_collections=None,
2891                                      updates_collections=None,
2892                                      name=None):
2893  """Computes the root mean squared error between the labels and predictions.
2894
2895  The `streaming_root_mean_squared_error` function creates two local variables,
2896  `total` and `count` that are used to compute the root mean squared error.
2897  This average is weighted by `weights`, and it is ultimately returned as
2898  `root_mean_squared_error`: an idempotent operation that takes the square root
2899  of the division of `total` by `count`.
2900
2901  For estimation of the metric over a stream of data, the function creates an
2902  `update_op` operation that updates these variables and returns the
2903  `root_mean_squared_error`. Internally, a `squared_error` operation computes
2904  the element-wise square of the difference between `predictions` and `labels`.
2905  Then `update_op` increments `total` with the reduced sum of the product of
2906  `weights` and `squared_error`, and it increments `count` with the reduced sum
2907  of `weights`.
2908
2909  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
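
  For example (illustrative values):

  ```python
    import tensorflow as tf

    predictions = tf.constant([0.0, 6.0])
    labels = tf.constant([4.0, 2.0])
    rmse, update_op = tf.contrib.metrics.streaming_root_mean_squared_error(
        predictions, labels)
    # Both squared errors are 16.0, so the result is sqrt(16.0) = 4.0.
  ```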
2910
2911  Args:
2912    predictions: A `Tensor` of arbitrary shape.
2913    labels: A `Tensor` of the same shape as `predictions`.
2914    weights: Optional `Tensor` indicating the frequency with which an example is
2915      sampled. Rank must be 0, or the same rank as `labels`, and must be
2916      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
2917      the same as the corresponding `labels` dimension).
2918    metrics_collections: An optional list of collections that
2919      `root_mean_squared_error` should be added to.
2920    updates_collections: An optional list of collections that `update_op` should
2921      be added to.
2922    name: An optional variable_scope name.
2923
2924  Returns:
2925    root_mean_squared_error: A `Tensor` representing the current mean, the value
2926      of `total` divided by `count`.
2927    update_op: An operation that increments the `total` and `count` variables
2928      appropriately and whose value matches `root_mean_squared_error`.
2929
2930  Raises:
2931    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2932      `weights` is not `None` and its shape doesn't match `predictions`, or if
2933      either `metrics_collections` or `updates_collections` are not a list or
2934      tuple.
2935  """
2936  return metrics.root_mean_squared_error(
2937      predictions=predictions,
2938      labels=labels,
2939      weights=weights,
2940      metrics_collections=metrics_collections,
2941      updates_collections=updates_collections,
2942      name=name)
2943
2944
2945def streaming_covariance(predictions,
2946                         labels,
2947                         weights=None,
2948                         metrics_collections=None,
2949                         updates_collections=None,
2950                         name=None):
2951  """Computes the unbiased sample covariance between `predictions` and `labels`.
2952
2953  The `streaming_covariance` function creates four local variables,
2954  `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to
2955  compute the sample covariance between predictions and labels across multiple
2956  batches of data. The covariance is ultimately returned as an idempotent
2957  operation that simply divides `comoment` by `count` - 1. We use `count` - 1
2958  in order to get an unbiased estimate.
2959
2960  The algorithm used for this online computation is described in
2961  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
2962  Specifically, the formula used to combine two sample comoments is
2963  `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB`
2964  The comoment for a single batch of data is simply
2965  `sum((x - E[x]) * (y - E[y]))`, optionally weighted.
2966
2967  If `weights` is not None, then it is used to compute weighted comoments,
2968  means, and count. NOTE: these weights are treated as "frequency weights", as
2969  opposed to "reliability weights". See discussion of the difference on
2970  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance
2971
2972  To facilitate the computation of covariance across multiple batches of data,
2973  the function creates an `update_op` operation, which updates underlying
2974  variables and returns the updated covariance.
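
  A small worked sketch (illustrative values):

  ```python
    import tensorflow as tf

    x = tf.constant([1.0, 2.0, 3.0, 4.0])
    y = tf.constant([2.0, 4.0, 6.0, 8.0])
    cov, update_op = tf.contrib.metrics.streaming_covariance(x, y)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(cov))  # sum((x - 2.5) * (y - 5.0)) / (4 - 1) = 10 / 3
  ```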
2975
2976  Args:
2977    predictions: A `Tensor` of arbitrary size.
2978    labels: A `Tensor` of the same size as `predictions`.
2979    weights: Optional `Tensor` indicating the frequency with which an example is
2980      sampled. Rank must be 0, or the same rank as `labels`, and must be
2981      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
2982      the same as the corresponding `labels` dimension).
2983    metrics_collections: An optional list of collections that the metric
2984      value variable should be added to.
2985    updates_collections: An optional list of collections that the metric update
2986      ops should be added to.
2987    name: An optional variable_scope name.
2988
2989  Returns:
2990    covariance: A `Tensor` representing the current unbiased sample covariance,
2991      `comoment` / (`count` - 1).
2992    update_op: An operation that updates the local variables appropriately.
2993
2994  Raises:
2995    ValueError: If labels and predictions are of different sizes or if either
2996      `metrics_collections` or `updates_collections` are not a list or tuple.
2997  """
2998  with variable_scope.variable_scope(name, 'covariance',
2999                                     (predictions, labels, weights)):
3000    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
3001        predictions, labels, weights)
3002    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
3003    count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
3004    mean_prediction = metrics_impl.metric_variable(
3005        [], dtypes.float32, name='mean_prediction')
3006    mean_label = metrics_impl.metric_variable(
3007        [], dtypes.float32, name='mean_label')
3008    comoment = metrics_impl.metric_variable(  # C_A in update equation
3009        [], dtypes.float32, name='comoment')
3010
3011    if weights is None:
3012      batch_count = math_ops.to_float(array_ops.size(labels))  # n_B in eqn
3013      weighted_predictions = predictions
3014      weighted_labels = labels
3015    else:
3016      weights = weights_broadcast_ops.broadcast_weights(weights, labels)
3017      batch_count = math_ops.reduce_sum(weights)  # n_B in eqn
3018      weighted_predictions = math_ops.multiply(predictions, weights)
3019      weighted_labels = math_ops.multiply(labels, weights)
3020
3021    update_count = state_ops.assign_add(count_, batch_count)  # n_AB in eqn
3022    prev_count = update_count - batch_count  # n_A in update equation
3023
3024    # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
3025    # batch_mean_prediction is E[x_B] in the update equation
3026    batch_mean_prediction = _safe_div(
3027        math_ops.reduce_sum(weighted_predictions), batch_count,
3028        'batch_mean_prediction')
3029    delta_mean_prediction = _safe_div(
3030        (batch_mean_prediction - mean_prediction) * batch_count, update_count,
3031        'delta_mean_prediction')
3032    update_mean_prediction = state_ops.assign_add(mean_prediction,
3033                                                  delta_mean_prediction)
3034    # prev_mean_prediction is E[x_A] in the update equation
3035    prev_mean_prediction = update_mean_prediction - delta_mean_prediction
3036
3037    # batch_mean_label is E[y_B] in the update equation
3038    batch_mean_label = _safe_div(
3039        math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
3040    delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
3041                                 update_count, 'delta_mean_label')
3042    update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
3043    # prev_mean_label is E[y_A] in the update equation
3044    prev_mean_label = update_mean_label - delta_mean_label
3045
3046    unweighted_batch_coresiduals = ((predictions - batch_mean_prediction) *
3047                                    (labels - batch_mean_label))
3048    # batch_comoment is C_B in the update equation
3049    if weights is None:
3050      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals)
3051    else:
3052      batch_comoment = math_ops.reduce_sum(
3053          unweighted_batch_coresiduals * weights)
3054
3055    # View delta_comoment as = C_AB - C_A in the update equation above.
3056    # Since C_A is stored in a var, by how much do we need to increment that var
3057    # to make the var = C_AB?
3058    delta_comoment = (
3059        batch_comoment + (prev_mean_prediction - batch_mean_prediction) *
3060        (prev_mean_label - batch_mean_label) *
3061        (prev_count * batch_count / update_count))
3062    update_comoment = state_ops.assign_add(comoment, delta_comoment)
3063
3064    covariance = array_ops.where(
3065        math_ops.less_equal(count_, 1.),
3066        float('nan'),
3067        math_ops.truediv(comoment, count_ - 1),
3068        name='covariance')
3069    with ops.control_dependencies([update_comoment]):
3070      update_op = array_ops.where(
3071          math_ops.less_equal(count_, 1.),
3072          float('nan'),
3073          math_ops.truediv(comoment, count_ - 1),
3074          name='update_op')
3075
3076  if metrics_collections:
3077    ops.add_to_collections(metrics_collections, covariance)
3078
3079  if updates_collections:
3080    ops.add_to_collections(updates_collections, update_op)
3081
3082  return covariance, update_op
3083
3084
3085def streaming_pearson_correlation(predictions,
3086                                  labels,
3087                                  weights=None,
3088                                  metrics_collections=None,
3089                                  updates_collections=None,
3090                                  name=None):
3091  """Computes Pearson correlation coefficient between `predictions`, `labels`.
3092
3093  The `streaming_pearson_correlation` function delegates to
3094  `streaming_covariance` the tracking of three [co]variances:
3095
3096  - `streaming_covariance(predictions, labels)`, i.e. covariance
3097  - `streaming_covariance(predictions, predictions)`, i.e. variance
3098  - `streaming_covariance(labels, labels)`, i.e. variance
3099
3100  The product-moment correlation ultimately returned is an idempotent operation
3101  `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To
3102  facilitate correlation computation across multiple batches, the function
3103  groups the `update_op`s of the underlying streaming_covariance and returns an
3104  `update_op`.
3105
3106  If `weights` is not None, then it is used to compute a weighted correlation.
3107  NOTE: these weights are treated as "frequency weights", as opposed to
3108  "reliability weights". See discussion of the difference on
3109  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance
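
  For example (illustrative values):

  ```python
    import tensorflow as tf

    x = tf.constant([1.0, 2.0, 3.0])
    y = tf.constant([-2.0, -4.0, -6.0])
    pearson_r, update_op = (
        tf.contrib.metrics.streaming_pearson_correlation(x, y))
    # y is a negative linear function of x, so the coefficient is -1.0.
  ```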
3110
3111  Args:
3112    predictions: A `Tensor` of arbitrary size.
3113    labels: A `Tensor` of the same size as predictions.
3114    weights: Optional `Tensor` indicating the frequency with which an example is
3115      sampled. Rank must be 0, or the same rank as `labels`, and must be
3116      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
3117      the same as the corresponding `labels` dimension).
3118    metrics_collections: An optional list of collections that the metric
3119      value variable should be added to.
3120    updates_collections: An optional list of collections that the metric update
3121      ops should be added to.
3122    name: An optional variable_scope name.
3123
3124  Returns:
3125    pearson_r: A `Tensor` representing the current Pearson product-moment
3126      correlation coefficient, the value of
3127      `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`.
3128    update_op: An operation that updates the underlying variables appropriately.
3129
3130  Raises:
3131    ValueError: If `labels` and `predictions` are of different sizes, or if
3132      `weights` is the wrong size, or if either `metrics_collections` or
3133      `updates_collections` are not a `list` or `tuple`.
3134  """
3135  with variable_scope.variable_scope(name, 'pearson_r',
3136                                     (predictions, labels, weights)):
3137    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
3138        predictions, labels, weights)
3139    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
3140    # Broadcast weights here to avoid duplicate broadcasting in each call to
3141    # `streaming_covariance`.
3142    if weights is not None:
3143      weights = weights_broadcast_ops.broadcast_weights(weights, labels)
3144    cov, update_cov = streaming_covariance(
3145        predictions, labels, weights=weights, name='covariance')
3146    var_predictions, update_var_predictions = streaming_covariance(
3147        predictions, predictions, weights=weights, name='variance_predictions')
3148    var_labels, update_var_labels = streaming_covariance(
3149        labels, labels, weights=weights, name='variance_labels')
3150
3151    pearson_r = math_ops.truediv(
3152        cov,
3153        math_ops.multiply(
3154            math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)),
3155        name='pearson_r')
3156    update_op = math_ops.truediv(
3157        update_cov,
3158        math_ops.multiply(
3159            math_ops.sqrt(update_var_predictions),
3160            math_ops.sqrt(update_var_labels)),
3161        name='update_op')
3162
3163  if metrics_collections:
3164    ops.add_to_collections(metrics_collections, pearson_r)
3165
3166  if updates_collections:
3167    ops.add_to_collections(updates_collections, update_op)
3168
3169  return pearson_r, update_op
3170
3171
3172# TODO(nsilberman): add a 'normalized' flag so that the user can request
3173# normalization if the inputs are not normalized.
3174def streaming_mean_cosine_distance(predictions,
3175                                   labels,
3176                                   dim,
3177                                   weights=None,
3178                                   metrics_collections=None,
3179                                   updates_collections=None,
3180                                   name=None):
3181  """Computes the cosine distance between the labels and predictions.
3182
3183  The `streaming_mean_cosine_distance` function creates two local variables,
3184  `total` and `count` that are used to compute the average cosine distance
3185  between `predictions` and `labels`. This average is weighted by `weights`,
3186  and it is ultimately returned as `mean_distance`, which is an idempotent
3187  operation that simply divides `total` by `count`.
3188
3189  For estimation of the metric over a stream of data, the function creates an
3190  `update_op` operation that updates these variables and returns the
3191  `mean_distance`.
3192
3193  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
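
  For example (illustrative unit vectors; this op assumes its inputs are
  already normalized):

  ```python
    import tensorflow as tf

    predictions = tf.constant([[1.0, 0.0]])
    labels = tf.constant([[0.0, 1.0]])
    distance, update_op = tf.contrib.metrics.streaming_mean_cosine_distance(
        predictions, labels, dim=1)
    # Orthogonal vectors have cosine similarity 0, so the distance is 1.0.
  ```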
3194
3195  Args:
3196    predictions: A `Tensor` of the same shape as `labels`.
3197    labels: A `Tensor` of arbitrary shape.
3198    dim: The dimension along which the cosine distance is computed.
3199    weights: An optional `Tensor` whose shape is broadcastable to `predictions`,
3200      and whose dimension `dim` is 1.
3201    metrics_collections: An optional list of collections that the metric
3202      value variable should be added to.
3203    updates_collections: An optional list of collections that the metric update
3204      ops should be added to.
3205    name: An optional variable_scope name.
3206
3207  Returns:
3208    mean_distance: A `Tensor` representing the current mean, the value of
3209      `total` divided by `count`.
3210    update_op: An operation that increments the `total` and `count` variables
3211      appropriately.
3212
3213  Raises:
3214    ValueError: If `predictions` and `labels` have mismatched shapes, or if
3215      `weights` is not `None` and its shape doesn't match `predictions`, or if
3216      either `metrics_collections` or `updates_collections` are not a list or
3217      tuple.
3218  """
3219  predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
3220      predictions, labels, weights)
3221  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
3222  radial_diffs = math_ops.multiply(predictions, labels)
3223  radial_diffs = math_ops.reduce_sum(
3224      radial_diffs, reduction_indices=[
3225          dim,
3226      ], keep_dims=True)
3227  mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None,
3228                                            name or 'mean_cosine_distance')
3229  mean_distance = math_ops.subtract(1.0, mean_distance)
3230  update_op = math_ops.subtract(1.0, update_op)
3231
3232  if metrics_collections:
3233    ops.add_to_collections(metrics_collections, mean_distance)
3234
3235  if updates_collections:
3236    ops.add_to_collections(updates_collections, update_op)
3237
3238  return mean_distance, update_op
3239
3240
3241def streaming_percentage_less(values,
3242                              threshold,
3243                              weights=None,
3244                              metrics_collections=None,
3245                              updates_collections=None,
3246                              name=None):
3247  """Computes the percentage of values less than the given threshold.
3248
3249  The `streaming_percentage_less` function creates two local variables,
3250  `total` and `count` that are used to compute the percentage of `values` that
3251  fall below `threshold`. This rate is weighted by `weights`, and it is
3252  ultimately returned as `percentage` which is an idempotent operation that
3253  simply divides `total` by `count`.
3254
3255  For estimation of the metric over a stream of data, the function creates an
3256  `update_op` operation that updates these variables and returns the
3257  `percentage`.
3258
3259  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
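
  For example (illustrative values):

  ```python
    import tensorflow as tf

    values = tf.constant([1.0, 2.0, 3.0, 4.0])
    percentage, update_op = tf.contrib.metrics.streaming_percentage_less(
        values, threshold=3.0)
    # 2 of the 4 values fall below the threshold, so the result is 0.5.
  ```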
3260
3261  Args:
3262    values: A numeric `Tensor` of arbitrary size.
3263    threshold: A scalar threshold.
3264    weights: An optional `Tensor` whose shape is broadcastable to `values`.
3265    metrics_collections: An optional list of collections that the metric
3266      value variable should be added to.
3267    updates_collections: An optional list of collections that the metric update
3268      ops should be added to.
3269    name: An optional variable_scope name.
3270
3271  Returns:
3272    percentage: A `Tensor` representing the current mean, the value of `total`
3273      divided by `count`.
3274    update_op: An operation that increments the `total` and `count` variables
3275      appropriately.
3276
3277  Raises:
3278    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
3279      or if either `metrics_collections` or `updates_collections` are not a list
3280      or tuple.
3281  """
3282  return metrics.percentage_below(
3283      values=values,
3284      threshold=threshold,
3285      weights=weights,
3286      metrics_collections=metrics_collections,
3287      updates_collections=updates_collections,
3288      name=name)
3289
3290
3291def streaming_mean_iou(predictions,
3292                       labels,
3293                       num_classes,
3294                       weights=None,
3295                       metrics_collections=None,
3296                       updates_collections=None,
3297                       name=None):
3298  """Calculate per-step mean Intersection-Over-Union (mIOU).
3299
3300  Mean Intersection-Over-Union is a common evaluation metric for
3301  semantic image segmentation, which first computes the IOU for each
3302  semantic class and then computes the average over classes.
3303  IOU is defined as follows:
3304    IOU = true_positive / (true_positive + false_positive + false_negative).
3305  The predictions are accumulated in a confusion matrix, weighted by `weights`,
3306  and mIOU is then calculated from it.
3307
3308  For estimation of the metric over a stream of data, the function creates an
3309  `update_op` operation that updates these variables and returns the `mean_iou`.
3310
3311  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
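
  For example (illustrative values):

  ```python
    import tensorflow as tf

    predictions = tf.constant([0, 1, 1])
    labels = tf.constant([0, 1, 0])
    mean_iou, update_op = tf.contrib.metrics.streaming_mean_iou(
        predictions, labels, num_classes=2)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(mean_iou))  # IOU is 0.5 for each class, so mIOU = 0.5.
  ```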
3312
3313  Args:
3314    predictions: A `Tensor` of prediction results for semantic labels, whose
3315      shape is [batch size] and type `int32` or `int64`. The tensor will be
3316      flattened, if its rank > 1.
3317    labels: A `Tensor` of ground truth labels with shape [batch size] and of
3318      type `int32` or `int64`. The tensor will be flattened, if its rank > 1.
3319    num_classes: The number of possible labels the prediction task can
3320      have. This value must be provided, since a confusion matrix of
3321      dimension = [num_classes, num_classes] will be allocated.
3322    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
3323    metrics_collections: An optional list of collections that `mean_iou`
3324      should be added to.
3325    updates_collections: An optional list of collections `update_op` should be
3326      added to.
3327    name: An optional variable_scope name.
3328
3329  Returns:
3330    mean_iou: A `Tensor` representing the mean intersection-over-union.
3331    update_op: An operation that increments the confusion matrix.
3332
3333  Raises:
3334    ValueError: If `predictions` and `labels` have mismatched shapes, or if
3335      `weights` is not `None` and its shape doesn't match `predictions`, or if
3336      either `metrics_collections` or `updates_collections` are not a list or
3337      tuple.
3338  """
3339  return metrics.mean_iou(
3340      num_classes=num_classes,
3341      predictions=predictions,
3342      labels=labels,
3343      weights=weights,
3344      metrics_collections=metrics_collections,
3345      updates_collections=updates_collections,
3346      name=name)
3347
3348
3349def _next_array_size(required_size, growth_factor=1.5):
3350  """Calculate the next size for reallocating a dynamic array.
3351
3352  Args:
3353    required_size: number or tf.Tensor specifying required array capacity.
3354    growth_factor: optional number or tf.Tensor specifying the growth factor
3355      between subsequent allocations.
3356
3357  Returns:
3358    tf.Tensor with dtype=int32 giving the next array size.
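
    For example, with the default growth factor, a required size of 5 yields
    `ceil(1.5 ** ceil(log(5) / log(1.5))) = ceil(1.5 ** 4) = 6`.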
3359  """
3360  exponent = math_ops.ceil(
3361      math_ops.log(math_ops.cast(required_size, dtypes.float32)) / math_ops.log(
3362          math_ops.cast(growth_factor, dtypes.float32)))
3363  return math_ops.cast(math_ops.ceil(growth_factor**exponent), dtypes.int32)
3364
3365
3366def streaming_concat(values,
3367                     axis=0,
3368                     max_size=None,
3369                     metrics_collections=None,
3370                     updates_collections=None,
3371                     name=None):
3372  """Concatenate values along an axis across batches.
3373
3374  The function `streaming_concat` creates two local variables, `array` and
3375  `size`, that are used to store concatenated values. Internally, `array` is
3376  used as storage for a dynamic array (if `max_size` is `None`), which ensures
3377  that updates can be run in amortized constant time.
3378
3379  For estimation of the metric over a stream of data, the function creates an
3380  `update_op` operation that appends the values of a tensor and returns the
3381  length of the concatenated axis.
3382
3383  This op allows for evaluating metrics that cannot be updated incrementally
3384  using the same framework as other streaming metrics.
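
  For example (a sketch; the non-concatenated dimension must be statically
  known):

  ```python
    import tensorflow as tf

    values = tf.placeholder(tf.float32, [None, 2])
    concatenated, update_op = tf.contrib.metrics.streaming_concat(values)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op, feed_dict={values: [[1., 2.]]})
      sess.run(update_op, feed_dict={values: [[3., 4.], [5., 6.]]})
      print(sess.run(concatenated))  # [[1., 2.], [3., 4.], [5., 6.]]
  ```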
3385
3386  Args:
3387    values: `Tensor` to concatenate. Rank and the shape along all axes other
3388      than the axis to concatenate along must be statically known.
3389    axis: optional integer axis to concatenate along.
3390    max_size: optional integer maximum size of `value` along the given axis.
3391      Once the maximum size is reached, further updates are no-ops. By default,
3392      there is no maximum size: the array is resized as necessary.
3393    metrics_collections: An optional list of collections that `value`
3394      should be added to.
3395    updates_collections: An optional list of collections `update_op` should be
3396      added to.
3397    name: An optional variable_scope name.
3398
3399  Returns:
3400    value: A `Tensor` representing the concatenated values.
3401    update_op: An operation that concatenates the next values.
3402
3403  Raises:
3404    ValueError: if `values` does not have a statically known rank, `axis` is
3405      not in the valid range or the size of `values` is not statically known
3406      along any axis other than `axis`.
3407  """
3408  with variable_scope.variable_scope(name, 'streaming_concat', (values,)):
3409    # pylint: disable=invalid-slice-index
3410    values_shape = values.get_shape()
3411    if values_shape.dims is None:
3412      raise ValueError('`values` must have statically known rank')
3413
3414    ndim = len(values_shape)
3415    if axis < 0:
3416      axis += ndim
3417    if not 0 <= axis < ndim:
3418      raise ValueError('axis = %r not in [0, %r)' % (axis, ndim))
3419
3420    fixed_shape = [dim.value for n, dim in enumerate(values_shape) if n != axis]
3421    if any(value is None for value in fixed_shape):
3422      raise ValueError('all dimensions of `values` other than the dimension to '
3423                       'concatenate along must have statically known size')
3424
3425    # We move `axis` to the front of the internal array so assign ops can be
3426    # applied to contiguous slices
3427    init_size = 0 if max_size is None else max_size
3428    init_shape = [init_size] + fixed_shape
3429    array = metrics_impl.metric_variable(
3430        init_shape, values.dtype, validate_shape=False, name='array')
3431    size = metrics_impl.metric_variable([], dtypes.int32, name='size')
3432
3433    perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
3434    valid_array = array[:size]
3435    valid_array.set_shape([None] + fixed_shape)
3436    value = array_ops.transpose(valid_array, perm, name='concat')
3437
3438    values_size = array_ops.shape(values)[axis]
3439    if max_size is None:
3440      batch_size = values_size
3441    else:
3442      batch_size = math_ops.minimum(values_size, max_size - size)
3443
3444    perm = [axis] + [n for n in range(ndim) if n != axis]
3445    batch_values = array_ops.transpose(values, perm)[:batch_size]
3446
3447    def reallocate():
3448      next_size = _next_array_size(new_size)
3449      next_shape = array_ops.stack([next_size] + fixed_shape)
3450      new_value = array_ops.zeros(next_shape, dtype=values.dtype)
3451      old_value = array.value()
3452      assign_op = state_ops.assign(array, new_value, validate_shape=False)
3453      with ops.control_dependencies([assign_op]):
3454        copy_op = array[:size].assign(old_value[:size])
3455      # return value needs to be the same dtype as no_op() for cond
3456      with ops.control_dependencies([copy_op]):
3457        return control_flow_ops.no_op()
3458
3459    new_size = size + batch_size
3460    array_size = array_ops.shape_internal(array, optimize=False)[0]
3461    maybe_reallocate_op = control_flow_ops.cond(
3462        new_size > array_size, reallocate, control_flow_ops.no_op)
3463    with ops.control_dependencies([maybe_reallocate_op]):
3464      append_values_op = array[size:new_size].assign(batch_values)
3465    with ops.control_dependencies([append_values_op]):
3466      update_op = size.assign(new_size)
3467
3468    if metrics_collections:
3469      ops.add_to_collections(metrics_collections, value)
3470
3471    if updates_collections:
3472      ops.add_to_collections(updates_collections, update_op)
3473
3474    return value, update_op
3475    # pylint: enable=invalid-slice-index
3476
3477
3478def aggregate_metrics(*value_update_tuples):
3479  """Aggregates the metric value tensors and update ops into two lists.
3480
3481  Args:
3482    *value_update_tuples: a variable number of tuples, each of which contain the
3483      pair of (value_tensor, update_op) from a streaming metric.
3484
3485  Returns:
3486    A list of value `Tensor` objects and a list of update ops.
3487
3488  Raises:
3489    ValueError: if `value_update_tuples` is empty.
3490  """
3491  if not value_update_tuples:
3492    raise ValueError('Expected at least one value_tensor/update_op pair')
3493  value_ops, update_ops = zip(*value_update_tuples)
3494  return list(value_ops), list(update_ops)
3495
3496
3497def aggregate_metric_map(names_to_tuples):
3498  """Aggregates the metric names to tuple dictionary.
3499
3500  This function is useful for pairing metric names with their associated value
3501  and update ops when the list of metrics is long. For example:
3502
3503  ```python
3504    metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
3505        'Mean Absolute Error': slim.metrics.streaming_mean_absolute_error(
3506            predictions, labels, weights),
3507        'Mean Relative Error': slim.metrics.streaming_mean_relative_error(
3508            predictions, labels, labels, weights),
3509        'RMSE Linear': slim.metrics.streaming_root_mean_squared_error(
3510            predictions, labels, weights),
3511        'RMSE Log': slim.metrics.streaming_root_mean_squared_error(
3512            predictions, labels, weights),
3513    })
3514  ```
3515
3516  Args:
3517    names_to_tuples: a map of metric names to tuples, each of which contain the
3518      pair of (value_tensor, update_op) from a streaming metric.
3519
3520  Returns:
3521    A dictionary from metric names to value ops and a dictionary from metric
3522    names to update ops.
3523  """
3524  metric_names = names_to_tuples.keys()
3525  value_ops, update_ops = zip(*names_to_tuples.values())
3526  return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
3527
3528
def count(values,
          weights=None,
          metrics_collections=None,
          updates_collections=None,
          name=None):
  """Computes the number of examples, or sum of `weights`.

  When evaluating some metric (e.g. mean) on one or more subsets of the data,
  this auxiliary metric is useful for keeping track of how many examples there
  are in each subset.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

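  For example (a minimal sketch; `values` and `mask` are placeholder tensors,
  with `mask` holding 0/1 weights that select the subset being counted):

  ```python
    num_examples, update_op = tf.contrib.metrics.count(values, weights=mask)
  ```
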
  Args:
    values: A `Tensor` of arbitrary dimensions. Only its shape is used.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `values`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    count: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the metric from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """

  with variable_scope.variable_scope(name, 'count', (values, weights)):
    count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')

    if weights is None:
      num_values = math_ops.to_float(array_ops.size(values))
    else:
      _, _, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
          predictions=values,
          labels=None,
          weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.to_float(weights), values)
      num_values = math_ops.reduce_sum(weights)

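    # Gate the accumulation on `values` so the count only advances once the
    # batch being counted has actually been computed.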
    with ops.control_dependencies([values]):
      update_op = state_ops.assign_add(count_, num_values)

    if metrics_collections:
      ops.add_to_collections(metrics_collections, count_)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return count_, update_op


def cohen_kappa(labels,
                predictions_idx,
                num_classes,
                weights=None,
                metrics_collections=None,
                updates_collections=None,
                name=None):
  """Calculates Cohen's kappa.

  [Cohen's kappa](https://en.wikipedia.org/wiki/Cohen's_kappa) is a statistic
  that measures inter-annotator agreement.

  The `cohen_kappa` function calculates the confusion matrix, and creates three
  local variables to compute the Cohen's kappa: `po`, `pe_row`, and `pe_col`,
  which refer to the diagonal part, row and column totals of the confusion
  matrix, respectively. This value is ultimately returned as `kappa`, an
  idempotent operation that is calculated by

      pe = (pe_row * pe_col) / N
      k = (sum(po) - sum(pe)) / (N - sum(pe))

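  For example, with the 2x2 confusion matrix [[20, 5], [10, 15]] (N = 50):
  sum(po) = 20 + 15 = 35, sum(pe) = (30 * 25 + 20 * 25) / 50 = 25, so
  k = (35 - 25) / (50 - 25) = 0.4.
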
  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `kappa`. `update_op` weights each prediction by the corresponding value in
  `weights`.

  Class labels are expected to start at 0. E.g., if `num_classes`
  is three, then the possible labels are [0, 1, 2].

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  NOTE: Equivalent to `sklearn.metrics.cohen_kappa_score`, but this method
  doesn't support a weighted confusion matrix yet.

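  For example (a minimal sketch; `labels` and `predictions_idx` are assumed to
  be 1-D integer tensors of class ids in `[0, num_classes)`):

  ```python
    kappa, update_op = tf.contrib.metrics.cohen_kappa(
        labels, predictions_idx, num_classes=3)
  ```
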
  Args:
    labels: 1-D `Tensor` of real labels for the classification task. Must be
      one of the following types: int16, int32, int64.
    predictions_idx: 1-D `Tensor` of predicted class indices for a given
      classification. Must have the same type as `labels`.
    num_classes: The number of possible labels.
    weights: Optional `Tensor` whose shape matches `predictions_idx`.
    metrics_collections: An optional list of collections that `kappa` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    kappa: Scalar float `Tensor` representing the current Cohen's kappa.
    update_op: `Operation` that increments `po`, `pe_row` and `pe_col`
      variables appropriately and whose value matches `kappa`.

  Raises:
    ValueError: If `num_classes` is less than 2, or `predictions_idx` and
      `labels` have mismatched shapes, or if `weights` is not `None` and its
      shape doesn't match `predictions_idx`, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported '
                       'when eager execution is enabled.')
  if num_classes < 2:
    raise ValueError('`num_classes` must be >= 2. '
                     'Found: {}'.format(num_classes))
  with variable_scope.variable_scope(name, 'cohen_kappa',
                                     (labels, predictions_idx, weights)):
    # Convert 2-dim (num, 1) to 1-dim (num,)
    labels.get_shape().with_rank_at_most(2)
    if labels.get_shape().ndims == 2:
      labels = array_ops.squeeze(labels, axis=[-1])
    predictions_idx, labels, weights = (
        metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
            predictions=predictions_idx,
            labels=labels,
            weights=weights))
    predictions_idx.get_shape().assert_is_compatible_with(labels.get_shape())

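    # Accumulate counts as int64 when weights are absent or integral;
    # otherwise fall back to float32.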
    stat_dtype = (
        dtypes.int64
        if weights is None or weights.dtype.is_integer else dtypes.float32)
    po = metrics_impl.metric_variable((num_classes,), stat_dtype, name='po')
    pe_row = metrics_impl.metric_variable(
        (num_classes,), stat_dtype, name='pe_row')
    pe_col = metrics_impl.metric_variable(
        (num_classes,), stat_dtype, name='pe_col')

    # Table of the counts of agreement:
    counts_in_table = confusion_matrix.confusion_matrix(
        labels,
        predictions_idx,
        num_classes=num_classes,
        weights=weights,
        dtype=stat_dtype,
        name='counts_in_table')

    po_t = array_ops.diag_part(counts_in_table)
    pe_row_t = math_ops.reduce_sum(counts_in_table, axis=0)
    pe_col_t = math_ops.reduce_sum(counts_in_table, axis=1)
    update_po = state_ops.assign_add(po, po_t)
    update_pe_row = state_ops.assign_add(pe_row, pe_row_t)
    update_pe_col = state_ops.assign_add(pe_col, pe_col_t)

    def _calculate_k(po, pe_row, pe_col, name):
      po_sum = math_ops.reduce_sum(po)
      total = math_ops.reduce_sum(pe_row)
      pe_sum = math_ops.reduce_sum(
          metrics_impl._safe_div(  # pylint: disable=protected-access
              pe_row * pe_col, total, None))
      po_sum, pe_sum, total = (math_ops.to_double(po_sum),
                               math_ops.to_double(pe_sum),
                               math_ops.to_double(total))
      # kappa = (po - pe) / (N - pe)
      k = metrics_impl._safe_scalar_div(  # pylint: disable=protected-access
          po_sum - pe_sum,
          total - pe_sum,
          name=name)
      return k

    kappa = _calculate_k(po, pe_row, pe_col, name='value')
    update_op = _calculate_k(
        update_po, update_pe_row, update_pe_col, name='update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, kappa)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return kappa, update_op


__all__ = [
    'auc_with_confidence_intervals',
    'aggregate_metric_map',
    'aggregate_metrics',
    'cohen_kappa',
    'count',
    'precision_recall_at_equal_thresholds',
    'recall_at_precision',
    'sparse_recall_at_top_k',
    'streaming_accuracy',
    'streaming_auc',
    'streaming_curve_points',
    'streaming_dynamic_auc',
    'streaming_false_negative_rate',
    'streaming_false_negative_rate_at_thresholds',
    'streaming_false_negatives',
    'streaming_false_negatives_at_thresholds',
    'streaming_false_positive_rate',
    'streaming_false_positive_rate_at_thresholds',
    'streaming_false_positives',
    'streaming_false_positives_at_thresholds',
    'streaming_mean',
    'streaming_mean_absolute_error',
    'streaming_mean_cosine_distance',
    'streaming_mean_iou',
    'streaming_mean_relative_error',
    'streaming_mean_squared_error',
    'streaming_mean_tensor',
    'streaming_percentage_less',
    'streaming_precision',
    'streaming_precision_at_thresholds',
    'streaming_recall',
    'streaming_recall_at_k',
    'streaming_recall_at_thresholds',
    'streaming_root_mean_squared_error',
    'streaming_sensitivity_at_specificity',
    'streaming_sparse_average_precision_at_k',
    'streaming_sparse_average_precision_at_top_k',
    'streaming_sparse_precision_at_k',
    'streaming_sparse_precision_at_top_k',
    'streaming_sparse_recall_at_k',
    'streaming_specificity_at_sensitivity',
    'streaming_true_negatives',
    'streaming_true_negatives_at_thresholds',
    'streaming_true_positives',
    'streaming_true_positives_at_thresholds',
]
