metric_ops.py revision 6e7c8b76bf2931754d28928623196c8090e87dc0
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains metric-computing operations on streamed tensors.
16
17Module documentation, including "@@" callouts, should be put in
18third_party/tensorflow/contrib/metrics/__init__.py
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25from tensorflow.contrib.framework.python.ops import variables as contrib_variables
26
27from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
28from tensorflow.contrib.metrics.python.ops import metric_ops_util
29from tensorflow.contrib.metrics.python.ops import set_ops
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import check_ops
34from tensorflow.python.ops import math_ops
35from tensorflow.python.ops import nn
36from tensorflow.python.ops import sparse_ops
37from tensorflow.python.ops import state_ops
38from tensorflow.python.ops import variable_scope
39from tensorflow.python.ops import variables
40from tensorflow.python.util.all_util import make_all
41
42
def _mask_to_weights(mask=None):
  """Converts a binary mask to a set of weights.

  Args:
    mask: A binary `Tensor`.

  Returns:
    The corresponding set of weights if `mask` is not `None`, otherwise `None`.
  """
  if mask is None:
    return None
  # The mask marks entries to ignore, so the weights are its complement.
  check_ops.assert_type(mask, dtypes.bool)
  return math_ops.logical_not(mask)
58
59
def _create_local(name, shape=None, collections=None, dtype=dtypes.float32):
  """Creates a new local variable, initialized to zeros.

  The variable is always placed in `tf.GraphKeys.LOCAL_VARIABLES`, in
  addition to any collections supplied by the caller, and is never trainable.

  Args:
    name: The name of the new or existing variable.
    shape: Shape of the new or existing variable.
    collections: A list of collection names to which the Variable will be added.
    dtype: Data type of the variables.

  Returns:
    The created variable.
  """
  all_collections = list(collections) if collections else []
  all_collections.append(ops.GraphKeys.LOCAL_VARIABLES)
  zero_init = array_ops.zeros(shape, dtype=dtype)
  return variables.Variable(
      initial_value=zero_init,
      trainable=False,
      collections=all_collections,
      name=name)
80
81
def _count_condition(values, ignore_mask=None, metrics_collections=None,
                     updates_collections=None):
  """Computes the total number of cases where the given values are True.

  Args:
    values: A binary `Tensor` of arbitrary size.
    ignore_mask: An optional, binary tensor whose size matches 'values'.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If either `metrics_collections` or `updates_collections` are not
      a list or tuple.
  """
  check_ops.assert_type(values, dtypes.bool)
  count = _create_local('count', shape=[])

  if ignore_mask is not None:
    values.get_shape().assert_is_compatible_with(ignore_mask.get_shape())
    check_ops.assert_type(ignore_mask, dtypes.bool)
    # Masked-out positions contribute zero to the count.
    values = math_ops.select(ignore_mask,
                             array_ops.zeros_like(values),
                             values)
  num_true = math_ops.reduce_sum(math_ops.to_float(values))

  value_tensor = array_ops.identity(count)
  update_op = state_ops.assign_add(count, num_true)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, value_tensor)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return value_tensor, update_op
