metric_ops.py revision 2800e4d47be9078d2688057722f1385276aae8b7
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains metric-computing operations on streamed tensors.
16
17Module documentation, including "@@" callouts, should be put in
18third_party/tensorflow/contrib/metrics/__init__.py
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25from tensorflow.contrib.framework import deprecated
26from tensorflow.contrib.framework import tensor_util
27from tensorflow.contrib.framework.python.ops import variables as contrib_variables
28from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
29from tensorflow.contrib.metrics.python.ops import set_ops
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import sparse_tensor
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import check_ops
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import nn
38from tensorflow.python.ops import sparse_ops
39from tensorflow.python.ops import state_ops
40from tensorflow.python.ops import variable_scope
41from tensorflow.python.ops import variables
42
43
def _safe_div(numerator, denominator, name):
  """Computes `numerator / denominator`, yielding 0 for non-positive divisors.

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.
    name: Name for the returned op.

  Returns:
    `numerator` / `denominator` where `denominator` > 0, and 0 otherwise.
  """
  denominator_is_positive = math_ops.greater(denominator, 0)
  quotient = math_ops.truediv(numerator, denominator)
  # The quotient op is always built; `select` only masks its result out where
  # the denominator is not strictly positive.
  return math_ops.select(denominator_is_positive, quotient, 0, name=name)
60
61
def _safe_scalar_div(numerator, denominator, name):
  """Computes `numerator / denominator`, yielding 0 for a zero divisor.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  # Static rank check only: both operands may have rank 0 or 1.
  for operand in (numerator, denominator):
    operand.get_shape().with_rank_at_most(1)

  def _zero():
    return array_ops.constant(0.0, dtype=dtypes.float64)

  def _quotient():
    return math_ops.div(numerator, denominator)

  denominator_is_zero = math_ops.equal(
      array_ops.constant(0.0, dtype=dtypes.float64), denominator)
  # `cond` only evaluates the division branch when the divisor is non-zero.
  return control_flow_ops.cond(
      denominator_is_zero, _zero, _quotient, name=name)
81
82
def _create_local(name, shape, collections=None, validate_shape=True,
                  dtype=dtypes.float32):
  """Builds a zero-initialized, non-trainable local variable.

  The variable is always registered in `tf.GraphKeys.LOCAL_VARIABLES`, in
  addition to any collections supplied by the caller.

  Args:
    name: The name of the new or existing variable.
    shape: Shape of the new or existing variable.
    collections: A list of collection names to which the Variable will be added.
    validate_shape: Whether to validate the shape of the variable.
    dtype: Data type of the variables.

  Returns:
    The created variable.
  """
  # Copy the caller's list so that appending the local-variables key never
  # mutates the argument.
  target_collections = list(collections or [])
  target_collections.append(ops.GraphKeys.LOCAL_VARIABLES)
  zeros = array_ops.zeros(shape, dtype=dtype)
  return variables.Variable(
      initial_value=zeros,
      name=name,
      trainable=False,
      collections=target_collections,
      validate_shape=validate_shape)
106
107
def _count_condition(values, weights=None, metrics_collections=None,
                     updates_collections=None):
  """Accumulates the total weight of the entries of `values` that are True.

  A scalar local variable named `count` holds the running sum. If `weights` is
  `None`, every entry counts as 1. Use weights of 0 to mask values.

  Args:
    values: A `bool` `Tensor` of arbitrary size.
    weights: An optional `Tensor` whose shape is broadcastable to `values`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that adds the (weighted) count of True entries in
      the current batch to the running sum.

  Raises:
    TypeError: If `values` is not of type `bool` (checked statically by
      `check_ops.assert_type`).
  """
  check_ops.assert_type(values, dtypes.bool)
  accumulator = _create_local('count', shape=[])

  # True -> 1.0, False -> 0.0, then optionally scaled by the weights.
  weighted_hits = math_ops.to_float(values)
  if weights is not None:
    weighted_hits = math_ops.mul(weighted_hits, math_ops.to_float(weights))

  value_tensor = array_ops.identity(accumulator)
  batch_total = math_ops.reduce_sum(weighted_hits)
  update_op = state_ops.assign_add(accumulator, batch_total)

  registrations = (
      (metrics_collections, value_tensor),
      (updates_collections, update_op),
  )
  for collection_names, tensor in registrations:
    if collection_names:
      ops.add_to_collections(collection_names, tensor)

  return value_tensor, update_op
149
150
def _streaming_true_positives(predictions, labels, weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Accumulates the total weight of true positives.

  An entry is a true positive when both its label and its prediction equal 1.
  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that adds the true positives of the current batch
      to the running count.

  Raises:
    ValueError: If the static shapes of `predictions` and `labels` are
      incompatible.
  """
  scope_inputs = [predictions, labels]
  with variable_scope.variable_scope(name, 'true_positives', scope_inputs):
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    label_positive = math_ops.equal(labels, 1)
    predicted_positive = math_ops.equal(predictions, 1)
    hits = math_ops.logical_and(label_positive, predicted_positive)
    return _count_condition(
        hits,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections)
