metric_ops.py revision d4eb834824d79c6a64a3c4a1c4a88b434b73e63e
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains metric-computing operations on streamed tensors.
16
17Module documentation, including "@@" callouts, should be put in
18third_party/tensorflow/contrib/metrics/__init__.py
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25from tensorflow.contrib.framework import deprecated
26from tensorflow.contrib.framework import tensor_util
27from tensorflow.contrib.framework.python.ops import variables as contrib_variables
28from tensorflow.contrib.metrics.python.ops import set_ops
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import check_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import metrics
37from tensorflow.python.ops import nn
38from tensorflow.python.ops import sparse_ops
39from tensorflow.python.ops import state_ops
40from tensorflow.python.ops import variable_scope
41from tensorflow.python.ops import variables
42
43
def _safe_div(numerator, denominator, name):
  """Computes `numerator / denominator`, yielding 0 for non-positive divisors.

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.
    name: Name for the returned op.

  Returns:
    `numerator` / `denominator` wherever `denominator` > 0, and 0 elsewhere.
  """
  denominator_is_positive = math_ops.greater(denominator, 0)
  quotient = math_ops.truediv(numerator, denominator)
  # Select the quotient only where the divisor is strictly positive.
  return array_ops.where(denominator_is_positive, quotient, 0, name=name)
60
61
def _safe_scalar_div(numerator, denominator, name):
  """Computes `numerator / denominator`, yielding 0 when the divisor is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  # Static shape check on both operands; the returned shapes are unused.
  for operand in (numerator, denominator):
    operand.get_shape().with_rank_at_most(1)

  denominator_is_zero = math_ops.equal(
      array_ops.constant(0.0, dtype=dtypes.float64), denominator)

  def _zero():
    # Built inside the branch so the constant belongs to the `cond` context.
    return array_ops.constant(0.0, dtype=dtypes.float64)

  def _quotient():
    return math_ops.div(numerator, denominator)

  return control_flow_ops.cond(
      denominator_is_zero, _zero, _quotient, name=name)
81
82
def _create_local(name, shape, collections=None, validate_shape=True,
                  dtype=dtypes.float32):
  """Builds a zero-initialized, non-trainable local variable.

  Args:
    name: The name of the new or existing variable.
    shape: Shape of the new or existing variable.
    collections: A list of collection names to which the Variable will be added.
    validate_shape: Whether to validate the shape of the variable.
    dtype: Data type of the variables.

  Returns:
    The created variable.
  """
  # Work on a copy so the caller's list is left untouched, and always append
  # tf.GraphKeys.LOCAL_VARIABLES so the variable is registered as local.
  all_collections = list(collections or []) + [ops.GraphKeys.LOCAL_VARIABLES]
  zeros = array_ops.zeros(shape, dtype=dtype)
  return variables.Variable(
      initial_value=zeros,
      name=name,
      trainable=False,
      collections=all_collections,
      validate_shape=validate_shape)
106
107
def _count_condition(values, weights=None, metrics_collections=None,
                     updates_collections=None):
  """Accumulates the total weight of the entries of `values` that are True.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `bool` `Tensor` of arbitrary size.
    weights: An optional `Tensor` whose shape is broadcastable to `values`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  check_ops.assert_type(values, dtypes.bool)
  count = _create_local('count', shape=[])

  # True -> 1.0, False -> 0.0, optionally scaled by the per-element weight.
  weighted = math_ops.to_float(values)
  if weights is not None:
    weighted = math_ops.mul(weighted, math_ops.to_float(weights))

  value_tensor = array_ops.identity(count)
  update_op = state_ops.assign_add(count, math_ops.reduce_sum(weighted))

  # Register each output only when a non-empty collection list was supplied.
  registrations = ((metrics_collections, value_tensor),
                   (updates_collections, update_op))
  for collection_names, item in registrations:
    if collection_names:
      ops.add_to_collections(collection_names, item)

  return value_tensor, update_op
149
150
def streaming_true_positives(predictions, labels, weights=None,
                             metrics_collections=None,
                             updates_collections=None,
                             name=None):
  """Sum the weights of true positives.

  Thin wrapper that forwards every argument to `metrics.true_positives`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.true_positives(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
185
186
def streaming_true_negatives(predictions, labels, weights=None,
                             metrics_collections=None,
                             updates_collections=None,
                             name=None):
  """Sum the weights of true_negatives.

  An element counts as a true negative when both its prediction and its label
  are False (or zero, for numeric inputs, which are cast to `bool`).

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(
      name, 'true_negatives', (predictions, labels, weights)):

    # Cast to bool instead of comparing against the integer literal 0: the
    # documented inputs are `bool` tensors, and the literal 0 cannot be
    # converted to a bool tensor for an `equal` comparison. Numeric inputs
    # keep their previous meaning (zero is negative, non-zero is positive).
    predictions = math_ops.cast(
        ops.convert_to_tensor(predictions), dtype=dtypes.bool)
    labels = math_ops.cast(ops.convert_to_tensor(labels), dtype=dtypes.bool)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    is_true_negative = math_ops.logical_and(
        math_ops.logical_not(labels), math_ops.logical_not(predictions))
    return _count_condition(is_true_negative, weights, metrics_collections,
                            updates_collections)
