metric_ops.py revision be270ecb79c8548a0caddf67908189e6169b1472
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains metric-computing operations on streamed tensors.
16
17Module documentation, including "@@" callouts, should be put in
18third_party/tensorflow/contrib/metrics/__init__.py
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25from tensorflow.contrib.framework import deprecated
26from tensorflow.contrib.framework import tensor_util
27from tensorflow.contrib.framework.python.ops import variables as contrib_variables
28from tensorflow.contrib.metrics.python.ops import set_ops
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import check_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import metrics
37from tensorflow.python.ops import nn
38from tensorflow.python.ops import sparse_ops
39from tensorflow.python.ops import state_ops
40from tensorflow.python.ops import variable_scope
41from tensorflow.python.ops import variables
42
43
def _safe_div(numerator, denominator, name):
  """Divides two values, returning 0 wherever the denominator is <= 0.

  The comparison and the fallback are element-wise, so this works for scalar
  inputs as well as for tensors of matching shape (e.g. per-threshold
  vectors).

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.
    name: Name for the returned op.

  Returns:
    A `Tensor` shaped like `numerator / denominator` that holds 0 where
    `denominator` <= 0, and `numerator` / `denominator` elsewhere.
  """
  quotient = math_ops.truediv(numerator, denominator)
  # `where` requires its two value arguments to have identical shapes, so the
  # fallback must be a zeros tensor shaped (and typed) like the quotient.
  # A bare Python `0` is only accepted when the quotient is itself a scalar.
  zeros = array_ops.zeros_like(quotient)
  return array_ops.where(
      math_ops.greater(denominator, 0), quotient, zeros, name=name)
60
61
def _safe_scalar_div(numerator, denominator, name):
  """Scalar division that yields 0 instead of dividing by zero.

  Both operands are statically checked to have rank at most 1. At run time a
  `cond` picks a constant `0.0` when `denominator` equals zero. Otherwise it
  performs the division.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`.
  """
  for operand in (numerator, denominator):
    operand.get_shape().with_rank_at_most(1)

  def _zero_branch():
    return array_ops.constant(0.0, dtype=dtypes.float64)

  def _divide_branch():
    return math_ops.div(numerator, denominator)

  denominator_is_zero = math_ops.equal(
      array_ops.constant(0.0, dtype=dtypes.float64), denominator)
  return control_flow_ops.cond(
      denominator_is_zero, _zero_branch, _divide_branch, name=name)
81
82
def _create_local(name, shape, collections=None, validate_shape=True,
                  dtype=dtypes.float32):
  """Creates a new non-trainable, zero-initialized local variable.

  Args:
    name: The name of the new or existing variable.
    shape: Shape of the new or existing variable.
    collections: A list of collection names to which the Variable will be
      added. `tf.GraphKeys.LOCAL_VARIABLES` is always appended.
    validate_shape: Whether to validate the shape of the variable.
    dtype: Data type of the variables.

  Returns:
    The created variable.
  """
  # Copy into a fresh list so the caller's `collections` is never mutated,
  # and always register the variable as a local variable.
  all_collections = list(collections or []) + [ops.GraphKeys.LOCAL_VARIABLES]
  zeros = array_ops.zeros(shape, dtype=dtype)
  return variables.Variable(
      initial_value=zeros,
      name=name,
      trainable=False,
      collections=all_collections,
      validate_shape=validate_shape)
106
107
108# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py.
def _assert_weights_rank(weights, values):
  """Returns an op asserting `weights` has rank 0 or the rank of `values`."""
  allowed_ranks = (0, array_ops.rank(values))
  return check_ops.assert_rank_in(weights, allowed_ranks)
112
113
def _count_condition(values, weights=None, metrics_collections=None,
                     updates_collections=None):
  """Accumulates the (weighted) number of `True` entries in `values`.

  A scalar local variable `count` holds the running total. Each run of
  `update_op` adds the sum of `values` (as floats), multiplied by `weights`
  when they are given. If `weights` is `None`, weights default to 1. Use
  weights of 0 to mask values.

  Args:
    values: A `bool` `Tensor` of arbitrary size.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `values`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    TypeError: If `values` is not of dtype `bool`.
    ValueError: If `weights` is not `None` and its shape doesn't match
      `values`.
  """
  check_ops.assert_type(values, dtypes.bool)
  count = _create_local('count', shape=[])

  addends = math_ops.to_float(values)
  if weights is not None:
    float_weights = math_ops.to_float(weights)
    rank_check = _assert_weights_rank(float_weights, addends)
    with ops.control_dependencies((rank_check,)):
      addends = math_ops.multiply(addends, float_weights)

  value_tensor = array_ops.identity(count)
  update_op = state_ops.assign_add(count, math_ops.reduce_sum(addends))

  for collections, tensor in ((metrics_collections, value_tensor),
                              (updates_collections, update_op)):
    if collections:
      ops.add_to_collections(collections, tensor)

  return value_tensor, update_op
159
160
def streaming_true_positives(predictions, labels, weights=None,
                             metrics_collections=None,
                             updates_collections=None,
                             name=None):
  """Sums the weights of true positives.

  This is a thin wrapper that forwards every argument, by keyword, to
  `metrics.true_positives`. If `weights` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.true_positives(
      name=name,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections)
198
199
def streaming_true_negatives(predictions, labels, weights=None,
                             metrics_collections=None,
                             updates_collections=None,
                             name=None):
  """Sum the weights of true_negatives.

  An element is a true negative when both its prediction and its label are
  `False`. Numeric inputs are also accepted. They are cast to `bool`, so an
  element counts as negative exactly when it equals 0.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`.
  """
  with variable_scope.variable_scope(
      name, 'true_negatives', (predictions, labels, weights)):

    # Cast to bool up front. Comparing a `bool` tensor with the int literal
    # `0` fails, because `0` cannot be converted to `bool`, even though bool
    # is the documented input type. For numeric inputs, `not cast(x, bool)`
    # is equivalent to `x == 0`, so the numeric behavior is unchanged.
    predictions = math_ops.cast(predictions, dtype=dtypes.bool)
    labels = math_ops.cast(labels, dtype=dtypes.bool)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    is_true_negative = math_ops.logical_and(
        math_ops.logical_not(labels), math_ops.logical_not(predictions))
    return _count_condition(is_true_negative, weights, metrics_collections,
                            updates_collections)
243
244
def streaming_false_positives(predictions, labels, weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Sums the weights of false positives.

  This is a thin wrapper that forwards every argument, by keyword, to
  `metrics.false_positives`. If `weights` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.false_positives(
      name=name,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections)
282
283
def streaming_false_negatives(predictions, labels, weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Sum the weights of false negatives.

  Delegates to `metrics.false_negatives`, forwarding every argument.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary
      dimensions.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.false_negatives(
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)
320
321
322# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py.
def _broadcast_weights(weights, values):
  """Broadcast `weights` to the same shape as `values`.

  This returns a version of `weights` following the same broadcast rules as
  `mul(weights, values)`. When computing a weighted average, use this function
  to broadcast `weights` before summing them; e.g.,
  `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`.

  The result always has the dtype of `weights`. `values` only contributes its
  shape, so it may have a different dtype (e.g. float weights with int
  values).

  Args:
    weights: `Tensor` (or value convertible to one) whose rank is either 0, or
      the same rank as `values`, and must be broadcastable to `values` (i.e.,
      all dimensions must be either `1`, or the same as the corresponding
      `values` dimension).
    values: `Tensor` (or value convertible to one) of any shape.

  Returns:
    `weights` broadcast to `values` shape.
  """
  with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope:
    # No-ops for Tensors. This lets Python scalars/lists be passed, which
    # would otherwise fail on `.get_shape()` below.
    weights = ops.convert_to_tensor(weights, name='weights')
    values = ops.convert_to_tensor(values, name='values')
    weights_shape = weights.get_shape()
    values_shape = values.get_shape()
    if (weights_shape.is_fully_defined() and
        values_shape.is_fully_defined() and
        weights_shape.is_compatible_with(values_shape)):
      # Both shapes are static and identical, so there is nothing to broadcast.
      return weights
    with ops.control_dependencies((_assert_weights_rank(weights, values),)):
      # Build the ones in `weights.dtype`. `multiply` requires matching
      # dtypes, and `ones_like(values)` would carry the dtype of `values`.
      return math_ops.multiply(
          weights, array_ops.ones_like(values, dtype=weights.dtype),
          name=scope)
350
351
def streaming_mean(values, weights=None, metrics_collections=None,
                   updates_collections=None, name=None):
  """Computes the (weighted) mean of the given values.

  The `streaming_mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `values`
      dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  return metrics.mean(
      values=values, weights=weights, metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)