124
125
def _streaming_true_negatives(predictions, labels, ignore_mask=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Computes the total number of true_negatives.

  Args:
    predictions: The predicted values, a binary `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a binary `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, binary tensor whose size matches 'predictions'.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If either `metrics_collections` or `updates_collections` are not
      a list or tuple.
  """
  # `variable_scope(name_or_scope, default_name, values)`: the arguments were
  # previously passed as `([predictions, labels], name, 'true_negatives')`,
  # putting the tensor list where the scope name belongs. Use the same
  # argument order as the other `_streaming_*` helpers in this file.
  with variable_scope.variable_scope(
      name, 'true_negatives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    is_true_negative = math_ops.logical_and(math_ops.equal(labels, 0),
                                            math_ops.equal(predictions, 0))
    return _count_condition(is_true_negative, ignore_mask, metrics_collections,
                            updates_collections)
160
161
def _streaming_true_positives(predictions, labels, ignore_mask=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Computes the total number of true_positives.

  Args:
    predictions: The predicted values, a binary `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a binary `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, binary tensor whose size matches 'predictions'.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If either `metrics_collections` or `updates_collections` are not
      a list or tuple.
  """
  with variable_scope.variable_scope(
      name, 'true_positives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # A true positive is a positive label with a positive prediction.
    label_is_pos = math_ops.equal(labels, 1)
    pred_is_pos = math_ops.equal(predictions, 1)
    is_true_positive = math_ops.logical_and(label_is_pos, pred_is_pos)
    return _count_condition(is_true_positive, ignore_mask,
                            metrics_collections, updates_collections)
196
197
def _streaming_false_positives(predictions, labels, ignore_mask=None,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the total number of false positives.

  Args:
    predictions: The predicted values, a binary `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a binary `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, binary tensor whose size matches 'predictions'.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If either `metrics_collections` or `updates_collections` are not
      a list or tuple.
  """
  with variable_scope.variable_scope(
      name, 'false_positives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # A false positive is a negative label with a positive prediction.
    label_is_neg = math_ops.equal(labels, 0)
    pred_is_pos = math_ops.equal(predictions, 1)
    is_false_positive = math_ops.logical_and(label_is_neg, pred_is_pos)
    return _count_condition(is_false_positive, ignore_mask,
                            metrics_collections, updates_collections)
232
233
def _streaming_false_negatives(predictions, labels, ignore_mask=None,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the total number of false negatives.

  Args:
    predictions: The predicted values, a binary `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a binary `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, binary tensor whose size matches 'predictions'.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If either `metrics_collections` or `updates_collections` are not
      a list or tuple.
  """
  with variable_scope.variable_scope(
      name, 'false_negatives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # A false negative is a positive label with a negative prediction.
    is_false_negative = math_ops.logical_and(math_ops.equal(labels, 1),
                                             math_ops.equal(predictions, 0))
    return _count_condition(is_false_negative, ignore_mask,
                            metrics_collections, updates_collections)
268
269
def streaming_mean(values, weights=None, metrics_collections=None,
                   updates_collections=None, name=None):
  """Computes the (weighted) mean of the given values.

  Two local variables, `total` and `count`, accumulate the sum of `values`
  and the number (or total weight) of values seen so far. The returned `mean`
  is an idempotent operation that simply divides `total` by `count`.

  To support streaming computation, the returned `update_op` accumulates a
  batch into the variables and returns the updated mean. If `weights` is
  `None`, `update_op` adds the reduced sum of `values` to `total` and the
  element count of `values` to `count`. Otherwise, it adds the reduced sum of
  `values * weights` to `total` and the reduced sum of `weights` to `count`.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: An optional set of weights of the same shape as `values`. If
      `weights` is not None, the function computes a weighted mean.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A tensor representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_value`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'mean', [values, weights]):
    values = math_ops.to_float(values)

    total = _create_local('total', shape=[])
    count = _create_local('count', shape=[])

    if weights is None:
      batch_count = math_ops.to_float(array_ops.size(values))
    else:
      values.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.to_float(weights)
      values = math_ops.mul(values, weights)
      batch_count = math_ops.reduce_sum(weights)

    update_total_op = state_ops.assign_add(total,
                                           math_ops.reduce_sum(values))
    update_count_op = state_ops.assign_add(count, batch_count)

    def compute_mean(total, count, name):
      # Return 0 (rather than NaN) while no values have been accumulated.
      return math_ops.select(math_ops.greater(count, 0),
                             math_ops.div(total, count),
                             0, name)

    mean = compute_mean(total, count, 'value')
    with ops.control_dependencies([update_total_op, update_count_op]):
      update_op = compute_mean(total, count, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean, update_op
341
342
def streaming_mean_tensor(values, weights=None, metrics_collections=None,
                          updates_collections=None, name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  In contrast to the `streaming_mean` function which returns a scalar with the
  mean,  this function returns an average tensor with the same shape as the
  input tensors.

  The `streaming_mean_tensor` function creates two local variables,
  `total_tensor` and `count_tensor` that are used to compute the average of
  `values`. This average is ultimately returned as `mean` which is an idempotent
  operation that simply divides `total` by `count`. To facilitate the estimation
  of a mean over a stream of data, the function creates an `update_op` operation
  whose behavior is dependent on the value of `weights`. If `weights` is None,
  then `update_op` increments `total` with the reduced sum of `values` and
  increments `count` with the number of elements in `values`. If `weights` is
  not `None`, then `update_op` increments `total` with the reduced sum of the
  product of `values` and `weights` and increments `count` with the reduced sum
  of weights. In addition to performing the updates, `update_op` also returns
  the `mean`.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: An optional set of weights of the same shape as `values`. If
      `weights` is not None, the function computes a weighted mean.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float tensor representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_value`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'mean', [values, weights]):
    total = _create_local('total_tensor', shape=values.get_shape())
    count = _create_local('count_tensor', shape=values.get_shape())

    if weights is not None:
      values.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.to_float(weights)
      values = math_ops.mul(values, weights)
      num_values = weights
    else:
      num_values = array_ops.ones_like(values)

    total_compute_op = state_ops.assign_add(total, values)
    count_compute_op = state_ops.assign_add(count, num_values)

    def compute_mean(total, count, name):
      # Clamp the count to at least 1 so elements that have not been updated
      # yet yield 0 instead of NaN. The clamp op is deliberately left unnamed:
      # previously `name` was passed to both ops, so TensorFlow uniquified the
      # division and the returned tensor was named 'value_1'/'update_op_1'
      # instead of the intended name.
      non_zero_count = math_ops.maximum(count, array_ops.ones_like(count))
      return math_ops.truediv(total, non_zero_count, name=name)

    mean = compute_mean(total, count, 'value')
    with ops.control_dependencies([total_compute_op, count_compute_op]):
      update_op = compute_mean(total, count, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean, update_op
417
418
def streaming_accuracy(predictions, labels, weights=None,
                       metrics_collections=None, updates_collections=None,
                       name=None):
  """Calculates how often `predictions` matches `labels`.

  Internally, an `is_correct` tensor is computed whose shape matches
  `predictions` and whose elements are 1.0 where the corresponding values of
  `predictions` and `labels` match, and 0.0 otherwise. The accuracy is then
  the (optionally weighted) streaming mean of `is_correct`, maintained by two
  local variables, `total` and `count`, and returned as `accuracy`: an
  idempotent operation that divides `total` by `count`. The returned
  `update_op` accumulates a batch into those variables and also returns the
  updated accuracy (see `streaming_mean` for the exact update semantics with
  and without `weights`).

  Args:
    predictions: The predicted values, a `Tensor` of any shape.
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    weights: An optional set of weights whose shape matches `predictions`
      which, when not `None`, produces a weighted mean accuracy.
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A tensor representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If the dimensions of `predictions` and `labels` don't match or
      if `weight` is not `None` and its shape doesn't match `predictions` or
      if either `metrics_collections` or `updates_collections` are not
      a list or tuple.
  """
  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  matches = math_ops.equal(predictions, labels)
  is_correct = math_ops.to_float(matches)
  return streaming_mean(is_correct, weights, metrics_collections,
                        updates_collections, name or 'accuracy')
471
472
def streaming_precision(predictions, labels, ignore_mask=None,
                        metrics_collections=None, updates_collections=None,
                        name=None):
  """Computes the precision of the predictions with respect to the labels.

  Two local variables, `true_positives` and `false_positives`, accumulate the
  respective counts. The returned `precision` is an idempotent operation that
  divides `true_positives` by `true_positives + false_positives`.

  The returned `update_op` accumulates a batch into both variables and also
  returns the updated precision. If `ignore_mask` is `None`, `true_positives`
  is incremented by the number of elements where both `predictions` and
  `labels` are `True`, and `false_positives` by the number of elements where
  `predictions` is `True` but `labels` is `False`. If `ignore_mask` is not
  `None`, only elements whose corresponding `ignore_mask` value is `False`
  contribute to either count.

  Args:
    predictions: The predicted values, a binary `Tensor` of arbitrary shape.
    labels: The ground truth values, a binary `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, binary tensor whose size matches `predictions`.
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: Scalar float `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately and whose value matches
      `precision`.

  Raises:
    ValueError: If the dimensions of `predictions` and `labels` don't match or
      if `ignore_mask` is not `None` and its shape doesn't match `predictions`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(
      name, 'precision', [predictions, labels]):

    predictions, labels = metric_ops_util.remove_squeezable_dimensions(
        predictions, labels)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    tp, tp_update_op = _streaming_true_positives(
        predictions, labels, ignore_mask, metrics_collections=None,
        updates_collections=None, name=None)
    fp, fp_update_op = _streaming_false_positives(
        predictions, labels, ignore_mask, metrics_collections=None,
        updates_collections=None, name=None)

    def compute_precision(name):
      # Return 0 (rather than NaN) while no positive predictions were seen.
      predicted_positives = tp + fp
      return math_ops.select(
          math_ops.greater(predicted_positives, 0),
          math_ops.div(tp, predicted_positives),
          0,
          name)

    precision = compute_precision('value')
    with ops.control_dependencies([tp_update_op, fp_update_op]):
      update_op = compute_precision('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, precision)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision, update_op
551
552
def streaming_recall(predictions, labels, ignore_mask=None,
                     metrics_collections=None, updates_collections=None,
                     name=None):
  """Computes the recall of the predictions with respect to the labels.

  The `streaming_recall` function creates two local variables,
  `true_positives` and `false_negatives`, that are used to compute the
  recall. This value is ultimately returned as `recall`, an idempotent
  operation that simply divides `true_positives` by the sum of `true_positives`
  and `false_negatives`. To facilitate the calculation of the recall over a
  stream of data, the function creates an `update_op` operation whose behavior
  is dependent on the value of `ignore_mask`. If `ignore_mask` is None, then
  `update_op` increments `true_positives` with the number of elements of
  `predictions` and `labels` that are both `True` and increments
  `false_negatives` with the number of elements of `predictions` that are
  `False` whose corresponding `labels` element is `True`. If `ignore_mask` is
  not `None`, then the increments for `true_positives` and `false_negatives` are
  only computed using elements of `predictions` and `labels` whose corresponding
  values in `ignore_mask` are `False`. In addition to performing the updates,
  `update_op` also returns the value of `recall`.

  Args:
    predictions: The predicted values, a binary `Tensor` of arbitrary shape.
    labels: The ground truth values, a binary `Tensor` whose dimensions must
      match `predictions`.
    ignore_mask: An optional, binary tensor whose size matches `predictions`.
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If the dimensions of `predictions` and `labels` don't match or
      if `ignore_mask` is not `None` and its shape doesn't match `predictions`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'recall', [predictions, labels]):
    predictions, labels = metric_ops_util.remove_squeezable_dimensions(
        predictions, labels)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    true_positives, true_positives_update_op = _streaming_true_positives(
        predictions, labels, ignore_mask, metrics_collections=None,
        updates_collections=None, name=None)
    false_negatives, false_negatives_update_op = _streaming_false_negatives(
        predictions, labels, ignore_mask, metrics_collections=None,
        updates_collections=None, name=None)

    def compute_recall(true_positives, false_negatives, name):
      # Return 0 (rather than NaN) while no positive labels have been seen.
      return math_ops.select(
          math_ops.greater(true_positives + false_negatives, 0),
          math_ops.div(true_positives, true_positives + false_negatives),
          0,
          name)

    recall = compute_recall(true_positives, false_negatives, 'value')
    with ops.control_dependencies([true_positives_update_op,
                                   false_negatives_update_op]):
      update_op = compute_recall(true_positives, false_negatives, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op
629
630
631def _tp_fn_tn_fp(predictions, labels, thresholds, weights):
632  """Computes true_positives, false_negatives, true_negatives, false_positives.
633
634  The `_tp_fn_tn_fp` function creates four local variables, `true_positives`,
635  `true_negatives`, `false_positives` and `false_negatives`.
636  `true_positive[i]` is defined as the total weight of values in `predictions`
637  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
638  `false_negatives[i]` is defined as the total weight of values in `predictions`
639  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
640  `true_negatives[i]` is defined as the total weight of values in `predictions`
641  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
642  `false_positives[i]` is defined as the total weight of values in `predictions`
643  above `thresholds[i]` whose corresponding entry in `labels` is `False`.
644
645  These four variables are updated through the `update_op`.
646  The streaming behavior is that the values of the variables after a few
647  `update_op`s is the same as if the inputs had been concatenated and a single
648  `update_op` had been performed.
649
650  If `weights` is `None`, all entries are assumed to have weight 1. Note that
651  a weight of 0 effectively discards an entry from consideration.
652
653  Args:
654    predictions: A floating point `Tensor` of arbitrary shape and whose values
655      are in the range `[0, 1]`.
656    labels: A binary `Tensor` whose shape matches `predictions`.
657    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
658    weights: An optional, floating point `Tensor` with the same shape as
659      `predictions`.
660
661  Returns:
662    true_positive: A variable of shape [len(thresholds)].
663    false_negative: A variable of shape [len(thresholds)].
664    true_negatives: A variable of shape [len(thresholds)].
665    false_positives: A variable of shape [len(thresholds)].
666    true_positives_update_op: An operation that increments the `true_positives`.
667    false_negative_update_op: An operation that increments the `false_negative`.
668    true_negatives_update_op: An operation that increments the `true_negatives`.
669    false_positives_update_op: An operation that increments the
670      `false_positives`.
671
672  Raises:
673    ValueError: If the shape of `predictions` and `labels` do not match or if
674      `weights` is not `None` and its shape doesn't match `predictions`
675      or if either `metrics_collections` or `updates_collections` are not a list
676      or tuple.
677  """
678  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
679      predictions, labels)
680  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
681
682  num_thresholds = len(thresholds)
683
684  # Reshape predictions and labels
685  predictions = array_ops.reshape(predictions, [-1, 1])
686  labels = array_ops.reshape(math_ops.cast(labels, dtype=dtypes.bool), [1, -1])
687
688  # Use static shape if known.
689  num_predictions = predictions.get_shape().as_list()[0]
690
691  # Otherwise use dynamic shape.
692  if num_predictions is None:
693    num_predictions = array_ops.shape(predictions)[0]
694  thresh_tiled = array_ops.tile(
695      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
696      array_ops.pack([1, num_predictions]))
697
698  # Tile the predictions after thresholding them across different thresholds.
699  pred_is_pos = math_ops.greater(
700      array_ops.tile(array_ops.transpose(predictions), [num_thresholds, 1]),
701      thresh_tiled)
702  pred_is_neg = math_ops.logical_not(pred_is_pos)
703
704  # Tile labels by number of thresholds
705  label_is_pos = array_ops.tile(labels, [num_thresholds, 1])
706  label_is_neg = math_ops.logical_not(label_is_pos)
707
708  true_positives = _create_local('true_positives', shape=[num_thresholds])
709  false_negatives = _create_local('false_negatives', shape=[num_thresholds])
710  true_negatives = _create_local('true_negatives', shape=[num_thresholds])
711  false_positives = _create_local('false_positives', shape=[num_thresholds])
712
713  is_true_positive = math_ops.to_float(
714      math_ops.logical_and(label_is_pos, pred_is_pos))
715  is_false_negative = math_ops.to_float(
716      math_ops.logical_and(label_is_pos, pred_is_neg))
717  is_false_positive = math_ops.to_float(
718      math_ops.logical_and(label_is_neg, pred_is_pos))
719  is_true_negative = math_ops.to_float(
720      math_ops.logical_and(label_is_neg, pred_is_neg))
721
722  if weights is not None:
723    weights_tiled = array_ops.tile(
724        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
725    thresh_tiled.get_shape().assert_is_compatible_with(
726        weights_tiled.get_shape())
727    check_ops.assert_type(weights_tiled, dtypes.float32)
728    is_true_positive *= weights_tiled
729    is_false_negative *= weights_tiled
730    is_false_positive *= weights_tiled
731    is_true_negative *= weights_tiled
732
733  true_positives_update_op = state_ops.assign_add(
734      true_positives, math_ops.reduce_sum(is_true_positive, 1))
735  false_negatives_update_op = state_ops.assign_add(
736      false_negatives, math_ops.reduce_sum(is_false_negative, 1))
737  true_negatives_update_op = state_ops.assign_add(
738      true_negatives, math_ops.reduce_sum(is_true_negative, 1))
739  false_positives_update_op = state_ops.assign_add(
740      false_positives, math_ops.reduce_sum(is_false_positive, 1))
741
742  return (true_positives, false_negatives, true_negatives, false_positives,
743          true_positives_update_op, false_negatives_update_op,
744          true_negatives_update_op, false_positives_update_op)
745
746
def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
                  metrics_collections=None, updates_collections=None,
                  curve='ROC', name=None):
  """Computes the approximate AUC via a Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is the
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall
  values (computed using the aforementioned variables). The `num_thresholds`
  variable controls the degree of discretization with larger numbers of
  thresholds more closely approximating the true AUC.

  To facilitate the estimation of the AUC over a stream of data, the function
  creates an `update_op` operation. `update_op` increments the
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  counts with the weighted number of each found in the current `predictions`
  and `labels` `Tensors`. If `weights` is `None`, it is assumed that all
  entries have weight 1. Note that a weight of 0 can be used to effectively
  mask out and ignore specific entries. In addition to performing the updates,
  `update_op` also returns the `auc`.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A binary `Tensor` whose shape matches `predictions`.
    weights: An optional, floating point `Tensor` of same shape as
      `predictions`.
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar tensor representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If the shape of `predictions` and `labels` do not match or if
      `weights` is not `None` and its shape doesn't match `predictions` or
      if either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'auc', [predictions, labels]):
    if curve not in ('ROC', 'PR'):
      raise ValueError('curve must be either ROC or PR, %s unknown' %
                       (curve))
    kepsilon = 1e-7  # to account for floating point imprecisions
    # Thresholds are evenly spaced in (0, 1); the endpoints are nudged just
    # outside [0, 1] so predictions of exactly 0 or 1 fall on the right side.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      recall = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        x = fp_rate
        y = recall
      else:  # curve == 'PR'.
        precision = math_ops.div(tp + epsilon, tp + fp + epsilon)
        x = recall
        y = precision
      # Riemann sum over trapezoids formed by consecutive (x, y) pairs.
      return math_ops.reduce_sum(math_ops.mul(
          x[:num_thresholds - 1] - x[1:],
          (y[:num_thresholds - 1] + y[1:]) / 2.), name=name)

    # sum up the areas of all the trapeziums
    auc = compute_auc(tp, fn, tn, fp, 'value')
    update_op = compute_auc(
        tp_update_op, fn_update_op, tn_update_op, fp_update_op, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, auc)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc, update_op
846
847
def streaming_specificity_at_sensitivity(
    predictions, labels, sensitivity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the specificity at a given sensitivity.

  The `streaming_specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  To facilitate the estimation of the metric over a stream of data, the
  function creates an `update_op` operation. `update_op` increments the
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  counts with the weighted number of each found in the current `predictions`
  and `labels` `Tensors`. If `weights` is `None`, it is assumed that all
  entries have weight 1. Note that a weight of 0 can be used to effectively
  mask out and ignore specific entries. In addition to performing the updates,
  `update_op` also returns the `specificity`.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A binary `Tensor` whose shape matches `predictions`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: An optional, floating point `Tensor` of same shape as
      `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar tensor representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If the shape of `predictions` and `labels` do not match or if
      `weights` is not `None` and its shape doesn't match `predictions` or
      `sensitivity` is not between 0 and 1 or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
  """
  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    # Evenly spaced thresholds in (0, 1), with both endpoints nudged below
    # their nominal values by kepsilon.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)

    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_specificity_at_sensitivity(name):
      """Computes the specificity at the given sensitivity.

      Args:
        name: The name of the operation.

      Returns:
        The specificity using the aggregated values.
      """
      # Per-threshold sensitivity (recall): tp / (tp + fn).
      sensitivities = math_ops.div(tp, tp + fn + kepsilon)

      # We'll need to use this trick until tf.argmax allows us to specify
      # whether we should use the first or last index in case of ties.
      min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
      indices_at_minval = math_ops.equal(
          math_ops.abs(sensitivities - sensitivity), min_val)
      indices_at_minval = math_ops.to_int64(indices_at_minval)
      # cumsum makes later tied indices strictly larger, so argmax selects
      # the LAST threshold whose sensitivity is closest to the target.
      indices_at_minval = math_ops.cumsum(indices_at_minval)
      tf_index = math_ops.argmax(indices_at_minval, 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the specificity:
      # specificity = tn / (tn + fp) at the selected threshold.
      return math_ops.div(tn[tf_index],
                          tn[tf_index] + fp[tf_index] + kepsilon,
                          name)

    specificity = compute_specificity_at_sensitivity('value')
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_specificity_at_sensitivity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, specificity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op
952
953
def streaming_sensitivity_at_specificity(
    predictions, labels, specificity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the sensitivity at a given specificity.

  The `streaming_sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is computed
  and used to evaluate the corresponding sensitivity.

  To facilitate the estimation of the metric over a stream of data, the
  function creates an `update_op` operation. `update_op` increments the
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  counts with the weighted number of each found in the current `predictions`
  and `labels` `Tensors`. If `weights` is `None`, it is assumed that all
  entries have weight 1. Note that a weight of 0 can be used to effectively
  mask out and ignore specific entries. In addition to performing the updates,
  `update_op` also returns the `sensitivity`.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A binary `Tensor` whose shape matches `predictions`.
    specificity: A scalar value in range `[0, 1]`.
    weights: An optional, floating point `Tensor` of same shape as
      `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar tensor representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If the shape of `predictions` and `labels` do not match or if
      `weights` is not `None` and its shape doesn't match `predictions` or
      `specificity` is not between 0 and 1 or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
  """
  if specificity < 0 or specificity > 1:
    raise ValueError('`specificity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
                                     [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    # Evenly spaced thresholds in (0, 1); the endpoints are nudged just
    # outside [0, 1] so extreme predictions are counted.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds-2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_sensitivity_at_specificity(name):
      """Computes the sensitivity at the threshold closest to `specificity`."""
      # Per-threshold specificity: tn / (tn + fp).
      specificities = math_ops.div(tn, tn + fp + kepsilon)
      # argmin picks the FIRST threshold whose specificity is closest to the
      # target (in case of ties).
      tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the sensitivity:
      # sensitivity = tp / (tp + fn) at the selected threshold.
      return math_ops.div(tp[tf_index],
                          tp[tf_index] + fn[tf_index] + kepsilon,
                          name)

    sensitivity = compute_sensitivity_at_specificity('value')
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_sensitivity_at_specificity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, sensitivity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return sensitivity, update_op
1041
1042
def streaming_precision_at_thresholds(predictions, labels, thresholds,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None, name=None):
  """Computes precision values for different `thresholds` on `predictions`.

  For each threshold `thresholds[i]`, `precision[i]` is the total weight of
  entries in `predictions` above `thresholds[i]` whose corresponding label is
  `True` (`true_positives[i]`), divided by the total weight of entries in
  `predictions` above `thresholds[i]`
  (`true_positives[i] + false_positives[i]`). Four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`,
  back the computation. When `weights` is `None`, every entry counts with
  weight 1.

  `precision` is returned along with an `update_op` whose value equals that of
  `precision`.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A binary `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional, floating point `Tensor` of same shape as
      `predictions`.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `precision`.

  Raises:
    ValueError: If the shape of `predictions` and `labels` do not match or if
      `weights` is not `None` and its shape doesn't match `predictions`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'precision_at_thresholds',
                                     [predictions, labels]):

    # TODO(nsilberman): Replace with only tp and fp, this results in unnecessary
    # variable creation. b/30842882
    (tp, _, _, fp, tp_update, _, _, fp_update) = _tp_fn_tn_fp(
        predictions, labels, thresholds, weights)

    epsilon = 1e-7  # guards against division by zero

    def compute_precision(name):
      # precision = tp / (tp + fp), with epsilon keeping the denominator > 0.
      return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)

    precision = compute_precision('value')
    with ops.control_dependencies([tp_update, fp_update]):
      update_op = compute_precision('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, precision)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision, update_op
1116
1117
def streaming_recall_at_thresholds(predictions, labels, thresholds,
                                   weights=None, metrics_collections=None,
                                   updates_collections=None, name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  For each threshold `thresholds[i]`, `recall[i]` is the total weight of
  entries in `predictions` above `thresholds[i]` whose corresponding label is
  `True` (`true_positives[i]`), divided by the total weight of `True` labels
  (`true_positives[i] + false_negatives[i]`). Four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`,
  back the computation. When `weights` is `None`, every entry counts with
  weight 1.

  `recall` is returned along with an `update_op` whose value equals that of
  `recall`.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A binary `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional, floating point `Tensor` of same shape as
      `predictions`.
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `recall`.

  Raises:
    ValueError: If the shape of `predictions` and `labels` do not match or if
      `weights` is not `None` and its shape doesn't match `predictions`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'recall_at_thresholds',
                                     [predictions, labels]):
    (tp, fn, _, _, tp_update, fn_update, _, _) = _tp_fn_tn_fp(
        predictions, labels, thresholds, weights)

    epsilon = 1e-7  # guards against division by zero

    def compute_recall(name):
      # recall = tp / (tp + fn), with epsilon keeping the denominator > 0.
      return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)

    recall = compute_recall('value')
    with ops.control_dependencies([tp_update, fn_update]):
      update_op = compute_recall('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op
1187
1188
def streaming_recall_at_k(predictions, labels, k, ignore_mask=None,
                          metrics_collections=None, updates_collections=None,
                          name=None):
  """Computes the recall@k of the predictions with respect to dense labels.

  This is implemented as a streaming mean over an `in_top_k` indicator: for
  each batch element, `in_top_k` yields 1.0 when the element's label is among
  the top `k` entries of its `predictions` row and 0.0 otherwise. The
  `streaming_mean` of that indicator maintains two local variables, `total`
  and `count`, whose ratio is returned as `recall_at_<k>`, together with an
  `update_op` that folds in the current batch and returns the updated recall.
  Elements whose `ignore_mask` entry is `True` are excluded via a zero weight.

  Args:
    predictions: A floating point tensor of dimension [batch_size, num_classes]
    labels: A tensor of dimension [batch_size] whose type is in `int32`,
      `int64`.
    k: The number of top elements to look at for computing recall.
    ignore_mask: An optional, binary tensor whose size matches `labels`. If an
      element of `ignore_mask` is True, the corresponding prediction and label
      pair is used to compute the metrics. Otherwise, the pair is ignored.
    metrics_collections: An optional list of collections that `recall_at_k`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    recall_at_k: A tensor representing the recall@k, the fraction of labels
      which fall into the top `k` predictions.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `recall_at_k`.

  Raises:
    ValueError: If the dimensions of `predictions` and `labels` don't match or
      if `ignore_mask` is not `None` and its shape doesn't match `predictions`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  default_name = 'recall_at_%d' % k
  hits = math_ops.to_float(nn.in_top_k(predictions, labels, k))
  return streaming_mean(hits,
                        _mask_to_weights(ignore_mask),
                        metrics_collections,
                        updates_collections,
                        name or default_name)
1241
1242
1243# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_recall_at_k(predictions,
                                 labels,
                                 k,
                                 class_id=None,
                                 ignore_mask=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate recall by considering only the
      entries in the batch for which `class_id` is in the label, and computing
      the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, we'll calculate recall as how often on
      average a class among the labels of a batch entry is in the top-k
      `predictions`.

  `streaming_sparse_recall_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
  the recall_at_k frequency. This frequency is ultimately returned as
  `recall_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`). To facilitate the estimation of recall@k over a
  stream of data, the function utilizes three steps.
  * A `top_k` operation computes a tensor whose elements indicate the top `k`
    predictions of the `predictions` `Tensor`.
  * Set operations are applied to `top_k` and `labels` to calculate true
    positives and false negatives.
  * An `update_op` operation increments `true_positive_at_<k>` and
    `false_negative_at_<k>`. It also returns the recall value.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`.
      Values should be in range [0, num_classes], where num_classes is the last
      dimension of `predictions`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes], where num_classes is the last dimension of
      `predictions`.
    ignore_mask: An optional, binary tensor whose shape is broadcastable to the
      the first [D1, ... DN] dimensions of `predictions_idx` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.
  """
  default_name = 'recall_at_%d' % k
  if class_id is not None:
    default_name = '%s_class%d' % (default_name, class_id)

  with ops.name_scope(name, default_name, [predictions, labels]) as scope:
    # Indices of the top-k predicted classes per example.
    _, top_k_idx = nn.top_k(predictions, k)
    top_k_idx = math_ops.to_int64(top_k_idx)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        ignore_mask=ignore_mask)
    fn, fn_update = _streaming_sparse_false_negative_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        ignore_mask=ignore_mask)

    # recall = tp / (tp + fn), computed from both the accumulated variables
    # (metric value) and the freshly updated counts (update op).
    metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
    update = math_ops.div(
        tp_update, math_ops.add(tp_update, fn_update), name='update')
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update
1327
1328
# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_k(predictions,
                                    labels,
                                    k,
                                    class_id=None,
                                    ignore_mask=None,
                                    metrics_collections=None,
                                    updates_collections=None,
                                    name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is specified, precision is computed over only the batch
  entries where `class_id` appears among the top-k highest `predictions`, as
  the fraction of those entries where `class_id` is also a correct label.
  If `class_id` is not specified, precision is how often, on average, a class
  among the top-k classes with the highest predicted values of a batch entry
  can be found in the labels for that entry.

  Two local variables are created, `true_positive_at_<k>` and
  `false_positive_at_<k>`, and the returned `precision_at_<k>` is the
  idempotent division `true_positive_at_<k> / (true_positive_at_<k> +
  false_positive_at_<k>)`. Streaming estimation proceeds in three steps:
  * A `top_k` operation extracts the indices of the top `k` `predictions`.
  * Set operations between the top-k indices and `labels` count true
    positives and false positives.
  * An `update_op` operation increments `true_positive_at_<k>` and
    `false_positive_at_<k>`, and also returns the precision value.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`. Values should be in range [0, num_classes], where
      num_classes is the last dimension of `predictions`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes], where num_classes is the last dimension of
      `predictions`.
    ignore_mask: An optional, binary tensor whose shape is broadcastable to
      the first [D1, ... DN] dimensions of `predictions_idx` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.
  """
  default_name = 'precision_at_%d' % k
  if class_id is not None:
    default_name = '%s_class%d' % (default_name, class_id)
  with ops.name_scope(name, default_name, [predictions, labels]) as scope:
    # Only the indices of the top-k predictions are needed; the values are
    # discarded.
    _, top_k_idx = nn.top_k(predictions, k)
    top_k_idx = math_ops.to_int64(top_k_idx)
    true_pos, true_pos_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        ignore_mask=ignore_mask)
    false_pos, false_pos_update = _streaming_sparse_false_positive_at_k(
        predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
        ignore_mask=ignore_mask)

    # precision = tp / (tp + fp), computed on both the accumulated variables
    # (the idempotent metric) and the freshly updated values (the update op).
    metric = math_ops.div(
        true_pos, math_ops.add(true_pos, false_pos), name=scope)
    update = math_ops.div(
        true_pos_update, math_ops.add(true_pos_update, false_pos_update),
        name='update')
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update
1414
1415
def _select_class_id(ids, selected_id):
  """Filter all but `selected_id` out of `ids`.

  Args:
    ids: `int64` `Tensor` or `SparseTensor` of IDs.
    selected_id: Int id to select.

  Returns:
    `SparseTensor` of same dimensions as `ids`, except for the last dimension,
    which might be smaller. This contains only the entries equal to
    `selected_id`.
  """
  if isinstance(ids, ops.SparseTensor):
    # Keep only the sparse entries whose value equals the selected ID.
    return sparse_ops.sparse_retain(
        ids, math_ops.equal(ids.values, selected_id))

  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
  # tf.equal and tf.reduce_any?

  # Build a tensor filled with `selected_id`, shaped like `ids` but with the
  # last dimension collapsed to 1, then intersect it with `ids`.
  dense_shape = array_ops.shape(ids)
  last_dim_index = array_ops.size(dense_shape) - 1
  filled_shape = math_ops.reduced_shape(
      dense_shape, array_ops.reshape(last_dim_index, [1]))
  filled = array_ops.fill(filled_shape, math_ops.to_int64(selected_id))
  return set_ops.set_intersection(filled, ids)
1445
1446
def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
  """If class ID is specified, filter all other classes.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
      where N >= 1. Commonly, N=1 and predictions has shape [batch size, k].
    selected_id: Int id to select.

  Returns:
    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
  """
  if selected_id is None:
    # Nothing to filter; pass both tensors through unchanged.
    return labels, predictions_idx
  selected_labels = _select_class_id(labels, selected_id)
  selected_predictions_idx = _select_class_id(predictions_idx, selected_id)
  return selected_labels, selected_predictions_idx
1467
1468
def _streaming_sparse_true_positive_at_k(predictions_idx,
                                         labels,
                                         k,
                                         class_id=None,
                                         ignore_mask=None,
                                         name=None):
  """Calculates per step true positives for recall@k and precision@k.

  If `class_id` is specified, only binary true positives for `class_id` are
  counted. Otherwise, matches between all `k` predicted classes and all label
  classes are counted.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    ignore_mask: An optional, binary tensor whose shape is broadcastable to
      the first [D1, ... DN] dimensions of `predictions_idx` and `labels`.
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.
  """
  default_name = 'true_positive_at_%d' % k
  if class_id is not None:
    default_name = '%s_class%d' % (default_name, class_id)
  with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope:
    labels, predictions_idx = _maybe_select_class_id(
        labels, predictions_idx, class_id)
    # A true positive is a predicted class that also appears in the labels.
    match_counts = set_ops.set_size(
        set_ops.set_intersection(predictions_idx, labels))
    if ignore_mask is not None:
      match_counts = math_ops.select(
          ignore_mask, array_ops.zeros_like(match_counts), match_counts)
    batch_total = math_ops.cast(
        math_ops.reduce_sum(match_counts), dtype=dtypes.float64)

    var = contrib_variables.local_variable(
        array_ops.zeros([], dtype=dtypes.float64), name=scope)
    return var, state_ops.assign_add(var, batch_total, name='update')
1516
1517
def _streaming_sparse_false_positive_at_k(predictions_idx,
                                          labels,
                                          k,
                                          class_id=None,
                                          ignore_mask=None,
                                          name=None):
  """Calculates per step false positives for precision@k.

  If `class_id` is specified, calculate binary false positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    ignore_mask: An optional, binary tensor whose shape is broadcastable to
      the first [D1, ... DN] dimensions of `predictions_idx` and `labels`.
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.
  """
  default_name = 'false_positive_at_%d' % k
  if class_id is not None:
    default_name = '%s_class%d' % (default_name, class_id)
  with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope:
    labels, predictions_idx = _maybe_select_class_id(labels,
                                                     predictions_idx,
                                                     class_id)
    # aminusb=True: predicted classes not present in the labels are the false
    # positives.
    fp = set_ops.set_size(set_ops.set_difference(predictions_idx,
                                                 labels,
                                                 aminusb=True))
    if ignore_mask is not None:
      fp = math_ops.select(ignore_mask, array_ops.zeros_like(fp), fp)
    batch_total_fp = math_ops.cast(
        math_ops.reduce_sum(fp), dtype=dtypes.float64)

    var = contrib_variables.local_variable(
        array_ops.zeros([], dtype=dtypes.float64), name=scope)
    return var, state_ops.assign_add(var, batch_total_fp, name='update')
1567
1568
def _streaming_sparse_false_negative_at_k(predictions_idx,
                                          labels,
                                          k,
                                          class_id=None,
                                          ignore_mask=None,
                                          name=None):
  """Calculates per step false negatives for recall@k.

  If `class_id` is specified, calculate binary false negatives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    ignore_mask: An optional, binary tensor whose shape is broadcastable to
      the first [D1, ... DN] dimensions of `predictions_idx` and `labels`.
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.
  """
  default_name = 'false_negative_at_%d' % k
  if class_id is not None:
    default_name = '%s_class%d' % (default_name, class_id)
  with ops.name_scope(name, default_name, [predictions_idx, labels]) as scope:
    labels, predictions_idx = _maybe_select_class_id(labels,
                                                     predictions_idx,
                                                     class_id)
    # aminusb=False: label classes that were not predicted are the false
    # negatives.
    fn = set_ops.set_size(set_ops.set_difference(predictions_idx,
                                                 labels,
                                                 aminusb=False))
    if ignore_mask is not None:
      fn = math_ops.select(ignore_mask, array_ops.zeros_like(fn), fn)
    batch_total_fn = math_ops.cast(
        math_ops.reduce_sum(fn), dtype=dtypes.float64)

    var = contrib_variables.local_variable(
        array_ops.zeros([], dtype=dtypes.float64), name=scope)
    return var, state_ops.assign_add(var, batch_total_fn, name='update')
1618
1619
def streaming_mean_absolute_error(predictions, labels, weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean absolute error between the labels and predictions.

  Two local variables, `total` and `count`, accumulate the (optionally
  weighted) absolute differences between `predictions` and `labels`. The
  returned `mean_absolute_error` is the idempotent division `total / count`,
  while `update_op` performs the accumulation for one batch and returns the
  updated metric value. If `weights` is None, `total` grows by the reduced sum
  of the absolute errors and `count` by the number of elements; otherwise each
  absolute error is scaled by its weight and `count` grows by the reduced sum
  of `weights`.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: An optional set of weights of the same shape as `predictions`. If
      `weights` is not None, the function computes a weighted mean.
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A tensor representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions` or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  # Delegate the streaming average of |predictions - labels| to streaming_mean.
  errors = math_ops.abs(predictions - labels)
  return streaming_mean(errors, weights, metrics_collections,
                        updates_collections, name or 'mean_absolute_error')
1670
1671
def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean relative error by normalizing with the given values.

  Two local variables, `total` and `count`, accumulate the (optionally
  weighted) relative errors, i.e. |labels - predictions| / normalizer. The
  returned `mean_relative_error` is the idempotent division `total / count`,
  while `update_op` performs the accumulation for one batch and returns the
  updated metric value. If `weights` is None, `total` grows by the reduced sum
  of the relative errors and `count` by the number of elements; otherwise each
  relative error is scaled by its weight and `count` grows by the reduced sum
  of `weights`.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: An optional set of weights of the same shape as `predictions`. If
      `weights` is not None, the function computes a weighted mean.
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_relative_error: A tensor representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions` or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  predictions, normalizer = metric_ops_util.remove_squeezable_dimensions(
      predictions, normalizer)
  predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
  # Where the normalizer is zero, define the relative error as zero to avoid
  # dividing by zero.
  zero_normalizer = math_ops.equal(normalizer, 0.0)
  relative_errors = math_ops.select(
      zero_normalizer,
      array_ops.zeros_like(labels),
      math_ops.div(math_ops.abs(labels - predictions), normalizer))
  return streaming_mean(relative_errors, weights, metrics_collections,
                        updates_collections, name or 'mean_relative_error')
1730
1731
def streaming_mean_squared_error(predictions, labels, weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes the mean squared error between the labels and predictions.

  Two local variables, `total` and `count`, accumulate the (optionally
  weighted) element-wise squared differences between `predictions` and
  `labels`. The returned `mean_squared_error` is the idempotent division
  `total / count`, while `update_op` performs the accumulation for one batch
  and returns the updated metric value. If `weights` is None, `total` grows by
  the reduced sum of the squared errors and `count` by the number of elements;
  otherwise each squared error is scaled by its weight and `count` grows by
  the reduced sum of `weights`.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: An optional set of weights of the same shape as `predictions`. If
      `weights` is not None, the function computes a weighted mean.
    metrics_collections: An optional list of collections that
      `mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_squared_error: A tensor representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_squared_error`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions` or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  # Delegate the streaming average of (labels - predictions)^2 to
  # streaming_mean.
  squared_diffs = math_ops.square(labels - predictions)
  return streaming_mean(squared_diffs, weights, metrics_collections,
                        updates_collections, name or 'mean_squared_error')
1782
1783
def streaming_root_mean_squared_error(predictions, labels, weights=None,
                                      metrics_collections=None,
                                      updates_collections=None,
                                      name=None):
  """Computes the root mean squared error between the labels and predictions.

  Two local variables, `total` and `count`, accumulate the (optionally
  weighted) element-wise squared differences between `predictions` and
  `labels`. The returned `root_mean_squared_error` is the idempotent square
  root of `total / count`, while `update_op` performs the accumulation for one
  batch and returns the square root of the updated mean. If `weights` is None,
  `total` grows by the reduced sum of the squared errors and `count` by the
  number of elements; otherwise each squared error is scaled by its weight and
  `count` grows by the reduced sum of `weights`.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: An optional set of weights of the same shape as `predictions`. If
      `weights` is not None, the function computes a weighted mean.
    metrics_collections: An optional list of collections that
      `root_mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    root_mean_squared_error: A tensor representing the current mean, the value
      of `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `root_mean_squared_error`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions` or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  # Build the streaming MSE without registering it in any collections; only
  # the square-rooted tensors below are exposed.
  mse, mse_update_op = streaming_mean_squared_error(
      predictions, labels, weights, None, None,
      name or 'root_mean_squared_error')

  root_mean_squared_error = math_ops.sqrt(mse)
  # Ensure the MSE accumulators are updated before the root is taken.
  with ops.control_dependencies([mse_update_op]):
    update_op = math_ops.sqrt(mse_update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, root_mean_squared_error)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return root_mean_squared_error, update_op
1846
1847
# TODO(nsilberman): add a 'normalized' flag so that the user can request
# normalization if the inputs are not normalized.
def streaming_mean_cosine_distance(predictions, labels, dim, weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
  """Computes the cosine distance between the labels and predictions.

  The `streaming_mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is ultimately returned as
  `mean_distance` which is an idempotent operation that simply divides `total`
  by `count`. To facilitate the estimation of a mean over multiple batches
  of data, the function creates an `update_op` operation whose behavior is
  dependent on the value of `weights`. If `weights` is None, then `update_op`
  increments `total` with the reduced sum of `values` and increments `count`
  with the number of elements in `values`. If `weights` is not `None`, then
  `update_op` increments `total` with the reduced sum of the product of
  `values` and `weights` and increments `count` with the reduced sum of
  `weights`.

  Args:
    predictions: A tensor of the same size as labels.
    labels: A tensor of arbitrary size.
    dim: The dimension along which the cosine distance is computed.
    weights: An optional set of weights which indicates which predictions to
      ignore during metric computation. Its size matches that of labels except
      for the value of 'dim' which should be 1. For example if labels has
      dimensions [32, 100, 200, 3], then `weights` should have dimensions
      [32, 100, 200, 1].
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A tensor representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If labels and predictions are of different sizes or if
      `weights` is of the wrong size or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
  """
  predictions, labels = metric_ops_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  # Cosine similarity per example: sum of the element-wise products along
  # `dim`. NOTE(review): inputs are assumed to already be normalized — see the
  # TODO above about adding a 'normalized' flag.
  radial_diffs = math_ops.mul(predictions, labels)
  radial_diffs = math_ops.reduce_sum(radial_diffs,
                                     reduction_indices=[dim,],
                                     keep_dims=True)
  mean_distance, update_op = streaming_mean(radial_diffs, weights,
                                            None,
                                            None,
                                            name or 'mean_cosine_distance')
  # Distance = 1 - similarity, applied to both the metric and its update.
  mean_distance = math_ops.sub(1.0, mean_distance)
  update_op = math_ops.sub(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op
1915
1916
def streaming_percentage_less(values, threshold, ignore_mask=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Computes the percentage of values less than the given threshold.

  The `streaming_percentage_less` function creates two local variables,
  `total` and `count`, used to compute the fraction of `values` that fall
  below `threshold`. The result, `percentage`, is an idempotent operation
  that simply divides `total` by `count`.

  To estimate that percentage over multiple batches of data, the function
  also returns an `update_op` whose behavior depends on `ignore_mask`. When
  `ignore_mask` is `None`, `update_op` increments `total` by the number of
  elements of `values` below `threshold` and `count` by the total number of
  elements in `values`. When `ignore_mask` is given, only elements whose
  corresponding mask entry is False contribute to either variable.

  Args:
    values: A numeric `Tensor` of arbitrary size.
    threshold: A scalar threshold.
    ignore_mask: An optional mask of the same shape as 'values' which indicates
      which elements to ignore during metric computation.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    percentage: A tensor representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `ignore_mask` is not None and its shape doesn't match `values
      or if either `metrics_collections` or `updates_collections` are supplied
      but are not a list or tuple.
  """
  # The percentage below threshold is just the streaming mean of the 0/1
  # indicator of values < threshold.
  below = math_ops.less(values, threshold)
  return streaming_mean(math_ops.to_float(below),
                        _mask_to_weights(ignore_mask),
                        metrics_collections,
                        updates_collections,
                        name or 'percentage_below_threshold')
1964
1965
def streaming_mean_iou(predictions,
                       labels,
                       num_classes,
                       ignore_mask=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for semantic
  image segmentation: the IOU is first computed per semantic class and then
  averaged over classes. For each class,
    IOU = true_positive / (true_positive + false_positive + false_negative).
  Predictions are accumulated in a running confusion matrix, from which the
  mIOU is derived.

  Args:
    predictions: A tensor of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened, if its rank > 1.
    labels: A tensor of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened, if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    ignore_mask: An optional, boolean tensor whose size matches `labels`. If an
      element of `ignore_mask` is True, the corresponding prediction and label
      pair is NOT used to compute the metrics. Otherwise, the pair is included.
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A tensor representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If the dimensions of `predictions` and `labels` don't match or
      if `ignore_mask` is not `None` and its shape doesn't match `labels`
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'mean_iou', [predictions, labels]):
    # Static shape checks; these raise at graph-construction time.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    if ignore_mask is not None:
      labels.get_shape().assert_is_compatible_with(ignore_mask.get_shape())

    # Running confusion matrix accumulated across batches.
    total_cm = _create_local('total_confusion_matrix',
                             shape=[num_classes, num_classes],
                             dtype=dtypes.int64)

    # confusion_matrix_ops requires int64 inputs.
    predictions = math_ops.to_int64(predictions)
    labels = math_ops.to_int64(labels)
    num_classes = math_ops.to_int64(num_classes)

    # Flatten any input whose rank exceeds 1.
    if predictions.get_shape().ndims > 1:
      predictions = array_ops.reshape(predictions, [-1])
    if labels.get_shape().ndims > 1:
      labels = array_ops.reshape(labels, [-1])

    if ignore_mask is not None:
      if ignore_mask.get_shape().ndims > 1:
        ignore_mask = array_ops.reshape(ignore_mask, [-1])
      # Drop every prediction/label pair whose mask entry is True.
      check_ops.assert_type(ignore_mask, dtypes.bool)
      keep = math_ops.logical_not(ignore_mask)
      predictions = array_ops.boolean_mask(predictions, keep)
      labels = array_ops.boolean_mask(labels, keep)

    # Fold this batch's confusion matrix into the running total.
    batch_cm = confusion_matrix_ops.confusion_matrix(
        predictions, labels, num_classes, dtype=dtypes.int64)
    update_op = state_ops.assign_add(total_cm, batch_cm)

    def compute_mean_iou(name):
      """Compute the mean intersection-over-union via the confusion matrix."""
      sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
      sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
      cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
      # Union = row sum + column sum - diagonal (which was counted twice).
      denominator = sum_over_row + sum_over_col - cm_diag
      # Replace zero denominators with 1 to avoid division by zero; the
      # corresponding numerators are necessarily zero as well.
      denominator = math_ops.select(
          math_ops.greater(denominator, 0),
          denominator,
          array_ops.ones_like(denominator))
      iou = math_ops.div(cm_diag, denominator)
      return math_ops.reduce_mean(iou, name=name)

    mean_iou = compute_mean_iou('mean_iou')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_iou)
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_iou, update_op
2076
2077
def aggregate_metrics(*value_update_tuples):
  """Aggregates the metric value tensors and update ops into two lists.

  Args:
    *value_update_tuples: a variable number of tuples, each of which contain the
      pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    a list of value tensors and a list of update ops.

  Raises:
    ValueError: if `value_update_tuples` is empty.
  """
  if not value_update_tuples:
    raise ValueError('Expected at least one value_tensor/update_op pair')
  # Split each (value, update_op) pair into its two parallel lists. Tuple
  # unpacking inside the comprehension rejects malformed (non-pair) entries.
  value_ops = [value for value, _ in value_update_tuples]
  update_ops = [update for _, update in value_update_tuples]
  return value_ops, update_ops
2095
2096
def aggregate_metric_map(names_to_tuples):
  """Aggregates the metric names to tuple dictionary.

  This function is useful for pairing metric names with their associated value
  and update ops when the list of metrics is long. For example:

    metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
        'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error(
            predictions, labels, weights),
        'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error(
            predictions, labels, labels, weights),
        'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
        'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
    })

  Args:
    names_to_tuples: a map of metric names to tuples, each of which contain the
      pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    A dictionary from metric names to value ops and a dictionary from metric
    names to update ops.
  """
  # Split the name -> (value, update_op) map into two parallel dictionaries.
  names_to_values = {}
  names_to_updates = {}
  for metric_name, (value_op, update_op) in names_to_tuples.items():
    names_to_values[metric_name] = value_op
    names_to_updates[metric_name] = update_op
  return names_to_values, names_to_updates
2125
2126
# Export every public (non-underscore) symbol documented in this module.
__all__ = make_all(__name__)
2128