189
190
def _streaming_false_positives(predictions, labels, weights=None,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Accumulates the total weight of false positives.

  An entry is a false positive when its label equals 0 while its prediction
  equals 1. If `weights` is `None`, weights default to 1. Use weights of 0 to
  mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that adds the false positives of the current batch
      to the running count.

  Raises:
    ValueError: If the static shapes of `predictions` and `labels` are
      incompatible.
  """
  scope_inputs = [predictions, labels]
  with variable_scope.variable_scope(name, 'false_positives', scope_inputs):
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    label_negative = math_ops.equal(labels, 0)
    predicted_positive = math_ops.equal(predictions, 1)
    false_alarms = math_ops.logical_and(label_negative, predicted_positive)
    return _count_condition(
        false_alarms,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections)
229
230
def _streaming_false_negatives(predictions, labels, weights=None,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Sum the weights of false negatives.

  An entry is a false negative when its label equals 1 while its prediction
  equals 0. If `weights` is `None`, weights default to 1. Use weights of 0 to
  mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A tensor representing the current value of the metric.
    update_op: An operation that adds the false negatives of the current batch
      to the running count.

  Raises:
    ValueError: If the static shapes of `predictions` and `labels` are
      incompatible.
  """
  with variable_scope.variable_scope(
      name, 'false_negatives', [predictions, labels]):

    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # Missed positives: the label says 1 but the prediction says 0.
    label_positive = math_ops.equal(labels, 1)
    predicted_negative = math_ops.equal(predictions, 0)
    is_false_negative = math_ops.logical_and(label_positive,
                                             predicted_negative)
    return _count_condition(is_false_negative, weights, metrics_collections,
                            updates_collections)
268
269
def _broadcast_weights(weights, values):
  """Expands `weights` so that it has the shape of `values`.

  The result follows the same broadcast rules as `mul(weights, values)`. When
  computing a weighted average, broadcast the weights with this function before
  summing them; e.g.,
  `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`.

  Args:
    weights: `Tensor` whose shape is broadcastable to `values`.
    values: `Tensor` of any shape.

  Returns:
    `weights` broadcast to `values` shape.
  """
  static_weights_shape = weights.get_shape()
  static_values_shape = values.get_shape()
  both_fully_known = (static_weights_shape.is_fully_defined() and
                      static_values_shape.is_fully_defined())
  # Fast path: shapes are statically known and already compatible, so no extra
  # op is needed.
  if both_fully_known and static_weights_shape.is_compatible_with(
      static_values_shape):
    return weights
  # Otherwise let multiplication by a ones tensor perform the broadcast.
  ones = array_ops.ones_like(values)
  return math_ops.mul(weights, ones, name='broadcast_weights')
293
294
def streaming_mean(values, weights=None, metrics_collections=None,
                   updates_collections=None, name=None):
  """Computes the (weighted) mean of the given values.

  Two scalar local variables, `total` and `count`, back this metric. The
  returned `mean` is an idempotent operation that divides `total` by `count`
  (and is 0 while `count` is not positive).

  To estimate the metric over a stream of data, run the returned `update_op`
  once per batch: it adds the reduced sum of `values * weights` to `total`, adds
  the reduced sum of the (broadcast) `weights` to `count`, and then evaluates
  to the updated mean.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: An optional `Tensor` whose shape is broadcastable to `values`.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A tensor representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.
  """
  with variable_scope.variable_scope(name, 'mean', [values, weights]):
    float_values = math_ops.to_float(values)

    running_total = _create_local('total', shape=[])
    running_count = _create_local('count', shape=[])

    if weights is None:
      # Unweighted: every element contributes exactly 1 to the count.
      batch_weight = math_ops.to_float(array_ops.size(float_values))
    else:
      float_weights = math_ops.to_float(weights)
      float_values = math_ops.mul(float_values, float_weights)
      # Broadcast first so that a lower-rank weight is counted once per value.
      batch_weight = math_ops.reduce_sum(
          _broadcast_weights(float_weights, float_values))

    add_to_total = state_ops.assign_add(
        running_total, math_ops.reduce_sum(float_values))
    add_to_count = state_ops.assign_add(running_count, batch_weight)

    mean = _safe_div(running_total, running_count, 'value')
    # The update op re-reads both accumulators only after they were bumped.
    with ops.control_dependencies([add_to_total, add_to_count]):
      update_op = _safe_div(running_total, running_count, 'update_op')

    registrations = (
        (metrics_collections, mean),
        (updates_collections, update_op),
    )
    for collection_names, tensor in registrations:
      if collection_names:
        ops.add_to_collections(collection_names, tensor)

    return mean, update_op
358
359
def streaming_mean_tensor(values, weights=None, metrics_collections=None,
                          updates_collections=None, name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  In contrast to the `streaming_mean` function which returns a scalar with the
  mean, this function returns an average tensor with the same shape as the
  input tensors.

  The `streaming_mean_tensor` function creates two local variables,
  `total_tensor` and `count_tensor`, both shaped like `values`, that are used to
  compute the average of `values`. This average is ultimately returned as `mean`
  which is an idempotent operation that divides `total_tensor` by
  `count_tensor` element-wise, yielding 0 wherever the count is not positive.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total_tensor` with the element-wise product of
  `values` and `weights`, and it increments `count_tensor` with `weights`
  broadcast to the shape of `values`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions. It is cast to `float32`.
    weights: An optional `Tensor` whose shape is broadcastable to `values`.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float tensor representing the current mean, the value of
      `total_tensor` divided by `count_tensor`.
    update_op: An operation that increments the `total_tensor` and
      `count_tensor` variables appropriately and whose value matches `mean`.
  """
  with variable_scope.variable_scope(name, 'mean', [values, weights]):
    # Cast up front (as `streaming_mean` does) so that non-float32 inputs can
    # be multiplied by the float weights and accumulated into float32 locals.
    values = math_ops.to_float(values)

    total = _create_local('total_tensor', shape=values.get_shape())
    count = _create_local('count_tensor', shape=values.get_shape())

    num_values = array_ops.ones_like(values)
    if weights is not None:
      weights = math_ops.to_float(weights)
      values = math_ops.mul(values, weights)
      num_values = math_ops.mul(num_values, weights)

    total_compute_op = state_ops.assign_add(total, values)
    count_compute_op = state_ops.assign_add(count, num_values)

    def compute_mean(total, count, name):
      # Divide by the true accumulated weight. Clamping the count to >= 1
      # would silently return `total` instead of `total / count` whenever the
      # accumulated weight of an element lies in (0, 1).
      return math_ops.select(
          math_ops.greater(count, 0),
          math_ops.truediv(total, count),
          array_ops.zeros_like(total),
          name=name)

    mean = compute_mean(total, count, 'value')
    with ops.control_dependencies([total_compute_op, count_compute_op]):
      update_op = compute_mean(total, count, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean, update_op
430
431
def streaming_accuracy(predictions, labels, weights=None,
                       metrics_collections=None, updates_collections=None,
                       name=None):
  """Calculates how often `predictions` matches `labels`.

  This metric is a `streaming_mean` over an `is_correct` tensor whose elements
  are 1.0 where `predictions` equals `labels` and 0.0 otherwise. It therefore
  relies on the two local variables of `streaming_mean`, `total` and `count`;
  the returned `accuracy` is an idempotent operation that divides `total` by
  `count`.

  To estimate the metric over a stream of data, run the returned `update_op`
  once per batch: it increments `total` with the reduced sum of the product of
  `weights` and `is_correct`, increments `count` with the reduced sum of
  `weights`, and evaluates to the updated `accuracy`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of any shape.
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A tensor representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If the static shapes of `predictions` and `labels` are
      incompatible.
  """
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  # Compare in the label dtype so that `equal` sees matching operand types.
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)

  matches = math_ops.equal(predictions, labels)
  is_correct = math_ops.to_float(matches)
  scope_name = name or 'accuracy'
  return streaming_mean(
      is_correct,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=scope_name)
483
484
def streaming_precision(predictions, labels, weights=None,
                        metrics_collections=None, updates_collections=None,
                        name=None):
  """Computes the precision of the predictions with respect to the labels.

  Two local counters, `true_positives` and `false_positives`, back this metric.
  The returned `precision` is an idempotent operation that divides
  `true_positives` by the sum of `true_positives` and `false_positives`, and is
  0 while that sum is not positive.

  To estimate the metric over a stream of data, run the returned `update_op`
  once per batch: it updates both counters, weighting each prediction by the
  corresponding value in `weights`, and evaluates to the updated `precision`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: Scalar float `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately and whose value matches
      `precision`.

  Raises:
    ValueError: If the static shapes of `predictions` and `labels` are
      incompatible.
  """
  scope_inputs = [predictions, labels]
  with variable_scope.variable_scope(name, 'precision', scope_inputs):

    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    # The counters are internal to this metric, so they are never registered
    # in the caller's collections.
    tp_count, tp_update = _streaming_true_positives(
        predictions, labels, weights, metrics_collections=None,
        updates_collections=None, name=None)
    fp_count, fp_update = _streaming_false_positives(
        predictions, labels, weights, metrics_collections=None,
        updates_collections=None, name=None)

    def _precision_op(op_name):
      has_positive_predictions = math_ops.greater(tp_count + fp_count, 0)
      ratio = math_ops.div(tp_count, tp_count + fp_count)
      return math_ops.select(has_positive_predictions, ratio, 0, op_name)

    precision = _precision_op('value')
    # Evaluate the ratio again, this time only after both counters advanced.
    with ops.control_dependencies([tp_update, fp_update]):
      update_op = _precision_op('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, precision)
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision, update_op
560
561
def streaming_recall(predictions, labels, weights=None,
                     metrics_collections=None, updates_collections=None,
                     name=None):
  """Computes the recall of the predictions with respect to the labels.

  Two local counters, `true_positives` and `false_negatives`, back this metric.
  The returned `recall` is an idempotent operation that divides
  `true_positives` by the sum of `true_positives` and `false_negatives`, and is
  0 while that sum is not positive.

  To estimate the metric over a stream of data, run the returned `update_op`
  once per batch: it updates both counters, weighting each prediction by the
  corresponding value in `weights`, and evaluates to the updated `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If the static shapes of `predictions` and `labels` are
      incompatible.
  """
  scope_inputs = [predictions, labels]
  with variable_scope.variable_scope(name, 'recall', scope_inputs):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    # The counters are internal to this metric, so they are never registered
    # in the caller's collections.
    tp_count, tp_update = _streaming_true_positives(
        predictions, labels, weights, metrics_collections=None,
        updates_collections=None, name=None)
    fn_count, fn_update = _streaming_false_negatives(
        predictions, labels, weights, metrics_collections=None,
        updates_collections=None, name=None)

    def _recall_op(hits, misses, op_name):
      has_positive_labels = math_ops.greater(hits + misses, 0)
      ratio = math_ops.div(hits, hits + misses)
      return math_ops.select(has_positive_labels, ratio, 0, op_name)

    recall = _recall_op(tp_count, fn_count, 'value')
    # Evaluate the ratio again, this time only after both counters advanced.
    with ops.control_dependencies([tp_update, fn_update]):
      update_op = _recall_op(tp_count, fn_count, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op
633
634
def _tp_fn_tn_fp(predictions, labels, thresholds, weights=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  Four local variables of shape `[len(thresholds)]` are created:
  `true_positives`, `false_negatives`, `true_negatives` and `false_positives`.
  A prediction counts as positive for threshold `i` when it is strictly greater
  than `thresholds[i]`, and as negative otherwise. Hence:

  * `true_positives[i]`: total weight of predictions above `thresholds[i]`
    whose label is `True`.
  * `false_negatives[i]`: total weight of predictions at most `thresholds[i]`
    whose label is `True`.
  * `true_negatives[i]`: total weight of predictions at most `thresholds[i]`
    whose label is `False`.
  * `false_positives[i]`: total weight of predictions above `thresholds[i]`
    whose label is `False`.

  For estimation over a stream of data, one update operation per variable is
  returned; each adds the totals of the current batch to its variable.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast
      to `bool`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.

  Returns:
    true_positives: A variable of shape [len(thresholds)].
    false_negatives: A variable of shape [len(thresholds)].
    true_negatives: A variable of shape [len(thresholds)].
    false_positives: A variable of shape [len(thresholds)].
    true_positives_update_op: An operation that increments the `true_positives`.
    false_negatives_update_op: An operation that increments the
      `false_negatives`.
    true_negatives_update_op: An operation that increments the `true_negatives`.
    false_positives_update_op: An operation that increments the
      `false_positives`.

  Raises:
    ValueError: If the static shapes of `predictions` and `labels` are
      incompatible, or if the tiled weights are statically incompatible with
      the tiled thresholds.
  """
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  threshold_count = len(thresholds)
  # Multiples that repeat a [1, N] row once per threshold.
  per_threshold = [threshold_count, 1]

  # Flatten: predictions become an [N, 1] column, bool labels a [1, N] row.
  predictions_column = array_ops.reshape(predictions, [-1, 1])
  bool_labels = math_ops.cast(labels, dtype=dtypes.bool)
  labels_row = array_ops.reshape(bool_labels, [1, -1])

  # Prefer the statically known example count; fall back to the runtime shape.
  example_count = predictions_column.get_shape().as_list()[0]
  if example_count is None:
    example_count = array_ops.shape(predictions_column)[0]

  # [T, N] grid holding threshold i in every column of row i.
  threshold_column = array_ops.expand_dims(
      array_ops.constant(thresholds), [1])
  threshold_grid = array_ops.tile(
      threshold_column, array_ops.pack([1, example_count]))

  # [T, N] grid of predictions, thresholded row by row.
  predictions_grid = array_ops.tile(
      array_ops.transpose(predictions_column), per_threshold)
  predicted_positive = math_ops.greater(predictions_grid, threshold_grid)
  predicted_negative = math_ops.logical_not(predicted_positive)

  # Labels repeated once per threshold.
  actual_positive = array_ops.tile(labels_row, per_threshold)
  actual_negative = math_ops.logical_not(actual_positive)

  true_positives = _create_local('true_positives', shape=[threshold_count])
  false_negatives = _create_local('false_negatives', shape=[threshold_count])
  true_negatives = _create_local('true_negatives', shape=[threshold_count])
  false_positives = _create_local('false_positives', shape=[threshold_count])

  # 1.0 / 0.0 indicator grids for each cell of the confusion matrix.
  tp_indicator = math_ops.to_float(
      math_ops.logical_and(actual_positive, predicted_positive))
  fn_indicator = math_ops.to_float(
      math_ops.logical_and(actual_positive, predicted_negative))
  fp_indicator = math_ops.to_float(
      math_ops.logical_and(actual_negative, predicted_positive))
  tn_indicator = math_ops.to_float(
      math_ops.logical_and(actual_negative, predicted_negative))

  if weights is not None:
    weights = math_ops.to_float(weights)
    # Expand the weights to one per prediction, then repeat per threshold.
    weights_row = array_ops.reshape(
        _broadcast_weights(weights, predictions), [1, -1])
    weights_grid = array_ops.tile(weights_row, per_threshold)

    threshold_grid.get_shape().assert_is_compatible_with(
        weights_grid.get_shape())
    tp_indicator = tp_indicator * weights_grid
    fn_indicator = fn_indicator * weights_grid
    fp_indicator = fp_indicator * weights_grid
    tn_indicator = tn_indicator * weights_grid

  # Summing over axis 1 collapses the examples, leaving one total per
  # threshold.
  true_positives_update_op = state_ops.assign_add(
      true_positives, math_ops.reduce_sum(tp_indicator, 1))
  false_negatives_update_op = state_ops.assign_add(
      false_negatives, math_ops.reduce_sum(fn_indicator, 1))
  true_negatives_update_op = state_ops.assign_add(
      true_negatives, math_ops.reduce_sum(tn_indicator, 1))
  false_positives_update_op = state_ops.assign_add(
      false_positives, math_ops.reduce_sum(fp_indicator, 1))

  return (true_positives, false_negatives, true_negatives, false_positives,
          true_positives_update_op, false_negatives_update_op,
          true_negatives_update_op, false_positives_update_op)
747
748
def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
                  metrics_collections=None, updates_collections=None,
                  curve='ROC', name=None):
  """Computes the approximate AUC via a trapezoidal Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the curve, a linearly spaced set of
  `num_thresholds` thresholds is used, yielding one point of the curve per
  threshold. For the ROC-curve the points are (false positive rate, recall)
  pairs; for the PR-curve they are (recall, precision) pairs.

  The returned `auc` is an idempotent operation that sums the areas of the
  trapezoids spanned by consecutive points of the discretized curve (computed
  using the aforementioned variables). The `num_thresholds` argument controls
  the degree of discretization, with larger numbers of thresholds more closely
  approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use when discretizing the
      curve. Must be at least 2.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar tensor representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `curve` is neither 'ROC' nor 'PR', if `num_thresholds` is
      less than 2, if `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape is incompatible with
      `predictions`.
  """
  # Validate the plain-Python arguments first so that a bad call fails with a
  # ValueError before any variable scope or local variable is created.
  if curve not in ('ROC', 'PR'):
    raise ValueError('curve must be either ROC or PR, %s unknown' %
                     (curve,))
  if num_thresholds < 2:
    # Fewer than two thresholds cannot form a single trapezoid; previously this
    # was only caught by the `assert` below (and not at all under `python -O`).
    raise ValueError('`num_thresholds` must be at least 2, got %s.' %
                     (num_thresholds,))

  with variable_scope.variable_scope(name, 'auc', [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    # Interior thresholds are evenly spaced in (0, 1); the two end points are
    # nudged just outside [0, 1] so predictions of exactly 0 or 1 fall strictly
    # on one side of them.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6
    # Internal sanity check that the helper produced one count per threshold.
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      recall = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        x = fp_rate
        y = recall
      else:  # curve == 'PR'.
        precision = math_ops.div(tp + epsilon, tp + fp + epsilon)
        x = recall
        y = precision
      # Trapezoid rule: width is the drop in x between neighbouring thresholds,
      # height is the mean of the two neighbouring y values.
      return math_ops.reduce_sum(math_ops.mul(
          x[:num_thresholds - 1] - x[1:],
          (y[:num_thresholds - 1] + y[1:]) / 2.), name=name)

    # sum up the areas of all the trapeziums
    auc = compute_auc(tp, fn, tn, fp, 'value')
    # Feeding the assign_add outputs makes evaluating `update_op` both update
    # the counts and yield the AUC of the updated counts.
    update_op = compute_auc(
        tp_update_op, fn_update_op, tn_update_op, fp_update_op, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, auc)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc, update_op
848
849
def streaming_specificity_at_sensitivity(
    predictions, labels, sensitivity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the specificity at a given sensitivity.

  The `streaming_specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity. Must be at least 2.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar tensor representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `sensitivity` is not between 0 and 1, if `num_thresholds`
      is less than 2, if `predictions` and `labels` have mismatched shapes, or
      if `weights` is not `None` and its shape is incompatible with
      `predictions`.
  """
  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1].')
  if num_thresholds < 2:
    # Previously only caught by the `assert` below (and not at all under
    # `python -O`); fail early with the documented exception type instead.
    raise ValueError('`num_thresholds` must be at least 2, got %s.' %
                     (num_thresholds,))

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    # NOTE(review): the last threshold is `1.0 - kepsilon` here, while
    # `streaming_auc` and `streaming_sensitivity_at_specificity` use
    # `1.0 + kepsilon`. Kept unchanged; confirm whether this is intentional.
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)

    # Internal sanity check that the helper produced one count per threshold.
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_specificity_at_sensitivity(name):
      """Computes the specificity at the given sensitivity.

      Args:
        name: The name of the operation.

      Returns:
        The specificity using the aggregated values.
      """
      sensitivities = math_ops.div(tp, tp + fn + kepsilon)
      # Distance of each threshold's sensitivity from the requested one;
      # computed once and reused below.
      distances = math_ops.abs(sensitivities - sensitivity)

      # We'll need to use this trick until tf.argmax allows us to specify
      # whether we should use the first or last index in case of ties.
      # The running count of tied positions first reaches its maximum at the
      # last tied position, so (assuming argmax reports the first maximum)
      # this selects the last threshold at minimum distance.
      min_val = math_ops.reduce_min(distances)
      indices_at_minval = math_ops.equal(distances, min_val)
      indices_at_minval = math_ops.to_int64(indices_at_minval)
      indices_at_minval = math_ops.cumsum(indices_at_minval)
      tf_index = math_ops.argmax(indices_at_minval, 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the specificity:
      return math_ops.div(tn[tf_index],
                          tn[tf_index] + fp[tf_index] + kepsilon,
                          name)

    specificity = compute_specificity_at_sensitivity('value')
    # The update variant reads the counts only after all four have been
    # incremented.
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_specificity_at_sensitivity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, specificity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op
952
953
def streaming_sensitivity_at_specificity(
    predictions, labels, specificity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the sensitivity at a given specificity.

  The `streaming_sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is computed
  and used to evaluate the corresponding sensitivity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    specificity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      specificity. Must be at least 2.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar tensor representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `specificity` is not between 0 and 1, if `num_thresholds`
      is less than 2, if `predictions` and `labels` have mismatched shapes, or
      if `weights` is not `None` and its shape is incompatible with
      `predictions`.
  """
  if specificity < 0 or specificity > 1:
    raise ValueError('`specificity` must be in the range [0, 1].')
  if num_thresholds < 2:
    # Previously only caught by the `assert` below (and not at all under
    # `python -O`); fail early with the documented exception type instead.
    raise ValueError('`num_thresholds` must be at least 2, got %s.' %
                     (num_thresholds,))

  with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
                                     [predictions, labels]):
    kepsilon = 1e-7  # to account for floating point imprecisions
    # Interior thresholds are evenly spaced in (0, 1); the end points sit just
    # outside [0, 1].
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
     fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, weights)
    # Internal sanity check that the helper produced one count per threshold.
    assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds

    def compute_sensitivity_at_specificity(name):
      """Computes the sensitivity at the given specificity.

      Args:
        name: The name of the operation.

      Returns:
        The sensitivity using the aggregated values.
      """
      specificities = math_ops.div(tn, tn + fp + kepsilon)
      # Pick the threshold whose specificity is closest to the requested one.
      tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the sensitivity:
      return math_ops.div(tp[tf_index],
                          tp[tf_index] + fn[tf_index] + kepsilon,
                          name)

    sensitivity = compute_sensitivity_at_specificity('value')
    # The update variant reads the counts only after all four have been
    # incremented.
    with ops.control_dependencies(
        [tp_update_op, fn_update_op, tn_update_op, fp_update_op]):
      update_op = compute_sensitivity_at_specificity('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, sensitivity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return sensitivity, update_op
1039
1040
def streaming_precision_at_thresholds(predictions, labels, thresholds,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None, name=None):
  """Computes one precision value per entry of `thresholds`.

  Four local variables, `true_positives`, `true_negatives`, `false_positives`
  and `false_negatives`, are created with one slot per threshold.
  `precision[i]` is the total weight of `predictions` above `thresholds[i]`
  whose `labels` entry is `True`, divided by the total weight of all
  `predictions` above `thresholds[i]`, i.e.
  `true_positives[i] / (true_positives[i] + false_positives[i])`.

  To estimate the metric over a stream of data, run the returned `update_op`,
  which updates the variables and yields the refreshed `precision`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `precision`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives` and
      `false_positives` variables and then evaluates `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape is incompatible with
      `predictions`.
  """
  with variable_scope.variable_scope(name, 'precision_at_thresholds',
                                     [predictions, labels]):

    # TODO(nsilberman): Replace with only tp and fp, this results in unnecessary
    # variable creation. b/30842882
    counts = _tp_fn_tn_fp(predictions, labels, thresholds, weights)
    # Helper returns (tp, fn, tn, fp, tp_update, fn_update, tn_update,
    # fp_update); only the tp/fp entries are needed for precision.
    tp, fp = counts[0], counts[3]
    tp_update, fp_update = counts[4], counts[7]

    # avoid division by zero
    epsilon = 1e-7

    def compute_precision(suffix):
      """Builds the tp / (tp + fp) op named 'precision_<suffix>'."""
      return math_ops.div(tp, epsilon + tp + fp, name='precision_' + suffix)

    precision = compute_precision('value')
    # The update variant reads the counts only after both were incremented.
    with ops.control_dependencies([tp_update, fp_update]):
      update_op = compute_precision('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, precision)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision, update_op
1114
1115
def streaming_recall_at_thresholds(predictions, labels, thresholds,
                                   weights=None, metrics_collections=None,
                                   updates_collections=None, name=None):
  """Computes one recall value per entry of `thresholds`.

  Four local variables, `true_positives`, `true_negatives`, `false_positives`
  and `false_negatives`, are created with one slot per threshold. `recall[i]`
  is the total weight of `predictions` above `thresholds[i]` whose `labels`
  entry is `True`, divided by the total weight of `True` values in `labels`,
  i.e. `true_positives[i] / (true_positives[i] + false_negatives[i])`.

  To estimate the metric over a stream of data, run the returned `update_op`,
  which updates the variables and yields the refreshed `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float tensor of shape [len(thresholds)].
    update_op: An operation that increments the `true_positives` and
      `false_negatives` variables and then evaluates `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape is incompatible with
      `predictions`.
  """
  with variable_scope.variable_scope(name, 'recall_at_thresholds',
                                     [predictions, labels]):
    counts = _tp_fn_tn_fp(predictions, labels, thresholds, weights)
    # Helper returns (tp, fn, tn, fp, tp_update, fn_update, tn_update,
    # fp_update); only the tp/fn entries are needed for recall.
    tp, fn = counts[0], counts[1]
    tp_update, fn_update = counts[4], counts[5]

    # avoid division by zero
    epsilon = 1e-7

    def compute_recall(suffix):
      """Builds the tp / (tp + fn) op named 'recall_<suffix>'."""
      return math_ops.div(tp, epsilon + tp + fn, name='recall_' + suffix)

    recall = compute_recall('value')
    # The update variant reads the counts only after both were incremented.
    with ops.control_dependencies([tp_update, fn_update]):
      update_op = compute_recall('update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, recall)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall, update_op
1183
1184
1185def _at_k_name(name, k=None, class_id=None):
1186  if k is not None:
1187    name = '%s_at_%d' % (name, k)
1188  else:
1189    name = '%s_at_k' % (name)
1190  if class_id is not None:
1191    name = '%s_class%d' % (name, class_id)
1192  return name
1193
1194
@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, '
            'and reshape labels from [batch_size] to [batch_size, 1].')
def streaming_recall_at_k(predictions, labels, k, weights=None,
                          metrics_collections=None, updates_collections=None,
                          name=None):
  """Computes the recall@k of the predictions with respect to dense labels.

  An `in_top_k` operation produces, per batch entry, whether the label is among
  the top `k` `predictions`. That indicator is cast to float and handed to
  `streaming_mean`, so the result is the (weighted) fraction of entries whose
  label was in the top `k`. Per the `streaming_mean` contract, two local
  variables, `total` and `count`, back the metric: `update_op` adds the reduced
  sum of `weights` where `in_top_k` is `True` to `total` and the reduced sum of
  `weights` to `count`, and the returned value divides `total` by `count`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point tensor of dimension [batch_size, num_classes]
    labels: A tensor of dimension [batch_size] whose type is in `int32`,
      `int64`.
    k: The number of top elements to look at for computing recall.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `recall_at_k`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name. Defaults to `recall_at_<k>`.

  Returns:
    recall_at_k: A tensor representing the recall@k, the fraction of labels
      which fall into the top `k` predictions.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `recall_at_k`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  hits = nn.in_top_k(predictions, labels, k)
  hit_indicator = math_ops.to_float(hits)
  # Fall back to the generated `recall_at_<k>` name when none was supplied.
  scope_name = name if name else _at_k_name('recall', k)
  return streaming_mean(hit_indicator,
                        weights,
                        metrics_collections,
                        updates_collections,
                        scope_name)
1247
1248
1249# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_recall_at_k(predictions,
                                 labels,
                                 k,
                                 class_id=None,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  If `class_id` is specified, recall is calculated by considering only the
      entries in the batch for which `class_id` is in the label, and computing
      the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, recall is calculated as how often on
      average a class among the labels of a batch entry is in the top-k
      `predictions`.

  `streaming_sparse_recall_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
  the recall_at_k frequency. This frequency is ultimately returned as
  `recall_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. The true positives and false negatives,
  weighted by `weights`, are derived from `top_k` and `labels` by the
  `_streaming_sparse_true_positive_at_k` and
  `_streaming_sparse_false_negative_at_k` helpers. Then `update_op` increments
  `true_positive_at_<k>` and `false_negative_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`.
      Values should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range always count
      towards `false_negative_at_<k>`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If class_id is outside this range, the method returns NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `Tensor` (`float64` per the upstream documentation of this
      op) with the value of `true_positives` divided by the sum of
      `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
  """
  default_name = _at_k_name('recall', k, class_id=class_id)
  with ops.name_scope(name, default_name, (predictions, labels)) as scope:
    # Only the indices of the k largest predictions are needed.
    top_k_result = nn.top_k(predictions, k)
    predicted_idx = math_ops.to_int64(top_k_result[1])

    # Both count helpers take exactly the same arguments.
    shared_args = dict(predictions_idx=predicted_idx, labels=labels, k=k,
                       class_id=class_id, weights=weights)
    tp, tp_update = _streaming_sparse_true_positive_at_k(**shared_args)
    fn, fn_update = _streaming_sparse_false_negative_at_k(**shared_args)

    # recall = tp / (tp + fn), once for the stored counts and once for the
    # freshly updated ones.
    total = math_ops.add(tp, fn)
    metric = math_ops.div(tp, total, name=scope)
    total_update = math_ops.add(tp_update, fn_update)
    update = math_ops.div(tp_update, total_update, name='update')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update
1339
1340
def _streaming_sparse_precision_at_k(top_k_idx,
                                     labels,
                                     k=None,
                                     class_id=None,
                                     weights=None,
                                     metrics_collections=None,
                                     updates_collections=None,
                                     name=None):
  """Builds streaming precision@k ops from precomputed top-k class indices.

  Common implementation behind `streaming_sparse_precision_at_k` and
  `streaming_sparse_precision_at_top_k`; those functions carry the full
  description of the metric.

  Args:
    top_k_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and `top_k_idx` has shape [batch size, k]. The final
      dimension holds the indices of the top-k classes. [D1, ... DN] must
      match `labels`. The tensor is cast to `int64` before use.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `top_k_idx`. Values should be in range [0, num_classes); values outside
      this range are ignored.
    k: Integer k for the @k metric, or `None`. It is forwarded to the
      true/false positive helpers and only affects default op names.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes). If `class_id` is outside this range, the
      method returns NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `top_k_idx` and `labels`.
    metrics_collections: An optional list of collections that the metric value
      should be added to.
    updates_collections: An optional list of collections that the update op
      should be added to.
    name: Name for the op producing the metric value. The public wrappers pass
      their enclosing name scope here.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments the `true_positives` and
      `false_positives` variables, and whose value matches `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  # Both helpers receive exactly the same keyword arguments.
  helper_kwargs = dict(
      predictions_idx=math_ops.to_int64(top_k_idx),
      labels=labels,
      k=k,
      class_id=class_id,
      weights=weights)
  true_pos, true_pos_update = _streaming_sparse_true_positive_at_k(
      **helper_kwargs)
  false_pos, false_pos_update = _streaming_sparse_false_positive_at_k(
      **helper_kwargs)

  # Value op: reads the accumulated counters without modifying them.
  retrieved = math_ops.add(true_pos, false_pos)
  metric = math_ops.div(true_pos, retrieved, name=name)

  # Update op: the same ratio, computed from the counter-incrementing ops.
  retrieved_update = math_ops.add(true_pos_update, false_pos_update)
  update = math_ops.div(true_pos_update, retrieved_update, name='update')

  if metrics_collections:
    ops.add_to_collections(metrics_collections, metric)
  if updates_collections:
    ops.add_to_collections(updates_collections, update)
  return metric, update
1407
1408
1409# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_k(predictions,
                                    labels,
                                    k,
                                    class_id=None,
                                    weights=None,
                                    metrics_collections=None,
                                    updates_collections=None,
                                    name=None):
  """Computes streaming precision@k of `predictions` against sparse labels.

  With `class_id` given, only batch entries whose top-k `predictions` contain
      `class_id` are considered, and the result is the fraction of those
      entries for which `class_id` is actually one of the labels.
  Without `class_id`, the result is how often, on average, a class among the
      top-k highest `predictions` of an entry appears in that entry's labels.

  Two local variables, `true_positive_at_<k>` and `false_positive_at_<k>`,
  accumulate the counts. The returned `precision_at_<k>` value is an
  idempotent op that divides `true_positive_at_<k>` by the total
  (`true_positive_at_<k>` + `false_positive_at_<k>`).

  To estimate the metric over a stream of data, run the returned `update_op`:
  a `top_k` op selects the top `k` `predictions`, set operations between those
  indices and `labels` give the true and false positives (weighted by
  `weights`), both variables are incremented, and the resulting
  `precision_at_<k>` is returned.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  scope_inputs = (predictions, labels, weights)
  fallback_name = _at_k_name('precision', k, class_id=class_id)
  with ops.name_scope(name, fallback_name, scope_inputs) as scope:
    # Only the indices of the k largest predictions are needed; the values
    # returned by top_k are discarded.
    top_k_idx = nn.top_k(predictions, k)[1]
    return _streaming_sparse_precision_at_k(
        top_k_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)
1495
1496
1497# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_top_k(top_k_predictions,
                                        labels,
                                        class_id=None,
                                        weights=None,
                                        metrics_collections=None,
                                        updates_collections=None,
                                        name=None):
  """Computes streaming precision@k from top-k indices and sparse labels.

  Unlike `streaming_sparse_precision_at_k`, the caller supplies the top-k
  class indices directly instead of per-class prediction scores.

  With `class_id` given, only batch entries whose `top_k_predictions` contain
      `class_id` are considered, and the result is the fraction of those
      entries for which `class_id` is actually one of the labels.
  Without `class_id`, the result is how often, on average, a class among the
      top-k predicted classes of an entry appears in that entry's labels.

  Two local variables, `true_positive_at_k` and `false_positive_at_k`,
  accumulate the counts. The returned `precision_at_k` value is an idempotent
  op that divides `true_positive_at_k` by the total (`true_positive_at_k` +
  `false_positive_at_k`).

  To estimate the metric over a stream of data, run the returned `update_op`:
  set operations between `top_k_predictions` and `labels` give the true and
  false positives (weighted by `weights`), both variables are incremented,
  and the resulting `precision_at_k` is returned.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
      The final dimension contains the indices of top-k labels. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `top_k_predictions`. Values should be in range [0, num_classes), where
      num_classes is the number of classes being predicted. Values outside
      this range are ignored.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes). If `class_id` is outside this range, the
      method returns NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `top_k_predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.

  Note: the rank >= 2 requirement on `top_k_predictions` is enforced by an
  `Assert` op attached as a control dependency of the metric ops, so a
  violation surfaces when those ops are evaluated rather than as a
  `ValueError` at construction time.
  """
  scope_inputs = (top_k_predictions, labels, weights)
  fallback_name = _at_k_name('precision', class_id=class_id)
  with ops.name_scope(name, fallback_name, scope_inputs) as scope:
    # Graph-level guard: the metric ops below only run once this passes.
    has_min_rank = math_ops.greater_equal(
        array_ops.rank(top_k_predictions), 2)
    rank_assertion = control_flow_ops.Assert(
        has_min_rank,
        ['top_k_predictions must have rank 2 or higher, e.g. [batch_size, k].'])
    with ops.control_dependencies([rank_assertion]):
      return _streaming_sparse_precision_at_k(
          top_k_idx=top_k_predictions,
          labels=labels,
          class_id=class_id,
          weights=weights,
          metrics_collections=metrics_collections,
          updates_collections=updates_collections,
          name=scope)
1584
1585
def num_relevant(labels, k):
  """Returns, per row of `labels`, the number of relevant values capped at `k`.

  For labels with shape [D1, ... DN, num_labels], each output value is the
  minimum of `num_labels` and `k`. For a `SparseTensor`, `num_labels` is
  the per-row set size, so rows may differ; for a dense `Tensor` it is the size
  of the last dimension and therefore identical for every row.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels].
    k: Integer, k for @k metric. Must be at least 1.

  Returns:
    Integer `Tensor` of shape [D1, ... DN], where each value is the number of
    relevant values for that row.

  Raises:
    ValueError: if `k` is less than 1.
  """
  if k < 1:
    raise ValueError('Invalid k=%s.' % k)

  sparse_types = (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)
  with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
    if not isinstance(labels, sparse_types):
      # Dense labels: every row has the same count (size of the last
      # dimension, capped at k), so fill the leading dimensions with it.
      dense_shape = array_ops.shape(labels)
      capped_count = math_ops.minimum(dense_shape[-1], k)
      return array_ops.fill(dense_shape[0:-1], capped_count, name=scope)

    # Sparse labels: count the entries of each row separately, then cap at k.
    per_row_sizes = set_ops.set_size(labels)
    return math_ops.minimum(per_row_sizes, k, name=scope)
1621
1622
def expand_and_tile(tensor, multiple, dim=0, name=None):
  """Inserts a new dimension into `tensor` and tiles values along it.

  The shape of `tensor` is split at `dim`, a new dimension is inserted there
  (i.e. before `dim`), and the values are repeated `multiple` times along the
  new dimension.

  Args:
    tensor: Input `Tensor` or `SparseTensor`. A `SparseTensorValue` is
      converted to a `SparseTensor` first.
    multiple: Integer, number of times to tile. Must be at least 1.
    dim: Integer, dimension along which to tile. May be negative.
    name: Name of operation.

  Returns:
    `Tensor` result of expanding and tiling `tensor` (a `SparseTensor` when
    the input is sparse). When `multiple` is 1, the expanded tensor is
    returned without a tiling step.

  Raises:
    ValueError: if `multiple` is less than 1. NOTE(review): `dim` is not
      range-checked here; an out-of-range `dim` is presumably rejected by the
      underlying ops - confirm.
  """
  if multiple < 1:
    raise ValueError('Invalid multiple %s, must be > 0.' % multiple)

  with ops.name_scope(
      name, 'expand_and_tile', (tensor, multiple, dim)) as scope:
    if isinstance(tensor, sparse_tensor.SparseTensorValue):
      tensor = sparse_tensor.SparseTensor.from_value(tensor)

    if not isinstance(tensor, sparse_tensor.SparseTensor):
      # Dense input. For negative `dim`, shift by one so that the new axis
      # lands before `dim` rather than after it.
      insert_axis = dim if (dim >= 0) else (dim - 1)
      expanded = array_ops.expand_dims(tensor, insert_axis, name='expand')
      if multiple == 1:
        return expanded
      # Multiples are all 1 except at the inserted axis.
      unit_multiples = array_ops.ones_like(array_ops.shape(tensor))
      head = unit_multiples[:dim]
      tail = unit_multiples[dim:]
      tile_multiples = array_ops.concat(
          0, (head, (multiple,), tail), name='multiples')
      return array_ops.tile(expanded, tile_multiples, name=scope)

    # Sparse input. Locate the split point in the dense shape; a negative
    # `dim` is resolved against the (dynamic) rank.
    if dim < 0:
      split_at = array_ops.reshape(
          array_ops.size(tensor.shape) + dim, [1])
    else:
      split_at = [dim]
    leading_dims = array_ops.slice(tensor.shape, [0], split_at)
    trailing_dims = array_ops.slice(tensor.shape, split_at, [-1])
    expanded_shape = array_ops.concat(
        0, (leading_dims, [1], trailing_dims), name='expanded_shape')
    expanded = sparse_ops.sparse_reshape(
        tensor, shape=expanded_shape, name='expand')
    if multiple == 1:
      return expanded
    # Tiling a sparse tensor: concatenate `multiple` copies along the new axis.
    concat_axis = dim - 1 if dim < 0 else dim
    return sparse_ops.sparse_concat(
        concat_axis, [expanded] * multiple, name=scope)
1675
1676
def sparse_average_precision_at_k(predictions, labels, k):
  """Computes per-row average precision@k of predictions vs. sparse labels.

  Following en.wikipedia.org/wiki/Information_retrieval#Average_precision, the
  value for each row is:

    AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items

  A "row" is the elements in dimension [D1, ... DN] of `predictions`, `labels`,
  and the result `Tensors`. In the common case, this is [batch_size]. Each row
  of the results contains the average precision for that row.

  A `top_k` op picks the indices of the top `k` `predictions`; set operations
  between each of those indices and `labels` give the true positives, from
  which the precision at every cutoff (the "P_{i}" term) is derived.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and `predictions` has shape
      [batch size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric. This will calculate an average precision for
      range `[1,k]`, as documented above.

  Returns:
    `float64` `Tensor` of shape [D1, ... DN], where each value is the average
    precision for that row.

  Raises:
    ValueError: if `k` is less than 1.
  """
  if k < 1:
    raise ValueError('Invalid k=%s.' % k)

  with ops.name_scope(
      None, 'average_precision', (predictions, labels, k)) as scope:
    # Indices of the k highest predictions: a [D1, ... DN, k] tensor.
    top_k_indices = nn.top_k(predictions, k)[1]
    top_k_indices = math_ops.to_int64(top_k_indices, name='predictions_idx')

    # Add a trailing axis -> [D1, ... DN, k, 1], so that every rank position is
    # its own single-element prediction set and gets its own true positive
    # value.
    idx_per_rank = array_ops.expand_dims(
        top_k_indices, -1, name='predictions_idx_per_k')

    # Repeat labels k times -> [D1, ... DN, k, num_labels], one copy per rank
    # position.
    labels_per_rank = expand_and_tile(
        labels, multiple=k, dim=-1, name='labels_per_k')

    # Everything below has shape [D1, ... DN, k]: one value per row and per
    # rank position i in 1..k.

    # rel_{i}: 1 where the i-th prediction is one of the labels, else 0
    # (`_sparse_true_positive_at_k` returns these counts cast to double).
    is_relevant = _sparse_true_positive_at_k(
        idx_per_rank, labels_per_rank, name='relevant_per_k')

    # Running count of true positives within the first i predictions.
    true_pos_so_far = math_ops.cumsum(is_relevant, axis=-1, name='tp_per_k')

    # Running count of retrieved predictions (1, 2, ..., k): the precision
    # denominator.
    retrieved_so_far = math_ops.cumsum(
        array_ops.ones_like(is_relevant), axis=-1, name='retrieved_per_k')

    # P_{i}: precision over the first i predictions.
    precision_at_rank = math_ops.div(
        math_ops.to_double(true_pos_so_far),
        math_ops.to_double(retrieved_so_far),
        name='precision_per_k')

    # P_{i} * rel_{i}: keep precision only at ranks holding a relevant item.
    relevant_precision = math_ops.mul(
        precision_at_rank, math_ops.to_double(is_relevant),
        name='relevant_precision_per_k')

    # Sum over the k axis, giving a [D1, ... DN] tensor.
    precision_sum = math_ops.reduce_sum(
        relevant_precision, reduction_indices=(-1,), name='precision_sum')

    # Normalize by "num_relevant_items" to obtain "AveP".
    relevant_count = math_ops.to_double(num_relevant(labels, k))
    return math_ops.div(precision_sum, relevant_count, name=scope)
1766
1767
def streaming_sparse_average_precision_at_k(predictions,
                                            labels,
                                            k,
                                            weights=None,
                                            metrics_collections=None,
                                            updates_collections=None,
                                            name=None):
  """Computes streaming mean average precision@k against sparse labels.

  See `sparse_average_precision_at_k` for the per-row formula. `weights` are
  applied to the result of `sparse_average_precision_at_k`.

  Two local variables, `average_precision_at_<k>/total` and
  `average_precision_at_<k>/max`, accumulate state across batches. The
  returned metric value is an idempotent op that divides
  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.

  To estimate the metric over a stream of data, run the returned update op.
  For each batch it adds the sum of the (weighted) per-row average precisions
  to `total`, adds the maximum attainable sum to `max` (the number of rows
  when unweighted, otherwise the sum of `weights` broadcast across the rows),
  and returns the resulting ratio.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and `predictions` has shape
      [batch size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric. This will calculate an average precision for
      range `[1,k]`, as documented in `sparse_average_precision_at_k`.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments the variables appropriately, and whose
      value matches `mean_average_precision`.
  """
  fallback_name = _at_k_name('average_precision', k)
  with ops.name_scope(name, fallback_name, (predictions, labels)) as scope:
    # Per-row average precision, scaled by the weights when provided.
    row_precision = sparse_average_precision_at_k(
        predictions=predictions, labels=labels, k=k)
    has_weights = weights is not None
    if has_weights:
      weights = math_ops.to_double(weights)
      row_precision = math_ops.mul(row_precision, weights)

    # Accumulator for the largest total attainable so far. The best possible
    # value for any single row is 1.0, so per batch this is:
    # - weighted: the sum of the weights broadcast across the rows;
    # - unweighted: simply the number of rows.
    with ops.name_scope(None, 'max', (row_precision,)) as max_scope:
      max_var = contrib_variables.local_variable(
          array_ops.zeros([], dtype=dtypes.float64), name=max_scope)
      if has_weights:
        # TODO(ptucker): More efficient way to broadcast?
        broadcast_weights = math_ops.mul(
            weights, array_ops.ones_like(row_precision),
            name='broadcast_weights')
        batch_max = math_ops.reduce_sum(broadcast_weights, name='batch_max')
      else:
        row_count = array_ops.size(row_precision, name='batch_max')
        batch_max = math_ops.to_double(row_count)
      max_update = state_ops.assign_add(max_var, batch_max, name='update')

    # Accumulator for the total of the (weighted) per-row average precisions.
    with ops.name_scope(None, 'total', (row_precision,)) as total_scope:
      total_var = contrib_variables.local_variable(
          array_ops.zeros([], dtype=dtypes.float64), name=total_scope)
      batch_total = math_ops.reduce_sum(row_precision, name='batch_total')
      total_update = state_ops.assign_add(total_var, batch_total, name='update')

    # mean = total / max, built once from the variables (value op) and once
    # from the assign ops (update op).
    mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
    update = _safe_scalar_div(total_update, max_update, name=scope)

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_average_precision)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)

    return mean_average_precision, update
1868
1869
def _select_class_id(ids, selected_id):
  """Keeps only the entries of `ids` that equal `selected_id`.

  Args:
    ids: `int64` `Tensor` or `SparseTensor` of IDs.
    selected_id: Int id to select.

  Returns:
    `SparseTensor` of same dimensions as `ids`. This contains only the entries
    equal to `selected_id`.
  """
  sparse_types = (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)
  if isinstance(ids, sparse_types):
    # Sparse input: retain exactly the stored values matching the ID.
    keep_mask = math_ops.equal(ids.values, selected_id)
    return sparse_ops.sparse_retain(ids, keep_mask)

  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
  # tf.equal and tf.reduce_any?

  # Dense input: build a tensor shaped like `ids` but with the last dimension
  # collapsed to 1, filled with `selected_id`, and intersect it with `ids`
  # row by row.
  full_shape = array_ops.shape(ids, out_type=dtypes.int64)
  last_axis = array_ops.size(full_shape) - 1
  collapsed_shape = math_ops.reduced_shape(
      full_shape, array_ops.reshape(last_axis, [1]))
  selected_fill = array_ops.fill(
      collapsed_shape, math_ops.to_int64(selected_id))
  matches = set_ops.set_intersection(selected_fill, ids)

  # Re-wrap the intersection with the original dense shape of `ids`.
  return sparse_tensor.SparseTensor(
      indices=matches.indices, values=matches.values, shape=full_shape)
1901
1902
1903def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
1904  """If class ID is specified, filter all other classes.
1905
1906  Args:
1907    labels: `int64` `Tensor` or `SparseTensor` with shape
1908      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1909      target classes for the associated prediction. Commonly, N=1 and `labels`
1910      has shape [batch_size, num_labels]. [D1, ... DN] must match
1911      `predictions_idx`.
1912    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
1913      where N >= 1. Commonly, N=1 and `predictions_idx` has shape
1914      [batch size, k].
1915    selected_id: Int id to select.
1916
1917  Returns:
1918    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
1919  """
1920  if selected_id is None:
1921    return labels, predictions_idx
1922  return (_select_class_id(labels, selected_id),
1923          _select_class_id(predictions_idx, selected_id))
1924
1925
def _sparse_true_positive_at_k(predictions_idx,
                               labels,
                               class_id=None,
                               weights=None,
                               name=None):
  """Counts per-row true positives for recall@k and precision@k.

  If `class_id` is specified, calculate binary true positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the last dimension of `labels`.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.
    name: Name of operation.

  Returns:
    A [D1, ... DN] `Tensor` of true positive counts, cast to double and
    multiplied by `weights` when those are given.
  """
  with ops.name_scope(name, 'true_positives', (predictions_idx, labels)):
    labels, predictions_idx = _maybe_select_class_id(
        labels, predictions_idx, class_id)

    # A true positive is a class present in both the predictions and the
    # labels of a row, i.e. an element of their set intersection.
    hits = set_ops.set_intersection(predictions_idx, labels)
    true_pos = math_ops.to_double(set_ops.set_size(hits))

    if weights is None:
      return true_pos
    return math_ops.mul(true_pos, math_ops.to_double(weights))
1964
1965
def _streaming_sparse_true_positive_at_k(predictions_idx,
                                         labels,
                                         k=None,
                                         class_id=None,
                                         weights=None,
                                         name=None):
  """Accumulates weighted true positives for recall@k and precision@k.

  If `class_id` is specified, calculate binary true positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the last dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`. The variable is a `float64`
    scalar local variable starting at zero; the update op adds the current
    batch's total true positives to it.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  fallback_name = _at_k_name('true_positive', k, class_id=class_id)
  with ops.name_scope(name, fallback_name, (predictions_idx, labels)) as scope:
    # Per-row (weighted) counts, collapsed into one scalar for the batch.
    per_row_tp = _sparse_true_positive_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    batch_sum = math_ops.reduce_sum(per_row_tp)
    batch_total_tp = math_ops.to_double(batch_sum)

    # Scalar accumulator named after the enclosing scope.
    initial_value = array_ops.zeros([], dtype=dtypes.float64)
    accumulator = contrib_variables.local_variable(initial_value, name=scope)
    update_op = state_ops.assign_add(
        accumulator, batch_total_tp, name='update')
    return accumulator, update_op
2012
2013
def _sparse_false_positive_at_k(predictions_idx,
                                labels,
                                class_id=None,
                                weights=None):
  """Counts per-example false positives for precision@k.

  When `class_id` is given, only binary false positives for that class are
  counted. Otherwise the `k` predicted classes are compared against the `n`
  label classes, where `n` is the 2nd dimension of `labels`.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.

  Returns:
    A [D1, ... DN] `Tensor` of false positive counts, as `float64`, multiplied
    by `weights` when those are provided.
  """
  with ops.name_scope(None, 'false_positives', (predictions_idx, labels)):
    labels, predictions_idx = _maybe_select_class_id(
        labels, predictions_idx, class_id)
    # A false positive is a predicted class absent from the labels, i.e. the
    # set difference predictions - labels.
    predicted_not_labeled = set_ops.set_difference(
        predictions_idx, labels, aminusb=True)
    fp_counts = math_ops.to_double(set_ops.set_size(predicted_not_labeled))
    if weights is None:
      return fp_counts
    return math_ops.mul(fp_counts, math_ops.to_double(weights))
2052
2053
def _streaming_sparse_false_positive_at_k(predictions_idx,
                                          labels,
                                          k=None,
                                          class_id=None,
                                          weights=None,
                                          name=None):
  """Accumulates weighted false positives for precision@k.

  When `class_id` is given, only binary false positives for that class are
  counted. Otherwise the `k` predicted classes are compared against the `n`
  label classes, where `n` is the 2nd dimension of `labels`.

  A `weights` of `None` is equivalent to weights of 1. Use weights of 0 to mask
  values.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`. The variable is a scalar
    `float64` running total; the update op adds this step's summed false
    positives to it.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  scope_default = _at_k_name('false_positive', k, class_id=class_id)
  with ops.name_scope(
      name, scope_default, (predictions_idx, labels)) as scope:
    per_example_fp = _sparse_false_positive_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    step_total = math_ops.to_double(math_ops.reduce_sum(per_example_fp))

    # Scalar float64 accumulator, named after the enclosing scope.
    initial_value = array_ops.zeros([], dtype=dtypes.float64)
    accumulator = contrib_variables.local_variable(initial_value, name=scope)
    update_op = state_ops.assign_add(accumulator, step_total, name='update')
    return accumulator, update_op
2100
2101
def _sparse_false_negative_at_k(predictions_idx,
                                labels,
                                class_id=None,
                                weights=None):
  """Counts per-example false negatives for recall@k.

  When `class_id` is given, only binary false negatives for that class are
  counted. Otherwise the `k` predicted classes are compared against the `n`
  label classes, where `n` is the 2nd dimension of `labels`.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.

  Returns:
    A [D1, ... DN] `Tensor` of false negative counts, as `float64`, multiplied
    by `weights` when those are provided.
  """
  with ops.name_scope(None, 'false_negatives', (predictions_idx, labels)):
    labels, predictions_idx = _maybe_select_class_id(
        labels, predictions_idx, class_id)
    # A false negative is a label class missing from the predictions. With
    # `aminusb=False` the difference is taken as labels - predictions.
    labeled_not_predicted = set_ops.set_difference(
        predictions_idx, labels, aminusb=False)
    fn_counts = math_ops.to_double(set_ops.set_size(labeled_not_predicted))
    if weights is None:
      return fn_counts
    return math_ops.mul(fn_counts, math_ops.to_double(weights))
2141
2142
def _streaming_sparse_false_negative_at_k(predictions_idx,
                                          labels,
                                          k=None,
                                          class_id=None,
                                          weights=None,
                                          name=None):
  """Calculates weighted per step false negatives for recall@k.

  If `class_id` is specified, calculate binary false negatives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
      Optional, defaulting to `None` like the sibling true-positive and
      false-positive streaming helpers.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`. The variable is a scalar
    `float64` running total; the update op adds this step's summed false
    negatives to it.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  default_name = _at_k_name('false_negative', k, class_id=class_id)
  with ops.name_scope(name, default_name, (predictions_idx, labels)) as scope:
    fn = _sparse_false_negative_at_k(
        predictions_idx=predictions_idx, labels=labels, class_id=class_id,
        weights=weights)
    batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn))

    # Scalar float64 accumulator, named after the enclosing scope.
    var = contrib_variables.local_variable(
        array_ops.zeros([], dtype=dtypes.float64), name=scope)
    return var, state_ops.assign_add(var, batch_total_fn, name='update')
2189
2190
def streaming_mean_absolute_error(predictions, labels, weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the streaming mean of `|predictions - labels|`.

  The element-wise absolute difference between `predictions` and `labels` is
  computed here and handed to `streaming_mean`, which maintains the `total` and
  `count` local variables. The returned `mean_absolute_error` is an idempotent
  operation that divides `total` by `count`; the returned `update_op`
  increments `total` with the reduced sum of the product of `weights` and the
  absolute errors, increments `count` with the reduced sum of `weights`, and
  yields the updated mean.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name. Defaults to 'mean_absolute_error'.

  Returns:
    mean_absolute_error: A tensor representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  abs_diff = math_ops.abs(predictions - labels)
  metric_name = name or 'mean_absolute_error'
  return streaming_mean(
      abs_diff, weights, metrics_collections, updates_collections, metric_name)
2241
2242
def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the streaming mean of `|labels - predictions| / normalizer`.

  The relative errors are computed element-wise here and handed to
  `streaming_mean`, which maintains the `total` and `count` local variables.
  Wherever `normalizer` equals 0 the relative error is taken to be 0 instead of
  the result of the division. The returned `mean_relative_error` is an
  idempotent operation that divides `total` by `count`; the returned
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and the relative errors, increments `count` with the reduced sum of
  `weights`, and yields the updated mean.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name. Defaults to 'mean_relative_error'.

  Returns:
    mean_relative_error: A tensor representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  # The normalizer goes through the same squeeze/shape check against
  # predictions as the labels did.
  predictions, normalizer = tensor_util.remove_squeezable_dimensions(
      predictions, normalizer)
  predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())

  # Select 0 wherever the normalizer is 0, the scaled error elsewhere.
  zero_normalizer = math_ops.equal(normalizer, 0.0)
  zero_errors = array_ops.zeros_like(labels)
  scaled_errors = math_ops.div(math_ops.abs(labels - predictions), normalizer)
  relative_errors = math_ops.select(zero_normalizer, zero_errors, scaled_errors)

  metric_name = name or 'mean_relative_error'
  return streaming_mean(
      relative_errors, weights, metrics_collections, updates_collections,
      metric_name)
2301
2302
def streaming_mean_squared_error(predictions, labels, weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes the streaming mean of `(labels - predictions)^2`.

  The element-wise squared difference between `labels` and `predictions` is
  computed here and handed to `streaming_mean`, which maintains the `total` and
  `count` local variables. The returned `mean_squared_error` is an idempotent
  operation that divides `total` by `count`; the returned `update_op`
  increments `total` with the reduced sum of the product of `weights` and the
  squared errors, increments `count` with the reduced sum of `weights`, and
  yields the updated mean.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that
      `mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name. Defaults to 'mean_squared_error'.

  Returns:
    mean_squared_error: A tensor representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  sq_diff = math_ops.square(labels - predictions)
  metric_name = name or 'mean_squared_error'
  return streaming_mean(
      sq_diff, weights, metrics_collections, updates_collections, metric_name)
2353
2354
def streaming_root_mean_squared_error(predictions, labels, weights=None,
                                      metrics_collections=None,
                                      updates_collections=None,
                                      name=None):
  """Computes the streaming root mean squared error of the predictions.

  This delegates to `streaming_mean_squared_error`, which tracks the `total`
  and `count` local variables of the weighted squared errors, and then takes
  the square root of both the value tensor and the update op it returns. The
  inner metric is created with no collections; only the square-rooted tensors
  are added to `metrics_collections` and `updates_collections`.

  The returned `root_mean_squared_error` is an idempotent operation: the square
  root of `total` divided by `count`. The returned `update_op` first updates
  the underlying variables and then yields the square root of the updated mean.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that
      `root_mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name. Defaults to
      'root_mean_squared_error'.

  Returns:
    root_mean_squared_error: A tensor representing the current value of the
      metric, the square root of `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `root_mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  scope_name = name or 'root_mean_squared_error'
  mse, update_mse = streaming_mean_squared_error(
      predictions, labels, weights,
      metrics_collections=None, updates_collections=None, name=scope_name)

  root_mean_squared_error = math_ops.sqrt(mse)
  # Build the sqrt of the update under a control dependency on the MSE update.
  with ops.control_dependencies([update_mse]):
    update_rmse = math_ops.sqrt(update_mse)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, root_mean_squared_error)
  if updates_collections:
    ops.add_to_collections(updates_collections, update_rmse)

  return root_mean_squared_error, update_rmse
2417
2418
def streaming_covariance(predictions,
                         labels,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the unbiased sample covariance between `predictions` and `labels`.

  The `streaming_covariance` function creates four local variables,
  `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to
  compute the sample covariance between predictions and labels across multiple
  batches of data. The covariance is ultimately returned as an idempotent
  operation that simply divides `comoment` by `count` - 1. We use `count` - 1
  in order to get an unbiased estimate.

  The algorithm used for this online computation is described in
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
  Specifically, the formula used to combine two sample comoments is
  `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB`
  The comoment for a single batch of data is simply
  `sum((x - E[x]) * (y - E[y]))`, optionally weighted.

  In the code below, subscript `A` denotes the state accumulated before the
  current batch, `B` denotes the current batch, and `AB` denotes the combined
  state after the update.

  If `weights` is not None, then it is used to compute weighted comoments,
  means, and count. NOTE: these weights are treated as "frequency weights", as
  opposed to "reliability weights". See discussion of the difference on
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

  To facilitate the computation of covariance across multiple batches of data,
  the function creates an `update_op` operation, which updates underlying
  variables and returns the updated covariance.

  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as `predictions`.
    weights: An optional set of weights which indicates the frequency with which
      an example is sampled. Must be broadcastable with `labels`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    covariance: A `Tensor` representing the current unbiased sample covariance,
      `comoment` / (`count` - 1).
    update_op: An operation that updates the local variables appropriately.

  Raises:
    ValueError: If labels and predictions are of different sizes or if either
      `metrics_collections` or `updates_collections` are not a list or tuple.
  """
  with variable_scope.variable_scope(name, 'covariance', [predictions, labels]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # Running state. All four are created with shape [], i.e. scalars.
    count = _create_local('count', [])
    mean_prediction = _create_local('mean_prediction', [])
    mean_label = _create_local('mean_label', [])
    comoment = _create_local('comoment', [])  # C_A in update equation

    if weights is None:
      # Unweighted: every element of the batch counts once.
      batch_count = math_ops.to_float(array_ops.size(labels))  # n_B in eqn
      weighted_predictions = predictions
      weighted_labels = labels
    else:
      # Weighted: the batch "count" is the total weight over the batch.
      batch_count = math_ops.reduce_sum(
          _broadcast_weights(weights, labels))  # n_B in eqn
      weighted_predictions = predictions * weights
      weighted_labels = labels * weights

    update_count = state_ops.assign_add(count, batch_count)  # n_AB in eqn
    # The pre-update count is recovered by subtracting the batch count from the
    # value produced by the assign_add, rather than by reading `count` again.
    prev_count = update_count - batch_count  # n_A in update equation

    # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
    # batch_mean_prediction is E[x_B] in the update equation
    batch_mean_prediction = _safe_div(
        math_ops.reduce_sum(weighted_predictions), batch_count,
        'batch_mean_prediction')
    delta_mean_prediction = _safe_div(
        (batch_mean_prediction - mean_prediction) * batch_count, update_count,
        'delta_mean_prediction')
    update_mean_prediction = state_ops.assign_add(mean_prediction,
                                                  delta_mean_prediction)
    # prev_mean_prediction is E[x_A] in the update equation; as with the count,
    # it is recovered by subtracting the delta from the updated value.
    prev_mean_prediction = update_mean_prediction - delta_mean_prediction

    # batch_mean_label is E[y_B] in the update equation
    batch_mean_label = _safe_div(
        math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
    delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
                                 update_count, 'delta_mean_label')
    update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
    # prev_mean_label is E[y_A] in the update equation
    prev_mean_label = update_mean_label - delta_mean_label

    # Per-element products of residuals about the *batch* means; weights (if
    # any) are applied when these are summed below.
    unweighted_batch_coresiduals = (
        (predictions - batch_mean_prediction) * (labels - batch_mean_label))
    # batch_comoment is C_B in the update equation
    if weights is None:
      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals)
    else:
      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals *
                                           weights)

    # View delta_comoment as = C_AB - C_A in the update equation above.
    # Since C_A is stored in a var, by how much do we need to increment that var
    # to make the var = C_AB?
    # NOTE(review): unlike the mean updates, the final factor uses a plain `/`
    # instead of `_safe_div`; behavior when `update_count` is 0 should be
    # confirmed.
    delta_comoment = (batch_comoment +
                      (prev_mean_prediction - batch_mean_prediction) *
                      (prev_mean_label - batch_mean_label) *
                      (prev_count * batch_count / update_count))
    update_comoment = state_ops.assign_add(comoment, delta_comoment)

    # Value tensor: computed from the variables as they currently stand, with
    # no dependency on any of the update ops.
    covariance = _safe_div(comoment, count - 1, 'covariance')
    # Update tensor: the same expression, but built under a control dependency
    # on the comoment update.
    with ops.control_dependencies([update_comoment]):
      update_op = _safe_div(comoment, count - 1, 'update_op')

  if metrics_collections:
    ops.add_to_collections(metrics_collections, covariance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return covariance, update_op
2543
2544
def streaming_pearson_correlation(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes Pearson correlation coefficient between `predictions`, `labels`.

  Three streaming [co]variances are tracked by delegating to
  `streaming_covariance`:

  - `streaming_covariance(predictions, labels)`, i.e. covariance
  - `streaming_covariance(predictions, predictions)`, i.e. variance
  - `streaming_covariance(labels, labels)`, i.e. variance

  The product-moment correlation returned is the idempotent operation
  `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`, evaluated
  here as `cov / (sqrt(var(predictions)) * sqrt(var(labels)))`. The returned
  `update_op` evaluates the same expression on the outputs of the three
  underlying update ops, so running it updates all of them.

  If `weights` is not None, then it is used to compute a weighted correlation.
  NOTE: these weights are treated as "frequency weights", as opposed to
  "reliability weights". See discussion of the difference on
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as predictions.
    weights: An optional set of weights which indicates the frequency with which
      an example is sampled. Must be broadcastable with `labels`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    pearson_r: A tensor representing the current Pearson product-moment
      correlation coefficient, the value of
      `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`.
    update_op: An operation that updates the underlying variables appropriately.

  Raises:
    ValueError: If `labels` and `predictions` are of different sizes, or if
      `weights` is the wrong size, or if either `metrics_collections` or
      `updates_collections` are not a `list` or `tuple`.
  """
  with variable_scope.variable_scope(name, 'pearson_r', [predictions, labels]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    # Each call below builds its own state under a distinct sub-scope.
    cov_xy, update_cov_xy = streaming_covariance(
        predictions, labels, weights=weights, name='covariance')
    var_x, update_var_x = streaming_covariance(
        predictions, predictions, weights=weights, name='variance_predictions')
    var_y, update_var_y = streaming_covariance(
        labels, labels, weights=weights, name='variance_labels')

    std_product = math_ops.mul(math_ops.sqrt(var_x), math_ops.sqrt(var_y))
    pearson_r = _safe_div(cov_xy, std_product, 'pearson_r')

    pending_updates = [update_cov_xy, update_var_x, update_var_y]
    with ops.control_dependencies(pending_updates):
      updated_std_product = math_ops.mul(
          math_ops.sqrt(update_var_x), math_ops.sqrt(update_var_y))
      update_op = _safe_div(update_cov_xy, updated_std_product, 'update_op')

  if metrics_collections:
    ops.add_to_collections(metrics_collections, pearson_r)
  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return pearson_r, update_op
2621
2622
2623# TODO(nsilberman): add a 'normalized' flag so that the user can request
2624# normalization if the inputs are not normalized.
def streaming_mean_cosine_distance(predictions, labels, dim, weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
  """Streams the mean cosine distance between `predictions` and `labels`.

  The per-example similarity is the sum over `dim` of the element-wise product
  of `predictions` and `labels`. No normalization is applied here, so the
  inputs are expected to be normalized already. The (optionally weighted)
  mean of that similarity is accumulated by `streaming_mean`, which owns the
  `total` and `count` local variables, and the distance is reported as one
  minus that mean.

  `mean_distance` only reads the accumulated state and is idempotent;
  `update_op` folds the current batch into the state and yields the refreshed
  distance.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of the same shape as `labels`.
    labels: A `Tensor` of arbitrary shape.
    dim: The dimension along which the cosine distance is computed.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`,
      and whose dimension `dim` is 1.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A tensor representing the current mean, the value of `total`
      divided by `count`, subtracted from one.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  labels_shape = labels.get_shape()
  predictions.get_shape().assert_is_compatible_with(labels_shape)

  # Dot product along `dim`; the reduced axis is kept (size 1) so `weights`
  # can still broadcast against the result.
  elementwise_products = math_ops.mul(predictions, labels)
  similarity = math_ops.reduce_sum(
      elementwise_products, reduction_indices=[dim], keep_dims=True)

  # Collections are handled below on the final (1 - mean) tensors, so the
  # inner mean metric is not registered in any collection.
  scope_name = name if name else 'mean_cosine_distance'
  mean_similarity, similarity_update = streaming_mean(
      similarity, weights, None, None, scope_name)

  mean_distance = math_ops.sub(1.0, mean_similarity)
  update_op = math_ops.sub(1.0, similarity_update)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)
  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op
2688
2689
def streaming_percentage_less(values, threshold, weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Streams the fraction of `values` that are strictly below `threshold`.

  Each element of `values` is turned into a float indicator (1.0 when it is
  less than `threshold`, 0.0 otherwise) and the (optionally weighted) mean of
  those indicators is accumulated by `streaming_mean`, which owns the `total`
  and `count` local variables.

  `percentage` only reads the accumulated state and is idempotent; `update_op`
  folds the current batch into the state and yields the refreshed
  `percentage`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A numeric `Tensor` of arbitrary size.
    threshold: A scalar threshold.
    weights: An optional `Tensor` whose shape is broadcastable to `values`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    percentage: A tensor representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  below_mask = math_ops.less(values, threshold)
  below_indicator = math_ops.to_float(below_mask)
  scope_name = name if name else 'percentage_below_threshold'
  return streaming_mean(
      below_indicator,
      weights,
      metrics_collections,
      updates_collections,
      scope_name)
2735
2736
def streaming_mean_iou(predictions,
                       labels,
                       num_classes,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for
  semantic image segmentation, which first computes the IOU for each
  semantic class and then computes the average over classes.
  IOU is defined as follows:
    IOU = true_positive / (true_positive + false_positive + false_negative).
  The predictions are accumulated in a confusion matrix, weighted by `weights`,
  and mIOU is then calculated from it.

  The confusion matrix is always accumulated as `float64`, so fractional
  `weights` contribute their exact value instead of being forced into an
  integer accumulator.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean_iou`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A tensor of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened, if its rank > 1 or is not statically known.
    labels: A tensor of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened, if its rank > 1
      or is not statically known.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A tensor representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """

  def _flatten_if_needed(tensor):
    """Reshapes `tensor` to rank 1 unless it is statically rank 0 or 1."""
    rank = tensor.get_shape().ndims
    # `ndims` is None when the rank is not statically known; comparing None
    # with an int raises TypeError on Python 3, and flattening is the safe
    # choice when the rank cannot be determined at graph-construction time.
    if rank is None or rank > 1:
      return array_ops.reshape(tensor, [-1])
    return tensor

  with variable_scope.variable_scope(name, 'mean_iou', [predictions, labels]):
    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    # Local variable to accumulate the predictions in the confusion matrix.
    # A floating point accumulator is used unconditionally: an integer one
    # cannot represent fractional weights.
    cm_dtype = dtypes.float64
    total_cm = _create_local('total_confusion_matrix',
                             shape=[num_classes, num_classes], dtype=cm_dtype)

    # Cast the type to int64 required by confusion_matrix_ops.
    predictions = math_ops.to_int64(predictions)
    labels = math_ops.to_int64(labels)
    num_classes = math_ops.to_int64(num_classes)

    # Flatten the inputs so each element is one (prediction, label) sample.
    predictions = _flatten_if_needed(predictions)
    labels = _flatten_if_needed(labels)
    if weights is not None:
      weights = _flatten_if_needed(weights)

    # Accumulate the prediction to current confusion matrix.
    current_cm = confusion_matrix_ops.confusion_matrix(
        predictions, labels, num_classes, weights=weights, dtype=cm_dtype)
    update_op = state_ops.assign_add(total_cm, current_cm)

    def compute_mean_iou(name):
      """Compute the mean intersection-over-union via the confusion matrix."""
      sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
      sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
      cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
      # The diagonal is counted in both the row and the column sums, so it is
      # subtracted once to obtain TP + FP + FN per class.
      denominator = sum_over_row + sum_over_col - cm_diag

      # If the value of the denominator is 0, set it to 1 to avoid
      # zero division.
      denominator = math_ops.select(
          math_ops.greater(denominator, 0),
          denominator,
          array_ops.ones_like(denominator))
      iou = math_ops.div(cm_diag, denominator)
      return math_ops.reduce_mean(iou, name=name)

    mean_iou = compute_mean_iou('mean_iou')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_iou)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_iou, update_op
2843
2844
def _next_array_size(required_size, growth_factor=1.5):
  """Returns the capacity to use when growing a dynamic array.

  The result is `growth_factor` raised to the smallest integer power for which
  it is at least `required_size`, rounded up to an integer.

  Args:
    required_size: number or tf.Tensor specifying required array capacity.
    growth_factor: optional number or tf.Tensor specifying the growth factor
      between subsequent allocations.

  Returns:
    tf.Tensor with dtype=int32 giving the next array size.
  """
  log_required = math_ops.log(math_ops.cast(required_size, dtypes.float32))
  log_growth = math_ops.log(math_ops.cast(growth_factor, dtypes.float32))
  # log_required / log_growth is the logarithm of `required_size` in base
  # `growth_factor`; rounding it up gives the number of growth steps needed.
  exponent = math_ops.ceil(log_required / log_growth)
  grown_size = growth_factor ** exponent
  return math_ops.cast(math_ops.ceil(grown_size), dtypes.int32)
2860
2861
def streaming_concat(values,
                     axis=0,
                     max_size=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
  """Concatenate values along an axis across batches.

  The function `streaming_concat` creates two local variables, `array` and
  `size`, that are used to store concatenated values. Internally, `array` is
  used as storage for a dynamic array (if `max_size` is `None`), which ensures
  that updates can be run in amortized constant time.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that appends the values of a tensor and returns the
  `value` of the concatenated tensors.

  This op allows for evaluating metrics that cannot be updated incrementally
  using the same framework as other streaming metrics.

  Args:
    values: tensor to concatenate. Rank and the shape along all axes other than
      the axis to concatenate along must be statically known.
    axis: optional integer axis to concatenate along. Negative values count
      from the last axis.
    max_size: optional integer maximum size of `value` along the given axis.
      Once the maximum size is reached, further updates are no-ops. By default,
      there is no maximum size: the array is resized as necessary.
    metrics_collections: An optional list of collections that `value`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    value: A tensor representing the concatenated values.
    update_op: An operation that concatenates the next values.

  Raises:
    ValueError: if `values` does not have a statically known rank, `axis` is
      not in the valid range or the size of `values` is not statically known
      along any axis other than `axis`.
  """
  with variable_scope.variable_scope(name, 'streaming_concat', [values]):
    # pylint: disable=invalid-slice-index
    values_shape = values.get_shape()
    if values_shape.dims is None:
      raise ValueError('`values` must have known statically known rank')

    # Normalize a negative axis, then range-check it.
    ndim = len(values_shape)
    if axis < 0:
      axis += ndim
    if not 0 <= axis < ndim:
      raise ValueError('axis = %r not in [0, %r)' % (axis, ndim))

    # Static sizes of every axis except the concatenation axis.
    fixed_shape = [dim.value for n, dim in enumerate(values_shape)
                   if n != axis]
    if any(value is None for value in fixed_shape):
      raise ValueError('all dimensions of `values` other than the dimension to '
                       'concatenate along must have statically known size')

    # We move `axis` to the front of the internal array so assign ops can be
    # applied to contiguous slices
    # With a `max_size` the storage is allocated up front at that size;
    # otherwise it starts empty and is grown by `reallocate` below.
    init_size = 0 if max_size is None else max_size
    init_shape = [init_size] + fixed_shape
    array = _create_local(
        'array', shape=init_shape, validate_shape=False, dtype=values.dtype)
    size = _create_local('size', shape=[], dtype=dtypes.int32)

    # Inverse permutation: maps the internal layout (concatenation axis first)
    # back to the axis order of `values`.
    perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
    # Only the first `size` rows of the storage hold appended data.
    valid_array = array[:size]
    valid_array.set_shape([None] + fixed_shape)
    value = array_ops.transpose(valid_array, perm, name='concat')

    values_size = array_ops.shape(values)[axis]
    if max_size is None:
      batch_size = values_size
    else:
      # Clip the batch so that `size` never grows beyond `max_size`.
      batch_size = math_ops.minimum(values_size, max_size - size)

    # Forward permutation: brings the concatenation axis of `values` to the
    # front so the batch matches the internal layout.
    perm = [axis] + [n for n in range(ndim) if n != axis]
    batch_values = array_ops.transpose(values, perm)[:batch_size]

    def reallocate():
      # NOTE: `new_size` is bound further down in the enclosing scope; it is
      # assigned before `cond` invokes this function.
      next_size = _next_array_size(new_size)
      next_shape = array_ops.pack([next_size] + fixed_shape)
      new_value = array_ops.zeros(next_shape, dtype=values.dtype)
      old_value = array.value()
      # Replace the storage with a larger zero-filled buffer, then copy the
      # first `size` rows of the previous contents back into it.
      assign_op = state_ops.assign(array, new_value, validate_shape=False)
      with ops.control_dependencies([assign_op]):
        copy_op = array[:size].assign(old_value[:size])
      # return value needs to be the same dtype as no_op() for cond
      with ops.control_dependencies([copy_op]):
        return control_flow_ops.no_op()

    new_size = size + batch_size
    # The capacity is read as a graph op rather than from the static shape,
    # since the variable is re-assigned with validate_shape=False.
    # NOTE(review): `optimize=False` presumably keeps this from being folded
    # to the static initial shape -- confirm against `shape_internal`.
    array_size = array_ops.shape_internal(array, optimize=False)[0]
    maybe_reallocate_op = control_flow_ops.cond(
        new_size > array_size, reallocate, control_flow_ops.no_op)
    # Ordering is enforced with control dependencies: (maybe) grow the
    # storage, then write the batch, then advance `size`.
    with ops.control_dependencies([maybe_reallocate_op]):
      append_values_op = array[size:new_size].assign(batch_values)
    with ops.control_dependencies([append_values_op]):
      # `update_op` is the assignment to `size`; the append runs first through
      # the control dependency above.
      update_op = size.assign(new_size)

    if metrics_collections:
      ops.add_to_collections(metrics_collections, value)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return value, update_op
    # pylint: enable=invalid-slice-index
2973
2974
def aggregate_metrics(*value_update_tuples):
  """Splits (value_tensor, update_op) pairs into two parallel lists.

  Args:
    *value_update_tuples: a variable number of tuples, each of which contain the
      pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    a list of value tensors and a list of update ops, both in the order the
    pairs were passed in.

  Raises:
    ValueError: if `value_update_tuples` is empty.
  """
  if len(value_update_tuples) == 0:
    raise ValueError('Expected at least one value_tensor/update_op pair')
  # Transpose the sequence of pairs into a (values, updates) pair of lists.
  value_list, update_list = map(list, zip(*value_update_tuples))
  return value_list, update_list
2992
2993
def aggregate_metric_map(names_to_tuples):
  """Aggregates the metric names to tuple dictionary.

  This function is useful for pairing metric names with their associated value
  and update ops when the list of metrics is long. For example:

  ```python
    metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
        'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error(
            predictions, labels, weights),
        'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error(
            predictions, labels, labels, weights),
        'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
        'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
    })
  ```

  Args:
    names_to_tuples: a map of metric names to tuples, each of which contain the
      pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    A dictionary from metric names to value ops and a dictionary from metric
    names to update ops.

  Raises:
    ValueError: if `names_to_tuples` is empty.
  """
  # An empty map would otherwise fail below with an opaque tuple-unpacking
  # ValueError; raise the same exception type with an explicit message.
  if not names_to_tuples:
    raise ValueError(
        'Expected at least one metric name mapped to a value_tensor/update_op '
        'pair')
  metric_names = list(names_to_tuples.keys())
  # keys() and values() iterate in the same order, so the transposed columns
  # line up with `metric_names`.
  value_ops, update_ops = zip(*names_to_tuples.values())
  names_to_values = dict(zip(metric_names, value_ops))
  names_to_updates = dict(zip(metric_names, update_ops))
  return names_to_values, names_to_updates
3024
3025
def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze last dim if needed.

  Squeezes `predictions` and `labels` if their rank differs by 1.
  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    weights: optional `weights` tensor. It will be squeezed if its rank is 1
      more than the new rank of `predictions`. Scalar (rank 0) weights are
      returned unchanged.

  Returns:
    Tuple of `predictions`, `labels` and `weights`, possibly with the last
    dimension squeezed.

  Raises:
    ValueError: If the static shapes of `predictions` and `labels` are
      incompatible after squeezing.
  """
  predictions, labels = tensor_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is None:
    return predictions, labels, weights

  weights_shape = weights.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:
    # Scalar weights have no trailing dimension to squeeze. Returning early
    # also keeps the `dims[-1]` lookup below from indexing an empty list when
    # the rank of `predictions` is unknown.
    return predictions, labels, weights

  predictions_rank = predictions.get_shape().ndims
  if (predictions_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - predictions_rank == 1:
      weights = array_ops.squeeze(weights, [-1])
  elif (weights_rank is None) or (
      weights_shape.dims[-1].is_compatible_with(1)):
    # Use dynamic rank: at least one rank is unknown, so decide at run time.
    # This branch is skipped when the last dimension of `weights` is
    # statically known and is not 1, since it could never be squeezed.
    weights = control_flow_ops.cond(
        math_ops.equal(array_ops.rank(weights),
                       math_ops.add(array_ops.rank(predictions), 1)),
        lambda: array_ops.squeeze(weights, [-1]),
        lambda: weights)
  return predictions, labels, weights
3068
3069
# Names exported by a wildcard import of this module, kept in alphabetical
# order.
# NOTE(review): `streaming_concat` is defined in this module without a leading
# underscore but is not listed here -- confirm whether the omission is
# intentional.
__all__ = [
    'aggregate_metric_map',
    'aggregate_metrics',
    'streaming_accuracy',
    'streaming_auc',
    'streaming_mean',
    'streaming_mean_absolute_error',
    'streaming_mean_cosine_distance',
    'streaming_mean_iou',
    'streaming_mean_relative_error',
    'streaming_mean_squared_error',
    'streaming_mean_tensor',
    'streaming_percentage_less',
    'streaming_precision',
    'streaming_precision_at_thresholds',
    'streaming_recall',
    'streaming_recall_at_k',
    'streaming_recall_at_thresholds',
    'streaming_root_mean_squared_error',
    'streaming_sensitivity_at_specificity',
    'streaming_sparse_average_precision_at_k',
    'streaming_sparse_precision_at_k',
    'streaming_sparse_recall_at_k',
    'streaming_specificity_at_sensitivity',
]
3095