393
394
def streaming_mean_tensor(values, weights=None, metrics_collections=None,
                          updates_collections=None, name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  In contrast to the `streaming_mean` function which returns a scalar with the
  mean, this function returns an average tensor with the same shape as the
  input tensors.

  The `streaming_mean_tensor` function creates two local variables,
  `total_tensor` and `count_tensor` that are used to compute the average of
  `values`. This average is ultimately returned as `mean` which is an idempotent
  operation that simply divides `total_tensor` by `count_tensor`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total_tensor` with the reduced sum of the product of
  `values` and `weights`, and it increments `count_tensor` with the reduced sum
  of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `values`
      dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float `Tensor` representing the current mean, the value of
      `total_tensor` divided by `count_tensor`.
    update_op: An operation that increments the `total_tensor` and
      `count_tensor` variables appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  return metrics.mean_tensor(
      values=values, weights=weights, metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)
440
441
def streaming_accuracy(predictions, labels, weights=None,
                       metrics_collections=None, updates_collections=None,
                       name=None):
  """Calculates how often `predictions` matches `labels`.

  This is a thin wrapper that forwards every argument, by keyword, to
  `metrics.accuracy`. The metric keeps two local accumulators, `total` and
  `count`. `accuracy` divides `total` by `count`. `update_op` adds the sum of
  `weights * is_correct` to `total` and the sum of `weights` to `count`, where
  `is_correct` is 1.0 where `predictions` equals `labels` and 0.0 elsewhere.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `Tensor` of any shape.
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.accuracy(
      name=name,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections)
491
492
def streaming_precision(predictions, labels, weights=None,
                        metrics_collections=None, updates_collections=None,
                        name=None):
  """Computes the precision of the predictions with respect to the labels.

  This is a thin wrapper that forwards every argument, by keyword, to
  `metrics.precision`. The metric keeps two local accumulators,
  `true_positives` and `false_positives`. `precision` is
  `true_positives / (true_positives + false_positives)`. `update_op` adds a
  new batch, weighting each prediction by the matching entry in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: Scalar float `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately and whose value matches
      `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.precision(
      name=name,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections)
541
542
def streaming_recall(predictions, labels, weights=None,
                     metrics_collections=None, updates_collections=None,
                     name=None):
  """Computes the recall of the predictions with respect to the labels.

  This is a thin wrapper that forwards every argument, by keyword, to
  `metrics.recall`. The metric keeps two local accumulators,
  `true_positives` and `false_negatives`. `recall` is
  `true_positives / (true_positives + false_negatives)`. `update_op` adds a
  new batch, weighting each prediction by the matching entry in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
      match `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.recall(
      name=name,
      predictions=predictions,
      labels=labels,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections)
589
590
def _streaming_confusion_matrix_at_thresholds(
    predictions, labels, thresholds, weights=None, includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast
      to `bool`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`. They
      are converted to the dtype of `predictions` before comparison.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels`
      dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
      default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
        `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
        `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError('Invalid key: %s.' % include)

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  num_thresholds = len(thresholds)

  # Reshape predictions into a column and labels into a row.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use the static number of predictions if known, otherwise the dynamic one.
  num_predictions = predictions_2d.get_shape().as_list()[0]
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]

  # Build thresholds in the dtype of `predictions`. `constant`'s default dtype
  # (float32 for Python floats, int32 for Python ints, float64 for numpy
  # float64 arrays) would otherwise make the `greater` comparison below fail
  # on a dtype mismatch.
  thresholds_tensor = array_ops.constant(
      thresholds, dtype=predictions.dtype.base_dtype)
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(thresholds_tensor, [1]),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  # Negated masks are only built when a requested cell needs them.
  pred_is_neg = None
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds.
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  label_is_neg = None
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    broadcast_weights = _broadcast_weights(
        math_ops.to_float(weights), predictions)
    weights_tiled = array_ops.tile(array_ops.reshape(
        broadcast_weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  # (key, variable name, label mask, prediction mask) for each cell, in the
  # fixed order tp, fn, tn, fp so variables are always created in that order.
  cells = (
      ('tp', 'true_positives', label_is_pos, pred_is_pos),
      ('fn', 'false_negatives', label_is_pos, pred_is_neg),
      ('tn', 'true_negatives', label_is_neg, pred_is_neg),
      ('fp', 'false_positives', label_is_neg, pred_is_pos),
  )

  values = {}
  update_ops = {}
  for key, var_name, label_mask, pred_mask in cells:
    if key not in includes:
      continue
    accumulator = _create_local(var_name, shape=[num_thresholds])
    in_cell = math_ops.to_float(math_ops.logical_and(label_mask, pred_mask))
    if weights_tiled is not None:
      in_cell *= weights_tiled
    # Sum over predictions (axis 1), leaving one count per threshold.
    update_ops[key] = state_ops.assign_add(
        accumulator, math_ops.reduce_sum(in_cell, 1))
    values[key] = accumulator

  return values, update_ops
731
732
def streaming_true_positives_at_thresholds(
    predictions, labels, thresholds, weights=None):
  """Streams the (weighted) true-positive count for every threshold.

  An example is a true positive for a threshold when its label is `True` and
  its prediction is strictly greater than that threshold.

  Args:
    predictions: Prediction scores compared against each threshold.
    labels: Ground-truth labels; cast to `bool` before counting.
    thresholds: A python list or tuple of float thresholds.
    weights: Optional weights, broadcast against `predictions`.

  Returns:
    A `(value, update_op)` pair: the local variable holding one count per
    threshold, and the op that adds the current batch's counts to it.
  """
  stat = 'tp'
  tallies, updaters = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=(stat,))
  return tallies[stat], updaters[stat]
738
739
def streaming_false_negatives_at_thresholds(
    predictions, labels, thresholds, weights=None):
  """Streams the (weighted) false-negative count for every threshold.

  An example is a false negative for a threshold when its label is `True` but
  its prediction is not strictly greater than that threshold.

  Args:
    predictions: Prediction scores compared against each threshold.
    labels: Ground-truth labels; cast to `bool` before counting.
    thresholds: A python list or tuple of float thresholds.
    weights: Optional weights, broadcast against `predictions`.

  Returns:
    A `(value, update_op)` pair: the local variable holding one count per
    threshold, and the op that adds the current batch's counts to it.
  """
  stat = 'fn'
  tallies, updaters = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=(stat,))
  return tallies[stat], updaters[stat]
745
746
def streaming_false_positives_at_thresholds(
    predictions, labels, thresholds, weights=None):
  """Streams the (weighted) false-positive count for every threshold.

  An example is a false positive for a threshold when its label is `False`
  but its prediction is strictly greater than that threshold.

  Args:
    predictions: Prediction scores compared against each threshold.
    labels: Ground-truth labels; cast to `bool` before counting.
    thresholds: A python list or tuple of float thresholds.
    weights: Optional weights, broadcast against `predictions`.

  Returns:
    A `(value, update_op)` pair: the local variable holding one count per
    threshold, and the op that adds the current batch's counts to it.
  """
  stat = 'fp'
  tallies, updaters = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=(stat,))
  return tallies[stat], updaters[stat]
752
753
def streaming_true_negatives_at_thresholds(
    predictions, labels, thresholds, weights=None):
  """Streams the (weighted) true-negative count for every threshold.

  An example is a true negative for a threshold when its label is `False` and
  its prediction is not strictly greater than that threshold.

  Args:
    predictions: Prediction scores compared against each threshold.
    labels: Ground-truth labels; cast to `bool` before counting.
    thresholds: A python list or tuple of float thresholds.
    weights: Optional weights, broadcast against `predictions`.

  Returns:
    A `(value, update_op)` pair: the local variable holding one count per
    threshold, and the op that adds the current batch's counts to it.
  """
  stat = 'tn'
  tallies, updaters = _streaming_confusion_matrix_at_thresholds(
      predictions, labels, thresholds, weights=weights, includes=(stat,))
  return tallies[stat], updaters[stat]
759
760
def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
                  metrics_collections=None, updates_collections=None,
                  curve='ROC', name=None):
  """Computes the approximate AUC via a Riemann sum.

  The `streaming_auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute one point of the curve per threshold: a pair
  of (false positive rate, recall) values for the ROC curve, or a pair of
  (recall, precision) values for the PR curve. The area under the ROC-curve is
  therefore computed using the height of the recall values by the false
  positive rate, while the area under the PR-curve is computed using the
  height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under the selected discretized curve (computed using the
  aforementioned variables). The `num_thresholds` variable controls the degree
  of discretization with larger numbers of thresholds more closely
  approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  This is a thin wrapper that forwards every argument to `metrics.auc`.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the
      selected curve (ROC or PR).
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.auc(
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections, num_thresholds=num_thresholds,
      curve=curve, updates_collections=updates_collections, name=name)
823
824
def streaming_specificity_at_sensitivity(
    predictions, labels, sensitivity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the specificity at a given sensitivity.

  The `streaming_specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  This is a thin wrapper that forwards every argument to
  `metrics.specificity_at_sensitivity`.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar `Tensor` representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  return metrics.specificity_at_sensitivity(
      sensitivity=sensitivity, num_thresholds=num_thresholds,
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)
881
882
def streaming_sensitivity_at_specificity(
    predictions, labels, specificity, weights=None, num_thresholds=200,
    metrics_collections=None, updates_collections=None, name=None):
  """Computes the sensitivity at a given specificity.

  The `streaming_sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The threshold for the given specificity value is computed
  and used to evaluate the corresponding sensitivity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  This is a thin wrapper that forwards every argument to
  `metrics.sensitivity_at_specificity`.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    specificity: A scalar value in range `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar `Tensor` representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `specificity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
  """
  return metrics.sensitivity_at_specificity(
      specificity=specificity, num_thresholds=num_thresholds,
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)
939
940
def streaming_precision_at_thresholds(predictions, labels, thresholds,
                                      weights=None,
                                      metrics_collections=None,
                                      updates_collections=None, name=None):
  """Computes precision values for different `thresholds` on `predictions`.

  The `streaming_precision_at_thresholds` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  for various values of thresholds. `precision[i]` is defined as the total
  weight of values in `predictions` above `thresholds[i]` whose corresponding
  entry in `labels` is `True`, divided by the total weight of values in
  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
  false_positives[i])`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  This is a thin wrapper that forwards every argument to
  `metrics.precision_at_thresholds`.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `precision`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables that
      are used in the computation of `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.precision_at_thresholds(
      thresholds=thresholds,
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)
992
993
def streaming_recall_at_thresholds(predictions, labels, thresholds,
                                   weights=None, metrics_collections=None,
                                   updates_collections=None, name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  For each `thresholds[i]`, `recall[i]` is the total weight of entries whose
  label is `True` and whose prediction lies above `thresholds[i]`, divided by
  the total weight of entries whose label is `True`; in other words
  `true_positives[i] / (true_positives[i] + false_negatives[i])`.

  The underlying counts are kept in local variables. The returned `update_op`
  folds the current batch into those counts and yields the refreshed
  `recall`, which lets the metric be estimated over a stream of batches. All
  of the work is delegated to `metrics.recall_at_thresholds`.

  When `weights` is `None` every entry has weight 1; a weight of 0 masks the
  corresponding entry out.

  Args:
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    labels: A `bool` `Tensor` whose shape matches `predictions`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: Optional list of collections to which `recall` is
      added.
    updates_collections: Optional list of collections to which `update_op` is
      added.
    name: Optional variable_scope name.

  Returns:
    recall: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that accumulates the current batch into the local
      count variables used to compute `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.recall_at_thresholds(
      predictions=predictions,
      labels=labels,
      thresholds=thresholds,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
1042
1043
1044def _at_k_name(name, k=None, class_id=None):
1045  if k is not None:
1046    name = '%s_at_%d' % (name, k)
1047  else:
1048    name = '%s_at_k' % (name)
1049  if class_id is not None:
1050    name = '%s_class%d' % (name, class_id)
1051  return name
1052
1053
@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, '
            'and reshape labels from [batch_size] to [batch_size, 1].')
