metric_ops.py revision 4fb94d38b25c73d3bd8bd6a0ea632eb34c2787d6
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains metric-computing operations on streamed tensors.

Module documentation, including "@@" callouts, should be put in
third_party/tensorflow/contrib/metrics/__init__.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
from tensorflow.contrib.metrics.python.ops import metric_ops_util
from tensorflow.contrib.metrics.python.ops import set_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.util.all_util import make_all


IGNORE_MASK_DATE = '2016-10-19'
IGNORE_MASK_INSTRUCTIONS = (
    '`ignore_mask` is being deprecated. Instead use `weights` with values 0.0 '
    'and 1.0 to mask values. For example, `weights=tf.logical_not(mask)`.')


def _mask_weights(mask=None, weights=None):
  """Mask a given set of weights.

  Elements are included when the corresponding `mask` element is `False`, and
  excluded otherwise.

  Args:
    mask: An optional, `bool` `Tensor`.
    weights: An optional `Tensor` whose shape matches `mask` if `mask` is not
      `None`.

  Returns:
    Masked weights if `mask` and `weights` are not `None`, weights equivalent to
    the logical negation of `mask` if `weights` is `None`, and otherwise
    `weights`.

  Raises:
    ValueError: If `weights` and `mask` are not `None` and have mismatched
      shapes.
  """
  if mask is not None:
    check_ops.assert_type(mask, dtypes.bool)
    if weights is None:
      weights = array_ops.ones_like(mask, dtype=dtypes.float32)
    weights = math_ops.cast(math_ops.logical_not(mask), weights.dtype) * weights

  return weights
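# Illustrative sketch (editorial addition, not part of this module): given a
# mask marking the second of three elements as ignored and no explicit
# `weights`, `_mask_weights` produces weights of 1.0 for kept elements and 0.0
# for masked ones. Assuming the standard `import tensorflow as tf`:
#
#   mask = tf.constant([False, True, False])
#   weights = _mask_weights(mask)  # evaluates to [1.0, 0.0, 1.0]
#
# When explicit `weights` are also given, they are simply zeroed wherever
# `mask` is `True`.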


def _safe_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is <= 0.

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` <= 0, else `numerator` / `denominator`
  """
  return math_ops.select(
      math_ops.greater(denominator, 0),
      math_ops.truediv(numerator, denominator),
      0,
      name=name)


def _create_local(name, shape=None, collections=None, dtype=dtypes.float32):
  """Creates a new local variable.

  Args:
    name: The name of the new or existing variable.
    shape: Shape of the new or existing variable.
    collections: A list of collection names to which the Variable will be added.
    dtype: Data type of the variables.

  Returns:
    The created variable.
  """
  # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES
  collections = list(collections or [])
  collections += [ops.GraphKeys.LOCAL_VARIABLES]
  return variables.Variable(
      initial_value=array_ops.zeros(shape, dtype=dtype),
      name=name,
      trainable=False,
      collections=collections)


def _count_condition(values, weights=None, metrics_collections=None,
                     updates_collections=None):
  """Sums the weights of cases where the given values are True.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `bool` `Tensor` of arbitrary size.
    weights: An optional `Tensor` whose shape matches `values`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  check_ops.assert_type(values, dtypes.bool)
  count = _create_local('count', shape=[])

  values = math_ops.to_float(values)
  if weights is not None:
    values.get_shape().assert_is_compatible_with(weights.get_shape())
    weights = math_ops.to_float(weights)
    values = math_ops.mul(values, weights)

  value_tensor = array_ops.identity(count)
  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))

  if metrics_collections:
    ops.add_to_collections(metrics_collections, value_tensor)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return value_tensor, update_op


def _streaming_true_negatives(predictions, labels, weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Sum the weights of true_negatives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(
      name, 'true_negatives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    is_true_negative = math_ops.logical_and(math_ops.equal(labels, 0),
                                            math_ops.equal(predictions, 0))
    return _count_condition(is_true_negative, weights, metrics_collections,
                            updates_collections)


def _streaming_true_positives(predictions, labels, weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Sum the weights of true_positives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(
      name, 'true_positives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    is_true_positive = math_ops.logical_and(math_ops.equal(labels, 1),
                                            math_ops.equal(predictions, 1))
    return _count_condition(is_true_positive, weights, metrics_collections,
                            updates_collections)


def _streaming_false_positives(predictions, labels, weights=None,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Sum the weights of false positives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(
      name, 'false_positives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    is_false_positive = math_ops.logical_and(math_ops.equal(labels, 0),
                                             math_ops.equal(predictions, 1))
    return _count_condition(is_false_positive, weights, metrics_collections,
                            updates_collections)


def _streaming_false_negatives(predictions, labels, weights=None,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Sum the weights of false negatives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(
      name, 'false_negatives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    is_false_negative = math_ops.logical_and(math_ops.equal(labels, 1),
                                             math_ops.equal(predictions, 0))
    return _count_condition(is_false_negative, weights, metrics_collections,
                            updates_collections)


def streaming_mean(values, weights=None, metrics_collections=None,
                   updates_collections=None, name=None):
  """Computes the (weighted) mean of the given values.

  The `streaming_mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: An optional `Tensor` whose shape matches `values`.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A tensor representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'mean', [values, weights]):
    values = math_ops.to_float(values)

    total = _create_local('total', shape=[])
    count = _create_local('count', shape=[])

    if weights is not None:
      values.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.to_float(weights)
      values = math_ops.mul(values, weights)
      num_values = math_ops.reduce_sum(weights)
    else:
      num_values = math_ops.to_float(array_ops.size(values))

    total_compute_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    count_compute_op = state_ops.assign_add(count, num_values)

    mean = _safe_div(total, count, 'value')
    with ops.control_dependencies([total_compute_op, count_compute_op]):
      update_op = _safe_div(total, count, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean, update_op
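# Illustrative usage sketch (editorial addition; `tf`, the `sess` session and
# the batch loop are assumptions, not part of this module). The returned value
# op is idempotent, while `update_op` folds each new batch into `total` and
# `count`:
#
#   values = tf.placeholder(tf.float32, [None])
#   mean, update_op = streaming_mean(values)
#   sess.run(tf.initialize_local_variables())
#   for batch in batches:          # hypothetical stream of numpy arrays
#     sess.run(update_op, feed_dict={values: batch})
#   print(sess.run(mean))          # mean over everything seen so far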


def streaming_mean_tensor(values, weights=None, metrics_collections=None,
                          updates_collections=None, name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  In contrast to the `streaming_mean` function which returns a scalar with the
  mean, this function returns an average tensor with the same shape as the
  input tensors.

  The `streaming_mean_tensor` function creates two local variables,
  `total_tensor` and `count_tensor` that are used to compute the average of
  `values`. This average is ultimately returned as `mean` which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: An optional `Tensor` whose shape matches `values`.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float tensor representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'mean', [values, weights]):
    total = _create_local('total_tensor', shape=values.get_shape())
    count = _create_local('count_tensor', shape=values.get_shape())

    if weights is not None:
      values.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.to_float(weights)
      values = math_ops.mul(values, weights)
      num_values = weights
    else:
      num_values = array_ops.ones_like(values)

    total_compute_op = state_ops.assign_add(total, values)
    count_compute_op = state_ops.assign_add(count, num_values)

    def compute_mean(total, count, name):
      non_zero_count = math_ops.maximum(count,
                                        array_ops.ones_like(count),
                                        name=name)
      return math_ops.truediv(total, non_zero_count, name=name)

    mean = compute_mean(total, count, 'value')
    with ops.control_dependencies([total_compute_op, count_compute_op]):
      update_op = compute_mean(total, count, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean, update_op
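# Sketch of the difference from `streaming_mean` (editorial addition): because
# `total_tensor` and `count_tensor` have the same shape as `values`, the result
# is an element-wise running mean rather than a scalar. For example, feeding
# the batches [1., 2.] and [3., 4.] would leave `mean` approximately equal to
# [2., 3.], whereas `streaming_mean` over the same stream would report the
# single scalar 2.5.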


def streaming_accuracy(predictions, labels, weights=None,
                       metrics_collections=None, updates_collections=None,
                       name=None):
  """Calculates how often `predictions` matches `labels`.

  The `streaming_accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `accuracy`.
  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
  where the corresponding elements of `predictions` and `labels` match and 0.0
  otherwise. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `is_correct`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of any shape.
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A tensor representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
  return streaming_mean(is_correct, weights, metrics_collections,
                        updates_collections, name or 'accuracy')
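# Illustrative sketch (editorial addition; names outside this module are
# assumptions): accuracy is simply `streaming_mean` over the 0/1 correctness
# indicator, so a weighted accuracy can be streamed as:
#
#   acc, update_op = streaming_accuracy(predictions, labels,
#                                       weights=example_weights)
#   sess.run(tf.initialize_local_variables())
#   sess.run(update_op)            # one batch; repeat for further batches
#   accuracy_so_far = sess.run(acc)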


@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_precision(predictions, labels, ignore_mask=None, weights=None,
                        metrics_collections=None, updates_collections=None,
                        name=None):
  """Computes the precision of the predictions with respect to the labels.

  The `streaming_precision` function creates two local variables,
  `true_positives` and `false_positives`, that are used to compute the
  precision. This value is ultimately returned as `precision`, an idempotent
  operation that simply divides `true_positives` by the sum of `true_positives`
  and `false_positives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`. `update_op` weights each prediction by the corresponding value in
  `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
  Alternatively, if `ignore_mask` is not `None`, then mask values where
  `ignore_mask` is `True`.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: Scalar float `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately and whose value matches
      `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `ignore_mask` is not `None` and its shape doesn't match `predictions`, or
      if `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(
      name, 'precision', [predictions, labels]):

    predictions, labels = metric_ops_util.remove_squeezable_dimensions(
        predictions, labels)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    weights = _mask_weights(ignore_mask, weights)
    true_positives, true_positives_update_op = _streaming_true_positives(
        predictions, labels, weights, metrics_collections=None,
        updates_collections=None, name=None)
    false_positives, false_positives_update_op = _streaming_false_positives(
        predictions, labels, weights, metrics_collections=None,
        updates_collections=None, name=None)

    def compute_precision(name):
      return math_ops.select(
          math_ops.greater(true_positives + false_positives, 0),
          math_ops.div(true_positives, true_positives + false_positives),
          0,
          name)

    precision = compute_precision('value')
    with ops.control_dependencies([true_positives_update_op,
                                   false_positives_update_op]):
      update_op = compute_precision('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, precision)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision, update_op
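# Worked example (editorial addition): with boolean predictions [T, T, F, F]
# and labels [T, F, F, T], there is 1 true positive and 1 false positive, so
# both the returned `precision` tensor and `update_op` evaluate to
# 1 / (1 + 1) = 0.5 after a single update.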


@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_recall(predictions, labels, ignore_mask=None, weights=None,
                     metrics_collections=None, updates_collections=None,
                     name=None):
  """Computes the recall of the predictions with respect to the labels.

  The `streaming_recall` function creates two local variables, `true_positives`
  and `false_negatives`, that are used to compute the recall. This value is
  ultimately returned as `recall`, an idempotent operation that simply divides
  `true_positives` by the sum of `true_positives` and `false_negatives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` that updates these variables and returns the `recall`. `update_op`
  weights each prediction by the corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
  Alternatively, if `ignore_mask` is not `None`, then mask values where
  `ignore_mask` is `True`.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `ignore_mask` is not `None` and its shape doesn't match `predictions`, or
      if `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'recall', [predictions, labels]):
    predictions, labels = metric_ops_util.remove_squeezable_dimensions(
        predictions, labels)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    weights = _mask_weights(ignore_mask, weights)
    true_positives, true_positives_update_op = _streaming_true_positives(
        predictions, labels, weights, metrics_collections=None,
        updates_collections=None, name=None)
    false_negatives, false_negatives_update_op = _streaming_false_negatives(
        predictions, labels, weights, metrics_collections=None,
        updates_collections=None, name=None)

    def compute_recall(true_positives, false_negatives, name):
      return math_ops.select(
          math_ops.greater(true_positives + false_negatives, 0),
          math_ops.div(true_positives, true_positives + false_negatives),
          0,
          name)

    recall = compute_recall(true_positives, false_negatives, 'value')
    with ops.control_dependencies([true_positives_update_op,
                                   false_negatives_update_op]):
      update_op = compute_recall(true_positives, false_negatives, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op
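# Worked example (editorial addition): with the same tensors as above,
# predictions [T, T, F, F] and labels [T, F, F, T], there is 1 true positive
# and 1 false negative (the last element), so `recall` evaluates to
# 1 / (1 + 1) = 0.5.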


def _tp_fn_tn_fp(predictions, labels, thresholds, weights=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  The `_tp_fn_tn_fp` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape matches `predictions`.

  Returns:
    true_positives: A variable of shape [len(thresholds)].
    false_negatives: A variable of shape [len(thresholds)].
    true_negatives: A variable of shape [len(thresholds)].
    false_positives: A variable of shape [len(thresholds)].
    true_positives_update_op: An operation that increments the `true_positives`.
    false_negatives_update_op: An operation that increments the
      `false_negatives`.
    true_negatives_update_op: An operation that increments the `true_negatives`.
    false_positives_update_op: An operation that increments the
      `false_positives`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`.
  """
  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  num_thresholds = len(thresholds)

  # Reshape predictions and labels
  predictions = array_ops.reshape(predictions, [-1, 1])
  labels = array_ops.reshape(math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.pack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions), [num_thresholds, 1]),
      thresh_tiled)
  pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels, [num_thresholds, 1])
  label_is_neg = math_ops.logical_not(label_is_pos)

  true_positives = _create_local('true_positives', shape=[num_thresholds])
  false_negatives = _create_local('false_negatives', shape=[num_thresholds])
  true_negatives = _create_local('true_negatives', shape=[num_thresholds])
  false_positives = _create_local('false_positives', shape=[num_thresholds])

  is_true_positive = math_ops.to_float(
      math_ops.logical_and(label_is_pos, pred_is_pos))
  is_false_negative = math_ops.to_float(
      math_ops.logical_and(label_is_pos, pred_is_neg))
  is_false_positive = math_ops.to_float(
      math_ops.logical_and(label_is_neg, pred_is_pos))
  is_true_negative = math_ops.to_float(
      math_ops.logical_and(label_is_neg, pred_is_neg))

  if weights is not None:
    weights = math_ops.to_float(weights)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
    is_true_positive *= weights_tiled
    is_false_negative *= weights_tiled
    is_false_positive *= weights_tiled
    is_true_negative *= weights_tiled

  true_positives_update_op = state_ops.assign_add(
      true_positives, math_ops.reduce_sum(is_true_positive, 1))
  false_negatives_update_op = state_ops.assign_add(
      false_negatives, math_ops.reduce_sum(is_false_negative, 1))
  true_negatives_update_op = state_ops.assign_add(
      true_negatives, math_ops.reduce_sum(is_true_negative, 1))
  false_positives_update_op = state_ops.assign_add(
      false_positives, math_ops.reduce_sum(is_false_positive, 1))

  return (true_positives, false_negatives, true_negatives, false_positives,
          true_positives_update_op, false_negatives_update_op,
          true_negatives_update_op, false_positives_update_op)
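# Shape sketch (editorial addition): after reshaping, `predictions` is
# [num_predictions, 1] and `labels` is [1, num_predictions]; both are tiled to
# [num_thresholds, num_predictions] so that row i holds the comparison of every
# prediction against thresholds[i]. Reducing each confusion indicator over axis
# 1 then yields one count per threshold. For example, with thresholds
# [0.25, 0.75], predictions [0.1, 0.9] and labels [False, True], the
# true-positive counts per threshold are [1, 1] and the false-positive counts
# are [0, 0].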


def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
                  metrics_collections=None, updates_collections=None,
                  curve='ROC', name=None):
  """Computes the approximate AUC via a Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar tensor representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'auc', [predictions, labels]):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' %
                       (curve))
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      recall = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        x = fp_rate
        y = recall
      else:  # curve == 'PR'.
        precision = math_ops.div(tp + epsilon, tp + fp + epsilon)
        x = recall
        y = precision
      return math_ops.reduce_sum(math_ops.mul(
          x[:num_thresholds - 1] - x[1:],
          (y[:num_thresholds - 1] + y[1:]) / 2.), name=name)

    # sum up the areas of all the trapeziums
    auc = compute_auc(tp, fn, tn, fp, 'value')
    update_op = compute_auc(
        tp_update_op, fn_update_op, tn_update_op, fp_update_op, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, auc)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc, update_op
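# Numerical sketch of the Riemann sum above (editorial addition): each pair of
# neighbouring thresholds contributes a trapezium with width x[i] - x[i+1]
# (false-positive rate for ROC, recall for PR) and height (y[i] + y[i+1]) / 2,
# and the reduce_sum adds those areas. A perfect classifier reaches recall 1 at
# zero false-positive rate, so the summed trapezia approach an AUC of 1.0 as
# `num_thresholds` grows.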


def streaming_specificity_at_sensitivity(
    predictions, labels, sensitivity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the specificity at a given sensitivity.

  The `streaming_specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar tensor representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)

    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_specificity_at_sensitivity(name):
      """Computes the specificity at the given sensitivity.

      Args:
        name: The name of the operation.

      Returns:
        The specificity using the aggregated values.
      """
      sensitivities = math_ops.div(tp, tp + fn + kepsilon)

      # We'll need to use this trick until tf.argmax allows us to specify
      # whether we should use the first or last index in case of ties.
      min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
      indices_at_minval = math_ops.equal(
          math_ops.abs(sensitivities - sensitivity), min_val)
      indices_at_minval = math_ops.to_int64(indices_at_minval)
      indices_at_minval = math_ops.cumsum(indices_at_minval)
      tf_index = math_ops.argmax(indices_at_minval, 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the specificity:
      return math_ops.div(tn[tf_index],
                          tn[tf_index] + fp[tf_index] + kepsilon,
                          name)

    specificity = compute_specificity_at_sensitivity('value')
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_specificity_at_sensitivity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, specificity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op
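# Illustrative sketch (editorial addition; tensors and values are assumptions):
# asking for the specificity at sensitivity 0.9 first locates the largest
# threshold index whose computed sensitivity (tp / (tp + fn)) is closest to
# 0.9, then reports tn / (tn + fp) at that same threshold, e.g.
#
#   spec, update_op = streaming_specificity_at_sensitivity(
#       predictions, labels, sensitivity=0.9)
#
# yields a scalar specificity once `update_op` has been run on the data.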


def streaming_sensitivity_at_specificity(
    predictions, labels, specificity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the sensitivity at a given specificity.

  The `streaming_sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is computed
  and used to evaluate the corresponding sensitivity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    specificity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar tensor representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `specificity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  if specificity < 0 or specificity > 1:
    raise ValueError('`specificity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
                                     [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_sensitivity_at_specificity(name):
      specificities = math_ops.div(tn, tn + fp + kepsilon)
      tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the sensitivity:
      return math_ops.div(tp[tf_index],
                          tp[tf_index] + fn[tf_index] + kepsilon,
                          name)

    sensitivity = compute_sensitivity_at_specificity('value')
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_sensitivity_at_specificity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, sensitivity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return sensitivity, update_op
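# Mirror-image sketch (editorial addition): the same procedure as above with
# the roles swapped; `argmin` picks the threshold whose tn / (tn + fp) is
# closest to the requested `specificity`, and tp / (tp + fn) at that threshold
# is returned as the sensitivity.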


def streaming_precision_at_thresholds(predictions, labels, thresholds,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None, name=None):
  """Computes precision values for different `thresholds` on `predictions`.

  The `streaming_precision_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `precision[i]` is defined as the total
  weight of values in `predictions` above `thresholds[i]` whose corresponding
  entry in `labels` is `True`, divided by the total weight of values in
  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
  false_positives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'precision_at_thresholds',
                                     [predictions, labels]):

    # TODO(nsilberman): Replace with only tp and fp, this results in unnecessary
    # variable creation. b/30842882
    (true_positives, _, _, false_positives, true_positives_compute_op, _, _,
     false_positives_compute_op,) = _tp_fn_tn_fp(
         predictions, labels, thresholds, weights)

    # avoid division by zero
    epsilon = 1e-7
    def compute_precision(name):
      precision = math_ops.div(true_positives,
                               epsilon + true_positives + false_positives,
                               name='precision_' + name)
      return precision

    precision = compute_precision('value')
    with ops.control_dependencies([true_positives_compute_op,
                                   false_positives_compute_op]):
      update_op = compute_precision('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, precision)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision, update_op
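# Worked example (editorial addition): with predictions [0.2, 0.8],
# labels [False, True] and thresholds [0.5], only the 0.8 prediction is
# positive and it is correct, so `precision` evaluates to [1.0] (one entry per
# threshold) after a single `update_op` run.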


def streaming_recall_at_thresholds(predictions, labels, thresholds,
                                   weights=None, metrics_collections=None,
                                   updates_collections=None, name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  The `streaming_recall_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `recall[i]` is defined as the total weight
  of values in `predictions` above `thresholds[i]` whose corresponding entry in
  `labels` is `True`, divided by the total weight of `True` values in `labels`
  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'recall_at_thresholds',
                                     [predictions, labels]):
    (true_positives, false_negatives, _, _, true_positives_compute_op,
     false_negatives_compute_op, _, _,) = _tp_fn_tn_fp(
         predictions, labels, thresholds, weights)

    # avoid division by zero
    epsilon = 1e-7
    def compute_recall(name):
      recall = math_ops.div(true_positives,
                            epsilon + true_positives + false_negatives,
                            name='recall_' + name)
      return recall

    recall = compute_recall('value')
    with ops.control_dependencies([true_positives_compute_op,
                                   false_negatives_compute_op]):
      update_op = compute_recall('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op
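# Worked example (editorial addition): with the same predictions [0.2, 0.8],
# labels [False, True] and thresholds [0.5], the single positive label is
# recovered above the threshold, so `recall` evaluates to [1.0]; raising the
# threshold above 0.8 would drop that entry toward 0.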


@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_recall_at_k(predictions, labels, k, ignore_mask=None,
                          weights=None, metrics_collections=None,
                          updates_collections=None, name=None):
  """Computes the recall@k of the predictions with respect to dense labels.

  The `streaming_recall_at_k` function creates two local variables, `total` and
  `count`, that are used to compute the recall@k frequency. This frequency is
  ultimately returned as `recall_at_<k>`: an idempotent operation that simply
  divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with
  shape [batch_size] whose elements indicate whether or not the corresponding
  label is in the top `k` `predictions`. Then `update_op` increments `total`
  with the reduced sum of `weights` where `in_top_k` is `True`, and it
  increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
  Alternatively, if `ignore_mask` is not `None`, then mask values where
  `ignore_mask` is `True`.

  Args:
    predictions: A floating point tensor of dimension [batch_size, num_classes]
    labels: A tensor of dimension [batch_size] whose type is in `int32`,
      `int64`.
    k: The number of top elements to look at for computing recall.
    ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
    weights: An optional `Tensor` whose shape matches `predictions`.
    metrics_collections: An optional list of collections that `recall_at_k`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    recall_at_k: A tensor representing the recall@k, the fraction of labels
      which fall into the top `k` predictions.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `recall_at_k`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `ignore_mask` is not `None` and its shape doesn't match `predictions`, or
      if `weights` is not `None` and its shape doesn't match `predictions`, or
      if either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k))
  return streaming_mean(in_top_k,
                        _mask_weights(ignore_mask, weights),
                        metrics_collections,
                        updates_collections,
                        name or ('recall_at_%d' % k))
1272
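# Illustrative usage sketch for `streaming_recall_at_k` (assuming the usual
# `tf.contrib.metrics` re-export; the logits and labels are made-up values):
#
#   import tensorflow as tf
#
#   # Two examples, three classes; the true classes are 2 and 1.
#   predictions = tf.constant([[0.1, 0.3, 0.6],
#                              [0.5, 0.4, 0.1]])
#   labels = tf.constant([2, 1])
#   recall_at_2, update_op = tf.contrib.metrics.streaming_recall_at_k(
#       predictions, labels, k=2)
#
#   with tf.Session() as sess:
#     sess.run(tf.initialize_local_variables())
#     sess.run(update_op)
#     print(sess.run(recall_at_2))  # 1.0: both labels fall in the top-2 scores.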
1273
1274# TODO(ptucker): Validate range of values in labels?
1275@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
1276def streaming_sparse_recall_at_k(predictions,
1277                                 labels,
1278                                 k,
1279                                 class_id=None,
1280                                 ignore_mask=None,
1281                                 weights=None,
1282                                 metrics_collections=None,
1283                                 updates_collections=None,
1284                                 name=None):
1285  """Computes recall@k of the predictions with respect to sparse labels.
1286
1287  If `class_id` is specified, we calculate recall by considering only the
1288      entries in the batch for which `class_id` is in the label, and computing
1289      the fraction of them for which `class_id` is in the top-k `predictions`.
1290  If `class_id` is not specified, we'll calculate recall as how often on
1291      average a class among the labels of a batch entry is in the top-k
1292      `predictions`.
1293
1294  `streaming_sparse_recall_at_k` creates two local variables,
1295  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
1296  the recall_at_k frequency. This frequency is ultimately returned as
1297  `recall_at_<k>`: an idempotent operation that simply divides
1298  `true_positive_at_<k>` by total (`true_positive_at_<k>` + `false_negative_at_<k>`).
1299
1300  For estimation of the metric over a stream of data, the function creates an
1301  `update_op` operation that updates these variables and returns the
1302  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
1303  indicating the top `k` `predictions`. Set operations applied to `top_k` and
1304  `labels` calculate the true positives and false negatives weighted by
1305  `weights`. Then `update_op` increments `true_positive_at_<k>` and
1306  `false_negative_at_<k>` using these values.
1307
1308  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1309  Alternatively, if `ignore_mask` is not `None`, then mask values where
1310  `ignore_mask` is `True`.
1311
1312  Args:
1313    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
1314      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
1315      The final dimension contains the logit values for each class. [D1, ... DN]
1316      must match `labels`.
1317    labels: `int64` `Tensor` or `SparseTensor` with shape
1318      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1319      target classes for the associated prediction. Commonly, N=1 and `labels`
1320      has shape [batch_size, num_labels]. [D1, ... DN] must match
1321      `predictions`. Values should be in range [0, num_classes), where
1322      num_classes is the last dimension of `predictions`.
1323    k: Integer, k for @k metric.
1324    class_id: Integer class ID for which we want binary metrics. This should be
1325      in range [0, num_classes), where num_classes is the last dimension of
1326      `predictions`.
1327    ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
1328      the first [D1, ... DN] dimensions of `predictions` and `labels`.
1329    weights: An optional `Tensor` whose shape is broadcastable to the first
1330      [D1, ... DN] dimensions of `predictions` and `labels`.
1331    metrics_collections: An optional list of collections that values should
1332      be added to.
1333    updates_collections: An optional list of collections that updates should
1334      be added to.
1335    name: Name of new update operation, and namespace for other dependent ops.
1336
1337  Returns:
1338    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
1339      by the sum of `true_positives` and `false_negatives`.
1340    update_op: `Operation` that increments `true_positives` and
1341      `false_negatives` variables appropriately, and whose value matches
1342      `recall`.
1343
1344  Raises:
1345    ValueError: If `ignore_mask` is not `None` and its shape doesn't match
1346      `predictions`, or if `weights` is not `None` and its shape doesn't match
1347      `predictions`, or if either `metrics_collections` or `updates_collections`
1348      are not a list or tuple.
1349  """
1350  default_name = 'recall_at_%d' % k
1351  if class_id is not None:
1352    default_name = '%s_class%d' % (default_name, class_id)
1353
1354  with ops.name_scope(name, default_name, [predictions, labels]) as scope:
1355    _, top_k_idx = nn.top_k(predictions, k)
1356    top_k_idx = math_ops.to_int64(top_k_idx)
1357    weights = _mask_weights(ignore_mask, weights)
1358    tp, tp_update = _streaming_sparse_true_positive_at_k(
1359        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
1360        weights=weights)
1361    fn, fn_update = _streaming_sparse_false_negative_at_k(
1362        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
1363        weights=weights)
1364
1365    metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
1366    update = math_ops.div(
1367        tp_update, math_ops.add(tp_update, fn_update), name='update')
1368    if metrics_collections:
1369      ops.add_to_collections(metrics_collections, metric)
1370    if updates_collections:
1371      ops.add_to_collections(updates_collections, update)
1372    return metric, update
1373
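# Illustrative usage sketch for `streaming_sparse_recall_at_k` with dense
# multi-label targets (assuming the usual `tf.contrib.metrics` re-export; the
# values are hypothetical):
#
#   import tensorflow as tf
#
#   # One example, four classes, two relevant labels (classes 0 and 3).
#   predictions = tf.constant([[0.4, 0.3, 0.2, 0.1]])
#   labels = tf.constant([[0, 3]], dtype=tf.int64)
#   recall, update_op = tf.contrib.metrics.streaming_sparse_recall_at_k(
#       predictions, labels, k=2)
#
#   with tf.Session() as sess:
#     sess.run(tf.initialize_local_variables())
#     sess.run(update_op)
#     print(sess.run(recall))  # 0.5: of the two relevant classes, only class 0
#                              # appears in the top-2 predictions {0, 1}.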
1374
1375# TODO(ptucker): Validate range of values in labels?
1376@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
1377def streaming_sparse_precision_at_k(predictions,
1378                                    labels,
1379                                    k,
1380                                    class_id=None,
1381                                    ignore_mask=None,
1382                                    weights=None,
1383                                    metrics_collections=None,
1384                                    updates_collections=None,
1385                                    name=None):
1386  """Computes precision@k of the predictions with respect to sparse labels.
1387
1388  If `class_id` is specified, we calculate precision by considering only the
1389      entries in the batch for which `class_id` is in the top-k highest
1390      `predictions`, and computing the fraction of them for which `class_id` is
1391      indeed a correct label.
1392  If `class_id` is not specified, we'll calculate precision as how often on
1393      average a class among the top-k classes with the highest predicted values
1394      of a batch entry is correct and can be found in the label for that entry.
1395
1396  `streaming_sparse_precision_at_k` creates two local variables,
1397  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
1398  the precision@k frequency. This frequency is ultimately returned as
1399  `precision_at_<k>`: an idempotent operation that simply divides
1400  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
1401  `false_positive_at_<k>`).
1402
1403  For estimation of the metric over a stream of data, the function creates an
1404  `update_op` operation that updates these variables and returns the
1405  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
1406  indicating the top `k` `predictions`. Set operations applied to `top_k` and
1407  `labels` calculate the true positives and false positives weighted by
1408  `weights`. Then `update_op` increments `true_positive_at_<k>` and
1409  `false_positive_at_<k>` using these values.
1410
1411  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1412  Alternatively, if `ignore_mask` is not `None`, then mask values where
1413  `ignore_mask` is `True`.
1414
1415  Args:
1416    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
1417      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
1418      The final dimension contains the logit values for each class. [D1, ... DN]
1419      must match `labels`.
1420    labels: `int64` `Tensor` or `SparseTensor` with shape
1421      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1422      target classes for the associated prediction. Commonly, N=1 and `labels`
1423      has shape [batch_size, num_labels]. [D1, ... DN] must match
1424      `predictions`. Values should be in range [0, num_classes), where
1425      num_classes is the last dimension of `predictions`.
1426    k: Integer, k for @k metric.
1427    class_id: Integer class ID for which we want binary metrics. This should be
1428      in range [0, num_classes), where num_classes is the last dimension of
1429      `predictions`.
1430    ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
1431      the first [D1, ... DN] dimensions of `predictions` and `labels`.
1432    weights: An optional `Tensor` whose shape is broadcastable to the first
1433      [D1, ... DN] dimensions of `predictions` and `labels`.
1434    metrics_collections: An optional list of collections that values should
1435      be added to.
1436    updates_collections: An optional list of collections that updates should
1437      be added to.
1438    name: Name of new update operation, and namespace for other dependent ops.
1439
1440  Returns:
1441    precision: Scalar `float64` `Tensor` with the value of `true_positives`
1442      divided by the sum of `true_positives` and `false_positives`.
1443    update_op: `Operation` that increments `true_positives` and
1444      `false_positives` variables appropriately, and whose value matches
1445      `precision`.
1446
1447  Raises:
1448    ValueError: If `ignore_mask` is not `None` and its shape doesn't match
1449      `predictions`, or if `weights` is not `None` and its shape doesn't match
1450      `predictions`, or if either `metrics_collections` or `updates_collections`
1451      are not a list or tuple.
1452  """
1453  default_name = 'precision_at_%d' % k
1454  if class_id is not None:
1455    default_name = '%s_class%d' % (default_name, class_id)
1456  with ops.name_scope(name, default_name, [predictions, labels]) as scope:
1457    _, top_k_idx = nn.top_k(predictions, k)
1458    top_k_idx = math_ops.to_int64(top_k_idx)
1459    weights = _mask_weights(ignore_mask, weights)
1460    tp, tp_update = _streaming_sparse_true_positive_at_k(
1461        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
1462        weights=weights)
1463    fp, fp_update = _streaming_sparse_false_positive_at_k(
1464        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
1465        weights=weights)
1466
1467    metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
1468    update = math_ops.div(
1469        tp_update, math_ops.add(tp_update, fp_update), name='update')
1470    if metrics_collections:
1471      ops.add_to_collections(metrics_collections, metric)
1472    if updates_collections:
1473      ops.add_to_collections(updates_collections, update)
1474    return metric, update
1475
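# Illustrative usage sketch for `streaming_sparse_precision_at_k`, mirroring
# the recall sketch above (values are hypothetical):
#
#   import tensorflow as tf
#
#   predictions = tf.constant([[0.4, 0.3, 0.2, 0.1]])
#   labels = tf.constant([[0, 3]], dtype=tf.int64)
#   precision, update_op = tf.contrib.metrics.streaming_sparse_precision_at_k(
#       predictions, labels, k=2)
#
#   with tf.Session() as sess:
#     sess.run(tf.initialize_local_variables())
#     sess.run(update_op)
#     print(sess.run(precision))  # 0.5: of the top-2 predicted classes {0, 1},
#                                 # only class 0 is an actual label.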
1476
1477def _select_class_id(ids, selected_id):
1478  """Filter all but `selected_id` out of `ids`.
1479
1480  Args:
1481    ids: `int64` `Tensor` or `SparseTensor` of IDs.
1482    selected_id: Int id to select.
1483
1484  Returns:
1485    `SparseTensor` of same dimensions as `ids`, except for the last dimension,
1486    which might be smaller. This contains only the entries equal to
1487    `selected_id`.
1488  """
1489  if isinstance(ids, ops.SparseTensor):
1490    return sparse_ops.sparse_retain(
1491        ids, math_ops.equal(ids.values, selected_id))
1492
1493  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
1494  # tf.equal and tf.reduce_any?
1495
1496  # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
1497  ids_shape = array_ops.shape(ids)
1498  ids_last_dim = array_ops.size(ids_shape) - 1
1499  filled_selected_id_shape = math_ops.reduced_shape(
1500      ids_shape, array_ops.reshape(ids_last_dim, [1]))
1501
1502  # Intersect `ids` with the selected ID.
1503  filled_selected_id = array_ops.fill(
1504      filled_selected_id_shape, math_ops.to_int64(selected_id))
1505  return set_ops.set_intersection(filled_selected_id, ids)
1506
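# For instance (hypothetical values), with dense `ids` of [[1, 2], [3, 2]] and
# `selected_id` of 2, the returned `SparseTensor` keeps a single entry equal to
# 2 in each row; a row that did not contain 2 would come back empty.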
1507
1508def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
1509  """If class ID is specified, filter all other classes.
1510
1511  Args:
1512    labels: `int64` `Tensor` or `SparseTensor` with shape
1513      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1514      target classes for the associated prediction. Commonly, N=1 and `labels`
1515      has shape [batch_size, num_labels]. [D1, ... DN] must match
1516      `predictions_idx`.
1517    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
1518      where N >= 1. Commonly, N=1 and predictions has shape [batch size, k].
1519    selected_id: Int id to select.
1520
1521  Returns:
1522    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
1523  """
1524  if selected_id is None:
1525    return labels, predictions_idx
1526  return (_select_class_id(labels, selected_id),
1527          _select_class_id(predictions_idx, selected_id))
1528
1529
1530def _streaming_sparse_true_positive_at_k(predictions_idx,
1531                                         labels,
1532                                         k,
1533                                         class_id=None,
1534                                         weights=None,
1535                                         name=None):
1536  """Calculates weighted per step true positives for recall@k and precision@k.
1537
1538  If `class_id` is specified, calculate binary true positives for `class_id`
1539      only.
1540  If `class_id` is not specified, calculate metrics for `k` predicted vs
1541      `n` label classes, where `n` is the 2nd dimension of `labels`.
1542
1543  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1544
1545  Args:
1546    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
1547      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
1548      match `labels`.
1549    labels: `int64` `Tensor` or `SparseTensor` with shape
1550      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1551      target classes for the associated prediction. Commonly, N=1 and `labels`
1552      has shape [batch_size, num_labels]. [D1, ... DN] must match
1553      `predictions_idx`.
1554    k: Integer, k for @k metric. This is only used for default op name.
1555    class_id: Class for which we want binary metrics.
1556    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
1557      dimensions of `predictions_idx` and `labels`.
1558    name: Name of new variable, and namespace for other dependent ops.
1559
1560  Returns:
1561    A tuple of `Variable` and update `Operation`.
1562
1563  Raises:
1564    ValueError: If `weights` is not `None` and has an incompatible shape.
1565  """
1566  default_name = 'true_positive_at_%d' % k
1567  if class_id is not None:
1568    default_name = '%s_class%d' % (default_name, class_id)
1569  with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope:
1570    labels, predictions_idx = _maybe_select_class_id(labels,
1571                                                     predictions_idx,
1572                                                     class_id)
1573    tp = set_ops.set_size(set_ops.set_intersection(predictions_idx, labels))
1574    tp = math_ops.to_double(tp)
1575    if weights is not None:
1576      tp.get_shape().assert_is_compatible_with(weights.get_shape())
1577      weights = math_ops.to_double(weights)
1578      tp = math_ops.mul(tp, weights)
1579    batch_total_tp = math_ops.reduce_sum(tp)
1580
1581    var = contrib_variables.local_variable(
1582        array_ops.zeros([], dtype=dtypes.float64), name=scope)
1583    return var, state_ops.assign_add(var, batch_total_tp, name='update')
1584
1585
1586def _streaming_sparse_false_positive_at_k(predictions_idx,
1587                                          labels,
1588                                          k,
1589                                          class_id=None,
1590                                          weights=None,
1591                                          name=None):
1592  """Calculates weighted per step false positives for precision@k.
1593
1594  If `class_id` is specified, calculate binary false positives for `class_id`
1595      only.
1596  If `class_id` is not specified, calculate metrics for `k` predicted vs
1597      `n` label classes, where `n` is the 2nd dimension of `labels`.
1598
1599  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1600
1601  Args:
1602    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
1603      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
1604      match `labels`.
1605    labels: `int64` `Tensor` or `SparseTensor` with shape
1606      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1607      target classes for the associated prediction. Commonly, N=1 and `labels`
1608      has shape [batch_size, num_labels]. [D1, ... DN] must match
1609      `predictions_idx`.
1610    k: Integer, k for @k metric. This is only used for default op name.
1611    class_id: Class for which we want binary metrics.
1612    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
1613      dimensions of `predictions_idx` and `labels`.
1614    name: Name of new variable, and namespace for other dependent ops.
1615
1616  Returns:
1617    A tuple of `Variable` and update `Operation`.
1618
1619  Raises:
1620    ValueError: If `weights` is not `None` and has an incompatible shape.
1621  """
1622  default_name = 'false_positive_at_%d' % k
1623  if class_id is not None:
1624    default_name = '%s_class%d' % (default_name, class_id)
1625  with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope:
1626    labels, predictions_idx = _maybe_select_class_id(labels,
1627                                                     predictions_idx,
1628                                                     class_id)
1629    fp = set_ops.set_size(set_ops.set_difference(predictions_idx,
1630                                                 labels,
1631                                                 aminusb=True))
1632    fp = math_ops.to_double(fp)
1633    if weights is not None:
1634      fp.get_shape().assert_is_compatible_with(weights.get_shape())
1635      weights = math_ops.to_double(weights)
1636      fp = math_ops.mul(fp, weights)
1637    batch_total_fp = math_ops.reduce_sum(fp)
1638
1639    var = contrib_variables.local_variable(
1640        array_ops.zeros([], dtype=dtypes.float64), name=scope)
1641    return var, state_ops.assign_add(var, batch_total_fp, name='update')
1642
1643
1644def _streaming_sparse_false_negative_at_k(predictions_idx,
1645                                          labels,
1646                                          k,
1647                                          class_id=None,
1648                                          weights=None,
1649                                          name=None):
1650  """Calculates weighted per step false negatives for recall@k.
1651
1652  If `class_id` is specified, calculate binary false negatives for `class_id`
1653      only.
1654  If `class_id` is not specified, calculate metrics for `k` predicted vs
1655      `n` label classes, where `n` is the 2nd dimension of `labels`.
1656
1657  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1658
1659  Args:
1660    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
1661      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
1662      match `labels`.
1663    labels: `int64` `Tensor` or `SparseTensor` with shape
1664      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1665      target classes for the associated prediction. Commonly, N=1 and `labels`
1666      has shape [batch_size, num_labels]. [D1, ... DN] must match
1667      `predictions_idx`.
1668    k: Integer, k for @k metric. This is only used for default op name.
1669    class_id: Class for which we want binary metrics.
1670    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
1671      dimensions of `predictions_idx` and `labels`.
1672    name: Name of new variable, and namespace for other dependent ops.
1673
1674  Returns:
1675    A tuple of `Variable` and update `Operation`.
1676
1677  Raises:
1678    ValueError: If `weights` is not `None` and has an incompatible shape.
1679  """
1680  default_name = 'false_negative_at_%d' % k
1681  if class_id is not None:
1682    default_name = '%s_class%d' % (default_name, class_id)
1683  with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope:
1684    labels, predictions_idx = _maybe_select_class_id(labels,
1685                                                     predictions_idx,
1686                                                     class_id)
1687    fn = set_ops.set_size(set_ops.set_difference(predictions_idx,
1688                                                 labels,
1689                                                 aminusb=False))
1690    fn = math_ops.to_double(fn)
1691    if weights is not None:
1692      fn.get_shape().assert_is_compatible_with(weights.get_shape())
1693      weights = math_ops.to_double(weights)
1694      fn = math_ops.mul(fn, weights)
1695    batch_total_fn = math_ops.reduce_sum(fn)
1696
1697    var = contrib_variables.local_variable(
1698        array_ops.zeros([], dtype=dtypes.float64), name=scope)
1699    return var, state_ops.assign_add(var, batch_total_fn, name='update')
1700
1701
1702def streaming_mean_absolute_error(predictions, labels, weights=None,
1703                                  metrics_collections=None,
1704                                  updates_collections=None,
1705                                  name=None):
1706  """Computes the mean absolute error between the labels and predictions.
1707
1708  The `streaming_mean_absolute_error` function creates two local variables,
1709  `total` and `count` that are used to compute the mean absolute error. This
1710  average is weighted by `weights`, and it is ultimately returned as
1711  `mean_absolute_error`: an idempotent operation that simply divides `total` by
1712  `count`.
1713
1714  For estimation of the metric over a stream of data, the function creates an
1715  `update_op` operation that updates these variables and returns the
1716  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
1717  absolute value of the differences between `predictions` and `labels`. Then
1718  `update_op` increments `total` with the reduced sum of the product of
1719  `weights` and `absolute_errors`, and it increments `count` with the reduced
1720  sum of `weights`.
1721
1722  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1723
1724  Args:
1725    predictions: A `Tensor` of arbitrary shape.
1726    labels: A `Tensor` of the same shape as `predictions`.
1727    weights: An optional `Tensor` whose shape matches `predictions`.
1728    metrics_collections: An optional list of collections that
1729      `mean_absolute_error` should be added to.
1730    updates_collections: An optional list of collections that `update_op` should
1731      be added to.
1732    name: An optional variable_scope name.
1733
1734  Returns:
1735    mean_absolute_error: A tensor representing the current mean, the value of
1736      `total` divided by `count`.
1737    update_op: An operation that increments the `total` and `count` variables
1738      appropriately and whose value matches `mean_absolute_error`.
1739
1740  Raises:
1741    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1742      `weights` is not `None` and its shape doesn't match `predictions`, or if
1743      either `metrics_collections` or `updates_collections` are not a list or
1744      tuple.
1745  """
1746  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
1747      predictions, labels)
1748  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1749  absolute_errors = math_ops.abs(predictions - labels)
1750  return streaming_mean(absolute_errors, weights, metrics_collections,
1751                        updates_collections, name or 'mean_absolute_error')
1752
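# Illustrative usage sketch for `streaming_mean_absolute_error` over two
# batches (assuming the usual `tf.contrib.metrics` re-export; all numeric
# values are hypothetical):
#
#   import tensorflow as tf
#
#   predictions = tf.placeholder(tf.float32, shape=[2])
#   labels = tf.placeholder(tf.float32, shape=[2])
#   mae, update_op = tf.contrib.metrics.streaming_mean_absolute_error(
#       predictions, labels)
#
#   with tf.Session() as sess:
#     sess.run(tf.initialize_local_variables())
#     sess.run(update_op, {predictions: [0.0, 1.0], labels: [0.5, 1.0]})
#     sess.run(update_op, {predictions: [2.0, 2.0], labels: [1.0, 3.0]})
#     print(sess.run(mae))  # 0.625: mean of the absolute errors
#                           # [0.5, 0.0, 1.0, 1.0] across both batches.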
1753
1754def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
1755                                  metrics_collections=None,
1756                                  updates_collections=None,
1757                                  name=None):
1758  """Computes the mean relative error by normalizing with the given values.
1759
1760  The `streaming_mean_relative_error` function creates two local variables,
1761  `total` and `count` that are used to compute the mean relative error.
1762  This average is weighted by `weights`, and it is ultimately returned as
1763  `mean_relative_error`: an idempotent operation that simply divides `total` by
1764  `count`.
1765
1766  For estimation of the metric over a stream of data, the function creates an
1767  `update_op` operation that updates these variables and returns the
1768  `mean_relative_error`. Internally, a `relative_errors` operation divides the
1769  absolute value of the differences between `predictions` and `labels` by the
1770  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
1771  product of `weights` and `relative_errors`, and it increments `count` with the
1772  reduced sum of `weights`.
1773
1774  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1775
1776  Args:
1777    predictions: A `Tensor` of arbitrary shape.
1778    labels: A `Tensor` of the same shape as `predictions`.
1779    normalizer: A `Tensor` of the same shape as `predictions`.
1780    weights: An optional `Tensor` whose shape matches `predictions`.
1781    metrics_collections: An optional list of collections that
1782      `mean_relative_error` should be added to.
1783    updates_collections: An optional list of collections that `update_op` should
1784      be added to.
1785    name: An optional variable_scope name.
1786
1787  Returns:
1788    mean_relative_error: A tensor representing the current mean, the value of
1789      `total` divided by `count`.
1790    update_op: An operation that increments the `total` and `count` variables
1791      appropriately and whose value matches `mean_relative_error`.
1792
1793  Raises:
1794    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1795      `weights` is not `None` and its shape doesn't match `predictions`, or if
1796      either `metrics_collections` or `updates_collections` are not a list or
1797      tuple.
1798  """
1799  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
1800      predictions, labels)
1801  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1802
1803  predictions, normalizer = metric_ops_util.remove_squeezable_dimensions(
1804      predictions, normalizer)
1805  predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
1806  relative_errors = math_ops.select(
1807      math_ops.equal(normalizer, 0.0),
1808      array_ops.zeros_like(labels),
1809      math_ops.div(math_ops.abs(labels - predictions), normalizer))
1810  return streaming_mean(relative_errors, weights, metrics_collections,
1811                        updates_collections, name or 'mean_relative_error')
1812
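# Illustrative usage sketch for `streaming_mean_relative_error`, normalizing by
# the labels themselves, which is a common choice (assuming the usual
# `tf.contrib.metrics` re-export; values are hypothetical):
#
#   import tensorflow as tf
#
#   predictions = tf.constant([1.0, 3.0])
#   labels = tf.constant([2.0, 4.0])
#   rel_err, update_op = tf.contrib.metrics.streaming_mean_relative_error(
#       predictions, labels, normalizer=labels)
#
#   with tf.Session() as sess:
#     sess.run(tf.initialize_local_variables())
#     sess.run(update_op)
#     print(sess.run(rel_err))  # 0.375: mean of |1-2|/2 = 0.5 and |3-4|/4 = 0.25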
1813
1814def streaming_mean_squared_error(predictions, labels, weights=None,
1815                                 metrics_collections=None,
1816                                 updates_collections=None,
1817                                 name=None):
1818  """Computes the mean squared error between the labels and predictions.
1819
1820  The `streaming_mean_squared_error` function creates two local variables,
1821  `total` and `count` that are used to compute the mean squared error.
1822  This average is weighted by `weights`, and it is ultimately returned as
1823  `mean_squared_error`: an idempotent operation that simply divides `total` by
1824  `count`.
1825
1826  For estimation of the metric over a stream of data, the function creates an
1827  `update_op` operation that updates these variables and returns the
1828  `mean_squared_error`. Internally, a `squared_error` operation computes the
1829  element-wise square of the difference between `predictions` and `labels`. Then
1830  `update_op` increments `total` with the reduced sum of the product of
1831  `weights` and `squared_error`, and it increments `count` with the reduced sum
1832  of `weights`.
1833
1834  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1835
1836  Args:
1837    predictions: A `Tensor` of arbitrary shape.
1838    labels: A `Tensor` of the same shape as `predictions`.
1839    weights: An optional `Tensor` whose shape matches `predictions`.
1840    metrics_collections: An optional list of collections that
1841      `mean_squared_error` should be added to.
1842    updates_collections: An optional list of collections that `update_op` should
1843      be added to.
1844    name: An optional variable_scope name.
1845
1846  Returns:
1847    mean_squared_error: A tensor representing the current mean, the value of
1848      `total` divided by `count`.
1849    update_op: An operation that increments the `total` and `count` variables
1850      appropriately and whose value matches `mean_squared_error`.
1851
1852  Raises:
1853    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1854      `weights` is not `None` and its shape doesn't match `predictions`, or if
1855      either `metrics_collections` or `updates_collections` are not a list or
1856      tuple.
1857  """
1858  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
1859      predictions, labels)
1860  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1861  squared_error = math_ops.square(labels - predictions)
1862  return streaming_mean(squared_error, weights, metrics_collections,
1863                        updates_collections, name or 'mean_squared_error')
1864
1865
1866def streaming_root_mean_squared_error(predictions, labels, weights=None,
1867                                      metrics_collections=None,
1868                                      updates_collections=None,
1869                                      name=None):
1870  """Computes the root mean squared error between the labels and predictions.
1871
1872  The `streaming_root_mean_squared_error` function creates two local variables,
1873  `total` and `count` that are used to compute the root mean squared error.
1874  This average is weighted by `weights`, and it is ultimately returned as
1875  `root_mean_squared_error`: an idempotent operation that takes the square root
1876  of the division of `total` by `count`.
1877
1878  For estimation of the metric over a stream of data, the function creates an
1879  `update_op` operation that updates these variables and returns the
1880  `root_mean_squared_error`. Internally, a `squared_error` operation computes
1881  the element-wise square of the difference between `predictions` and `labels`.
1882  Then `update_op` increments `total` with the reduced sum of the product of
1883  `weights` and `squared_error`, and it increments `count` with the reduced sum
1884  of `weights`.
1885
1886  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1887
1888  Args:
1889    predictions: A `Tensor` of arbitrary shape.
1890    labels: A `Tensor` of the same shape as `predictions`.
1891    weights: An optional `Tensor` whose shape matches `predictions`.
1892    metrics_collections: An optional list of collections that
1893      `root_mean_squared_error` should be added to.
1894    updates_collections: An optional list of collections that `update_op` should
1895      be added to.
1896    name: An optional variable_scope name.
1897
1898  Returns:
1899    root_mean_squared_error: A tensor representing the current mean, the value
1900      of `total` divided by `count`.
1901    update_op: An operation that increments the `total` and `count` variables
1902      appropriately and whose value matches `root_mean_squared_error`.
1903
1904  Raises:
1905    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1906      `weights` is not `None` and its shape doesn't match `predictions`, or if
1907      either `metrics_collections` or `updates_collections` are not a list or
1908      tuple.
1909  """
1910  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
1911      predictions, labels)
1912  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1913  value_tensor, update_op = streaming_mean_squared_error(
1914      predictions, labels, weights, None, None,
1915      name or 'root_mean_squared_error')
1916
1917  root_mean_squared_error = math_ops.sqrt(value_tensor)
1918  with ops.control_dependencies([update_op]):
1919    update_op = math_ops.sqrt(update_op)
1920
1921  if metrics_collections:
1922    ops.add_to_collections(metrics_collections, root_mean_squared_error)
1923
1924  if updates_collections:
1925    ops.add_to_collections(updates_collections, update_op)
1926
1927  return root_mean_squared_error, update_op
1928
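# Illustrative usage sketch for `streaming_root_mean_squared_error`, which is
# just the square root of the streaming MSE (values are hypothetical):
#
#   import tensorflow as tf
#
#   predictions = tf.constant([1.0, 2.0, 3.0])
#   labels = tf.constant([2.0, 2.0, 5.0])
#   rmse, update_op = tf.contrib.metrics.streaming_root_mean_squared_error(
#       predictions, labels)
#
#   with tf.Session() as sess:
#     sess.run(tf.initialize_local_variables())
#     sess.run(update_op)
#     print(sess.run(rmse))  # ~1.291 = sqrt((1 + 0 + 4) / 3)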
1929
1930def streaming_covariance(predictions,
1931                         labels,
1932                         weights=None,
1933                         metrics_collections=None,
1934                         updates_collections=None,
1935                         name=None):
1936  """Computes the unbiased sample covariance between `predictions` and `labels`.
1937
1938  The `streaming_covariance` function creates four local variables,
1939  `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to
1940  compute the sample covariance between predictions and labels across multiple
1941  batches of data. The covariance is ultimately returned as an idempotent
1942  operation that simply divides `comoment` by `count` - 1. We use `count` - 1
1943  in order to get an unbiased estimate.
1944
1945  The algorithm used for this online computation is described in
1946  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
1947  Specifically, the formula used to combine two sample comoments is
1948  `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB`
1949  The comoment for a single batch of data is simply
1950  `sum((x - E[x]) * (y - E[y]))`, optionally weighted.
1951
1952  If `weights` is not None, then it is used to compute weighted comoments,
1953  means, and count. NOTE: these weights are treated as "frequency weights", as
1954  opposed to "reliability weights". See discussion of the difference on
1955  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance
1956
1957  To facilitate the computation of covariance across multiple batches of data,
1958  the function creates an `update_op` operation, which updates underlying
1959  variables and returns the updated covariance.
1960
1961  Args:
1962    predictions: A `Tensor` of arbitrary size.
1963    labels: A `Tensor` of the same size as `predictions`.
1964    weights: An optional set of weights which indicates the frequency with which
1965      an example is sampled. Must be broadcastable with `labels`.
1966    metrics_collections: An optional list of collections that the metric
1967      value variable should be added to.
1968    updates_collections: An optional list of collections that the metric update
1969      ops should be added to.
1970    name: An optional variable_scope name.
1971
1972  Returns:
1973    covariance: A `Tensor` representing the current unbiased sample covariance,
1974      `comoment` / (`count` - 1).
1975    update_op: An operation that updates the local variables appropriately.
1976
1977  Raises:
1978    ValueError: If labels and predictions are of different sizes or if either
1979      `metrics_collections` or `updates_collections` are not a list or tuple.
1980  """
1981  with variable_scope.variable_scope(name, 'covariance', [predictions, labels]):
1982    predictions, labels = metric_ops_util.remove_squeezable_dimensions(
1983        predictions, labels)
1984    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1985    count = _create_local('count', [])
1986    mean_prediction = _create_local('mean_prediction', [])
1987    mean_label = _create_local('mean_label', [])
1988    comoment = _create_local('comoment', [])  # C_A in update equation
1989
1990    if weights is None:
1991      batch_count = math_ops.to_float(array_ops.size(labels))  # n_B in eqn
1992      weighted_predictions = predictions
1993      weighted_labels = labels
1994    else:
1995      batch_count = math_ops.reduce_sum(weights)  # n_B in update equation
1996      weighted_predictions = predictions * weights
1997      weighted_labels = labels * weights
1998
1999    update_count = state_ops.assign_add(count, batch_count)  # n_AB in eqn
2000    prev_count = update_count - batch_count  # n_A in update equation
2001
2002    # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
2003    # batch_mean_prediction is E[x_B] in the update equation
2004    batch_mean_prediction = _safe_div(
2005        math_ops.reduce_sum(weighted_predictions), batch_count,
2006        'batch_mean_prediction')
2007    delta_mean_prediction = _safe_div(
2008        (batch_mean_prediction - mean_prediction) * batch_count, update_count,
2009        'delta_mean_prediction')
2010    update_mean_prediction = state_ops.assign_add(mean_prediction,
2011                                                  delta_mean_prediction)
2012    # prev_mean_prediction is E[x_A] in the update equation
2013    prev_mean_prediction = update_mean_prediction - delta_mean_prediction
2014
2015    # batch_mean_label is E[y_B] in the update equation
2016    batch_mean_label = _safe_div(
2017        math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
2018    delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
2019                                 update_count, 'delta_mean_label')
2020    update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
2021    # prev_mean_label is E[y_A] in the update equation
2022    prev_mean_label = update_mean_label - delta_mean_label
2023
2024    unweighted_batch_coresiduals = (
2025        (predictions - batch_mean_prediction) * (labels - batch_mean_label))
2026    # batch_comoment is C_B in the update equation
2027    if weights is None:
2028      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals)
2029    else:
2030      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals *
2031                                           weights)
2032
2033    # View delta_comoment as = C_AB - C_A in the update equation above.
2034    # Since C_A is stored in a var, by how much do we need to increment that var
2035    # to make the var = C_AB?
2036    delta_comoment = (batch_comoment +
2037                      (prev_mean_prediction - batch_mean_prediction) *
2038                      (prev_mean_label - batch_mean_label) *
2039                      (prev_count * batch_count / update_count))
2040    update_comoment = state_ops.assign_add(comoment, delta_comoment)
2041
2042    covariance = _safe_div(comoment, count - 1, 'covariance')
2043    with ops.control_dependencies([update_comoment]):
2044      update_op = _safe_div(comoment, count - 1, 'update_op')
2045
2046  if metrics_collections:
2047    ops.add_to_collections(metrics_collections, covariance)
2048
2049  if updates_collections:
2050    ops.add_to_collections(updates_collections, update_op)
2051
2052  return covariance, update_op
2053
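# Illustrative numeric check for `streaming_covariance` (values are
# hypothetical): streaming over two batches should match the unbiased
# covariance of the concatenated data.
#
#   import tensorflow as tf
#
#   x = tf.placeholder(tf.float32, shape=[2])
#   y = tf.placeholder(tf.float32, shape=[2])
#   cov, update_op = tf.contrib.metrics.streaming_covariance(x, y)
#
#   with tf.Session() as sess:
#     sess.run(tf.initialize_local_variables())
#     sess.run(update_op, {x: [1.0, 2.0], y: [1.0, 2.0]})
#     sess.run(update_op, {x: [3.0, 4.0], y: [3.0, 4.0]})
#     # x == y == [1, 2, 3, 4], so this is the unbiased variance:
#     print(sess.run(cov))  # ~1.667 = sum((x - 2.5)**2) / (4 - 1)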
2054
2055def streaming_pearson_correlation(predictions,
2056                                  labels,
2057                                  weights=None,
2058                                  metrics_collections=None,
2059                                  updates_collections=None,
2060                                  name=None):
2061  """Computes pearson correlation coefficient between `predictions`, `labels`.
2062
2063  The `streaming_pearson_correlation` function delegates to
2064  `streaming_covariance` the tracking of three [co]variances:
2065  - streaming_covariance(predictions, labels), i.e. covariance
2066  - streaming_covariance(predictions, predictions), i.e. variance
2067  - streaming_covariance(labels, labels), i.e. variance
2068
2069  The product-moment correlation ultimately returned is an idempotent operation
2070  `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To
2071  facilitate correlation computation across multiple batches, the function
2072  groups the `update_op`s of the underlying streaming_covariance and returns an
2073  `update_op`.
2074
2075  If `weights` is not None, then it is used to compute a weighted correlation.
2076  NOTE: these weights are treated as "frequency weights", as opposed to
2077  "reliability weights". See discussion of the difference on
2078  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance
2079
2080  Args:
2081    predictions: A `Tensor` of arbitrary size.
2082    labels: A `Tensor` of the same size as predictions.
2083    weights: An optional set of weights which indicates the frequency with which
2084      an example is sampled. Must be broadcastable with `labels`.
2085    metrics_collections: An optional list of collections that the metric
2086      value variable should be added to.
2087    updates_collections: An optional list of collections that the metric update
2088      ops should be added to.
2089    name: An optional variable_scope name.
2090
2091  Returns:
2092    pearson_r: A tensor representing the current Pearson product-moment
2093      correlation coefficient, the value of
2094      `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`.
2095    update_op: An operation that updates the underlying variables appropriately.
2096
2097  Raises:
2098    ValueError: If `labels` and `predictions` are of different sizes, or if
2099      `weights` is the wrong size, or if either `metrics_collections` or
2100      `updates_collections` are not a list or tuple.
2101  """
2102  with variable_scope.variable_scope(name, 'pearson_r', [predictions, labels]):
2103    predictions, labels = metric_ops_util.remove_squeezable_dimensions(
2104        predictions, labels)
2105    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
2106    cov, update_cov = streaming_covariance(
2107        predictions, labels, weights=weights, name='covariance')
2108    var_predictions, update_var_predictions = streaming_covariance(
2109        predictions, predictions, weights=weights, name='variance_predictions')
2110    var_labels, update_var_labels = streaming_covariance(
2111        labels, labels, weights=weights, name='variance_labels')
2112
2113    pearson_r = _safe_div(
2114        cov,
2115        math_ops.mul(math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)),
2116        'pearson_r')
2117    with ops.control_dependencies(
2118        [update_cov, update_var_predictions, update_var_labels]):
2119      update_op = _safe_div(update_cov, math_ops.mul(
2120          math_ops.sqrt(update_var_predictions),
2121          math_ops.sqrt(update_var_labels)), 'update_op')
2122
2123  if metrics_collections:
2124    ops.add_to_collections(metrics_collections, pearson_r)
2125
2126  if updates_collections:
2127    ops.add_to_collections(updates_collections, update_op)
2128
2129  return pearson_r, update_op
2130
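# Illustrative numeric check for `streaming_pearson_correlation` (values are
# hypothetical): inputs related by an exact positive linear map give ~1.0.
#
#   import tensorflow as tf
#
#   predictions = tf.constant([1.0, 2.0, 3.0, 4.0])
#   labels = 2.0 * predictions + 1.0
#   pearson_r, update_op = tf.contrib.metrics.streaming_pearson_correlation(
#       predictions, labels)
#
#   with tf.Session() as sess:
#     sess.run(tf.initialize_local_variables())
#     sess.run(update_op)
#     print(sess.run(pearson_r))  # ~1.0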
2131
2132# TODO(nsilberman): add a 'normalized' flag so that the user can request
2133# normalization if the inputs are not normalized.
2134def streaming_mean_cosine_distance(predictions, labels, dim, weights=None,
2135                                   metrics_collections=None,
2136                                   updates_collections=None,
2137                                   name=None):
2138  """Computes the cosine distance between the labels and predictions.
2139
2140  The `streaming_mean_cosine_distance` function creates two local variables,
2141  `total` and `count` that are used to compute the average cosine distance
2142  between `predictions` and `labels`. This average is weighted by `weights`,
2143  and it is ultimately returned as `mean_distance`, which is an idempotent
2144  operation that simply divides `total` by `count`.
2145
2146  For estimation of the metric over a stream of data, the function creates an
2147  `update_op` operation that updates these variables and returns the
2148  `mean_distance`.
2149
2150  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2151
2152  Args:
2153    predictions: A `Tensor` of the same shape as `labels`.
2154    labels: A `Tensor` of arbitrary shape.
2155    dim: The dimension along which the cosine distance is computed.
2156    weights: An optional `Tensor` whose shape matches `predictions`, and whose
2157      dimension `dim` is 1.
2158    metrics_collections: An optional list of collections that the metric
2159      value variable should be added to.
2160    updates_collections: An optional list of collections that the metric update
2161      ops should be added to.
2162    name: An optional variable_scope name.
2163
2164  Returns:
2165    mean_distance: A tensor representing the current mean, the value of `total`
2166      divided by `count`.
2167    update_op: An operation that increments the `total` and `count` variables
2168      appropriately.
2169
2170  Raises:
2171    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2172      `weights` is not `None` and its shape doesn't match `predictions`, or if
2173      either `metrics_collections` or `updates_collections` are not a list or
2174      tuple.
2175  """
2176  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
2177      predictions, labels)
2178  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
2179  radial_diffs = math_ops.mul(predictions, labels)
2180  radial_diffs = math_ops.reduce_sum(radial_diffs,
2181                                     reduction_indices=[dim,],
2182                                     keep_dims=True)
2183  mean_distance, update_op = streaming_mean(radial_diffs, weights,
2184                                            None,
2185                                            None,
2186                                            name or 'mean_cosine_distance')
2187  mean_distance = math_ops.sub(1.0, mean_distance)
2188  update_op = math_ops.sub(1.0, update_op)
2189
2190  if metrics_collections:
2191    ops.add_to_collections(metrics_collections, mean_distance)
2192
2193  if updates_collections:
2194    ops.add_to_collections(updates_collections, update_op)
2195
2196  return mean_distance, update_op
2197
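# Illustrative usage sketch for `streaming_mean_cosine_distance` on unit
# vectors (values are hypothetical; inputs are assumed to be pre-normalized,
# since this op does not normalize them):
#
#   import tensorflow as tf
#
#   predictions = tf.constant([[1.0, 0.0], [0.0, 1.0]])
#   labels = tf.constant([[1.0, 0.0], [1.0, 0.0]])
#   distance, update_op = tf.contrib.metrics.streaming_mean_cosine_distance(
#       predictions, labels, dim=1)
#
#   with tf.Session() as sess:
#     sess.run(tf.initialize_local_variables())
#     sess.run(update_op)
#     print(sess.run(distance))  # 0.5: per-example distances are 0.0 and 1.0.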
2198
2199@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
2200def streaming_percentage_less(values, threshold, ignore_mask=None, weights=None,
2201                              metrics_collections=None,
2202                              updates_collections=None,
2203                              name=None):
2204  """Computes the percentage of values less than the given threshold.
2205
2206  The `streaming_percentage_less` function creates two local variables,
2207  `total` and `count` that are used to compute the percentage of `values` that
2208  fall below `threshold`. This rate is weighted by `weights`, and it is
2209  ultimately returned as `percentage` which is an idempotent operation that
2210  simply divides `total` by `count`.
2211
2212  For estimation of the metric over a stream of data, the function creates an
2213  `update_op` operation that updates these variables and returns the
2214  `percentage`.
2215
2216  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2217  Alternatively, if `ignore_mask` is not `None`, then mask values where
2218  `ignore_mask` is `True`.
2219
2220  Args:
2221    values: A numeric `Tensor` of arbitrary size.
2222    threshold: A scalar threshold.
2223    ignore_mask: An optional, `bool` `Tensor` whose shape matches `values`.
2224    weights: An optional `Tensor` whose shape matches `values`.
2225    metrics_collections: An optional list of collections that the metric
2226      value variable should be added to.
2227    updates_collections: An optional list of collections that the metric update
2228      ops should be added to.
2229    name: An optional variable_scope name.
2230
2231  Returns:
2232    percentage: A tensor representing the current mean, the value of `total`
2233      divided by `count`.
2234    update_op: An operation that increments the `total` and `count` variables
2235      appropriately.
2236
2237  Raises:
2238    ValueError: If `ignore_mask` is not `None` and its shape doesn't match
2239      `values`, or if `weights` is not `None` and its shape doesn't match
2240      `values`, or if either `metrics_collections` or `updates_collections` are
2241      not a list or tuple.
2242  """
2243  is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
2244  return streaming_mean(is_below_threshold, _mask_weights(ignore_mask, weights),
2245                        metrics_collections,
2246                        updates_collections,
2247                        name or 'percentage_below_threshold')
2248
2249
2250@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
2251def streaming_mean_iou(predictions,
2252                       labels,
2253                       num_classes,
2254                       ignore_mask=None,
2255                       weights=None,
2256                       metrics_collections=None,
2257                       updates_collections=None,
2258                       name=None):
2259  """Calculate per-step mean Intersection-Over-Union (mIOU).
2260
2261  Mean Intersection-Over-Union is a common evaluation metric for
2262  semantic image segmentation, which first computes the IOU for each
2263  semantic class and then computes the average over classes.
2264  IOU is defined as follows:
2265    IOU = true_positive / (true_positive + false_positive + false_negative).
2266  The predictions are accumulated in a confusion matrix, weighted by `weights`,
2267  and mIOU is then calculated from it.
2268
2269  For estimation of the metric over a stream of data, the function creates an
2270  `update_op` operation that updates these variables and returns the `mean_iou`.
2271
2272  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2273  Alternatively, if `ignore_mask` is not `None`, then mask values where
2274  `ignore_mask` is `True`.
2275
2276  Args:
2277    predictions: A tensor of prediction results for semantic labels, whose
2278      shape is [batch size] and type `int32` or `int64`. The tensor will be
2279      flattened, if its rank > 1.
2280    labels: A tensor of ground truth labels with shape [batch size] and of
2281      type `int32` or `int64`. The tensor will be flattened, if its rank > 1.
2282    num_classes: The possible number of labels the prediction task can
2283      have. This value must be provided, since a confusion matrix of
2284      dimension = [num_classes, num_classes] will be allocated.
2285    ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
2286    weights: An optional `Tensor` whose shape matches `predictions`.
2287    metrics_collections: An optional list of collections that `mean_iou`
2288      should be added to.
2289    updates_collections: An optional list of collections `update_op` should be
2290      added to.
2291    name: An optional variable_scope name.
2292
2293  Returns:
2294    mean_iou: A tensor representing the mean intersection-over-union.
2295    update_op: An operation that increments the confusion matrix.
2296
2297  Raises:
2298    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2299      `ignore_mask` is not `None` and its shape doesn't match `predictions`, or
2300      if `weights` is not `None` and its shape doesn't match `predictions`, or
2301      if either `metrics_collections` or `updates_collections` are not a list or
2302      tuple.
2303  """
2304  with variable_scope.variable_scope(name, 'mean_iou', [predictions, labels]):
2305    # Check if shape is compatible.
2306    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
2307
2308    # Local variable to accumulate the predictions in the confusion matrix.
2309    cm_dtype = dtypes.int64 if weights is not None else dtypes.float64
    total_cm = _create_local('total_confusion_matrix',
                             shape=[num_classes, num_classes], dtype=cm_dtype)

    # Cast the type to int64 required by confusion_matrix_ops.
    predictions = math_ops.to_int64(predictions)
    labels = math_ops.to_int64(labels)
    num_classes = math_ops.to_int64(num_classes)

    # Flatten the input if its rank > 1.
    predictions_rank = predictions.get_shape().ndims
    if predictions_rank > 1:
      predictions = array_ops.reshape(predictions, [-1])

    labels_rank = labels.get_shape().ndims
    if labels_rank > 1:
      labels = array_ops.reshape(labels, [-1])

    weights = _mask_weights(ignore_mask, weights)
    if weights is not None:
      weights_rank = weights.get_shape().ndims
      if weights_rank > 1:
        weights = array_ops.reshape(weights, [-1])

    # Accumulate the prediction to current confusion matrix.
    current_cm = confusion_matrix_ops.confusion_matrix(
        predictions, labels, num_classes, weights=weights, dtype=cm_dtype)
    update_op = state_ops.assign_add(total_cm, current_cm)

    def compute_mean_iou(name):
      """Compute the mean intersection-over-union via the confusion matrix."""
      sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
      sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
      cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
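      # Per class, row sum + column sum counts the diagonal entry (the true
      # positives) twice, so subtracting it once leaves TP + FP + FN, i.e. the
      # union used in the IOU denominator.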
      denominator = sum_over_row + sum_over_col - cm_diag

      # If the value of the denominator is 0, set it to 1 to avoid
      # zero division.
      denominator = math_ops.select(
          math_ops.greater(denominator, 0),
          denominator,
          array_ops.ones_like(denominator))
      iou = math_ops.div(cm_diag, denominator)
      return math_ops.reduce_mean(iou, name=name)

    mean_iou = compute_mean_iou('mean_iou')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_iou)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_iou, update_op


def aggregate_metrics(*value_update_tuples):
  """Aggregates the metric value tensors and update ops into two lists.

  Args:
    *value_update_tuples: a variable number of tuples, each of which contains
      the pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    a list of value tensors and a list of update ops.

  Raises:
    ValueError: if `value_update_tuples` is empty.
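
  For example, a minimal sketch (the streaming metric calls are only
  illustrative; any (value_tensor, update_op) pairs can be passed):

    value_ops, update_ops = aggregate_metrics(
        streaming_mean_absolute_error(predictions, labels),
        streaming_root_mean_squared_error(predictions, labels))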
  """
  if not value_update_tuples:
    raise ValueError('Expected at least one value_tensor/update_op pair')
  value_ops, update_ops = zip(*value_update_tuples)
  return list(value_ops), list(update_ops)


def aggregate_metric_map(names_to_tuples):
  """Aggregates a dictionary of metric names to (value_tensor, update_op) tuples.

  This function is useful for pairing metric names with their associated value
  and update ops when the list of metrics is long. For example:

    metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
        'Mean Absolute Error': slim.metrics.streaming_mean_absolute_error(
            predictions, labels, weights),
        'Mean Relative Error': slim.metrics.streaming_mean_relative_error(
            predictions, labels, labels, weights),
        'RMSE Linear': slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
        'RMSE Log': slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
    })

  Args:
    names_to_tuples: a map of metric names to tuples, each of which contains
      the pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    A dictionary from metric names to value ops and a dictionary from metric
    names to update ops.
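
  The returned dictionaries can then be consumed directly; a minimal sketch
  (`sess` and `num_batches` are assumed to come from the caller):

    for _ in range(num_batches):
      sess.run(list(metrics_to_updates.values()))
    for name, value_op in metrics_to_values.items():
      print('%s: %f' % (name, sess.run(value_op)))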
  """
  metric_names = names_to_tuples.keys()
  value_ops, update_ops = zip(*names_to_tuples.values())
  return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))


__all__ = make_all(__name__)