227
228
def streaming_false_positives(predictions, labels, weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Sum the weights of false positives.

  Thin wrapper that forwards every argument to `metrics.false_positives`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.false_positives(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
263
264
def streaming_false_negatives(predictions, labels, weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Sum the weights of false negatives.

  Thin wrapper that forwards every argument to `metrics.false_negatives`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
  """
  return metrics.false_negatives(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
298
299
def _broadcast_weights(weights, values):
  """Expands `weights` so that it has the same shape as `values`.

  The result follows the same broadcast rules as `mul(weights, values)`. Use it
  when computing a weighted average so the weights are broadcast before being
  summed; e.g.,
  `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`.

  Args:
    weights: `Tensor` whose shape is broadcastable to `values`.
    values: `Tensor` of any shape.

  Returns:
    `weights` broadcast to `values` shape.
  """
  w_shape, v_shape = weights.get_shape(), values.get_shape()
  # When both static shapes are fully known and compatible, `weights` is
  # returned as-is and no extra op is added to the graph.
  needs_broadcast = not (w_shape.is_fully_defined() and
                         v_shape.is_fully_defined() and
                         w_shape.is_compatible_with(v_shape))
  if needs_broadcast:
    return math_ops.mul(
        weights, array_ops.ones_like(values), name='broadcast_weights')
  return weights
323
324
def streaming_mean(values, weights=None, metrics_collections=None,
                   updates_collections=None, name=None):
  """Computes the (weighted) mean of the given values.

  Thin wrapper that forwards every argument to `metrics.mean`.

  Two local variables, `total` and `count`, are used to compute the average of
  `values`. The average is returned as `mean`, an idempotent operation that
  simply divides `total` by `count`.

  To estimate the metric over a stream of data, an `update_op` operation is
  created that updates these variables and returns the `mean`. `update_op`
  increments `total` with the reduced sum of the product of `values` and
  `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: An optional `Tensor` whose shape is broadcastable to `values`.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  return metrics.mean(
      values=values,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
364
365
def streaming_mean_tensor(values, weights=None, metrics_collections=None,
                          updates_collections=None, name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  Thin wrapper that forwards every argument to `metrics.mean_tensor`.

  Unlike `streaming_mean`, which returns a scalar mean, this function returns
  an average tensor with the same shape as the input tensors.

  Two local variables, `total_tensor` and `count_tensor`, are used to compute
  the average of `values`. The average is returned as `mean`, an idempotent
  operation that simply divides `total` by `count`.

  To estimate the metric over a stream of data, an `update_op` operation is
  created that updates these variables and returns the `mean`. `update_op`
  increments `total` with the reduced sum of the product of `values` and
  `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: An optional `Tensor` whose shape is broadcastable to `values`.
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  return metrics.mean_tensor(
      values=values,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
409
410
def streaming_accuracy(predictions, labels, weights=None,
                       metrics_collections=None, updates_collections=None,
                       name=None):
  """Calculates how often `predictions` matches `labels`.

  Thin wrapper that forwards every argument to `metrics.accuracy`.

  Two local variables, `total` and `count`, are used to compute the frequency
  with which `predictions` matches `labels`. The frequency is returned as
  `accuracy`: an idempotent operation that simply divides `total` by `count`.

  To estimate the metric over a stream of data, an `update_op` operation is
  created that updates these variables and returns the `accuracy`. Internally,
  an `is_correct` operation computes a `Tensor` with elements 1.0 where the
  corresponding elements of `predictions` and `labels` match and 0.0 otherwise.
  `update_op` then increments `total` with the reduced sum of the product of
  `weights` and `is_correct`, and increments `count` with the reduced sum of
  `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of any shape.
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.accuracy(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
458
459
def streaming_precision(predictions, labels, weights=None,
                        metrics_collections=None, updates_collections=None,
                        name=None):
  """Computes the precision of the predictions with respect to the labels.

  Thin wrapper that forwards every argument to `metrics.precision`.

  Two local variables, `true_positives` and `false_positives`, are used to
  compute the precision. The value is returned as `precision`, an idempotent
  operation that simply divides `true_positives` by the sum of `true_positives`
  and `false_positives`.

  To estimate the metric over a stream of data, an `update_op` operation is
  created that updates these variables and returns the `precision`.
  `update_op` weights each prediction by the corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: Scalar float `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately and whose value matches
      `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.precision(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
506
507
def streaming_recall(predictions, labels, weights=None,
                     metrics_collections=None, updates_collections=None,
                     name=None):
  """Computes the recall of the predictions with respect to the labels.

  Thin wrapper that forwards every argument to `metrics.recall`.

  Two local variables, `true_positives` and `false_negatives`, are used to
  compute the recall. The value is returned as `recall`, an idempotent
  operation that simply divides `true_positives` by the sum of `true_positives`
  and `false_negatives`.

  To estimate the metric over a stream of data, an `update_op` is created that
  updates these variables and returns the `recall`. `update_op` weights each
  prediction by the corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.recall(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
552
553
def _streaming_confusion_matrix_at_thresholds(
    predictions, labels, thresholds, weights=None, includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast
      to `bool`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
        default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
        `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
        `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    # Materialize once so a one-shot iterable is not exhausted by validation
    # before the membership tests further down.
    includes = tuple(includes)
    for include in includes:
      if include not in all_includes:
        raise ValueError('Invalid key: %s.' % include)

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  num_thresholds = len(thresholds)

  # Flatten: predictions become a column, labels (cast to bool) become a row.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  # One row per threshold, one column per prediction.
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.pack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  # The negated masks are only built when a requested key needs them.
  pred_is_neg = None
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  label_is_neg = None
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    weights = math_ops.to_float(weights)
    weights_tiled = array_ops.tile(array_ops.reshape(_broadcast_weights(
        weights, predictions), [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  # (key, local variable name, label mask, prediction mask). Entries whose key
  # was not requested may hold a `None` mask; they are skipped below.
  confusion_specs = (
      ('tp', 'true_positives', label_is_pos, pred_is_pos),
      ('fn', 'false_negatives', label_is_pos, pred_is_neg),
      ('tn', 'true_negatives', label_is_neg, pred_is_neg),
      ('fp', 'false_positives', label_is_neg, pred_is_pos),
  )

  values = {}
  update_ops = {}
  for key, variable_name, label_mask, pred_mask in confusion_specs:
    if key not in includes:
      continue
    accumulator = _create_local(variable_name, shape=[num_thresholds])
    is_match = math_ops.to_float(math_ops.logical_and(label_mask, pred_mask))
    if weights_tiled is not None:
      is_match *= weights_tiled
    # Summing over axis 1 (the predictions) leaves one total per threshold.
    update_ops[key] = state_ops.assign_add(
        accumulator, math_ops.reduce_sum(is_match, 1))
    values[key] = accumulator

  return values, update_ops
690
691
def streaming_true_positives_at_thresholds(
    predictions, labels, thresholds, weights=None):
  """Accumulates true positives for each of the given thresholds.

  Calls `_streaming_confusion_matrix_at_thresholds` requesting only the 'tp'
  entry and unpacks it from the returned dicts.

  Args:
    predictions: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    labels: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    thresholds: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    weights: Optional weights, forwarded unchanged.

  Returns:
    The 'tp' value variable and the 'tp' update op.
  """
  key = 'tp'
  all_values, all_update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=(key,))
  return all_values[key], all_update_ops[key]
697
698
def streaming_false_negatives_at_thresholds(
    predictions, labels, thresholds, weights=None):
  """Accumulates false negatives for each of the given thresholds.

  Calls `_streaming_confusion_matrix_at_thresholds` requesting only the 'fn'
  entry and unpacks it from the returned dicts.

  Args:
    predictions: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    labels: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    thresholds: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    weights: Optional weights, forwarded unchanged.

  Returns:
    The 'fn' value variable and the 'fn' update op.
  """
  key = 'fn'
  all_values, all_update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=(key,))
  return all_values[key], all_update_ops[key]
704
705
def streaming_false_positives_at_thresholds(
    predictions, labels, thresholds, weights=None):
  """Accumulates false positives for each of the given thresholds.

  Calls `_streaming_confusion_matrix_at_thresholds` requesting only the 'fp'
  entry and unpacks it from the returned dicts.

  Args:
    predictions: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    labels: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    thresholds: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    weights: Optional weights, forwarded unchanged.

  Returns:
    The 'fp' value variable and the 'fp' update op.
  """
  key = 'fp'
  all_values, all_update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=(key,))
  return all_values[key], all_update_ops[key]
711
712
def streaming_true_negatives_at_thresholds(
    predictions, labels, thresholds, weights=None):
  """Accumulates true negatives for each of the given thresholds.

  Calls `_streaming_confusion_matrix_at_thresholds` requesting only the 'tn'
  entry and unpacks it from the returned dicts.

  Args:
    predictions: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    labels: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    thresholds: Forwarded to `_streaming_confusion_matrix_at_thresholds`.
    weights: Optional weights, forwarded unchanged.

  Returns:
    The 'tn' value variable and the 'tn' update op.
  """
  key = 'tn'
  all_values, all_update_ops = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=(key,))
  return all_values[key], all_update_ops[key]
718
719
def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
                  metrics_collections=None, updates_collections=None,
                  curve='ROC', name=None):
  """Computes the approximate AUC via a Riemann sum.

  This function delegates directly to `metrics.auc`.

  Four local variables, `true_positives`, `true_negatives`, `false_positives`
  and `false_negatives`, are created and used to compute the AUC. The curve is
  discretized with a linearly spaced set of thresholds at which pairs of recall
  and precision values are computed. The area under the ROC-curve is computed
  using the height of the recall values by the false positive rate, while the
  area under the PR-curve is computed using the height of the precision values
  by the recall.

  The result is returned as `auc`, an idempotent operation that computes the
  area under the discretized curve from the aforementioned variables.
  `num_thresholds` controls the degree of discretization: larger numbers of
  thresholds approximate the true AUC more closely, and the quality of the
  approximation may vary dramatically depending on this value.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use when discretizing the ROC
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.auc(
      labels=labels,
      predictions=predictions,
      weights=weights,
      num_thresholds=num_thresholds,
      curve=curve,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
780
781
def streaming_specificity_at_sensitivity(
    predictions, labels, sensitivity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the specificity at a given sensitivity.

  This function delegates directly to `metrics.specificity_at_sensitivity`.

  Four local variables, `true_positives`, `true_negatives`, `false_positives`
  and `false_negatives`, are created and used to compute the specificity at
  the given sensitivity value. The threshold for the given sensitivity value is
  computed and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar `Tensor` representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  return metrics.specificity_at_sensitivity(
      labels=labels,
      predictions=predictions,
      sensitivity=sensitivity,
      weights=weights,
      num_thresholds=num_thresholds,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
836
837
def streaming_sensitivity_at_specificity(
    predictions, labels, specificity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the sensitivity at a given specificity.

  This function delegates directly to `metrics.sensitivity_at_specificity`.

  Four local variables, `true_positives`, `true_negatives`, `false_positives`
  and `false_negatives`, are created and used to compute the sensitivity at
  the given specificity value. The threshold for the given specificity value is
  computed and used to evaluate the corresponding sensitivity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    specificity: A scalar value in range `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar `Tensor` representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `specificity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  return metrics.sensitivity_at_specificity(
      labels=labels,
      predictions=predictions,
      specificity=specificity,
      weights=weights,
      num_thresholds=num_thresholds,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
892
893
def streaming_precision_at_thresholds(predictions, labels, thresholds,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None, name=None):
  """Computes precision values for different `thresholds` on `predictions`.

  This function delegates directly to `metrics.precision_at_thresholds`.

  Four local variables, `true_positives`, `true_negatives`, `false_positives`
  and `false_negatives`, are created for the various threshold values.
  `precision[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`,
  divided by the total weight of values in `predictions` above `thresholds[i]`
  (`true_positives[i] / (true_positives[i] + false_positives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `precision`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.precision_at_thresholds(
      labels=labels,
      predictions=predictions,
      thresholds=thresholds,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
943
944
def streaming_recall_at_thresholds(predictions, labels, thresholds,
                                   weights=None, metrics_collections=None,
                                   updates_collections=None, name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  This function delegates directly to `metrics.recall_at_thresholds`.

  Four local variables, `true_positives`, `true_negatives`, `false_positives`
  and `false_negatives`, are created for the various threshold values.
  `recall[i]` is defined as the total weight of values in `predictions` above
  `thresholds[i]` whose corresponding entry in `labels` is `True`, divided by
  the total weight of `True` values in `labels`
  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.recall_at_thresholds(
      labels=labels,
      predictions=predictions,
      thresholds=thresholds,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
991
992
993def _at_k_name(name, k=None, class_id=None):
994  if k is not None:
995    name = '%s_at_%d' % (name, k)
996  else:
997    name = '%s_at_k' % (name)
998  if class_id is not None:
999    name = '%s_class%d' % (name, class_id)
1000  return name
1001
1002
@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, '
            'and reshape labels from [batch_size] to [batch_size, 1].')
def streaming_recall_at_k(predictions, labels, k, weights=None,
                          metrics_collections=None, updates_collections=None,
                          name=None):
  """Computes the recall@k of the predictions with respect to dense labels.

  Two local variables, `total` and `count`, are created and used to compute the
  recall@k frequency. This frequency is ultimately returned as `recall_at_<k>`:
  an idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with
  shape [batch_size] whose elements indicate whether or not the corresponding
  label is in the top `k` `predictions`. Then `update_op` increments `total`
  with the reduced sum of `weights` where `in_top_k` is `True`, and it
  increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A float `Tensor` of dimension [batch_size, num_classes].
    labels: A `Tensor` of dimension [batch_size] whose type is in `int32`,
      `int64`.
    k: The number of top elements to look at for computing recall.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `recall_at_k`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name. When falsy, a default of the form
      `recall_at_<k>` is used.

  Returns:
    recall_at_k: A `Tensor` representing the recall@k, the fraction of labels
      which fall into the top `k` predictions.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `recall_at_k`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Float indicator of whether each label is among the top `k` predictions.
  hits = math_ops.to_float(nn.in_top_k(predictions, labels, k))
  metric_name = name or _at_k_name('recall', k)
  # The recall@k is the (weighted) streaming mean of the hit indicator.
  return streaming_mean(
      hits, weights, metrics_collections, updates_collections, metric_name)
1055
1056
1057# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_recall_at_k(predictions,
                                 labels,
                                 k,
                                 class_id=None,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  This function delegates directly to `metrics.recall_at_k`.

  If `class_id` is specified, we calculate recall by considering only the
      entries in the batch for which `class_id` is in the label, and computing
      the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, we'll calculate recall as how often on
      average a class among the labels of a batch entry is in the top-k
      `predictions`.

  Two local variables, `true_positive_at_<k>` and `false_negative_at_<k>`, are
  created and used to compute the recall_at_k frequency. This frequency is
  ultimately returned as `recall_at_<k>`: an idempotent operation that simply
  divides `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false negatives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_negative_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`.
      Values should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range always count
      towards `false_negative_at_<k>`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If class_id is outside this range, the method returns NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  return metrics.recall_at_k(
      labels=labels,
      predictions=predictions,
      k=k,
      class_id=class_id,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
1133
1134
def _streaming_sparse_precision_at_k(top_k_idx,
                                     labels,
                                     k=None,
                                     class_id=None,
                                     weights=None,
                                     metrics_collections=None,
                                     updates_collections=None,
                                     name=None):
  """Computes precision@k of the top-k indices with respect to sparse labels.

  This method contains the code shared by streaming_sparse_precision_at_k and
  streaming_sparse_precision_at_top_k. Refer to those methods for more details.

  Args:
    top_k_idx: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and top_k_idx has shape [batch size, k].
      The final dimension contains the indices of top-k labels. [D1, ... DN]
      must match `labels`. It is cast to `int64` before use.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `top_k_idx`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric or `None`. Only used for default op name.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name given to the op that computes the metric value.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  # Both confusion-count helpers take exactly the same arguments.
  shared_kwargs = dict(
      predictions_idx=math_ops.to_int64(top_k_idx),
      labels=labels,
      k=k,
      class_id=class_id,
      weights=weights)
  true_pos, true_pos_update = _streaming_sparse_true_positive_at_k(
      **shared_kwargs)
  false_pos, false_pos_update = _streaming_sparse_false_positive_at_k(
      **shared_kwargs)

  # precision = tp / (tp + fp), built once from the accumulated values and
  # once from the values produced by the update ops.
  predicted_total = math_ops.add(true_pos, false_pos)
  metric = math_ops.div(true_pos, predicted_total, name=name)
  predicted_total_update = math_ops.add(true_pos_update, false_pos_update)
  update = math_ops.div(
      true_pos_update, predicted_total_update, name='update')

  if metrics_collections:
    ops.add_to_collections(metrics_collections, metric)
  if updates_collections:
    ops.add_to_collections(updates_collections, update)
  return metric, update
1201
1202
1203# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_k(predictions,
                                    labels,
                                    k,
                                    class_id=None,
                                    weights=None,
                                    metrics_collections=None,
                                    updates_collections=None,
                                    name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  This function delegates directly to `metrics.sparse_precision_at_k`.

  If `class_id` is specified, we calculate precision by considering only the
      entries in the batch for which `class_id` is in the top-k highest
      `predictions`, and computing the fraction of them for which `class_id` is
      indeed a correct label.
  If `class_id` is not specified, we'll calculate precision as how often on
      average a class among the top-k classes with the highest predicted values
      of a batch entry is correct and can be found in the label for that entry.

  Two local variables, `true_positive_at_<k>` and `false_positive_at_<k>`, are
  created and used to compute the precision@k frequency. This frequency is
  ultimately returned as `precision_at_<k>`: an idempotent operation that
  simply divides `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  return metrics.sparse_precision_at_k(
      labels=labels,
      predictions=predictions,
      k=k,
      class_id=class_id,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
1281
1282
1283# TODO(ptucker): Validate range of values in labels?
1284def streaming_sparse_precision_at_top_k(top_k_predictions,
1285                                        labels,
1286                                        class_id=None,
1287                                        weights=None,
1288                                        metrics_collections=None,
1289                                        updates_collections=None,
1290                                        name=None):
1291  """Computes precision@k of top-k predictions with respect to sparse labels.
1292
1293  If `class_id` is specified, we calculate precision by considering only the
1294      entries in the batch for which `class_id` is in the top-k highest
1295      `predictions`, and computing the fraction of them for which `class_id` is
1296      indeed a correct label.
1297  If `class_id` is not specified, we'll calculate precision as how often on
1298      average a class among the top-k classes with the highest predicted values
1299      of a batch entry is correct and can be found in the label for that entry.
1300
1301  `streaming_sparse_precision_at_top_k` creates two local variables,
1302  `true_positive_at_k` and `false_positive_at_k`, that are used to compute
1303  the precision@k frequency. This frequency is ultimately returned as
1304  `precision_at_k`: an idempotent operation that simply divides
1305  `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`).
1306
1307  For estimation of the metric over a stream of data, the function creates an
1308  `update_op` operation that updates these variables and returns the
1309  `precision_at_k`. Internally, set operations applied to `top_k_predictions`
1310  and `labels` calculate the true positives and false positives weighted by
1311  `weights`. Then `update_op` increments `true_positive_at_k` and
1312  `false_positive_at_k` using these values.
1313
1314  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1315
1316  Args:
1317    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
1318      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
1319      The final dimension contains the indices of top-k labels. [D1, ... DN]
1320      must match `labels`.
1321    labels: `int64` `Tensor` or `SparseTensor` with shape
1322      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1323      target classes for the associated prediction. Commonly, N=1 and `labels`
1324      has shape [batch_size, num_labels]. [D1, ... DN] must match
1325      `top_k_predictions`. Values should be in range [0, num_classes), where
1326      num_classes is the last dimension of `predictions`. Values outside this
1327      range are ignored.
1328    class_id: Integer class ID for which we want binary metrics. This should be
1329      in range [0, num_classes), where num_classes is the last dimension of
1330      `predictions`. If `class_id` is outside this range, the method returns
1331      NAN.
1332    weights: An optional `Tensor` whose shape is broadcastable to the first
1333      [D1, ... DN] dimensions of `predictions` and `labels`.
1334    metrics_collections: An optional list of collections that values should
1335      be added to.
1336    updates_collections: An optional list of collections that updates should
1337      be added to.
1338    name: Name of new update operation, and namespace for other dependent ops.
1339
1340  Returns:
1341    precision: Scalar `float64` `Tensor` with the value of `true_positives`
1342      divided by the sum of `true_positives` and `false_positives`.
1343    update_op: `Operation` that increments `true_positives` and
1344      `false_positives` variables appropriately, and whose value matches
1345      `precision`.
1346
1347  Raises:
1348    ValueError: If `weights` is not `None` and its shape doesn't match
1349      `predictions`, or if either `metrics_collections` or `updates_collections`
1350      are not a list or tuple.
1351    ValueError: If `top_k_predictions` has rank < 2.
1352  """
1353  default_name = _at_k_name('precision', class_id=class_id)
1354  with ops.name_scope(
1355      name, default_name,
1356      (top_k_predictions, labels, weights)) as scope:
1357    rank = array_ops.rank(top_k_predictions)
1358    check_rank_op = control_flow_ops.Assert(
1359        math_ops.greater_equal(rank, 2),
1360        ['top_k_predictions must have rank 2 or higher, e.g. [batch_size, k].'])
1361    with ops.control_dependencies([check_rank_op]):
1362      return _streaming_sparse_precision_at_k(
1363          top_k_idx=top_k_predictions,
1364          labels=labels,
1365          class_id=class_id,
1366          weights=weights,
1367          metrics_collections=metrics_collections,
1368          updates_collections=updates_collections,
1369          name=scope)
1370
1371
def num_relevant(labels, k):
  """Returns, for each row of `labels`, how many labels are relevant at `k`.

  For labels with shape [D1, ... DN, num_labels], the count for a row is the
  minimum of that row's number of labels and `k`. Sparse `labels` get a
  separate count per row (via `set_ops.set_size`); dense `labels` share one
  count derived from the size of the last dimension.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels].
    k: Integer, k for @k metric. Must be at least 1.

  Returns:
    Integer `Tensor` of shape [D1, ... DN], where each value is the number of
    relevant values for that row.

  Raises:
    ValueError: if `k` is less than 1. The ops used here may additionally
      reject inputs with invalid dtypes or values.
  """
  if k < 1:
    raise ValueError('Invalid k=%s.' % k)
  sparse_types = (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)
  with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
    if not isinstance(labels, sparse_types):
      # Dense labels: all rows have the same label count (the last dimension),
      # so compute a single scalar and fill the leading dimensions with it.
      dense_shape = array_ops.shape(labels)
      shared_count = math_ops.minimum(dense_shape[-1], k)
      return array_ops.fill(dense_shape[0:-1], shared_count, name=scope)

    # Sparse labels: rows may hold different numbers of entries, so the size
    # is computed per row before clamping it to `k`.
    per_row_sizes = set_ops.set_size(labels)
    return math_ops.minimum(per_row_sizes, k, name=scope)
1407
1408
def expand_and_tile(tensor, multiple, dim=0, name=None):
  """Inserts a new dimension into `tensor` and tiles the values along it.

  The shape of `tensor` is split at `dim`, a new dimension is inserted there,
  and the values are then repeated `multiple` times along that new dimension.
  A negative `dim` counts from the end of the shape.

  Args:
    tensor: Input `Tensor` or `SparseTensor`. A `SparseTensorValue` is
      converted to a `SparseTensor` first.
    multiple: Integer, number of times to tile.
    dim: Integer, dimension along which to tile.
    name: Name of operation.

  Returns:
    `Tensor` result of expanding and tiling `tensor` (a `SparseTensor` when
    the input is sparse). When `multiple` is 1 the expanded tensor is returned
    without any tiling op.

  Raises:
    ValueError: if `multiple` is less than 1.
      NOTE(review): earlier documentation also promised a `ValueError` when
      `dim` is not in `[-rank(tensor), rank(tensor)]`; no such check is made
      in this function, so any error would come from the underlying ops.
  """
  if multiple < 1:
    raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
  with ops.name_scope(
      name, 'expand_and_tile', (tensor, multiple, dim)) as scope:
    if isinstance(tensor, sparse_tensor.SparseTensorValue):
      tensor = sparse_tensor.SparseTensor.from_value(tensor)

    if not isinstance(tensor, sparse_tensor.SparseTensor):
      # Dense input. For a negative `dim` the axis is shifted by one so the
      # new dimension lands before the dimension that `dim` refers to.
      new_axis = dim if dim >= 0 else dim - 1
      widened = array_ops.expand_dims(tensor, new_axis, name='expand')
      if multiple == 1:
        return widened
      # Multiples are all ones except for `multiple` at the inserted position.
      unit_multiples = array_ops.ones_like(array_ops.shape(tensor))
      leading = unit_multiples[:dim]
      trailing = unit_multiples[dim:]
      tile_multiples = array_ops.concat_v2(
          (leading, (multiple,), trailing), 0, name='multiples')
      return array_ops.tile(widened, tile_multiples, name=scope)

    # Sparse input. Locate the split point in the dense shape; a negative
    # `dim` is resolved against the (dynamic) rank.
    if dim < 0:
      split_at = array_ops.reshape(
          array_ops.size(tensor.dense_shape) + dim, [1])
    else:
      split_at = [dim]
    shape_head = array_ops.strided_slice(tensor.dense_shape, [0], split_at)
    # Bit 0 of `end_mask` is set, so the end value is ignored and the slice
    # runs through the last element of the dense shape.
    shape_tail = array_ops.strided_slice(
        tensor.dense_shape, split_at, [-1], end_mask=1)
    new_shape = array_ops.concat_v2(
        (shape_head, [1], shape_tail), 0, name='expanded_shape')
    reshaped = sparse_ops.sparse_reshape(
        tensor, shape=new_shape, name='expand')
    if multiple == 1:
      return reshaped
    # Tiling a sparse tensor is done by concatenating copies along the new
    # dimension.
    concat_axis = dim - 1 if dim < 0 else dim
    return sparse_ops.sparse_concat(
        concat_axis, [reshaped] * multiple, name=scope)
1463
1464
def sparse_average_precision_at_k(predictions, labels, k):
  """Computes per-row average precision@k against sparse labels.

  From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
  for each row is:

    AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items

  A "row" is the elements in dimension [D1, ... DN] of `predictions`, `labels`,
  and the result `Tensors`. In the common case, this is [batch_size]. Each row
  of the results contains the average precision for that row.

  Internally, a `top_k` operation picks the indices of the `k` highest
  `predictions`. Each of those `k` positions is then compared against the
  labels with set operations to obtain the relevance indicators, from which
  the precision at every position (the "P_{i}" term above) is derived.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and `predictions` has shape
      [batch size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric. This will calculate an average precision for
      range `[1,k]`, as documented above.

  Returns:
    `float64` `Tensor` of shape [D1, ... DN], where each value is the average
    precision for that row.

  Raises:
    ValueError: if `k` is less than 1.
  """
  if k < 1:
    raise ValueError('Invalid k=%s.' % k)
  with ops.name_scope(
      None, 'average_precision', (predictions, labels, k)) as scope:
    # Indices of the k highest predictions: a [D1, ... DN, k] tensor.
    _, top_idx = nn.top_k(predictions, k)
    top_idx = math_ops.to_int64(top_idx, name='predictions_idx')

    # Add a trailing dimension to get [D1, ... DN, k, 1], so that every one of
    # the k positions is treated as its own single-element prediction set.
    idx_by_position = array_ops.expand_dims(
        top_idx, -1, name='predictions_idx_per_k')

    # Repeat the labels k times to get [D1, ... DN, k, num_labels], pairing
    # each position with the full label set of its row.
    labels_by_position = expand_and_tile(
        labels, multiple=k, dim=-1, name='labels_per_k')

    # Everything below has shape [D1, ... DN, k], one value per row and per
    # position within the top k.

    # Relevance indicator ("rel_{i}"): 1 where the prediction at that position
    # is among the labels, 0 otherwise.
    hits = _sparse_true_positive_at_k(
        idx_by_position, labels_by_position, name='relevant_per_k')

    # Running true positive count up to and including each position.
    hits_so_far = math_ops.cumsum(hits, axis=-1, name='tp_per_k')

    # Number of predictions retrieved at each position (1, 2, ..., k); this is
    # the precision denominator.
    retrieved_so_far = math_ops.cumsum(
        array_ops.ones_like(hits), axis=-1, name='retrieved_per_k')

    # Precision at each position ("P_{i}"), as float64.
    precision_at_position = math_ops.div(
        math_ops.to_double(hits_so_far),
        math_ops.to_double(retrieved_so_far),
        name='precision_per_k')

    # Keep only the precisions at positions whose prediction was relevant.
    relevant_precisions = math_ops.mul(
        precision_at_position, math_ops.to_double(hits),
        name='relevant_precision_per_k')

    # Sum over the k dimension, leaving a [D1, ... DN] tensor.
    summed_precision = math_ops.reduce_sum(
        relevant_precisions, reduction_indices=(-1,), name='precision_sum')

    # Normalize by "num_relevant_items" to obtain "AveP".
    relevant_count = math_ops.to_double(num_relevant(labels, k))
    return math_ops.div(summed_precision, relevant_count, name=scope)
1554
1555
def streaming_sparse_average_precision_at_k(predictions,
                                            labels,
                                            k,
                                            weights=None,
                                            metrics_collections=None,
                                            updates_collections=None,
                                            name=None):
  """Computes streaming mean average precision@k against sparse labels.

  This is a thin wrapper: every argument is forwarded by keyword to
  `metrics.sparse_average_precision_at_k` and its result is returned as is.
  See `sparse_average_precision_at_k` for details on the per-row formula;
  `weights` are applied to the per-row average precision values.

  As documented for this metric, two local variables,
  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, are
  created. The returned metric is an idempotent operation that simply divides
  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. For
  estimation over a stream of data, the returned update operation updates
  these variables from the current batch and returns the updated metric value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and `predictions` has shape
      [batch size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric. This will calculate an average precision for
      range `[1,k]`, as documented above.
    weights: An optional `Tensor` whose shape is broadcastable to the first
      [D1, ... DN] dimensions of `predictions` and `labels`.
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `mean_average_precision`.
  """
  metric_and_update = metrics.sparse_average_precision_at_k(
      predictions=predictions,
      labels=labels,
      k=k,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
  return metric_and_update
1616
1617
def _select_class_id(ids, selected_id):
  """Keeps only the entries of `ids` that equal `selected_id`.

  Args:
    ids: `int64` `Tensor` or `SparseTensor` of IDs.
    selected_id: Int id to select.

  Returns:
    `SparseTensor` of same dimensions as `ids`. This contains only the entries
    equal to `selected_id`.
  """
  sparse_types = (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)
  if not isinstance(ids, sparse_types):
    # TODO(ptucker): Make this more efficient, maybe add a sparse version of
    # tf.equal and tf.reduce_any?

    # Dense input: build a tensor shaped like `ids` but with the last
    # dimension collapsed to 1, filled with `selected_id`, and intersect it
    # with `ids` row by row.
    full_shape = array_ops.shape(ids, out_type=dtypes.int64)
    last_axis = array_ops.size(full_shape) - 1
    collapsed_shape = math_ops.reduced_shape(
        full_shape, array_ops.reshape(last_axis, [1]))
    wanted = array_ops.fill(collapsed_shape, math_ops.to_int64(selected_id))
    matched = set_ops.set_intersection(wanted, ids)
    # Re-wrap the intersection with the original dense shape of `ids`.
    return sparse_tensor.SparseTensor(
        indices=matched.indices, values=matched.values, dense_shape=full_shape)

  # Sparse input: retain exactly the values equal to `selected_id`.
  keep_mask = math_ops.equal(ids.values, selected_id)
  return sparse_ops.sparse_retain(ids, keep_mask)
1649
1650
def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
  """Restricts labels and predictions to `selected_id`, if one is given.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
      where N >= 1. Commonly, N=1 and `predictions_idx` has shape
      [batch size, k].
    selected_id: Int id to select. When `None`, nothing is filtered.

  Returns:
    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
    The inputs are returned untouched when `selected_id` is `None`.
  """
  if selected_id is not None:
    filtered_labels = _select_class_id(labels, selected_id)
    filtered_predictions = _select_class_id(predictions_idx, selected_id)
    return filtered_labels, filtered_predictions
  return labels, predictions_idx
1672
1673
def _sparse_true_positive_at_k(predictions_idx,
                               labels,
                               class_id=None,
                               weights=None,
                               name=None):
  """Counts true positives per row, for recall@k and precision@k.

  If `class_id` is specified, calculate binary true positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  A true positive is a class ID present in both `predictions_idx` and
  `labels` for the same row; the count is the size of that intersection.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.
    name: Name of operation.

  Returns:
    A [D1, ... DN] `Tensor` of true positive counts, cast to double and
    multiplied by `weights` when they are given.
  """
  with ops.name_scope(name, 'true_positives', (predictions_idx, labels)):
    selected_labels, selected_idx = _maybe_select_class_id(
        labels, predictions_idx, class_id)
    overlap = set_ops.set_intersection(selected_idx, selected_labels)
    counts = math_ops.to_double(set_ops.set_size(overlap))
    if weights is None:
      return counts
    return math_ops.mul(counts, math_ops.to_double(weights))
1712
1713
def _streaming_sparse_true_positive_at_k(predictions_idx,
                                         labels,
                                         k=None,
                                         class_id=None,
                                         weights=None,
                                         name=None):
  """Accumulates weighted true positives per step for recall@k, precision@k.

  If `class_id` is specified, calculate binary true positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  The per-row counts from `_sparse_true_positive_at_k` are summed into one
  scalar for the batch, which the update op adds to a scalar `float64` local
  variable.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
      No shape check is made in this function itself, so this would come from
      the ops it calls.
  """
  with ops.name_scope(
      name,
      _at_k_name('true_positive', k, class_id=class_id),
      (predictions_idx, labels)) as scope:
    per_row_tp = _sparse_true_positive_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    step_total = math_ops.to_double(math_ops.reduce_sum(per_row_tp))

    # Scalar accumulator, named after the enclosing scope.
    accumulator = contrib_variables.local_variable(
        array_ops.zeros([], dtype=dtypes.float64), name=scope)
    update_op = state_ops.assign_add(accumulator, step_total, name='update')
    return accumulator, update_op
1760
1761
def _sparse_false_positive_at_k(predictions_idx,
                                labels,
                                class_id=None,
                                weights=None):
  """Counts false positives per row, for precision@k.

  If `class_id` is specified, calculate binary false positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  The count is the size of the set difference between `predictions_idx` and
  `labels`, computed with `aminusb=True` (predictions minus labels).

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.

  Returns:
    A [D1, ... DN] `Tensor` of false positive counts, cast to double and
    multiplied by `weights` when they are given.
  """
  with ops.name_scope(None, 'false_positives', (predictions_idx, labels)):
    selected_labels, selected_idx = _maybe_select_class_id(
        labels, predictions_idx, class_id)
    predicted_only = set_ops.set_difference(
        selected_idx, selected_labels, aminusb=True)
    counts = math_ops.to_double(set_ops.set_size(predicted_only))
    if weights is None:
      return counts
    return math_ops.mul(counts, math_ops.to_double(weights))
1800
1801
def _streaming_sparse_false_positive_at_k(predictions_idx,
                                          labels,
                                          k=None,
                                          class_id=None,
                                          weights=None,
                                          name=None):
  """Accumulates weighted false positives per step for precision@k.

  If `class_id` is specified, calculate binary false positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  The per-row counts from `_sparse_false_positive_at_k` are summed into one
  scalar for the batch, which the update op adds to a scalar `float64` local
  variable.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
      No shape check is made in this function itself, so this would come from
      the ops it calls.
  """
  with ops.name_scope(
      name,
      _at_k_name('false_positive', k, class_id=class_id),
      (predictions_idx, labels)) as scope:
    per_row_fp = _sparse_false_positive_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    step_total = math_ops.to_double(math_ops.reduce_sum(per_row_fp))

    # Scalar accumulator, named after the enclosing scope.
    accumulator = contrib_variables.local_variable(
        array_ops.zeros([], dtype=dtypes.float64), name=scope)
    update_op = state_ops.assign_add(accumulator, step_total, name='update')
    return accumulator, update_op
1848
1849
def _sparse_false_negative_at_k(predictions_idx,
                                labels,
                                class_id=None,
                                weights=None):
  """Counts false negatives per row, for recall@k.

  If `class_id` is specified, calculate binary false negatives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  The count is the size of the set difference between `predictions_idx` and
  `labels`, computed with `aminusb=False` (labels minus predictions).

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.

  Returns:
    A [D1, ... DN] `Tensor` of false negative counts, cast to double and
    multiplied by `weights` when they are given.
  """
  with ops.name_scope(None, 'false_negatives', (predictions_idx, labels)):
    selected_labels, selected_idx = _maybe_select_class_id(
        labels, predictions_idx, class_id)
    labeled_only = set_ops.set_difference(
        selected_idx, selected_labels, aminusb=False)
    counts = math_ops.to_double(set_ops.set_size(labeled_only))
    if weights is None:
      return counts
    return math_ops.mul(counts, math_ops.to_double(weights))
1889
1890
def _streaming_sparse_false_negative_at_k(predictions_idx,
                                          labels,
                                          k,
                                          class_id=None,
                                          weights=None,
                                          name=None):
  """Accumulates weighted false negatives per step for recall@k.

  If `class_id` is specified, calculate binary false negatives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  The per-row counts from `_sparse_false_negative_at_k` are summed into one
  scalar for the batch, which the update op adds to a scalar `float64` local
  variable.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose shape is broadcastable to the first [D1, ... DN]
      dimensions of `predictions_idx` and `labels`.
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
      No shape check is made in this function itself, so this would come from
      the ops it calls.
  """
  with ops.name_scope(
      name,
      _at_k_name('false_negative', k, class_id=class_id),
      (predictions_idx, labels)) as scope:
    per_row_fn = _sparse_false_negative_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    step_total = math_ops.to_double(math_ops.reduce_sum(per_row_fn))

    # Scalar accumulator, named after the enclosing scope.
    accumulator = contrib_variables.local_variable(
        array_ops.zeros([], dtype=dtypes.float64), name=scope)
    update_op = state_ops.assign_add(accumulator, step_total, name='update')
    return accumulator, update_op
1937
1938
def streaming_mean_absolute_error(predictions, labels, weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean absolute error between the labels and predictions.

  This is a thin wrapper: every argument is forwarded by keyword to
  `metrics.mean_absolute_error` and its result is returned as is.

  As documented for this metric, two local variables, `total` and `count`,
  are created and used to compute the mean absolute error. The average is
  weighted by `weights`, and it is ultimately returned as
  `mean_absolute_error`: an idempotent operation that simply divides `total`
  by `count`.

  For estimation of the metric over a stream of data, an `update_op`
  operation updates these variables and returns the `mean_absolute_error`.
  Internally, an `absolute_errors` operation computes the absolute value of
  the differences between `predictions` and `labels`. Then `update_op`
  increments `total` with the reduced sum of the product of `weights` and
  `absolute_errors`, and it increments `count` with the reduced sum of
  `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  metric_and_update = metrics.mean_absolute_error(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
  return metric_and_update
1987
1988
def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean relative error by normalizing with the given values.

  This is a thin wrapper: every argument is forwarded, by keyword, to
  `metrics.mean_relative_error`.

  Two local variables, `total` and `count`, are created and used to compute
  the mean relative absolute error. The average is weighted by `weights` and is
  ultimately returned as `mean_relative_error`: an idempotent operation that
  simply divides `total` by `count`.

  To estimate the metric over a stream of data, an `update_op` operation is
  also created; it updates these variables and returns the
  `mean_relative_error`. Internally, a `relative_errors` operation divides the
  absolute value of the differences between `predictions` and `labels` by the
  `normalizer`. `update_op` then increments `total` with the reduced sum of the
  product of `weights` and `relative_errors`, and increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_relative_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  # Keywords are listed in this wrapper's own parameter order; the callee
  # receives exactly the same keyword/value pairs as before.
  return metrics.mean_relative_error(
      predictions=predictions,
      labels=labels,
      normalizer=normalizer,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
2038
2039
def streaming_mean_squared_error(predictions, labels, weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes the mean squared error between the labels and predictions.

  This is a thin wrapper: every argument is forwarded, by keyword, to
  `metrics.mean_squared_error`.

  Two local variables, `total` and `count`, are created and used to compute
  the mean squared error. The average is weighted by `weights` and is
  ultimately returned as `mean_squared_error`: an idempotent operation that
  simply divides `total` by `count`.

  To estimate the metric over a stream of data, an `update_op` operation is
  also created; it updates these variables and returns the
  `mean_squared_error`. Internally, a `squared_error` operation computes the
  element-wise square of the difference between `predictions` and `labels`.
  `update_op` then increments `total` with the reduced sum of the product of
  `weights` and `squared_error`, and increments `count` with the reduced sum of
  `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that
      `mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_squared_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_squared_error(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
2088
2089
def streaming_root_mean_squared_error(predictions, labels, weights=None,
                                      metrics_collections=None,
                                      updates_collections=None,
                                      name=None):
  """Computes the root mean squared error between the labels and predictions.

  This is a thin wrapper: every argument is forwarded, by keyword, to
  `metrics.root_mean_squared_error`.

  Two local variables, `total` and `count`, are created and used to compute
  the root mean squared error. The average is weighted by `weights` and is
  ultimately returned as `root_mean_squared_error`: an idempotent operation
  that takes the square root of the division of `total` by `count`.

  To estimate the metric over a stream of data, an `update_op` operation is
  also created; it updates these variables and returns the
  `root_mean_squared_error`. Internally, a `squared_error` operation computes
  the element-wise square of the difference between `predictions` and `labels`.
  `update_op` then increments `total` with the reduced sum of the product of
  `weights` and `squared_error`, and increments `count` with the reduced sum of
  `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that
      `root_mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    root_mean_squared_error: A `Tensor` representing the current root mean
      squared error, the square root of `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `root_mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.root_mean_squared_error(
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
2138
2139
def streaming_covariance(predictions,
                         labels,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the unbiased sample covariance between `predictions` and `labels`.

  The `streaming_covariance` function creates four local variables,
  `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to
  compute the sample covariance between predictions and labels across multiple
  batches of data. The covariance is ultimately returned as an idempotent
  operation that simply divides `comoment` by `count` - 1. We use `count` - 1
  in order to get an unbiased estimate.

  The algorithm used for this online computation is described in
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
  Specifically, the formula used to combine two sample comoments is
  `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB`
  The comoment for a single batch of data is simply
  `sum((x - E[x]) * (y - E[y]))`, optionally weighted.

  In the code below, subscript `A` refers to everything accumulated before the
  current batch, `B` to the current batch, and `AB` to their combination.

  If `weights` is not None, then it is used to compute weighted comoments,
  means, and count. NOTE: these weights are treated as "frequency weights", as
  opposed to "reliability weights". See discussion of the difference on
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

  To facilitate the computation of covariance across multiple batches of data,
  the function creates an `update_op` operation, which updates underlying
  variables and returns the updated covariance.

  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as `predictions`.
    weights: An optional set of weights which indicates the frequency with which
      an example is sampled. Must be broadcastable with `labels`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    covariance: A `Tensor` representing the current unbiased sample covariance,
      `comoment` / (`count` - 1).
    update_op: An operation that updates the local variables appropriately and
      then evaluates `comoment` / (`count` - 1) on the updated values.

  Raises:
    ValueError: If labels and predictions are of different sizes or if either
      `metrics_collections` or `updates_collections` are not a list or tuple.
  """
  with variable_scope.variable_scope(
      name, 'covariance', (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # Streaming state; each accumulator is created with shape [] (a scalar).
    count = _create_local('count', [])
    mean_prediction = _create_local('mean_prediction', [])
    mean_label = _create_local('mean_label', [])
    comoment = _create_local('comoment', [])  # C_A in update equation

    if weights is None:
      # Unweighted: every element counts once.
      batch_count = math_ops.to_float(array_ops.size(labels))  # n_B in eqn
      weighted_predictions = predictions
      weighted_labels = labels
    else:
      # Weighted: the batch "count" is the sum of the weights after they are
      # broadcast against `labels`.
      batch_count = math_ops.reduce_sum(
          _broadcast_weights(weights, labels))  # n_B in eqn
      weighted_predictions = predictions * weights
      weighted_labels = labels * weights

    update_count = state_ops.assign_add(count, batch_count)  # n_AB in eqn
    # The previous count is recovered from the assign_add result instead of
    # reading `count` directly.
    # NOTE(review): presumably this keeps the value independent of whether the
    # variable is read before or after the update runs -- confirm.
    prev_count = update_count - batch_count  # n_A in update equation

    # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
    # batch_mean_prediction is E[x_B] in the update equation
    batch_mean_prediction = _safe_div(
        math_ops.reduce_sum(weighted_predictions), batch_count,
        'batch_mean_prediction')
    delta_mean_prediction = _safe_div(
        (batch_mean_prediction - mean_prediction) * batch_count, update_count,
        'delta_mean_prediction')
    update_mean_prediction = state_ops.assign_add(mean_prediction,
                                                  delta_mean_prediction)
    # prev_mean_prediction is E[x_A] in the update equation
    # (recovered from the updated mean, same pattern as `prev_count`).
    prev_mean_prediction = update_mean_prediction - delta_mean_prediction

    # batch_mean_label is E[y_B] in the update equation
    batch_mean_label = _safe_div(
        math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
    delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
                                 update_count, 'delta_mean_label')
    update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
    # prev_mean_label is E[y_A] in the update equation
    prev_mean_label = update_mean_label - delta_mean_label

    # Per-element (x - E[x_B]) * (y - E[y_B]) for the current batch, before any
    # weighting is applied.
    unweighted_batch_coresiduals = (
        (predictions - batch_mean_prediction) * (labels - batch_mean_label))
    # batch_comoment is C_B in the update equation
    if weights is None:
      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals)
    else:
      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals *
                                           weights)

    # View delta_comoment as = C_AB - C_A in the update equation above.
    # Since C_A is stored in a var, by how much do we need to increment that var
    # to make the var = C_AB?
    delta_comoment = (batch_comoment +
                      (prev_mean_prediction - batch_mean_prediction) *
                      (prev_mean_label - batch_mean_label) *
                      (prev_count * batch_count / update_count))
    update_comoment = state_ops.assign_add(comoment, delta_comoment)

    # Value op: reads the accumulators as they currently are.
    covariance = _safe_div(comoment, count - 1, 'covariance')
    # Update op: `update_comoment` is built from `prev_count`,
    # `prev_mean_prediction` and `prev_mean_label`, which in turn are built from
    # the count and mean assign_add ops, so depending on it alone pulls in all
    # four variable updates before the division below is evaluated.
    with ops.control_dependencies([update_comoment]):
      update_op = _safe_div(comoment, count - 1, 'update_op')

  if metrics_collections:
    ops.add_to_collections(metrics_collections, covariance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return covariance, update_op
2265
2266
def streaming_pearson_correlation(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes Pearson correlation coefficient between `predictions`, `labels`.

  The heavy lifting is delegated to `streaming_covariance`, which is invoked
  three times to track the following [co]variances:

  - `streaming_covariance(predictions, labels)`, i.e. covariance
  - `streaming_covariance(predictions, predictions)`, i.e. variance
  - `streaming_covariance(labels, labels)`, i.e. variance

  The product-moment correlation that is returned is an idempotent operation,
  `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. For
  computing the correlation across multiple batches, the `update_op`s of the
  three underlying `streaming_covariance` calls are grouped behind a single
  returned `update_op`.

  If `weights` is not None, then it is used to compute a weighted correlation.
  NOTE: these weights are treated as "frequency weights", as opposed to
  "reliability weights". See discussion of the difference on
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as predictions.
    weights: An optional set of weights which indicates the frequency with which
      an example is sampled. Must be broadcastable with `labels`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    pearson_r: A `Tensor` representing the current Pearson product-moment
      correlation coefficient, the value of
      `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`.
    update_op: An operation that updates the underlying variables appropriately.

  Raises:
    ValueError: If `labels` and `predictions` are of different sizes, or if
      `weights` is the wrong size, or if either `metrics_collections` or
      `updates_collections` are not a `list` or `tuple`.
  """
  scope_inputs = (predictions, labels, weights)
  with variable_scope.variable_scope(name, 'pearson_r', scope_inputs):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    # One streaming covariance for the numerator, and one variance (covariance
    # of a tensor with itself) for each factor of the denominator.
    covariance, covariance_update = streaming_covariance(
        predictions, labels, weights=weights, name='covariance')
    prediction_variance, prediction_variance_update = streaming_covariance(
        predictions, predictions, weights=weights, name='variance_predictions')
    label_variance, label_variance_update = streaming_covariance(
        labels, labels, weights=weights, name='variance_labels')

    # Value op: built from the value tensors of the three metrics.
    value_denominator = math_ops.mul(
        math_ops.sqrt(prediction_variance), math_ops.sqrt(label_variance))
    pearson_r = _safe_div(covariance, value_denominator, 'pearson_r')

    # Update op: the same formula, but built from the three update ops and
    # created under a control dependency on all of them.
    grouped_updates = [
        covariance_update, prediction_variance_update, label_variance_update]
    with ops.control_dependencies(grouped_updates):
      update_denominator = math_ops.mul(
          math_ops.sqrt(prediction_variance_update),
          math_ops.sqrt(label_variance_update))
      update_op = _safe_div(
          covariance_update, update_denominator, 'update_op')

  # Register the outputs in whichever (non-empty) collections were requested.
  for collections, tensor in ((metrics_collections, pearson_r),
                              (updates_collections, update_op)):
    if collections:
      ops.add_to_collections(collections, tensor)

  return pearson_r, update_op
2344
2345
2346# TODO(nsilberman): add a 'normalized' flag so that the user can request
2347# normalization if the inputs are not normalized.
def streaming_mean_cosine_distance(predictions, labels, dim, weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
  """Computes the cosine distance between the labels and predictions.

  The distance is obtained as one minus the streaming mean of the sum, along
  `dim`, of the element-wise product of `predictions` and `labels`. No
  normalization of the inputs is performed by this function.

  Two local variables, `total` and `count`, are created and used to compute
  the average cosine distance between `predictions` and `labels`. The average
  is weighted by `weights` and is ultimately returned as `mean_distance`, an
  idempotent operation that simply divides `total` by `count`.

  To estimate the metric over a stream of data, an `update_op` operation is
  also created; it updates these variables and returns the `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of the same shape as `labels`.
    labels: A `Tensor` of arbitrary shape.
    dim: The dimension along which the cosine distance is computed.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`,
      and whose dimension `dim` is 1.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  # Dot product along `dim`; the reduced dimension is kept (with size 1).
  elementwise_products = math_ops.mul(predictions, labels)
  dot_products = math_ops.reduce_sum(
      elementwise_products, reduction_indices=[dim], keep_dims=True)

  # The collections are deliberately not forwarded here: the tensors that get
  # registered are the "1 - x" results computed below, not the raw mean.
  metric_name = name or 'mean_cosine_distance'
  mean_similarity, similarity_update = streaming_mean(
      dot_products, weights, None, None, metric_name)

  mean_distance = math_ops.sub(1.0, mean_similarity)
  update_op = math_ops.sub(1.0, similarity_update)

  for collections, tensor in ((metrics_collections, mean_distance),
                              (updates_collections, update_op)):
    if collections:
      ops.add_to_collections(collections, tensor)

  return mean_distance, update_op
2411
2412
def streaming_percentage_less(values, threshold, weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Computes the percentage of values less than the given threshold.

  This is a thin wrapper: every argument is forwarded, by keyword, to
  `metrics.percentage_below`.

  Two local variables, `total` and `count`, are created and used to compute
  the percentage of `values` that fall below `threshold`. The rate is weighted
  by `weights` and is ultimately returned as `percentage`, an idempotent
  operation that simply divides `total` by `count`.

  To estimate the metric over a stream of data, an `update_op` operation is
  also created; it updates these variables and returns the `percentage`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A numeric `Tensor` of arbitrary size.
    threshold: A scalar threshold.
    weights: An optional `Tensor` whose shape is broadcastable to `values`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    percentage: A `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  return metrics.percentage_below(
      values=values,
      threshold=threshold,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
2456
2457
def streaming_mean_iou(predictions,
                       labels,
                       num_classes,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  This is a thin wrapper: every argument is forwarded, by keyword, to
  `metrics.mean_iou`.

  Mean Intersection-Over-Union is a common evaluation metric for semantic image
  segmentation. The IOU is first computed for each semantic class and then
  averaged over classes, where IOU is defined as:
    IOU = true_positive / (true_positive + false_positive + false_negative).
  The predictions are accumulated in a confusion matrix, weighted by `weights`,
  and mIOU is then calculated from it.

  To estimate the metric over a stream of data, an `update_op` operation is
  also created; it updates these variables and returns the `mean_iou`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened, if its rank > 1.
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened, if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A `Tensor` representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_iou(
      predictions=predictions,
      labels=labels,
      num_classes=num_classes,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
2510
2511
def _next_array_size(required_size, growth_factor=1.5):
  """Calculate the next size for reallocating a dynamic array.

  The result is `ceil(growth_factor ** k)`, where `k` is
  `ceil(log(required_size) / log(growth_factor))`; the logarithms are taken
  after casting both inputs to float32.

  Args:
    required_size: number or tf.Tensor specifying required array capacity.
    growth_factor: optional number or tf.Tensor specifying the growth factor
      between subsequent allocations.

  Returns:
    tf.Tensor with dtype=int32 giving the next array size.
  """
  log_required = math_ops.log(math_ops.cast(required_size, dtypes.float32))
  log_growth = math_ops.log(math_ops.cast(growth_factor, dtypes.float32))
  # Number of growth steps needed, i.e. log base `growth_factor` rounded up.
  exponent = math_ops.ceil(log_required / log_growth)
  grown_size = math_ops.ceil(growth_factor ** exponent)
  return math_ops.cast(grown_size, dtypes.int32)
2527
2528
def streaming_concat(values,
                     axis=0,
                     max_size=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
  """Concatenate values along an axis across batches.

  The function `streaming_concat` creates two local variables, `array` and
  `size`, that are used to store concatenated values. Internally, `array` is
  used as storage for a dynamic array (if `max_size` is `None`), which ensures
  that updates can be run in amortized constant time.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that appends the values of a tensor to the stored
  array. The concatenated result is read through the returned `value` tensor;
  `update_op` itself is the result of assigning the new length to `size`.

  This op allows for evaluating metrics that cannot be updated incrementally
  using the same framework as other streaming metrics.

  Args:
    values: `Tensor` to concatenate. Rank and the shape along all axes other
      than the axis to concatenate along must be statically known.
    axis: optional integer axis to concatenate along. Negative values are
      counted from the last axis.
    max_size: optional integer maximum size of `value` along the given axis.
      Once the maximum size is reached, further updates are no-ops. By default,
      there is no maximum size: the array is resized as necessary.
    metrics_collections: An optional list of collections that `value`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    value: A `Tensor` representing the concatenated values.
    update_op: An operation that concatenates the next values.

  Raises:
    ValueError: if `values` does not have a statically known rank, `axis` is
      not in the valid range or the size of `values` is not statically known
      along any axis other than `axis`.
  """
  with variable_scope.variable_scope(name, 'streaming_concat', (values,)):
    # pylint: disable=invalid-slice-index
    values_shape = values.get_shape()
    if values_shape.dims is None:
      raise ValueError('`values` must have known statically known rank')

    # Normalize a negative axis, then range-check it against the static rank.
    ndim = len(values_shape)
    if axis < 0:
      axis += ndim
    if not 0 <= axis < ndim:
      raise ValueError('axis = %r not in [0, %r)' % (axis, ndim))

    # Static sizes of every dimension except the concatenation axis.
    fixed_shape = [dim.value for n, dim in enumerate(values_shape)
                   if n != axis]
    if any(value is None for value in fixed_shape):
      raise ValueError('all dimensions of `values` other than the dimension to '
                       'concatenate along must have statically known size')

    # We move `axis` to the front of the internal array so assign ops can be
    # applied to contiguous slices
    # With `max_size` set, the full capacity is allocated up front; otherwise
    # the array starts empty and is grown by `reallocate` below.
    init_size = 0 if max_size is None else max_size
    init_shape = [init_size] + fixed_shape
    array = _create_local(
        'array', shape=init_shape, validate_shape=False, dtype=values.dtype)
    # Number of leading rows of `array` that hold real data.
    size = _create_local('size', shape=[], dtype=dtypes.int32)

    # Permutation from the internal layout (concat axis first) back to the
    # caller's layout: output dim `axis` comes from internal dim 0, dims before
    # `axis` are shifted by one, and dims after it are unchanged.
    perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
    valid_array = array[:size]
    # The leading dimension is only known at run time.
    valid_array.set_shape([None] + fixed_shape)
    value = array_ops.transpose(valid_array, perm, name='concat')

    values_size = array_ops.shape(values)[axis]
    if max_size is None:
      batch_size = values_size
    else:
      # Only take as many rows as still fit below `max_size`; once the array is
      # full this is 0 and the update appends nothing.
      batch_size = math_ops.minimum(values_size, max_size - size)

    # Permutation from the caller's layout to the internal layout (concat axis
    # moved to the front), followed by truncation to `batch_size` rows.
    perm = [axis] + [n for n in range(ndim) if n != axis]
    batch_values = array_ops.transpose(values, perm)[:batch_size]

    def reallocate():
      # Replaces `array` with a larger zero-filled buffer and copies the first
      # `size` rows of the old contents into it. `new_size` is bound in the
      # enclosing scope after this definition but before `cond` invokes it.
      next_size = _next_array_size(new_size)
      next_shape = array_ops.pack([next_size] + fixed_shape)
      new_value = array_ops.zeros(next_shape, dtype=values.dtype)
      old_value = array.value()
      # validate_shape=False because the new buffer has a different leading
      # dimension than the variable currently holds.
      assign_op = state_ops.assign(array, new_value, validate_shape=False)
      with ops.control_dependencies([assign_op]):
        copy_op = array[:size].assign(old_value[:size])
      # return value needs to be the same dtype as no_op() for cond
      with ops.control_dependencies([copy_op]):
        return control_flow_ops.no_op()

    new_size = size + batch_size
    # NOTE(review): `optimize=False` presumably forces the shape to be read at
    # run time rather than taken from the static shape, since the variable is
    # resized with validate_shape=False -- confirm.
    array_size = array_ops.shape_internal(array, optimize=False)[0]
    # Grow the buffer only when the appended rows would not fit. When
    # `max_size` is given, `new_size` never exceeds `max_size` (see
    # `batch_size` above), which is the size `array` was created with.
    maybe_reallocate_op = control_flow_ops.cond(
        new_size > array_size, reallocate, control_flow_ops.no_op)
    # Ordering: (possible) reallocation -> write the new rows -> bump `size`.
    with ops.control_dependencies([maybe_reallocate_op]):
      append_values_op = array[size:new_size].assign(batch_values)
    with ops.control_dependencies([append_values_op]):
      update_op = size.assign(new_size)

    if metrics_collections:
      ops.add_to_collections(metrics_collections, value)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return value, update_op
    # pylint: enable=invalid-slice-index
2640
2641
def aggregate_metrics(*value_update_tuples):
  """Splits `(value_tensor, update_op)` pairs into two parallel lists.

  Args:
    *value_update_tuples: one or more tuples, each holding the
      `(value_tensor, update_op)` pair returned by a streaming metric.

  Returns:
    A pair `(values, updates)`: the list of value `Tensor` objects and the list
    of update ops, both in the order in which the tuples were passed.

  Raises:
    ValueError: if no tuples are passed.
  """
  if value_update_tuples:
    # Transpose the sequence of pairs into a pair of sequences.
    values, updates = (list(column) for column in zip(*value_update_tuples))
    return values, updates
  raise ValueError('Expected at least one value_tensor/update_op pair')
2659
2660
def aggregate_metric_map(names_to_tuples):
  """Aggregates the metric names to tuple dictionary.

  This function is useful for pairing metric names with their associated value
  and update ops when the list of metrics is long. For example:

  ```python
    metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
        'Mean Absolute Error': new_slim.metrics.streaming_mean_absolute_error(
            predictions, labels, weights),
        'Mean Relative Error': new_slim.metrics.streaming_mean_relative_error(
            predictions, labels, labels, weights),
        'RMSE Linear': new_slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
        'RMSE Log': new_slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
    })
  ```

  Args:
    names_to_tuples: a non-empty map of metric names to tuples, each of which
      contain the pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    A dictionary from metric names to value ops and a dictionary from metric
    names to update ops.

  Raises:
    ValueError: if `names_to_tuples` is empty.
  """
  if not names_to_tuples:
    # Without this guard the tuple unpacking below fails on an empty `zip`
    # with an opaque "values to unpack" error. Raise the same exception type
    # with an actionable message, as `aggregate_metrics` does for empty input.
    raise ValueError(
        'Expected at least one metric in names_to_tuples, got an empty map')

  metric_names = names_to_tuples.keys()
  # Transpose {name: (value, update)} into parallel value/update sequences;
  # keys() and values() iterate in matching order for an unmodified mapping.
  value_ops, update_ops = zip(*names_to_tuples.values())
  names_to_values = dict(zip(metric_names, value_ops))
  names_to_updates = dict(zip(metric_names, update_ops))
  return names_to_values, names_to_updates
2691
2692
def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze last dim if needed.

  Squeezes `predictions` and `labels` if their rank differs by 1.
  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    weights: Optional weight `Tensor`. It will be squeezed if its rank is 1
      more than the new rank of `predictions`

  Returns:
    Tuple of `predictions`, `labels` and `weights`, possibly with the last
    dimension squeezed.
  """
  predictions, labels = tensor_util.remove_squeezable_dimensions(
      predictions, labels)
  # After the optional squeeze the two static shapes must agree.
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is not None:
    weights = ops.convert_to_tensor(weights)
    predictions_shape = predictions.get_shape()
    predictions_rank = predictions_shape.ndims
    weights_shape = weights.get_shape()
    weights_rank = weights_shape.ndims

    if (predictions_rank is not None) and (weights_rank is not None):
      # Use static rank.
      if weights_rank - predictions_rank == 1:
        weights = array_ops.squeeze(weights, [-1])
    elif (weights_rank is None) or (
        # A statically scalar `weights` has no last dimension to inspect
        # (`dims` is empty, so `dims[-1]` would raise IndexError), and its rank
        # can never be one more than that of `predictions`; leave it as is.
        weights_rank > 0 and weights_shape.dims[-1].is_compatible_with(1)):
      # Use dynamic rank: at least one rank is unknown statically, and the
      # last dimension of `weights` is not known to differ from 1, so the
      # squeeze decision is deferred to graph execution.
      weights = control_flow_ops.cond(
          math_ops.equal(array_ops.rank(weights),
                         math_ops.add(array_ops.rank(predictions), 1)),
          lambda: array_ops.squeeze(weights, [-1]),
          lambda: weights)
  return predictions, labels, weights
2736
2737
# Names exported by this module through `from ... import *`. Entries are kept
# in alphabetical order.
__all__ = [
    'aggregate_metric_map',
    'aggregate_metrics',
    'streaming_accuracy',
    'streaming_auc',
    'streaming_false_negatives',
    'streaming_false_negatives_at_thresholds',
    'streaming_false_positives',
    'streaming_false_positives_at_thresholds',
    'streaming_mean',
    'streaming_mean_absolute_error',
    'streaming_mean_cosine_distance',
    'streaming_mean_iou',
    'streaming_mean_relative_error',
    'streaming_mean_squared_error',
    'streaming_mean_tensor',
    'streaming_percentage_less',
    'streaming_precision',
    'streaming_precision_at_thresholds',
    'streaming_recall',
    'streaming_recall_at_k',
    'streaming_recall_at_thresholds',
    'streaming_root_mean_squared_error',
    'streaming_sensitivity_at_specificity',
    'streaming_sparse_average_precision_at_k',
    'streaming_sparse_precision_at_k',
    'streaming_sparse_recall_at_k',
    'streaming_specificity_at_sensitivity',
    'streaming_true_negatives',
    'streaming_true_negatives_at_thresholds',
    'streaming_true_positives',
    'streaming_true_positives_at_thresholds',
]
2771