def streaming_recall_at_k(predictions, labels, k, weights=None,
                          metrics_collections=None, updates_collections=None,
                          name=None):
  """Computes the recall@k of the predictions with respect to dense labels.

  Each example contributes a hit (1.0) when its label is among the top `k`
  entries of its row of `predictions`, as decided by `in_top_k`, and a miss
  (0.0) otherwise. The hits are then averaged with `streaming_mean`, which
  keeps a running weighted total and count in local variables and divides
  one by the other. The metric name defaults to `recall_at_<k>` when `name` is
  not given.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A float `Tensor` of dimension [batch_size, num_classes].
    labels: A `Tensor` of dimension [batch_size] whose type is in `int32`,
      `int64`.
    k: The number of top elements to look at for computing recall.
    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
      must be broadcastable to `labels` (i.e., all dimensions must be either
      `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall_at_k`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    recall_at_k: A `Tensor` representing the recall@k, the fraction of labels
      which fall into the top `k` predictions.
    update_op: An operation that updates the running mean and whose value
      matches `recall_at_k`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  hit_mask = nn.in_top_k(predictions, labels, k)
  hits = math_ops.to_float(hit_mask)
  metric_name = name or _at_k_name('recall', k)
  return streaming_mean(
      hits, weights, metrics_collections, updates_collections, metric_name)
1108
1109
# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_recall_at_k(predictions,
                                 labels,
                                 k,
                                 class_id=None,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  If `class_id` is not specified, we'll calculate recall as the ratio of true
      positives (i.e., correct predictions, items in the top `k` highest
      `predictions` that are found in the corresponding row in `labels`) to
      actual positives (the full `labels` row).
  If `class_id` is specified, we calculate recall by considering only the rows
      in the batch for which `class_id` is in `labels`, and computing the
      fraction of them for which `class_id` is in the top `k` highest
      `predictions`.

  `streaming_sparse_recall_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
  the recall_at_k frequency. This frequency is ultimately returned as
  `recall_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false negatives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_negative_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  This is a thin wrapper that forwards every argument to `metrics.recall_at_k`.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`.
      Values should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range always count
      towards `false_negative_at_<k>`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If class_id is outside this range, the method returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
    `predictions`, or if either `metrics_collections` or `updates_collections`
    are not a list or tuple.
  """
  return metrics.recall_at_k(
      k=k, class_id=class_id,
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)
1190
1191
def _streaming_sparse_precision_at_k(top_k_idx,
                                     labels,
                                     k=None,
                                     class_id=None,
                                     weights=None,
                                     metrics_collections=None,
                                     updates_collections=None,
                                     name=None):
  """Computes precision@k of the top-k indices with respect to sparse labels.

  This method contains the code shared by streaming_sparse_precision_at_k and
  streaming_sparse_precision_at_top_k. Refer to those methods for more details.

  Args:
    top_k_idx: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and top_k_idx has shape [batch size, k].
      The final dimension contains the indices of top-k labels. [D1, ... DN]
      must match `labels`. Cast to `int64` before use.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `top_k_idx`. Values should be in range [0, num_classes), where
      num_classes is the number of classes that `top_k_idx` indexes into.
      Values outside this range are ignored.
    k: Integer, k for @k metric or `None`. Forwarded unchanged to the
      true-positive and false-positive helpers (NOTE(review): presumably only
      used there for default variable names — confirm).
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes). If `class_id` is outside this range, the
      method returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name given to the returned `precision` op. The update op is always
      named 'update', so callers typically invoke this inside their own name
      scope.

  Returns:
    precision: Scalar `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and is not broadcastable to
      `labels`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  top_k_idx = math_ops.to_int64(top_k_idx)
  tp, tp_update = _streaming_sparse_true_positive_at_k(
      predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
      weights=weights)
  fp, fp_update = _streaming_sparse_false_positive_at_k(
      predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
      weights=weights)

  # Deliberately unguarded division: when tp + fp == 0 (e.g. `class_id` never
  # appears in `top_k_idx`) this is 0/0, which gives the documented NAN
  # assuming the counts are floating point (TODO confirm dtype in helpers).
  metric = math_ops.div(tp, math_ops.add(tp, fp), name=name)
  update = math_ops.div(
      tp_update, math_ops.add(tp_update, fp_update), name='update')
  if metrics_collections:
    ops.add_to_collections(metrics_collections, metric)
  if updates_collections:
    ops.add_to_collections(updates_collections, update)
  return metric, update
1260
1261
# TODO(ptucker): Validate range of values in labels?
def streaming_sparse_precision_at_k(predictions,
                                    labels,
                                    k,
                                    class_id=None,
                                    weights=None,
                                    metrics_collections=None,
                                    updates_collections=None,
                                    name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is not specified, we calculate precision as the ratio of true
      positives (i.e., correct predictions, items in the top `k` highest
      `predictions` that are found in the corresponding row in `labels`) to
      positives (all top `k` `predictions`).
  If `class_id` is specified, we calculate precision by considering only the
      rows in the batch for which `class_id` is in the top `k` highest
      `predictions`, and computing the fraction of them for which `class_id` is
      in the corresponding row in `labels`.

  We expect precision to decrease as `k` increases.

  `streaming_sparse_precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  This is a thin wrapper that forwards every argument to
  `metrics.sparse_precision_at_k`.

  Args:
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
      num_classes is the last dimension of `predictions`. Values outside this
      range are ignored.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
  """
  return metrics.sparse_precision_at_k(
      k=k, class_id=class_id,
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)
1345
1346
1347# TODO(ptucker): Validate range of values in labels?
1348def streaming_sparse_precision_at_top_k(top_k_predictions,
1349                                        labels,
1350                                        class_id=None,
1351                                        weights=None,
1352                                        metrics_collections=None,
1353                                        updates_collections=None,
1354                                        name=None):
1355  """Computes precision@k of top-k predictions with respect to sparse labels.
1356
1357  If `class_id` is not specified, we calculate precision as the ratio of
1358      true positives (i.e., correct predictions, items in `top_k_predictions`
1359      that are found in the corresponding row in `labels`) to positives (all
1360      `top_k_predictions`).
1361  If `class_id` is specified, we calculate precision by considering only the
1362      rows in the batch for which `class_id` is in the top `k` highest
1363      `predictions`, and computing the fraction of them for which `class_id` is
1364      in the corresponding row in `labels`.
1365
1366  We expect precision to decrease as `k` increases.
1367
1368  `streaming_sparse_precision_at_top_k` creates two local variables,
1369  `true_positive_at_k` and `false_positive_at_k`, that are used to compute
1370  the precision@k frequency. This frequency is ultimately returned as
1371  `precision_at_k`: an idempotent operation that simply divides
1372  `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`).
1373
1374  For estimation of the metric over a stream of data, the function creates an
1375  `update_op` operation that updates these variables and returns the
1376  `precision_at_k`. Internally, set operations applied to `top_k_predictions`
1377  and `labels` calculate the true positives and false positives weighted by
1378  `weights`. Then `update_op` increments `true_positive_at_k` and
1379  `false_positive_at_k` using these values.
1380
1381  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1382
1383  Args:
1384    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
1385      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
1386      The final dimension contains the indices of top-k labels. [D1, ... DN]
1387      must match `labels`.
1388    labels: `int64` `Tensor` or `SparseTensor` with shape
1389      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1390      target classes for the associated prediction. Commonly, N=1 and `labels`
1391      has shape [batch_size, num_labels]. [D1, ... DN] must match
1392      `top_k_predictions`. Values should be in range [0, num_classes), where
1393      num_classes is the last dimension of `predictions`. Values outside this
1394      range are ignored.
1395    class_id: Integer class ID for which we want binary metrics. This should be
1396      in range [0, num_classes), where num_classes is the last dimension of
1397      `predictions`. If `class_id` is outside this range, the method returns
1398      NAN.
1399    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1400      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1401      dimensions must be either `1`, or the same as the corresponding `labels`
1402      dimension).
1403    metrics_collections: An optional list of collections that values should
1404      be added to.
1405    updates_collections: An optional list of collections that updates should
1406      be added to.
1407    name: Name of new update operation, and namespace for other dependent ops.
1408
1409  Returns:
1410    precision: Scalar `float64` `Tensor` with the value of `true_positives`
1411      divided by the sum of `true_positives` and `false_positives`.
1412    update_op: `Operation` that increments `true_positives` and
1413      `false_positives` variables appropriately, and whose value matches
1414      `precision`.
1415
1416  Raises:
1417    ValueError: If `weights` is not `None` and its shape doesn't match
1418      `predictions`, or if either `metrics_collections` or `updates_collections`
1419      are not a list or tuple.
1420    ValueError: If `top_k_predictions` has rank < 2.
1421  """
1422  default_name = _at_k_name('precision', class_id=class_id)
1423  with ops.name_scope(
1424      name, default_name,
1425      (top_k_predictions, labels, weights)) as name_scope:
1426    return _streaming_sparse_precision_at_k(
1427        top_k_idx=top_k_predictions,
1428        labels=labels,
1429        class_id=class_id,
1430        weights=weights,
1431        metrics_collections=metrics_collections,
1432        updates_collections=updates_collections,
1433        name=name_scope)
1434
1435
1436def num_relevant(labels, k):
1437  """Computes number of relevant values for each row in labels.
1438
1439  For labels with shape [D1, ... DN, num_labels], this is the minimum of
1440  `num_labels` and `k`.
1441
1442  Args:
1443    labels: `int64` `Tensor` or `SparseTensor` with shape
1444      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1445      target classes for the associated prediction. Commonly, N=1 and `labels`
1446      has shape [batch_size, num_labels].
1447    k: Integer, k for @k metric.
1448
1449  Returns:
1450    Integer `Tensor` of shape [D1, ... DN], where each value is the number of
1451    relevant values for that row.
1452
1453  Raises:
1454    ValueError: if inputs have invalid dtypes or values.
1455  """
1456  if k < 1:
1457    raise ValueError('Invalid k=%s.' % k)
1458  with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
1459    # For SparseTensor, calculate separate count for each row.
1460    if isinstance(
1461        labels, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
1462      labels_sizes = set_ops.set_size(labels)
1463      return math_ops.minimum(labels_sizes, k, name=scope)
1464
1465    # For dense Tensor, calculate scalar count based on last dimension, and
1466    # tile across labels shape.
1467    labels_shape = array_ops.shape(labels)
1468    labels_size = labels_shape[-1]
1469    num_relevant_scalar = math_ops.minimum(labels_size, k)
1470    return array_ops.fill(labels_shape[0:-1], num_relevant_scalar, name=scope)
1471
1472
1473def expand_and_tile(tensor, multiple, dim=0, name=None):
1474  """Slice `tensor` shape in 2, then tile along the sliced dimension.
1475
1476  A new dimension is inserted in shape of `tensor` before `dim`, then values are
1477  tiled `multiple` times along the new dimension.
1478
1479  Args:
1480    tensor: Input `Tensor` or `SparseTensor`.
1481    multiple: Integer, number of times to tile.
1482    dim: Integer, dimension along which to tile.
1483    name: Name of operation.
1484
1485  Returns:
1486    `Tensor` result of expanding and tiling `tensor`.
1487
1488  Raises:
1489    ValueError: if `multiple` is less than 1, or `dim` is not in
1490    `[-rank(tensor), rank(tensor)]`.
1491  """
1492  if multiple < 1:
1493    raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
1494  with ops.name_scope(
1495      name, 'expand_and_tile', (tensor, multiple, dim)) as scope:
1496    # Sparse.
1497    if isinstance(tensor, sparse_tensor.SparseTensorValue):
1498      tensor = sparse_tensor.SparseTensor.from_value(tensor)
1499    if isinstance(tensor, sparse_tensor.SparseTensor):
1500      if dim < 0:
1501        expand_dims = array_ops.reshape(
1502            array_ops.size(tensor.dense_shape) + dim, [1])
1503      else:
1504        expand_dims = [dim]
1505      expanded_shape = array_ops.concat(
1506          (array_ops.strided_slice(tensor.dense_shape, [0], expand_dims), [1],
1507           array_ops.strided_slice(
1508               tensor.dense_shape, expand_dims, [-1], end_mask=1 << 0)),
1509          0,
1510          name='expanded_shape')
1511      expanded = sparse_ops.sparse_reshape(
1512          tensor, shape=expanded_shape, name='expand')
1513      if multiple == 1:
1514        return expanded
1515      return sparse_ops.sparse_concat(
1516          dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)
1517
1518    # Dense.
1519    expanded = array_ops.expand_dims(
1520        tensor, dim if (dim >= 0) else (dim - 1), name='expand')
1521    if multiple == 1:
1522      return expanded
1523    ones = array_ops.ones_like(array_ops.shape(tensor))
1524    tile_multiples = array_ops.concat(
1525        (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples')
1526    return array_ops.tile(expanded, tile_multiples, name=scope)
1527
1528
1529def sparse_average_precision_at_k(predictions, labels, k):
1530  """Computes average precision@k of predictions with respect to sparse labels.
1531
1532  From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
1533  for each row is:
1534
1535    AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items
1536
1537  A "row" is the elements in dimension [D1, ... DN] of `predictions`, `labels`,
1538  and the result `Tensors`. In the common case, this is [batch_size]. Each row
1539  of the results contains the average precision for that row.
1540
1541  Internally, a `top_k` operation computes a `Tensor` indicating the top `k`
1542  `predictions`. Set operations applied to `top_k` and `labels` calculate the
1543  true positives, which are used to calculate the precision ("P_{i}" term,
1544  above).
1545
1546  Args:
1547    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
1548      N >= 1. Commonly, N=1 and `predictions` has shape
1549      [batch size, num_classes]. The final dimension contains the logit values
1550      for each class. [D1, ... DN] must match `labels`.
1551    labels: `int64` `Tensor` or `SparseTensor` with shape
1552      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1553      target classes for the associated prediction. Commonly, N=1 and `labels`
1554      has shape [batch_size, num_labels]. [D1, ... DN] must match
1555      `predictions`. Values should be in range [0, num_classes), where
1556      num_classes is the last dimension of `predictions`. Values outside this
1557      range are ignored.
1558    k: Integer, k for @k metric. This will calculate an average precision for
1559      range `[1,k]`, as documented above.
1560
1561  Returns:
1562    `float64` `Tensor` of shape [D1, ... DN], where each value is the average
1563    precision for that row.
1564
1565  Raises:
1566    ValueError: if k is invalid.
1567  """
1568  if k < 1:
1569    raise ValueError('Invalid k=%s.' % k)
1570  with ops.name_scope(
1571      None, 'average_precision', (predictions, labels, k)) as scope:
1572    # Calculate top k indices to produce [D1, ... DN, k] tensor.
1573    _, predictions_idx = nn.top_k(predictions, k)
1574    predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx')
1575
1576    # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate
1577    # prediction for each k, so we can calculate separate true positive values
1578    # for each k.
1579    predictions_idx_per_k = array_ops.expand_dims(
1580        predictions_idx, -1, name='predictions_idx_per_k')
1581
1582    # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor.
1583    labels_per_k = expand_and_tile(
1584        labels, multiple=k, dim=-1, name='labels_per_k')
1585
1586    # The following tensors are all of shape [D1, ... DN, k], containing values
1587    # per row, per k value.
1588    # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at
1589    #     that k value is correct, 0 otherwise. This is the "rel_{i}" term from
1590    #     the formula above.
1591    # `tp_per_k` (int32) - True positive counts.
1592    # `retrieved_per_k` (int32) - Number of predicted values at each k. This is
1593    #     the precision denominator.
1594    # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}"
1595    #     term from the formula above.
1596    # `relevant_precision_per_k` (float64) - Relevant precisions; i.e.,
1597    #     precisions at all k for which relevance indicator is true.
1598    relevant_per_k = _sparse_true_positive_at_k(
1599        predictions_idx_per_k, labels_per_k, name='relevant_per_k')
1600    tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
1601    retrieved_per_k = math_ops.cumsum(
1602        array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
1603    precision_per_k = math_ops.div(
1604        math_ops.to_double(tp_per_k), math_ops.to_double(retrieved_per_k),
1605        name='precision_per_k')
1606    relevant_precision_per_k = math_ops.multiply(
1607        precision_per_k, math_ops.to_double(relevant_per_k),
1608        name='relevant_precision_per_k')
1609
1610    # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
1611    precision_sum = math_ops.reduce_sum(
1612        relevant_precision_per_k, reduction_indices=(-1,), name='precision_sum')
1613
1614    # Divide by number of relevant items to get average precision. These are
1615    # the "num_relevant_items" and "AveP" terms from the formula above.
1616    num_relevant_items = math_ops.to_double(num_relevant(labels, k))
1617    return math_ops.div(precision_sum, num_relevant_items, name=scope)
1618
1619
1620def streaming_sparse_average_precision_at_k(predictions,
1621                                            labels,
1622                                            k,
1623                                            weights=None,
1624                                            metrics_collections=None,
1625                                            updates_collections=None,
1626                                            name=None):
1627  """Computes average precision@k of predictions with respect to sparse labels.
1628
1629  See `sparse_average_precision_at_k` for details on formula. `weights` are
1630  applied to the result of `sparse_average_precision_at_k`
1631
1632  `streaming_sparse_average_precision_at_k` creates two local variables,
1633  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
1634  are used to compute the frequency. This frequency is ultimately returned as
1635  `average_precision_at_<k>`: an idempotent operation that simply divides
1636  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
1637
1638  For estimation of the metric over a stream of data, the function creates an
1639  `update_op` operation that updates these variables and returns the
1640  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
1641  indicating the top `k` `predictions`. Set operations applied to `top_k` and
1642  `labels` calculate the true positives and false positives weighted by
1643  `weights`. Then `update_op` increments `true_positive_at_<k>` and
1644  `false_positive_at_<k>` using these values.
1645
1646  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1647
1648  Args:
1649    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
1650      N >= 1. Commonly, N=1 and `predictions` has shape
1651      [batch size, num_classes]. The final dimension contains the logit values
1652      for each class. [D1, ... DN] must match `labels`.
1653    labels: `int64` `Tensor` or `SparseTensor` with shape
1654      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1655      target classes for the associated prediction. Commonly, N=1 and `labels`
1656      has shape [batch_size, num_labels]. [D1, ... DN] must match
1657      `predictions_`. Values should be in range [0, num_classes), where
1658      num_classes is the last dimension of `predictions`. Values outside this
1659      range are ignored.
1660    k: Integer, k for @k metric. This will calculate an average precision for
1661      range `[1,k]`, as documented above.
1662    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1663      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1664      dimensions must be either `1`, or the same as the corresponding `labels`
1665      dimension).
1666    metrics_collections: An optional list of collections that values should
1667      be added to.
1668    updates_collections: An optional list of collections that updates should
1669      be added to.
1670    name: Name of new update operation, and namespace for other dependent ops.
1671
1672  Returns:
1673    mean_average_precision: Scalar `float64` `Tensor` with the mean average
1674      precision values.
1675    update: `Operation` that increments  variables appropriately, and whose
1676      value matches `metric`.
1677  """
1678  return metrics.sparse_average_precision_at_k(
1679      k=k, predictions=predictions, labels=labels, weights=weights,
1680      metrics_collections=metrics_collections,
1681      updates_collections=updates_collections, name=name)
1682
1683
1684def _select_class_id(ids, selected_id):
1685  """Filter all but `selected_id` out of `ids`.
1686
1687  Args:
1688    ids: `int64` `Tensor` or `SparseTensor` of IDs.
1689    selected_id: Int id to select.
1690
1691  Returns:
1692    `SparseTensor` of same dimensions as `ids`. This contains only the entries
1693    equal to `selected_id`.
1694  """
1695  if isinstance(
1696      ids, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
1697    return sparse_ops.sparse_retain(
1698        ids, math_ops.equal(ids.values, selected_id))
1699
1700  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
1701  # tf.equal and tf.reduce_any?
1702
1703  # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
1704  ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
1705  ids_last_dim = array_ops.size(ids_shape) - 1
1706  filled_selected_id_shape = math_ops.reduced_shape(
1707      ids_shape, array_ops.reshape(ids_last_dim, [1]))
1708
1709  # Intersect `ids` with the selected ID.
1710  filled_selected_id = array_ops.fill(
1711      filled_selected_id_shape, math_ops.to_int64(selected_id))
1712  result = set_ops.set_intersection(filled_selected_id, ids)
1713  return sparse_tensor.SparseTensor(
1714      indices=result.indices, values=result.values, dense_shape=ids_shape)
1715
1716
1717def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
1718  """If class ID is specified, filter all other classes.
1719
1720  Args:
1721    labels: `int64` `Tensor` or `SparseTensor` with shape
1722      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1723      target classes for the associated prediction. Commonly, N=1 and `labels`
1724      has shape [batch_size, num_labels]. [D1, ... DN] must match
1725      `predictions_idx`.
1726    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
1727      where N >= 1. Commonly, N=1 and `predictions_idx` has shape
1728      [batch size, k].
1729    selected_id: Int id to select.
1730
1731  Returns:
1732    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
1733  """
1734  if selected_id is None:
1735    return labels, predictions_idx
1736  return (_select_class_id(labels, selected_id),
1737          _select_class_id(predictions_idx, selected_id))
1738
1739
1740def _sparse_true_positive_at_k(predictions_idx,
1741                               labels,
1742                               class_id=None,
1743                               weights=None,
1744                               name=None):
1745  """Calculates true positives for recall@k and precision@k.
1746
1747  If `class_id` is specified, calculate binary true positives for `class_id`
1748      only.
1749  If `class_id` is not specified, calculate metrics for `k` predicted vs
1750      `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
1751
1752  Args:
1753    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
1754      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
1755      match `labels`.
1756    labels: `int64` `Tensor` or `SparseTensor` with shape
1757      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1758      target classes for the associated prediction. Commonly, N=1 and `labels`
1759      has shape [batch_size, num_labels]. [D1, ... DN] must match
1760      `predictions_idx`.
1761    class_id: Class for which we want binary metrics.
1762    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1763      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1764      dimensions must be either `1`, or the same as the corresponding `labels`
1765      dimension).
1766    name: Name of operation.
1767
1768  Returns:
1769    A [D1, ... DN] `Tensor` of true positive counts.
1770  """
1771  with ops.name_scope(
1772      name, 'true_positives', (predictions_idx, labels, weights)):
1773    labels, predictions_idx = _maybe_select_class_id(
1774        labels, predictions_idx, class_id)
1775    tp = set_ops.set_size(set_ops.set_intersection(predictions_idx, labels))
1776    tp = math_ops.to_double(tp)
1777    if weights is not None:
1778      weights = math_ops.to_double(weights)
1779      with ops.control_dependencies((_assert_weights_rank(weights, tp),)):
1780        tp = math_ops.multiply(tp, weights)
1781    return tp
1782
1783
1784def _streaming_sparse_true_positive_at_k(predictions_idx,
1785                                         labels,
1786                                         k=None,
1787                                         class_id=None,
1788                                         weights=None,
1789                                         name=None):
1790  """Calculates weighted per step true positives for recall@k and precision@k.
1791
1792  If `class_id` is specified, calculate binary true positives for `class_id`
1793      only.
1794  If `class_id` is not specified, calculate metrics for `k` predicted vs
1795      `n` label classes, where `n` is the 2nd dimension of `labels`.
1796
1797  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1798
1799  Args:
1800    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
1801      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
1802      match `labels`.
1803    labels: `int64` `Tensor` or `SparseTensor` with shape
1804      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1805      target classes for the associated prediction. Commonly, N=1 and `labels`
1806      has shape [batch_size, num_labels]. [D1, ... DN] must match
1807      `predictions_idx`.
1808    k: Integer, k for @k metric. This is only used for default op name.
1809    class_id: Class for which we want binary metrics.
1810    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1811      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1812      dimensions must be either `1`, or the same as the corresponding `labels`
1813      dimension).
1814    name: Name of new variable, and namespace for other dependent ops.
1815
1816  Returns:
1817    A tuple of `Variable` and update `Operation`.
1818
1819  Raises:
1820    ValueError: If `weights` is not `None` and has an incomptable shape.
1821  """
1822  default_name = _at_k_name('true_positive', k, class_id=class_id)
1823  with ops.name_scope(
1824      name, default_name, (predictions_idx, labels, weights)) as scope:
1825    tp = _sparse_true_positive_at_k(
1826        predictions_idx=predictions_idx, labels=labels, class_id=class_id,
1827        weights=weights)
1828    batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
1829
1830    var = contrib_variables.local_variable(
1831        array_ops.zeros([], dtype=dtypes.float64), name=scope)
1832    return var, state_ops.assign_add(var, batch_total_tp, name='update')
1833
1834
1835def _sparse_false_positive_at_k(predictions_idx,
1836                                labels,
1837                                class_id=None,
1838                                weights=None):
1839  """Calculates false positives for precision@k.
1840
1841  If `class_id` is specified, calculate binary true positives for `class_id`
1842      only.
1843  If `class_id` is not specified, calculate metrics for `k` predicted vs
1844      `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
1845
1846  Args:
1847    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
1848      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
1849      match `labels`.
1850    labels: `int64` `Tensor` or `SparseTensor` with shape
1851      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1852      target classes for the associated prediction. Commonly, N=1 and `labels`
1853      has shape [batch_size, num_labels]. [D1, ... DN] must match
1854      `predictions_idx`.
1855    class_id: Class for which we want binary metrics.
1856    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1857      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1858      dimensions must be either `1`, or the same as the corresponding `labels`
1859      dimension).
1860
1861  Returns:
1862    A [D1, ... DN] `Tensor` of false positive counts.
1863  """
1864  with ops.name_scope(
1865      None, 'false_positives', (predictions_idx, labels, weights)):
1866    labels, predictions_idx = _maybe_select_class_id(labels,
1867                                                     predictions_idx,
1868                                                     class_id)
1869    fp = set_ops.set_size(set_ops.set_difference(
1870        predictions_idx, labels, aminusb=True))
1871    fp = math_ops.to_double(fp)
1872    if weights is not None:
1873      weights = math_ops.to_double(weights)
1874      with ops.control_dependencies((_assert_weights_rank(weights, fp),)):
1875        fp = math_ops.multiply(fp, weights)
1876    return fp
1877
1878
1879def _streaming_sparse_false_positive_at_k(predictions_idx,
1880                                          labels,
1881                                          k=None,
1882                                          class_id=None,
1883                                          weights=None,
1884                                          name=None):
1885  """Calculates weighted per step false positives for precision@k.
1886
1887  If `class_id` is specified, calculate binary true positives for `class_id`
1888      only.
1889  If `class_id` is not specified, calculate metrics for `k` predicted vs
1890      `n` label classes, where `n` is the 2nd dimension of `labels`.
1891
1892  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1893
1894  Args:
1895    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
1896      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
1897      match `labels`.
1898    labels: `int64` `Tensor` or `SparseTensor` with shape
1899      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1900      target classes for the associated prediction. Commonly, N=1 and `labels`
1901      has shape [batch_size, num_labels]. [D1, ... DN] must match
1902      `predictions_idx`.
1903    k: Integer, k for @k metric. This is only used for default op name.
1904    class_id: Class for which we want binary metrics.
1905    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1906      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1907      dimensions must be either `1`, or the same as the corresponding `labels`
1908      dimension).
1909    name: Name of new variable, and namespace for other dependent ops.
1910
1911  Returns:
1912    A tuple of `Variable` and update `Operation`.
1913
1914  Raises:
1915    ValueError: If `weights` is not `None` and has an incomptable shape.
1916  """
1917  with ops.name_scope(
1918      name, _at_k_name('false_positive', k, class_id=class_id),
1919      (predictions_idx, labels, weights)) as scope:
1920    fp = _sparse_false_positive_at_k(
1921        predictions_idx=predictions_idx, labels=labels, class_id=class_id,
1922        weights=weights)
1923    batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp))
1924
1925    var = contrib_variables.local_variable(
1926        array_ops.zeros([], dtype=dtypes.float64), name=scope)
1927    return var, state_ops.assign_add(var, batch_total_fp, name='update')
1928
1929
1930def _sparse_false_negative_at_k(predictions_idx,
1931                                labels,
1932                                class_id=None,
1933                                weights=None):
1934  """Calculates false negatives for recall@k.
1935
1936  If `class_id` is specified, calculate binary true positives for `class_id`
1937      only.
1938  If `class_id` is not specified, calculate metrics for `k` predicted vs
1939      `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
1940
1941  Args:
1942    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
1943      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
1944      match `labels`.
1945    labels: `int64` `Tensor` or `SparseTensor` with shape
1946      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1947      target classes for the associated prediction. Commonly, N=1 and `labels`
1948      has shape [batch_size, num_labels]. [D1, ... DN] must match
1949      `predictions_idx`.
1950    class_id: Class for which we want binary metrics.
1951    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1952      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1953      dimensions must be either `1`, or the same as the corresponding `labels`
1954      dimension).
1955
1956  Returns:
1957    A [D1, ... DN] `Tensor` of false negative counts.
1958  """
1959  with ops.name_scope(
1960      None, 'false_negatives', (predictions_idx, labels, weights)):
1961    labels, predictions_idx = _maybe_select_class_id(labels,
1962                                                     predictions_idx,
1963                                                     class_id)
1964    fn = set_ops.set_size(set_ops.set_difference(predictions_idx,
1965                                                 labels,
1966                                                 aminusb=False))
1967    fn = math_ops.to_double(fn)
1968    if weights is not None:
1969      weights = math_ops.to_double(weights)
1970      with ops.control_dependencies((_assert_weights_rank(weights, fn),)):
1971        fn = math_ops.multiply(fn, weights)
1972    return fn
1973
1974
def _streaming_sparse_false_negative_at_k(predictions_idx,
                                          labels,
                                          k,
                                          class_id=None,
                                          weights=None,
                                          name=None):
  """Calculates weighted per step false negatives for recall@k.

  If `class_id` is specified, calculate binary false negatives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`. The variable is a `float64`
    scalar local variable holding the running total of false negatives; the
    update op adds the current batch's (weighted) total to it.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  with ops.name_scope(
      name, _at_k_name('false_negative', k, class_id=class_id),
      (predictions_idx, labels, weights)) as scope:
    fn = _sparse_false_negative_at_k(
        predictions_idx=predictions_idx, labels=labels, class_id=class_id,
        weights=weights)
    batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn))

    # Scalar float64 accumulator named after the enclosing scope.
    var = contrib_variables.local_variable(
        array_ops.zeros([], dtype=dtypes.float64), name=scope)
    return var, state_ops.assign_add(var, batch_total_fn, name='update')
2024
2025
def streaming_mean_absolute_error(predictions, labels, weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean absolute error between the labels and predictions.

  This is a thin wrapper that delegates to `metrics.mean_absolute_error`.

  Two local variables, `total` and `count`, back the metric. The reported
  `mean_absolute_error` is an idempotent operation that divides `total` by
  `count`, so the average is weighted by `weights`.

  To stream the metric over many batches, an `update_op` is also created.
  An internal `absolute_errors` operation computes |predictions - labels|.
  Each run of `update_op` adds the reduced sum of `weights * absolute_errors`
  to `total` and the reduced sum of `weights` to `count`. It then returns the
  refreshed `mean_absolute_error`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` holding the current mean, i.e. `total`
      divided by `count`.
    update_op: An operation that increments `total` and `count` and whose value
      matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_absolute_error(
      labels=labels,
      predictions=predictions,
      weights=weights,
      name=name,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections)
2077
2078
def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes the mean relative error by normalizing with the given values.

  This delegates to `metrics.mean_relative_error`.

  The `streaming_mean_relative_error` function creates two local variables,
  `total` and `count` that are used to compute the mean relative absolute error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_relative_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_relative_error`. Internally, a `relative_errors` operation divides the
  absolute value of the differences between `predictions` and `labels` by the
  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `relative_errors`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_relative_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_relative_error(
      normalizer=normalizer, predictions=predictions, labels=labels,
      weights=weights, metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)
2131
2132
def streaming_mean_squared_error(predictions, labels, weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes the mean squared error between the labels and predictions.

  This is a thin wrapper that delegates to `metrics.mean_squared_error`.

  Two local variables, `total` and `count`, back the metric. The reported
  `mean_squared_error` is an idempotent operation that divides `total` by
  `count`, so the average is weighted by `weights`.

  To stream the metric over many batches, an `update_op` is also created.
  An internal `squared_error` operation computes the element-wise square of
  `predictions - labels`. Each run of `update_op` adds the reduced sum of
  `weights * squared_error` to `total` and the reduced sum of `weights` to
  `count`. It then returns the refreshed `mean_squared_error`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_squared_error: A `Tensor` holding the current mean, i.e. `total`
      divided by `count`.
    update_op: An operation that increments `total` and `count` and whose value
      matches `mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_squared_error(
      labels=labels,
      predictions=predictions,
      weights=weights,
      name=name,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections)
2184
2185
def streaming_root_mean_squared_error(predictions, labels, weights=None,
                                      metrics_collections=None,
                                      updates_collections=None,
                                      name=None):
  """Computes the root mean squared error between the labels and predictions.

  This delegates to `metrics.root_mean_squared_error`.

  The `streaming_root_mean_squared_error` function creates two local variables,
  `total` and `count` that are used to compute the root mean squared error.
  This average is weighted by `weights`, and it is ultimately returned as
  `root_mean_squared_error`: an idempotent operation that takes the square root
  of the division of `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `root_mean_squared_error`. Internally, a `squared_error` operation computes
  the element-wise square of the difference between `predictions` and `labels`.
  Then `update_op` increments `total` with the reduced sum of the product of
  `weights` and `squared_error`, and it increments `count` with the reduced sum
  of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of arbitrary shape.
    labels: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `root_mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    root_mean_squared_error: A `Tensor` representing the current root mean
      squared error, the square root of `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `root_mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.root_mean_squared_error(
      predictions=predictions, labels=labels, weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections, name=name)
2237
2238
def streaming_covariance(predictions,
                         labels,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the unbiased sample covariance between `predictions` and `labels`.

  The `streaming_covariance` function creates four local variables,
  `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to
  compute the sample covariance between predictions and labels across multiple
  batches of data. The covariance is ultimately returned as an idempotent
  operation that simply divides `comoment` by `count` - 1. We use `count` - 1
  in order to get an unbiased estimate.

  The algorithm used for this online computation is described in
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
  Specifically, the formula used to combine two sample comoments is
  `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB`
  The comoment for a single batch of data is simply
  `sum((x - E[x]) * (y - E[y]))`, optionally weighted.

  If `weights` is not None, then it is used to compute weighted comoments,
  means, and count. NOTE: these weights are treated as "frequency weights", as
  opposed to "reliability weights". See discussion of the difference on
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

  To facilitate the computation of covariance across multiple batches of data,
  the function creates an `update_op` operation, which updates underlying
  variables and returns the updated covariance.

  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as `predictions`.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    covariance: A `Tensor` representing the current unbiased sample covariance,
      `comoment` / (`count` - 1).
    update_op: An operation that updates the local variables appropriately.

  Raises:
    ValueError: If labels and predictions are of different sizes or if either
      `metrics_collections` or `updates_collections` are not a list or tuple.
  """
  with variable_scope.variable_scope(
      name, 'covariance', (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    count = _create_local('count', [])
    mean_prediction = _create_local('mean_prediction', [])
    mean_label = _create_local('mean_label', [])
    comoment = _create_local('comoment', [])  # C_A in update equation

    if weights is None:
      batch_count = math_ops.to_float(array_ops.size(labels))  # n_B in eqn
      weighted_predictions = predictions
      weighted_labels = labels
    else:
      weights = _broadcast_weights(weights, labels)
      batch_count = math_ops.reduce_sum(weights)  # n_B in eqn
      weighted_predictions = math_ops.multiply(predictions, weights)
      weighted_labels = math_ops.multiply(labels, weights)

    update_count = state_ops.assign_add(count, batch_count)  # n_AB in eqn
    prev_count = update_count - batch_count  # n_A in update equation

    # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
    # batch_mean_prediction is E[x_B] in the update equation
    batch_mean_prediction = _safe_div(
        math_ops.reduce_sum(weighted_predictions), batch_count,
        'batch_mean_prediction')
    delta_mean_prediction = _safe_div(
        (batch_mean_prediction - mean_prediction) * batch_count, update_count,
        'delta_mean_prediction')
    update_mean_prediction = state_ops.assign_add(mean_prediction,
                                                  delta_mean_prediction)
    # prev_mean_prediction is E[x_A] in the update equation
    prev_mean_prediction = update_mean_prediction - delta_mean_prediction

    # batch_mean_label is E[y_B] in the update equation
    batch_mean_label = _safe_div(
        math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
    delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
                                 update_count, 'delta_mean_label')
    update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
    # prev_mean_label is E[y_A] in the update equation
    prev_mean_label = update_mean_label - delta_mean_label

    unweighted_batch_coresiduals = (
        (predictions - batch_mean_prediction) * (labels - batch_mean_label))
    # batch_comoment is C_B in the update equation
    if weights is None:
      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals)
    else:
      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals *
                                           weights)

    # n_A * n_B / n_AB in the update equation. A plain division evaluates
    # 0 / 0 = NaN when no weight has been accumulated yet (update_count == 0,
    # e.g. an empty first batch or one fully masked by zero weights). That NaN
    # would be added into `comoment` and corrupt it permanently. Guard it with
    # _safe_div, exactly as the delta_mean_* updates above guard the same
    # `update_count` denominator.
    batch_count_ratio = _safe_div(prev_count * batch_count, update_count,
                                  'batch_count_ratio')

    # View delta_comoment as = C_AB - C_A in the update equation above.
    # Since C_A is stored in a var, by how much do we need to increment that var
    # to make the var = C_AB?
    delta_comoment = (batch_comoment +
                      (prev_mean_prediction - batch_mean_prediction) *
                      (prev_mean_label - batch_mean_label) *
                      batch_count_ratio)
    update_comoment = state_ops.assign_add(comoment, delta_comoment)

    covariance = _safe_div(comoment, count - 1, 'covariance')
    # Re-read `comoment` only after the increment above has been applied.
    with ops.control_dependencies([update_comoment]):
      update_op = _safe_div(comoment, count - 1, 'update_op')

  if metrics_collections:
    ops.add_to_collections(metrics_collections, covariance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return covariance, update_op
2366
2367
def streaming_pearson_correlation(predictions,
                                  labels,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes Pearson correlation coefficient between `predictions`, `labels`.

  Three streaming [co]variances are tracked through `streaming_covariance`:

  - `streaming_covariance(predictions, labels)`: the covariance,
  - `streaming_covariance(predictions, predictions)`: the prediction variance,
  - `streaming_covariance(labels, labels)`: the label variance.

  The returned product-moment correlation is the idempotent operation
  `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`.
  The returned `update_op` runs after all three underlying `update_op`s, so
  the correlation can be accumulated over multiple batches.

  When `weights` is not None, a weighted correlation is computed. NOTE: the
  weights are "frequency weights", not "reliability weights"; see
  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance

  Args:
    predictions: A `Tensor` of arbitrary size.
    labels: A `Tensor` of the same size as predictions.
    weights: Optional `Tensor` indicating the frequency with which an example is
      sampled. Rank must be 0, or the same rank as `labels`, and must be
      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
      the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    pearson_r: A `Tensor` holding the current Pearson product-moment
      correlation coefficient,
      `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`.
    update_op: An operation that updates the underlying variables appropriately.

  Raises:
    ValueError: If `labels` and `predictions` are of different sizes, or if
      `weights` is the wrong size, or if either `metrics_collections` or
      `updates_collections` are not a `list` or `tuple`.
  """
  with variable_scope.variable_scope(
      name, 'pearson_r', (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    # Broadcast once here so the three streaming_covariance calls below do not
    # each repeat the broadcast.
    if weights is not None:
      weights = _broadcast_weights(weights, labels)

    cov_value, cov_update = streaming_covariance(
        predictions, labels, weights=weights, name='covariance')
    pred_var_value, pred_var_update = streaming_covariance(
        predictions, predictions, weights=weights, name='variance_predictions')
    label_var_value, label_var_update = streaming_covariance(
        labels, labels, weights=weights, name='variance_labels')

    # Product of the two standard deviations: the correlation's denominator.
    stddev_product = math_ops.multiply(math_ops.sqrt(pred_var_value),
                                       math_ops.sqrt(label_var_value))
    pearson_r = _safe_div(cov_value, stddev_product, 'pearson_r')

    pending_updates = [cov_update, pred_var_update, label_var_update]
    with ops.control_dependencies(pending_updates):
      updated_stddev_product = math_ops.multiply(
          math_ops.sqrt(pred_var_update), math_ops.sqrt(label_var_update))
      update_op = _safe_div(cov_update, updated_stddev_product, 'update_op')

  for target_collections, tensor in ((metrics_collections, pearson_r),
                                     (updates_collections, update_op)):
    if target_collections:
      ops.add_to_collections(target_collections, tensor)

  return pearson_r, update_op
2452
2453
2454# TODO(nsilberman): add a 'normalized' flag so that the user can request
2455# normalization if the inputs are not normalized.
def streaming_mean_cosine_distance(predictions, labels, dim, weights=None,
                                   metrics_collections=None,
                                   updates_collections=None,
                                   name=None):
  """Computes the cosine distance between the labels and predictions.

  Two local variables, `total` and `count`, are created through
  `streaming_mean`. They track the average of `predictions . labels` taken
  along `dim`, weighted by `weights`. The returned `mean_distance` is the
  idempotent operation `1 - total / count`.

  To stream the metric over many batches, an `update_op` is also created; it
  updates those variables and returns the refreshed `mean_distance`.

  This function does not normalize its inputs, so the dot product equals the
  cosine similarity only when `predictions` and `labels` are already
  unit-normalized along `dim`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of the same shape as `labels`.
    labels: A `Tensor` of arbitrary shape.
    dim: The dimension along which the cosine distance is computed.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`,
      and whose dimension `dim` is 1.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` holding the current mean distance, derived from
      `total` divided by `count`.
    update_op: An operation that increments `total` and `count` appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions, labels, weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  # Dot product along `dim`. keep_dims=True keeps a size-1 axis at `dim`, which
  # is why `weights` must be 1 in that dimension.
  dot_products = math_ops.reduce_sum(math_ops.multiply(predictions, labels),
                                     reduction_indices=[dim,],
                                     keep_dims=True)
  scope_name = name or 'mean_cosine_distance'
  mean_dot, update_dot = streaming_mean(dot_products, weights, None, None,
                                        scope_name)

  # Distance = 1 - similarity, for both the read and the update path.
  mean_distance = math_ops.subtract(1.0, mean_dot)
  update_op = math_ops.subtract(1.0, update_dot)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)
  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op
2519
2520
def streaming_percentage_less(values, threshold, weights=None,
                              metrics_collections=None,
                              updates_collections=None,
                              name=None):
  """Computes the percentage of values less than the given threshold.

  This is a thin wrapper that delegates to `metrics.percentage_below`.

  Two local variables, `total` and `count`, back the metric and track how many
  `values` fall below `threshold`. The rate is weighted by `weights`. The
  reported `percentage` is an idempotent operation that divides `total` by
  `count`.

  To stream the metric over many batches, an `update_op` is also created; it
  updates those variables and returns the refreshed `percentage`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A numeric `Tensor` of arbitrary size.
    threshold: A scalar threshold.
    weights: An optional `Tensor` whose shape is broadcastable to `values`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    percentage: A `Tensor` holding the current mean, i.e. `total` divided by
      `count`.
    update_op: An operation that increments `total` and `count` appropriately.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  return metrics.percentage_below(
      threshold=threshold,
      values=values,
      weights=weights,
      name=name,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections)
2564
2565
def streaming_mean_iou(predictions,
                       labels,
                       num_classes,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  This is a thin wrapper that delegates to `metrics.mean_iou`.

  Mean Intersection-Over-Union is a common evaluation metric for semantic
  image segmentation. It first computes the IOU for each semantic class and
  then averages over the classes, where
    IOU = true_positive / (true_positive + false_positive + false_negative).
  Predictions are accumulated into a confusion matrix, weighted by `weights`,
  and the mIOU is derived from that matrix.

  To stream the metric over many batches, an `update_op` is also created; it
  updates the accumulated state and returns the `mean_iou`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened, if its rank > 1.
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened, if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A `Tensor` representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  return metrics.mean_iou(
      labels=labels,
      predictions=predictions,
      num_classes=num_classes,
      weights=weights,
      name=name,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections)
2618
2619
def _next_array_size(required_size, growth_factor=1.5):
  """Calculate the next size for reallocating a dynamic array.

  The result is `ceil(growth_factor ** k)` for the smallest integer `k` with
  `growth_factor ** k >= required_size`. Growing geometrically like this keeps
  the amortized cost of appends constant.

  Args:
    required_size: number or tf.Tensor specifying required array capacity.
    growth_factor: optional number or tf.Tensor specifying the growth factor
      between subsequent allocations. Presumably expected to be > 1 (a value
      <= 1 makes the log ratio meaningless) -- not validated here.

  Returns:
    tf.Tensor with dtype=int32 giving the next array size. It is never smaller
    than `required_size`.
  """
  # Cast once so a tensor `growth_factor` of any numeric dtype works in both
  # the log and the pow below (pow needs both operands to be float32).
  growth_factor = math_ops.cast(growth_factor, dtypes.float32)
  exponent = math_ops.ceil(
      math_ops.log(math_ops.cast(required_size, dtypes.float32))
      / math_ops.log(growth_factor))
  next_size = math_ops.cast(
      math_ops.ceil(growth_factor ** exponent), dtypes.int32)
  # float32 rounding in the log ratio can land just below an integer and leave
  # `next_size` one growth step short; never return less than was requested.
  return math_ops.maximum(next_size,
                          math_ops.cast(required_size, dtypes.int32))
2635
2636
def streaming_concat(values,
                     axis=0,
                     max_size=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
  """Concatenate values along an axis across batches.

  The function `streaming_concat` creates two local variables, `array` and
  `size`, that are used to store concatenated values. Internally, `array` is
  used as storage for a dynamic array (if `max_size` is `None`), which ensures
  that updates can be run in amortized constant time.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that appends the values of a tensor and returns the
  `value` of the concatenated tensors.

  This op allows for evaluating metrics that cannot be updated incrementally
  using the same framework as other streaming metrics.

  Args:
    values: `Tensor` (or a value convertible to one) to concatenate. Rank and
      the shape along all axes other than the axis to concatenate along must be
      statically known.
    axis: optional integer axis to concatenate along. Negative values count
      from the last axis.
    max_size: optional integer maximum size of `value` along the given axis.
      Once the maximum size is reached, further updates are no-ops. By default,
      there is no maximum size: the array is resized as necessary.
    metrics_collections: An optional list of collections that `value`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    value: A `Tensor` representing the concatenated values.
    update_op: An operation that concatenates the next values.

  Raises:
    ValueError: if `values` does not have a statically known rank, `axis` is
      not in the valid range or the size of `values` is not statically known
      along any axis other than `axis`.
  """
  with variable_scope.variable_scope(name, 'streaming_concat', (values,)):
    # pylint: disable=invalid-slice-index
    values = ops.convert_to_tensor(values)
    values_shape = values.get_shape()
    if values_shape.dims is None:
      raise ValueError('`values` must have statically known rank')

    ndim = len(values_shape)
    if axis < 0:
      axis += ndim
    if not 0 <= axis < ndim:
      raise ValueError('axis = %r not in [0, %r)' % (axis, ndim))

    fixed_shape = [dim.value for n, dim in enumerate(values_shape)
                   if n != axis]
    if any(value is None for value in fixed_shape):
      raise ValueError('all dimensions of `values` other than the dimension to '
                       'concatenate along must have statically known size')

    # We move `axis` to the front of the internal array so assign ops can be
    # applied to contiguous slices. With `max_size` set, the storage is
    # allocated up front and never reallocated.
    init_size = 0 if max_size is None else max_size
    init_shape = [init_size] + fixed_shape
    array = _create_local(
        'array', shape=init_shape, validate_shape=False, dtype=values.dtype)
    size = _create_local('size', shape=[], dtype=dtypes.int32)

    # Inverse permutation: move the leading storage axis back to `axis`.
    perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
    valid_array = array[:size]
    valid_array.set_shape([None] + fixed_shape)
    value = array_ops.transpose(valid_array, perm, name='concat')

    values_size = array_ops.shape(values)[axis]
    if max_size is None:
      batch_size = values_size
    else:
      # Truncate the incoming batch so `size` never exceeds `max_size`.
      batch_size = math_ops.minimum(values_size, max_size - size)

    perm = [axis] + [n for n in range(ndim) if n != axis]
    batch_values = array_ops.transpose(values, perm)[:batch_size]

    # Defined before `reallocate`, which closes over it.
    new_size = size + batch_size

    def reallocate():
      """Grows `array` geometrically, preserving its first `size` rows."""
      next_size = _next_array_size(new_size)
      next_shape = array_ops.stack([next_size] + fixed_shape)
      new_value = array_ops.zeros(next_shape, dtype=values.dtype)
      old_value = array.value()
      assign_op = state_ops.assign(array, new_value, validate_shape=False)
      with ops.control_dependencies([assign_op]):
        copy_op = array[:size].assign(old_value[:size])
      # return value needs to be the same dtype as no_op() for cond
      with ops.control_dependencies([copy_op]):
        return control_flow_ops.no_op()

    array_size = array_ops.shape_internal(array, optimize=False)[0]
    maybe_reallocate_op = control_flow_ops.cond(
        new_size > array_size, reallocate, control_flow_ops.no_op)
    # Append only after any reallocation, and bump `size` only after the
    # append, so `value` never exposes uninitialized rows.
    with ops.control_dependencies([maybe_reallocate_op]):
      append_values_op = array[size:new_size].assign(batch_values)
    with ops.control_dependencies([append_values_op]):
      update_op = size.assign(new_size)

    if metrics_collections:
      ops.add_to_collections(metrics_collections, value)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return value, update_op
    # pylint: enable=invalid-slice-index
2748
2749
def aggregate_metrics(*value_update_tuples):
  """Splits (value_tensor, update_op) pairs into two parallel lists.

  Args:
    *value_update_tuples: one or more `(value_tensor, update_op)` pairs, as
      returned by a streaming metric.

  Returns:
    A tuple `(values, updates)`: a list of the value `Tensor` objects and a
    list of the update ops, in the same order as the input pairs.

  Raises:
    ValueError: if no pairs are given.
  """
  if len(value_update_tuples) == 0:
    raise ValueError('Expected at least one value_tensor/update_op pair')
  # Transpose the pairs: first column holds values, second holds updates.
  columns = [list(column) for column in zip(*value_update_tuples)]
  values_list, updates_list = columns
  return values_list, updates_list
2767
2768
def aggregate_metric_map(names_to_tuples):
  """Aggregates the metric names to tuple dictionary.

  This function is useful for pairing metric names with their associated value
  and update ops when the list of metrics is long. For example:

  ```python
    metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
        'Mean Absolute Error': slim.metrics.streaming_mean_absolute_error(
            predictions, labels, weights),
        'Mean Relative Error': slim.metrics.streaming_mean_relative_error(
            predictions, labels, labels, weights),
        'RMSE Linear': slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
        'RMSE Log': slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
    })
  ```

  Args:
    names_to_tuples: a map of metric names to tuples, each of which contain the
      pair of (value_tensor, update_op) from a streaming metric.

  Returns:
    A dictionary from metric names to value ops and a dictionary from metric
    names to update ops. Both are empty if `names_to_tuples` is empty.
  """
  metrics_to_values = {}
  metrics_to_updates = {}
  # A single pass over items() keeps each name paired with its own tuple and
  # handles an empty map naturally (the zip/unpack form raised on `{}`).
  for metric_name, (value_op, update_op) in names_to_tuples.items():
    metrics_to_values[metric_name] = value_op
    metrics_to_updates[metric_name] = update_op
  return metrics_to_values, metrics_to_updates
2799
2800
def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze last dim if needed.

  Squeezes `predictions` and `labels` if their rank differs by 1.
  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    weights: Optional weight `Tensor`. It will be squeezed if its rank is 1
      more than the new rank of `predictions`. Scalar weights are returned
      unchanged.

  Returns:
    Tuple of `predictions`, `labels` and `weights`, possibly with the last
    dimension squeezed.

  Raises:
    ValueError: if the static shapes of `predictions` and `labels` are not
      compatible after squeezing.
  """
  predictions, labels = tensor_util.remove_squeezable_dimensions(
      predictions, labels)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is not None:
    weights = ops.convert_to_tensor(weights)
    predictions_rank = predictions.get_shape().ndims
    weights_shape = weights.get_shape()
    weights_rank = weights_shape.ndims

    if (predictions_rank is not None) and (weights_rank is not None):
      # Use static rank.
      if weights_rank - predictions_rank == 1:
        weights = array_ops.squeeze(weights, [-1])
    elif (weights_rank is None) or (
        weights_rank > 0 and weights_shape.dims[-1].is_compatible_with(1)):
      # Use dynamic rank. A scalar `weights` (static rank 0) has no last
      # dimension, so it is skipped; indexing `dims[-1]` would raise IndexError.
      weights = control_flow_ops.cond(
          math_ops.equal(array_ops.rank(weights),
                         math_ops.add(array_ops.rank(predictions), 1)),
          lambda: array_ops.squeeze(weights, [-1]),
          lambda: weights)
  return predictions, labels, weights
2844
2845
# Public API of this module. `streaming_concat` is a public metric defined in
# this module, so it is exported alongside the other streaming metrics.
__all__ = [
    'aggregate_metric_map',
    'aggregate_metrics',
    'streaming_accuracy',
    'streaming_auc',
    'streaming_concat',
    'streaming_false_negatives',
    'streaming_false_negatives_at_thresholds',
    'streaming_false_positives',
    'streaming_false_positives_at_thresholds',
    'streaming_mean',
    'streaming_mean_absolute_error',
    'streaming_mean_cosine_distance',
    'streaming_mean_iou',
    'streaming_mean_relative_error',
    'streaming_mean_squared_error',
    'streaming_mean_tensor',
    'streaming_percentage_less',
    'streaming_precision',
    'streaming_precision_at_thresholds',
    'streaming_recall',
    'streaming_recall_at_k',
    'streaming_recall_at_thresholds',
    'streaming_root_mean_squared_error',
    'streaming_sensitivity_at_specificity',
    'streaming_sparse_average_precision_at_k',
    'streaming_sparse_precision_at_k',
    'streaming_sparse_recall_at_k',
    'streaming_specificity_at_sensitivity',
    'streaming_true_negatives',
    'streaming_true_negatives_at_thresholds',
    'streaming_true_positives',
    'streaming_true_positives_at_thresholds',
]
2879