metric_ops.py revision 0845c9c976974ced632cb85e52808b22c1e7c669
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains metric-computing operations on streamed tensors.
16
17Module documentation, including "@@" callouts, should be put in
18third_party/tensorflow/contrib/metrics/__init__.py
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25from tensorflow.contrib.framework import deprecated
26from tensorflow.contrib.framework.python.framework import tensor_util
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import check_ops
31from tensorflow.python.ops import control_flow_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import metrics
34from tensorflow.python.ops import metrics_impl
35from tensorflow.python.ops import nn
36from tensorflow.python.ops import state_ops
37from tensorflow.python.ops import variable_scope
38
39
40def _safe_div(numerator, denominator, name):
41  """Divides two values, returning 0 if the denominator is <= 0.
42
43  Args:
44    numerator: A real `Tensor`.
45    denominator: A real `Tensor`, with dtype matching `numerator`.
46    name: Name for the returned op.
47
48  Returns:
49    0 if `denominator` <= 0, else `numerator` / `denominator`
50  """
51  return array_ops.where(
52      math_ops.greater(denominator, 0),
53      math_ops.truediv(numerator, denominator),
54      0,
55      name=name)
56
57
58def _create_local(name, shape, collections=None, validate_shape=True,
59                  dtype=dtypes.float32):
60  """Creates a new local variable.
61
62  Args:
63    name: The name of the new or existing variable.
64    shape: Shape of the new or existing variable.
65    collections: A list of collection names to which the Variable will be added.
66    validate_shape: Whether to validate the shape of the variable.
67    dtype: Data type of the variables.
68
69  Returns:
70    The created variable.
71  """
72  # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES
73  collections = list(collections or [])
74  collections += [ops.GraphKeys.LOCAL_VARIABLES]
75  return variable_scope.variable(
76      initial_value=array_ops.zeros(shape, dtype=dtype),
77      name=name,
78      trainable=False,
79      collections=collections,
80      validate_shape=validate_shape)
81
82
83# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py.
84def _assert_weights_rank(weights, values):
85  """`weights` rank must be either `0`, or the same as 'values'."""
86  return check_ops.assert_rank_in(weights, (0, array_ops.rank(values)))
87
88
89def _count_condition(values, weights=None, metrics_collections=None,
90                     updates_collections=None):
91  """Sums the weights of cases where the given values are True.
92
93  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
94
95  Args:
96    values: A `bool` `Tensor` of arbitrary size.
97    weights: Optional `Tensor` whose rank is either 0, or the same rank as
98      `values`, and must be broadcastable to `values` (i.e., all dimensions
99      must be either `1`, or the same as the corresponding `values`
100      dimension).
101    metrics_collections: An optional list of collections that the metric
102      value variable should be added to.
103    updates_collections: An optional list of collections that the metric update
104      ops should be added to.
105
106  Returns:
107    value_tensor: A `Tensor` representing the current value of the metric.
108    update_op: An operation that accumulates the error from a batch of data.
109
110  Raises:
111    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
112      or if either `metrics_collections` or `updates_collections` are not a list
113      or tuple.
114  """
115  check_ops.assert_type(values, dtypes.bool)
116  count = _create_local('count', shape=[])
117
118  values = math_ops.to_float(values)
119  if weights is not None:
120    weights = math_ops.to_float(weights)
121    with ops.control_dependencies((_assert_weights_rank(weights, values),)):
122      values = math_ops.multiply(values, weights)
123
124  value_tensor = array_ops.identity(count)
125  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
126
127  if metrics_collections:
128    ops.add_to_collections(metrics_collections, value_tensor)
129
130  if updates_collections:
131    ops.add_to_collections(updates_collections, update_op)
132
133  return value_tensor, update_op
134
135
136def streaming_true_positives(predictions, labels, weights=None,
137                             metrics_collections=None,
138                             updates_collections=None,
139                             name=None):
140  """Sum the weights of true_positives.
141
142  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
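
  For example, a minimal sketch of typical usage (illustrative values; assumes
  `import tensorflow as tf` and TF 1.x graph execution):

    predictions = tf.constant([True, False, True, False])
    labels = tf.constant([True, True, False, False])
    tp, tp_update_op = streaming_true_positives(predictions, labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(tp_update_op)   # only the first pair is (True, True)
      print(sess.run(tp))      # 1.0

  The other streaming counters in this module (`streaming_true_negatives`,
  `streaming_false_positives`, `streaming_false_negatives`) follow the same
  pattern.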
143
144  Args:
145    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
146      be cast to `bool`.
147    labels: The ground truth values, a `Tensor` whose dimensions must match
148      `predictions`. Will be cast to `bool`.
149    weights: Optional `Tensor` whose rank is either 0, or the same rank as
150      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
151      must be either `1`, or the same as the corresponding `labels`
152      dimension).
153    metrics_collections: An optional list of collections that the metric
154      value variable should be added to.
155    updates_collections: An optional list of collections that the metric update
156      ops should be added to.
157    name: An optional variable_scope name.
158
159  Returns:
160    value_tensor: A `Tensor` representing the current value of the metric.
161    update_op: An operation that accumulates the error from a batch of data.
162
163  Raises:
164    ValueError: If `predictions` and `labels` have mismatched shapes, or if
165      `weights` is not `None` and its shape doesn't match `predictions`, or if
166      either `metrics_collections` or `updates_collections` are not a list or
167      tuple.
168  """
169  return metrics.true_positives(
170      predictions=predictions, labels=labels, weights=weights,
171      metrics_collections=metrics_collections,
172      updates_collections=updates_collections, name=name)
173
174
175def streaming_true_negatives(predictions, labels, weights=None,
176                             metrics_collections=None,
177                             updates_collections=None,
178                             name=None):
179  """Sum the weights of true_negatives.
180
181  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
182
183  Args:
184    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
185      be cast to `bool`.
186    labels: The ground truth values, a `Tensor` whose dimensions must match
187      `predictions`. Will be cast to `bool`.
188    weights: Optional `Tensor` whose rank is either 0, or the same rank as
189      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
190      must be either `1`, or the same as the corresponding `labels`
191      dimension).
192    metrics_collections: An optional list of collections that the metric
193      value variable should be added to.
194    updates_collections: An optional list of collections that the metric update
195      ops should be added to.
196    name: An optional variable_scope name.
197
198  Returns:
199    value_tensor: A `Tensor` representing the current value of the metric.
200    update_op: An operation that accumulates the error from a batch of data.
201
202  Raises:
203    ValueError: If `predictions` and `labels` have mismatched shapes, or if
204      `weights` is not `None` and its shape doesn't match `predictions`, or if
205      either `metrics_collections` or `updates_collections` are not a list or
206      tuple.
207  """
208  with variable_scope.variable_scope(
209      name, 'true_negatives', (predictions, labels, weights)):
210
211    predictions = math_ops.cast(predictions, dtype=dtypes.bool)
212    labels = math_ops.cast(labels, dtype=dtypes.bool)
213    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
214    is_true_negative = math_ops.logical_and(math_ops.equal(labels, False),
215                                            math_ops.equal(predictions, False))
216    return _count_condition(is_true_negative, weights, metrics_collections,
217                            updates_collections)
218
219
220def streaming_false_positives(predictions, labels, weights=None,
221                              metrics_collections=None,
222                              updates_collections=None,
223                              name=None):
224  """Sum the weights of false positives.
225
226  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
227
228  Args:
229    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
230      be cast to `bool`.
231    labels: The ground truth values, a `Tensor` whose dimensions must match
232      `predictions`. Will be cast to `bool`.
233    weights: Optional `Tensor` whose rank is either 0, or the same rank as
234      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
235      must be either `1`, or the same as the corresponding `labels`
236      dimension).
237    metrics_collections: An optional list of collections that the metric
238      value variable should be added to.
239    updates_collections: An optional list of collections that the metric update
240      ops should be added to.
241    name: An optional variable_scope name.
242
243  Returns:
244    value_tensor: A `Tensor` representing the current value of the metric.
245    update_op: An operation that accumulates the error from a batch of data.
246
247  Raises:
248    ValueError: If `predictions` and `labels` have mismatched shapes, or if
249      `weights` is not `None` and its shape doesn't match `predictions`, or if
250      either `metrics_collections` or `updates_collections` are not a list or
251      tuple.
252  """
253  return metrics.false_positives(
254      predictions=predictions, labels=labels, weights=weights,
255      metrics_collections=metrics_collections,
256      updates_collections=updates_collections, name=name)
257
258
259def streaming_false_negatives(predictions, labels, weights=None,
260                              metrics_collections=None,
261                              updates_collections=None,
262                              name=None):
263  """Computes the total number of false negatives.
264
265  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
266
267  Args:
268    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
269      be cast to `bool`.
270    labels: The ground truth values, a `Tensor` whose dimensions must match
271      `predictions`. Will be cast to `bool`.
272    weights: Optional `Tensor` whose rank is either 0, or the same rank as
273      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
274      must be either `1`, or the same as the corresponding `labels`
275      dimension).
276    metrics_collections: An optional list of collections that the metric
277      value variable should be added to.
278    updates_collections: An optional list of collections that the metric update
279      ops should be added to.
280    name: An optional variable_scope name.
281
282  Returns:
283    value_tensor: A `Tensor` representing the current value of the metric.
284    update_op: An operation that accumulates the error from a batch of data.
285
286  Raises:
287    ValueError: If `weights` is not `None` and its shape doesn't match
288      `predictions`, or if either `metrics_collections` or
289      `updates_collections` are not a list or tuple.
290  """
291  return metrics.false_negatives(
292      predictions=predictions, labels=labels, weights=weights,
293      metrics_collections=metrics_collections,
294      updates_collections=updates_collections, name=name)
295
296
297# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py.
298def _broadcast_weights(weights, values):
299  """Broadcast `weights` to the same shape as `values`.
300
301  This returns a version of `weights` following the same broadcast rules as
302  `mul(weights, values)`. When computing a weighted average, use this function
303  to broadcast `weights` before summing them; e.g.,
304  `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`.
305
306  Args:
307    weights: `Tensor` whose rank is either 0, or the same rank as `values`, and
308      must be broadcastable to `values` (i.e., all dimensions must be either
309      `1`, or the same as the corresponding `values` dimension).
310    values: `Tensor` of any shape.
311
312  Returns:
313    `weights` broadcast to `values` shape.
314  """
315  with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope:
316    weights_shape = weights.get_shape()
317    values_shape = values.get_shape()
318    if (weights_shape.is_fully_defined() and
319        values_shape.is_fully_defined() and
320        weights_shape.is_compatible_with(values_shape)):
321      return weights
322    with ops.control_dependencies((_assert_weights_rank(weights, values),)):
323      return math_ops.multiply(
324          weights, array_ops.ones_like(values), name=scope)
325
326
327def streaming_mean(values, weights=None, metrics_collections=None,
328                   updates_collections=None, name=None):
329  """Computes the (weighted) mean of the given values.
330
331  The `streaming_mean` function creates two local variables, `total` and `count`
332  that are used to compute the average of `values`. This average is ultimately
333  returned as `mean` which is an idempotent operation that simply divides
334  `total` by `count`.
335
336  For estimation of the metric over a stream of data, the function creates an
337  `update_op` operation that updates these variables and returns the `mean`.
338  `update_op` increments `total` with the reduced sum of the product of `values`
339  and `weights`, and it increments `count` with the reduced sum of `weights`.
340
341  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
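
  For example, a minimal usage sketch (illustrative values; assumes
  `import tensorflow as tf` and TF 1.x graph execution):

    values = tf.constant([1.0, 2.0, 3.0, 4.0])
    mean, update_op = streaming_mean(values)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)    # total += 10.0, count += 4.0
      print(sess.run(mean))  # 2.5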
342
343  Args:
344    values: A `Tensor` of arbitrary dimensions.
345    weights: `Tensor` whose rank is either 0, or the same rank as `values`, and
346      must be broadcastable to `values` (i.e., all dimensions must be either
347      `1`, or the same as the corresponding `values` dimension).
348    metrics_collections: An optional list of collections that `mean`
349      should be added to.
350    updates_collections: An optional list of collections that `update_op`
351      should be added to.
352    name: An optional variable_scope name.
353
354  Returns:
355    mean: A `Tensor` representing the current mean, the value of `total` divided
356      by `count`.
357    update_op: An operation that increments the `total` and `count` variables
358      appropriately and whose value matches `mean`.
359
360  Raises:
361    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
362      or if either `metrics_collections` or `updates_collections` are not a list
363      or tuple.
364  """
365  return metrics.mean(
366      values=values, weights=weights, metrics_collections=metrics_collections,
367      updates_collections=updates_collections, name=name)
368
369
370def streaming_mean_tensor(values, weights=None, metrics_collections=None,
371                          updates_collections=None, name=None):
372  """Computes the element-wise (weighted) mean of the given tensors.
373
374  In contrast to the `streaming_mean` function, which returns a scalar with
375  the mean, this function returns an average tensor with the same shape as the
376  input tensors.
377
378  The `streaming_mean_tensor` function creates two local variables,
379  `total_tensor` and `count_tensor` that are used to compute the average of
380  `values`. This average is ultimately returned as `mean` which is an idempotent
381  operation that simply divides `total` by `count`.
382
383  For estimation of the metric over a stream of data, the function creates an
384  `update_op` operation that updates these variables and returns the `mean`.
385  `update_op` increments `total` with the reduced sum of the product of `values`
386  and `weights`, and it increments `count` with the reduced sum of `weights`.
387
388  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
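
  For example, a minimal sketch over two batches (illustrative values; assumes
  `import tensorflow as tf` and TF 1.x graph execution):

    values = tf.placeholder(tf.float32, shape=[2])
    mean, update_op = streaming_mean_tensor(values)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op, feed_dict={values: [0.0, 2.0]})
      sess.run(update_op, feed_dict={values: [4.0, 6.0]})
      print(sess.run(mean))  # [2.0, 4.0], the element-wise mean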
389
390  Args:
391    values: A `Tensor` of arbitrary dimensions.
392    weights: `Tensor` whose rank is either 0, or the same rank as `values`, and
393      must be broadcastable to `values` (i.e., all dimensions must be either
394      `1`, or the same as the corresponding `values` dimension).
395    metrics_collections: An optional list of collections that `mean`
396      should be added to.
397    updates_collections: An optional list of collections that `update_op`
398      should be added to.
399    name: An optional variable_scope name.
400
401  Returns:
402    mean: A float `Tensor` representing the current mean, the value of `total`
403      divided by `count`.
404    update_op: An operation that increments the `total` and `count` variables
405      appropriately and whose value matches `mean`.
406
407  Raises:
408    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
409      or if either `metrics_collections` or `updates_collections` are not a list
410      or tuple.
411  """
412  return metrics.mean_tensor(
413      values=values, weights=weights, metrics_collections=metrics_collections,
414      updates_collections=updates_collections, name=name)
415
416
417def streaming_accuracy(predictions, labels, weights=None,
418                       metrics_collections=None, updates_collections=None,
419                       name=None):
420  """Calculates how often `predictions` matches `labels`.
421
422  The `streaming_accuracy` function creates two local variables, `total` and
423  `count` that are used to compute the frequency with which `predictions`
424  matches `labels`. This frequency is ultimately returned as `accuracy`: an
425  idempotent operation that simply divides `total` by `count`.
426
427  For estimation of the metric over a stream of data, the function creates an
428  `update_op` operation that updates these variables and returns the `accuracy`.
429  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
430  where the corresponding elements of `predictions` and `labels` match and 0.0
431  otherwise. Then `update_op` increments `total` with the reduced sum of the
432  product of `weights` and `is_correct`, and it increments `count` with the
433  reduced sum of `weights`.
434
435  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
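
  For example, a minimal usage sketch (illustrative values; assumes
  `import tensorflow as tf` and TF 1.x graph execution):

    predictions = tf.constant([1, 0, 1, 1])
    labels = tf.constant([1, 0, 0, 1])
    accuracy, update_op = streaming_accuracy(predictions, labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)        # 3 of the 4 predictions match
      print(sess.run(accuracy))  # 0.75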
436
437  Args:
438    predictions: The predicted values, a `Tensor` of any shape.
439    labels: The ground truth values, a `Tensor` whose shape matches
440      `predictions`.
441    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
442      must be broadcastable to `labels` (i.e., all dimensions must be either
443      `1`, or the same as the corresponding `labels` dimension).
444    metrics_collections: An optional list of collections that `accuracy` should
445      be added to.
446    updates_collections: An optional list of collections that `update_op` should
447      be added to.
448    name: An optional variable_scope name.
449
450  Returns:
451    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
452      by `count`.
453    update_op: An operation that increments the `total` and `count` variables
454      appropriately and whose value matches `accuracy`.
455
456  Raises:
457    ValueError: If `predictions` and `labels` have mismatched shapes, or if
458      `weights` is not `None` and its shape doesn't match `predictions`, or if
459      either `metrics_collections` or `updates_collections` are not a list or
460      tuple.
461  """
462  return metrics.accuracy(
463      predictions=predictions, labels=labels, weights=weights,
464      metrics_collections=metrics_collections,
465      updates_collections=updates_collections, name=name)
466
467
468def streaming_precision(predictions, labels, weights=None,
469                        metrics_collections=None, updates_collections=None,
470                        name=None):
471  """Computes the precision of the predictions with respect to the labels.
472
473  The `streaming_precision` function creates two local variables,
474  `true_positives` and `false_positives`, that are used to compute the
475  precision. This value is ultimately returned as `precision`, an idempotent
476  operation that simply divides `true_positives` by the sum of `true_positives`
477  and `false_positives`.
478
479  For estimation of the metric over a stream of data, the function creates an
480  `update_op` operation that updates these variables and returns the
481  `precision`. `update_op` weights each prediction by the corresponding value in
482  `weights`.
483
484  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
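
  For example, a minimal usage sketch (illustrative values; assumes
  `import tensorflow as tf` and TF 1.x graph execution):

    predictions = tf.constant([True, True, False, True])
    labels = tf.constant([True, False, False, True])
    precision, update_op = streaming_precision(predictions, labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)         # true_positives = 2, false_positives = 1
      print(sess.run(precision))  # 2 / 3, approximately 0.667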
485
486  Args:
487    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
488    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
489      match `predictions`.
490    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
491      must be broadcastable to `labels` (i.e., all dimensions must be either
492      `1`, or the same as the corresponding `labels` dimension).
493    metrics_collections: An optional list of collections that `precision` should
494      be added to.
495    updates_collections: An optional list of collections that `update_op` should
496      be added to.
497    name: An optional variable_scope name.
498
499  Returns:
500    precision: Scalar float `Tensor` with the value of `true_positives`
501      divided by the sum of `true_positives` and `false_positives`.
502    update_op: `Operation` that increments `true_positives` and
503      `false_positives` variables appropriately and whose value matches
504      `precision`.
505
506  Raises:
507    ValueError: If `predictions` and `labels` have mismatched shapes, or if
508      `weights` is not `None` and its shape doesn't match `predictions`, or if
509      either `metrics_collections` or `updates_collections` are not a list or
510      tuple.
511  """
512  return metrics.precision(
513      predictions=predictions, labels=labels, weights=weights,
514      metrics_collections=metrics_collections,
515      updates_collections=updates_collections, name=name)
516
517
518def streaming_recall(predictions, labels, weights=None,
519                     metrics_collections=None, updates_collections=None,
520                     name=None):
521  """Computes the recall of the predictions with respect to the labels.
522
523  The `streaming_recall` function creates two local variables, `true_positives`
524  and `false_negatives`, that are used to compute the recall. This value is
525  ultimately returned as `recall`, an idempotent operation that simply divides
526  `true_positives` by the sum of `true_positives` and `false_negatives`.
527
528  For estimation of the metric over a stream of data, the function creates an
529  `update_op` that updates these variables and returns the `recall`. `update_op`
530  weights each prediction by the corresponding value in `weights`.
531
532  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
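
  For example, a minimal usage sketch (illustrative values; assumes
  `import tensorflow as tf` and TF 1.x graph execution):

    predictions = tf.constant([True, False, False, True])
    labels = tf.constant([True, True, False, True])
    recall, update_op = streaming_recall(predictions, labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)      # true_positives = 2, false_negatives = 1
      print(sess.run(recall))  # 2 / 3, approximately 0.667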
533
534  Args:
535    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
536    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
537      match `predictions`.
538    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
539      must be broadcastable to `labels` (i.e., all dimensions must be either
540      `1`, or the same as the corresponding `labels` dimension).
541    metrics_collections: An optional list of collections that `recall` should
542      be added to.
543    updates_collections: An optional list of collections that `update_op` should
544      be added to.
545    name: An optional variable_scope name.
546
547  Returns:
548    recall: Scalar float `Tensor` with the value of `true_positives` divided
549      by the sum of `true_positives` and `false_negatives`.
550    update_op: `Operation` that increments `true_positives` and
551      `false_negatives` variables appropriately and whose value matches
552      `recall`.
553
554  Raises:
555    ValueError: If `predictions` and `labels` have mismatched shapes, or if
556      `weights` is not `None` and its shape doesn't match `predictions`, or if
557      either `metrics_collections` or `updates_collections` are not a list or
558      tuple.
559  """
560  return metrics.recall(
561      predictions=predictions, labels=labels, weights=weights,
562      metrics_collections=metrics_collections,
563      updates_collections=updates_collections, name=name)
564
565
566def _streaming_confusion_matrix_at_thresholds(
567    predictions, labels, thresholds, weights=None, includes=None):
568  """Computes true_positives, false_negatives, true_negatives, false_positives.
569
570  This function creates up to four local variables, `true_positives`,
571  `true_negatives`, `false_positives` and `false_negatives`.
572  `true_positives[i]` is defined as the total weight of values in `predictions`
573  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
574  `false_negatives[i]` is defined as the total weight of values in `predictions`
575  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
576  `true_negatives[i]` is defined as the total weight of values in `predictions`
577  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
578  `false_positives[i]` is defined as the total weight of values in `predictions`
579  above `thresholds[i]` whose corresponding entry in `labels` is `False`.
580
581  For estimation of these metrics over a stream of data, for each metric the
582  function respectively creates an `update_op` operation that updates the
583  variable and returns its value.
584
585  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
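
  For example, a minimal sketch of the returned dicts (illustrative values;
  assumes `import tensorflow as tf` and TF 1.x graph execution):

    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        predictions=tf.constant([0.1, 0.6, 0.8]),
        labels=tf.constant([False, True, False]),
        thresholds=[0.5])

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_ops)     # runs all four update ops
      print(sess.run(values))  # tp=[1.], fn=[0.], tn=[1.], fp=[1.]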
586
587  Args:
588    predictions: A floating point `Tensor` of arbitrary shape and whose values
589      are in the range `[0, 1]`.
590    labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast
591      to `bool`.
592    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
593    weights: Optional `Tensor` whose rank is either 0, or the same rank as
594      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
595      must be either `1`, or the same as the corresponding `labels`
596      dimension).
597    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
598      defaults to all four.
599
600  Returns:
601    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
602        `includes`.
603    update_ops: Dict of operations that increments the `values`. Keys are from
604        `includes`.
605
606  Raises:
607    ValueError: If `predictions` and `labels` have mismatched shapes, or if
608      `weights` is not `None` and its shape doesn't match `predictions`, or if
609      `includes` contains invalid keys.
610  """
611  all_includes = ('tp', 'fn', 'tn', 'fp')
612  if includes is None:
613    includes = all_includes
614  else:
615    for include in includes:
616      if include not in all_includes:
617        raise ValueError('Invalid key: %s.' % include)
618
619  predictions, labels, weights = _remove_squeezable_dimensions(
620      predictions, labels, weights)
621  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
622
623  num_thresholds = len(thresholds)
624
625  # Reshape predictions and labels.
626  predictions_2d = array_ops.reshape(predictions, [-1, 1])
627  labels_2d = array_ops.reshape(
628      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])
629
630  # Use static shape if known.
631  num_predictions = predictions_2d.get_shape().as_list()[0]
632
633  # Otherwise use dynamic shape.
634  if num_predictions is None:
635    num_predictions = array_ops.shape(predictions_2d)[0]
636  thresh_tiled = array_ops.tile(
637      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
638      array_ops.stack([1, num_predictions]))
639
640  # Tile the predictions after thresholding them across different thresholds.
641  pred_is_pos = math_ops.greater(
642      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
643      thresh_tiled)
644  if ('fn' in includes) or ('tn' in includes):
645    pred_is_neg = math_ops.logical_not(pred_is_pos)
646
647  # Tile labels by number of thresholds
648  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
649  if ('fp' in includes) or ('tn' in includes):
650    label_is_neg = math_ops.logical_not(label_is_pos)
651
652  if weights is not None:
653    broadcast_weights = _broadcast_weights(
654        math_ops.to_float(weights), predictions)
655    weights_tiled = array_ops.tile(array_ops.reshape(
656        broadcast_weights, [1, -1]), [num_thresholds, 1])
657    thresh_tiled.get_shape().assert_is_compatible_with(
658        weights_tiled.get_shape())
659  else:
660    weights_tiled = None
661
662  values = {}
663  update_ops = {}
664
665  if 'tp' in includes:
666    true_positives = _create_local('true_positives', shape=[num_thresholds])
667    is_true_positive = math_ops.to_float(
668        math_ops.logical_and(label_is_pos, pred_is_pos))
669    if weights_tiled is not None:
670      is_true_positive *= weights_tiled
671    update_ops['tp'] = state_ops.assign_add(
672        true_positives, math_ops.reduce_sum(is_true_positive, 1))
673    values['tp'] = true_positives
674
675  if 'fn' in includes:
676    false_negatives = _create_local('false_negatives', shape=[num_thresholds])
677    is_false_negative = math_ops.to_float(
678        math_ops.logical_and(label_is_pos, pred_is_neg))
679    if weights_tiled is not None:
680      is_false_negative *= weights_tiled
681    update_ops['fn'] = state_ops.assign_add(
682        false_negatives, math_ops.reduce_sum(is_false_negative, 1))
683    values['fn'] = false_negatives
684
685  if 'tn' in includes:
686    true_negatives = _create_local('true_negatives', shape=[num_thresholds])
687    is_true_negative = math_ops.to_float(
688        math_ops.logical_and(label_is_neg, pred_is_neg))
689    if weights_tiled is not None:
690      is_true_negative *= weights_tiled
691    update_ops['tn'] = state_ops.assign_add(
692        true_negatives, math_ops.reduce_sum(is_true_negative, 1))
693    values['tn'] = true_negatives
694
695  if 'fp' in includes:
696    false_positives = _create_local('false_positives', shape=[num_thresholds])
697    is_false_positive = math_ops.to_float(
698        math_ops.logical_and(label_is_neg, pred_is_pos))
699    if weights_tiled is not None:
700      is_false_positive *= weights_tiled
701    update_ops['fp'] = state_ops.assign_add(
702        false_positives, math_ops.reduce_sum(is_false_positive, 1))
703    values['fp'] = false_positives
704
705  return values, update_ops
706
707
708def streaming_true_positives_at_thresholds(
709    predictions, labels, thresholds, weights=None):
710  values, update_ops = _streaming_confusion_matrix_at_thresholds(
711      predictions, labels, thresholds, weights=weights, includes=('tp',))
712  return values['tp'], update_ops['tp']
713
714
715def streaming_false_negatives_at_thresholds(
716    predictions, labels, thresholds, weights=None):
717  values, update_ops = _streaming_confusion_matrix_at_thresholds(
718      predictions, labels, thresholds, weights=weights, includes=('fn',))
719  return values['fn'], update_ops['fn']
720
721
722def streaming_false_positives_at_thresholds(
723    predictions, labels, thresholds, weights=None):
724  values, update_ops = _streaming_confusion_matrix_at_thresholds(
725      predictions, labels, thresholds, weights=weights, includes=('fp',))
726  return values['fp'], update_ops['fp']
727
728
729def streaming_true_negatives_at_thresholds(
730    predictions, labels, thresholds, weights=None):
731  values, update_ops = _streaming_confusion_matrix_at_thresholds(
732      predictions, labels, thresholds, weights=weights, includes=('tn',))
733  return values['tn'], update_ops['tn']
734
735
736def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
737                  metrics_collections=None, updates_collections=None,
738                  curve='ROC', name=None):
739  """Computes the approximate AUC via a Riemann sum.
740
741  The `streaming_auc` function creates four local variables, `true_positives`,
742  `true_negatives`, `false_positives` and `false_negatives` that are used to
743  compute the AUC. To discretize the AUC curve, a linearly spaced set of
744  thresholds is used to compute pairs of recall and precision values. The area
745  under the ROC-curve is therefore computed using the height of the recall
746  values by the false positive rate, while the area under the PR-curve is
747  computed using the height of the precision values by the recall.
748
749  This value is ultimately returned as `auc`, an idempotent operation that
750  computes the area under a discretized curve of precision versus recall values
751  (computed using the aforementioned variables). The `num_thresholds` variable
752  controls the degree of discretization with larger numbers of thresholds more
753  closely approximating the true AUC. The quality of the approximation may vary
754  dramatically depending on `num_thresholds`.
755
756  For best results, `predictions` should be distributed approximately uniformly
757  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
758  approximation may be poor if this is not the case.
759
760  For estimation of the metric over a stream of data, the function creates an
761  `update_op` operation that updates these variables and returns the `auc`.
762
763  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
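
  For example, a minimal usage sketch (illustrative values; assumes
  `import tensorflow as tf` and TF 1.x graph execution):

    predictions = tf.constant([0.1, 0.4, 0.35, 0.8])
    labels = tf.constant([False, False, True, True])
    auc, update_op = streaming_auc(predictions, labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(auc))  # close to 0.75 (the exact ROC AUC here), up to
                            # discretization error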
764
765  Args:
766    predictions: A floating point `Tensor` of arbitrary shape and whose values
767      are in the range `[0, 1]`.
768    labels: A `bool` `Tensor` whose shape matches `predictions`.
769    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
770      must be broadcastable to `labels` (i.e., all dimensions must be either
771      `1`, or the same as the corresponding `labels` dimension).
772    num_thresholds: The number of thresholds to use when discretizing the roc
773      curve.
774    metrics_collections: An optional list of collections that `auc` should be
775      added to.
776    updates_collections: An optional list of collections that `update_op` should
777      be added to.
778    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
779      'PR' for the Precision-Recall-curve.
780    name: An optional variable_scope name.
781
782  Returns:
783    auc: A scalar `Tensor` representing the current area-under-curve.
784    update_op: An operation that increments the `true_positives`,
785      `true_negatives`, `false_positives` and `false_negatives` variables
786      appropriately and whose value matches `auc`.
787
788  Raises:
789    ValueError: If `predictions` and `labels` have mismatched shapes, or if
790      `weights` is not `None` and its shape doesn't match `predictions`, or if
791      either `metrics_collections` or `updates_collections` are not a list or
792      tuple.
793  """
794  return metrics.auc(
795      predictions=predictions, labels=labels, weights=weights,
796      metrics_collections=metrics_collections, num_thresholds=num_thresholds,
797      curve=curve, updates_collections=updates_collections, name=name)
798
799
800def streaming_specificity_at_sensitivity(
801    predictions, labels, sensitivity, weights=None, num_thresholds=200,
802    metrics_collections=None, updates_collections=None, name=None):
803  """Computes the specificity at a given sensitivity.
804
805  The `streaming_specificity_at_sensitivity` function creates four local
806  variables, `true_positives`, `true_negatives`, `false_positives` and
807  `false_negatives` that are used to compute the specificity at the given
808  sensitivity value. The threshold for the given sensitivity value is computed
809  and used to evaluate the corresponding specificity.
810
811  For estimation of the metric over a stream of data, the function creates an
812  `update_op` operation that updates these variables and returns the
813  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
814  `false_positives` and `false_negatives` counts with the weight of each case
815  found in the `predictions` and `labels`.
816
817  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
818
819  For additional information about specificity and sensitivity, see the
820  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
821
822  Args:
823    predictions: A floating point `Tensor` of arbitrary shape and whose values
824      are in the range `[0, 1]`.
825    labels: A `bool` `Tensor` whose shape matches `predictions`.
826    sensitivity: A scalar value in range `[0, 1]`.
827    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
828      must be broadcastable to `labels` (i.e., all dimensions must be either
829      `1`, or the same as the corresponding `labels` dimension).
830    num_thresholds: The number of thresholds to use for matching the given
831      sensitivity.
832    metrics_collections: An optional list of collections that `specificity`
833      should be added to.
834    updates_collections: An optional list of collections that `update_op` should
835      be added to.
836    name: An optional variable_scope name.
837
838  Returns:
839    specificity: A scalar `Tensor` representing the specificity at the given
840      `sensitivity` value.
841    update_op: An operation that increments the `true_positives`,
842      `true_negatives`, `false_positives` and `false_negatives` variables
843      appropriately and whose value matches `specificity`.
844
845  Raises:
846    ValueError: If `predictions` and `labels` have mismatched shapes, if
847      `weights` is not `None` and its shape doesn't match `predictions`, or if
848      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
849      or `updates_collections` are not a list or tuple.
850  """
851  return metrics.specificity_at_sensitivity(
852      sensitivity=sensitivity, num_thresholds=num_thresholds,
853      predictions=predictions, labels=labels, weights=weights,
854      metrics_collections=metrics_collections,
855      updates_collections=updates_collections, name=name)
856
857
858def streaming_sensitivity_at_specificity(
859    predictions, labels, specificity, weights=None, num_thresholds=200,
860    metrics_collections=None, updates_collections=None, name=None):
861  """Computes the specificity at a given sensitivity.
862
863  The `streaming_sensitivity_at_specificity` function creates four local
864  variables, `true_positives`, `true_negatives`, `false_positives` and
865  `false_negatives` that are used to compute the sensitivity at the given
866  specificity value. The threshold for the given specificity value is computed
867  and used to evaluate the corresponding sensitivity.
868
869  For estimation of the metric over a stream of data, the function creates an
870  `update_op` operation that updates these variables and returns the
871  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
872  `false_positives` and `false_negatives` counts with the weight of each case
873  found in the `predictions` and `labels`.
874
875  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
876
877  For additional information about specificity and sensitivity, see the
878  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
879
880  Args:
881    predictions: A floating point `Tensor` of arbitrary shape and whose values
882      are in the range `[0, 1]`.
883    labels: A `bool` `Tensor` whose shape matches `predictions`.
884    specificity: A scalar value in range `[0, 1]`.
885    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
886      must be broadcastable to `labels` (i.e., all dimensions must be either
887      `1`, or the same as the corresponding `labels` dimension).
888    num_thresholds: The number of thresholds to use for matching the given
889      specificity.
890    metrics_collections: An optional list of collections that `sensitivity`
891      should be added to.
892    updates_collections: An optional list of collections that `update_op` should
893      be added to.
894    name: An optional variable_scope name.
895
896  Returns:
897    sensitivity: A scalar `Tensor` representing the sensitivity at the given
898      `specificity` value.
899    update_op: An operation that increments the `true_positives`,
900      `true_negatives`, `false_positives` and `false_negatives` variables
901      appropriately and whose value matches `sensitivity`.
902
903  Raises:
904    ValueError: If `predictions` and `labels` have mismatched shapes, if
905      `weights` is not `None` and its shape doesn't match `predictions`, or if
906      `specificity` is not between 0 and 1, or if either `metrics_collections`
907      or `updates_collections` are not a list or tuple.
908  """
909  return metrics.sensitivity_at_specificity(
910      specificity=specificity, num_thresholds=num_thresholds,
911      predictions=predictions, labels=labels, weights=weights,
912      metrics_collections=metrics_collections,
913      updates_collections=updates_collections, name=name)
914
915
916def streaming_precision_at_thresholds(predictions, labels, thresholds,
917                                      weights=None,
918                                      metrics_collections=None,
919                                      updates_collections=None, name=None):
920  """Computes precision values for different `thresholds` on `predictions`.
921
922  The `streaming_precision_at_thresholds` function creates four local variables,
923  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
924  for various values of thresholds. `precision[i]` is defined as the total
925  weight of values in `predictions` above `thresholds[i]` whose corresponding
926  entry in `labels` is `True`, divided by the total weight of values in
927  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
928  false_positives[i])`).
929
930  For estimation of the metric over a stream of data, the function creates an
931  `update_op` operation that updates these variables and returns the
932  `precision`.
933
934  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
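
  For example, a minimal usage sketch (illustrative values; assumes
  `import tensorflow as tf` and TF 1.x graph execution):

    predictions = tf.constant([0.2, 0.6, 0.9])
    labels = tf.constant([True, False, True])
    precision, update_op = streaming_precision_at_thresholds(
        predictions, labels, thresholds=[0.5, 0.8])

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(precision))  # approximately [0.5, 1.0]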
935
936  Args:
937    predictions: A floating point `Tensor` of arbitrary shape and whose values
938      are in the range `[0, 1]`.
939    labels: A `bool` `Tensor` whose shape matches `predictions`.
940    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
941    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
942      must be broadcastable to `labels` (i.e., all dimensions must be either
943      `1`, or the same as the corresponding `labels` dimension).
944    metrics_collections: An optional list of collections that `auc` should be
945      added to.
946    updates_collections: An optional list of collections that `update_op` should
947      be added to.
948    name: An optional variable_scope name.
949
950  Returns:
951    precision: A float `Tensor` of shape `[len(thresholds)]`.
952    update_op: An operation that increments the `true_positives`,
953      `true_negatives`, `false_positives` and `false_negatives` variables that
954      are used in the computation of `precision`.
955
956  Raises:
957    ValueError: If `predictions` and `labels` have mismatched shapes, or if
958      `weights` is not `None` and its shape doesn't match `predictions`, or if
959      either `metrics_collections` or `updates_collections` are not a list or
960      tuple.
961  """
962  return metrics.precision_at_thresholds(
963      thresholds=thresholds,
964      predictions=predictions, labels=labels, weights=weights,
965      metrics_collections=metrics_collections,
966      updates_collections=updates_collections, name=name)
967
968
969def streaming_recall_at_thresholds(predictions, labels, thresholds,
970                                   weights=None, metrics_collections=None,
971                                   updates_collections=None, name=None):
972  """Computes various recall values for different `thresholds` on `predictions`.
973
974  The `streaming_recall_at_thresholds` function creates four local variables,
975  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
976  for various values of thresholds. `recall[i]` is defined as the total weight
977  of values in `predictions` above `thresholds[i]` whose corresponding entry in
978  `labels` is `True`, divided by the total weight of `True` values in `labels`
979  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).
980
981  For estimation of the metric over a stream of data, the function creates an
982  `update_op` operation that updates these variables and returns the `recall`.
983
984  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
985
986  Args:
987    predictions: A floating point `Tensor` of arbitrary shape and whose values
988      are in the range `[0, 1]`.
989    labels: A `bool` `Tensor` whose shape matches `predictions`.
990    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
991    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
992      must be broadcastable to `labels` (i.e., all dimensions must be either
993      `1`, or the same as the corresponding `labels` dimension).
994    metrics_collections: An optional list of collections that `recall` should be
995      added to.
996    updates_collections: An optional list of collections that `update_op` should
997      be added to.
998    name: An optional variable_scope name.
999
1000  Returns:
1001    recall: A float `Tensor` of shape `[len(thresholds)]`.
1002    update_op: An operation that increments the `true_positives`,
1003      `true_negatives`, `false_positives` and `false_negatives` variables that
1004      are used in the computation of `recall`.
1005
1006  Raises:
1007    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1008      `weights` is not `None` and its shape doesn't match `predictions`, or if
1009      either `metrics_collections` or `updates_collections` are not a list or
1010      tuple.
1011  """
1012  return metrics.recall_at_thresholds(
1013      thresholds=thresholds,
1014      predictions=predictions, labels=labels, weights=weights,
1015      metrics_collections=metrics_collections,
1016      updates_collections=updates_collections, name=name)
1017
1018
1019def _at_k_name(name, k=None, class_id=None):
1020  if k is not None:
1021    name = '%s_at_%d' % (name, k)
1022  else:
1023    name = '%s_at_k' % (name)
1024  if class_id is not None:
1025    name = '%s_class%d' % (name, class_id)
1026  return name
1027
1028
1029@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, '
1030            'and reshape labels from [batch_size] to [batch_size, 1].')
1031def streaming_recall_at_k(predictions, labels, k, weights=None,
1032                          metrics_collections=None, updates_collections=None,
1033                          name=None):
1034  """Computes the recall@k of the predictions with respect to dense labels.
1035
1036  The `streaming_recall_at_k` function creates two local variables, `total` and
1037  `count`, that are used to compute the recall@k frequency. This frequency is
1038  ultimately returned as `recall_at_<k>`: an idempotent operation that simply
1039  divides `total` by `count`.
1040
1041  For estimation of the metric over a stream of data, the function creates an
1042  `update_op` operation that updates these variables and returns the
1043  `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with
1044  shape [batch_size] whose elements indicate whether or not the corresponding
1045  label is in the top `k` `predictions`. Then `update_op` increments `total`
1046  with the reduced sum of `weights` where `in_top_k` is `True`, and it
1047  increments `count` with the reduced sum of `weights`.
1048
1049  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1050
1051  Args:
1052    predictions: A float `Tensor` of dimension [batch_size, num_classes].
1053    labels: A `Tensor` of dimension [batch_size] whose type is in `int32`,
1054      `int64`.
1055    k: The number of top elements to look at for computing recall.
1056    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1057      must be broadcastable to `labels` (i.e., all dimensions must be either
1058      `1`, or the same as the corresponding `labels` dimension).
1059    metrics_collections: An optional list of collections that `recall_at_k`
1060      should be added to.
1061    updates_collections: An optional list of collections `update_op` should be
1062      added to.
1063    name: An optional variable_scope name.
1064
1065  Returns:
1066    recall_at_k: A `Tensor` representing the recall@k, the fraction of labels
1067      which fall into the top `k` predictions.
1068    update_op: An operation that increments the `total` and `count` variables
1069      appropriately and whose value matches `recall_at_k`.
1070
1071  Raises:
1072    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1073      `weights` is not `None` and its shape doesn't match `predictions`, or if
1074      either `metrics_collections` or `updates_collections` are not a list or
1075      tuple.
1076  """
1077  in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k))
1078  return streaming_mean(in_top_k,
1079                        weights,
1080                        metrics_collections,
1081                        updates_collections,
1082                        name or _at_k_name('recall', k))
1083
1084
1085# TODO(ptucker): Validate range of values in labels?
1086def streaming_sparse_recall_at_k(predictions,
1087                                 labels,
1088                                 k,
1089                                 class_id=None,
1090                                 weights=None,
1091                                 metrics_collections=None,
1092                                 updates_collections=None,
1093                                 name=None):
1094  """Computes recall@k of the predictions with respect to sparse labels.
1095
1096  If `class_id` is not specified, we'll calculate recall as the ratio of true
1097      positives (i.e., correct predictions, items in the top `k` highest
1098      `predictions` that are found in the corresponding row in `labels`) to
1099      actual positives (the full `labels` row).
1100  If `class_id` is specified, we calculate recall by considering only the rows
1101      in the batch for which `class_id` is in `labels`, and computing the
1102      fraction of them for which `class_id` is in the corresponding row in
1103      `labels`.
1104
1105  `streaming_sparse_recall_at_k` creates two local variables,
1106  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
1107  the recall_at_k frequency. This frequency is ultimately returned as
1108  `recall_at_<k>`: an idempotent operation that simply divides
1109  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
1110  `false_negative_at_<k>`).
1111
1112  For estimation of the metric over a stream of data, the function creates an
1113  `update_op` operation that updates these variables and returns the
1114  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
1115  indicating the top `k` `predictions`. Set operations applied to `top_k` and
1116  `labels` calculate the true positives and false negatives weighted by
1117  `weights`. Then `update_op` increments `true_positive_at_<k>` and
1118  `false_negative_at_<k>` using these values.
1119
1120  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
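
  For example, a minimal usage sketch (illustrative values; assumes
  `import tensorflow as tf` and TF 1.x graph execution):

    # Two examples, each with a single relevant class (3 and 1, respectively).
    predictions = tf.constant([[0.1, 0.2, 0.3, 0.4],
                               [0.5, 0.1, 0.3, 0.1]])
    labels = tf.constant([[3], [1]], dtype=tf.int64)
    recall, update_op = streaming_sparse_recall_at_k(predictions, labels, k=2)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)      # class 3 is in row 0's top 2; class 1 is not
      print(sess.run(recall))  # 0.5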
1121
1122  Args:
1123    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
1124      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
1125      The final dimension contains the logit values for each class. [D1, ... DN]
1126      must match `labels`.
1127    labels: `int64` `Tensor` or `SparseTensor` with shape
1128      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1129      target classes for the associated prediction. Commonly, N=1 and `labels`
1130      has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`.
1131      Values should be in range [0, num_classes), where num_classes is the last
1132      dimension of `predictions`. Values outside this range always count
1133      towards `false_negative_at_<k>`.
1134    k: Integer, k for @k metric.
1135    class_id: Integer class ID for which we want binary metrics. This should be
1136      in range [0, num_classes), where num_classes is the last dimension of
1137      `predictions`. If class_id is outside this range, the method returns NAN.
1138    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1139      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1140      dimensions must be either `1`, or the same as the corresponding `labels`
1141      dimension).
1142    metrics_collections: An optional list of collections that values should
1143      be added to.
1144    updates_collections: An optional list of collections that updates should
1145      be added to.
1146    name: Name of new update operation, and namespace for other dependent ops.
1147
1148  Returns:
1149    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
1150      by the sum of `true_positives` and `false_negatives`.
1151    update_op: `Operation` that increments `true_positives` and
1152      `false_negatives` variables appropriately, and whose value matches
1153      `recall`.
1154
1155  Raises:
1156    ValueError: If `weights` is not `None` and its shape doesn't match
1157      `predictions`, or if either `metrics_collections` or
1158      `updates_collections` are not a list or tuple.
1159  """
1160  return metrics.recall_at_k(
1161      k=k, class_id=class_id,
1162      predictions=predictions, labels=labels, weights=weights,
1163      metrics_collections=metrics_collections,
1164      updates_collections=updates_collections, name=name)
1165
1166
1167# TODO(ptucker): Validate range of values in labels?
1168def streaming_sparse_precision_at_k(predictions,
1169                                    labels,
1170                                    k,
1171                                    class_id=None,
1172                                    weights=None,
1173                                    metrics_collections=None,
1174                                    updates_collections=None,
1175                                    name=None):
1176  """Computes precision@k of the predictions with respect to sparse labels.
1177
1178  If `class_id` is not specified, we calculate precision as the ratio of true
1179      positives (i.e., correct predictions, items in the top `k` highest
1180      `predictions` that are found in the corresponding row in `labels`) to
1181      positives (all top `k` `predictions`).
1182  If `class_id` is specified, we calculate precision by considering only the
1183      rows in the batch for which `class_id` is in the top `k` highest
1184      `predictions`, and computing the fraction of them for which `class_id` is
1185      in the corresponding row in `labels`.
1186
1187  We expect precision to decrease as `k` increases.
1188
1189  `streaming_sparse_precision_at_k` creates two local variables,
1190  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
1191  the precision@k frequency. This frequency is ultimately returned as
1192  `precision_at_<k>`: an idempotent operation that simply divides
1193  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
1194  `false_positive_at_<k>`).
1195
1196  For estimation of the metric over a stream of data, the function creates an
1197  `update_op` operation that updates these variables and returns the
1198  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
1199  indicating the top `k` `predictions`. Set operations applied to `top_k` and
1200  `labels` calculate the true positives and false positives weighted by
1201  `weights`. Then `update_op` increments `true_positive_at_<k>` and
1202  `false_positive_at_<k>` using these values.
1203
1204  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
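
  For example, a minimal sketch (illustrative values; assumes the TF 1.x
  `Session` API):

  ```python
    import tensorflow as tf

    # Illustrative logits for 2 examples over 4 classes.
    predictions = tf.constant([[0.1, 0.9, 0.3, 0.2],
                               [0.8, 0.1, 0.6, 0.4]])
    labels = tf.constant([[1], [2]], dtype=tf.int64)
    precision, update_op = (
        tf.contrib.metrics.streaming_sparse_precision_at_k(
            predictions, labels, k=2))

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)         # top-2 classes are [1, 2] and [0, 2]
      print(sess.run(precision))  # 2 of 4 top-2 predictions hit -> 0.5
  ```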
1205
1206  Args:
1207    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
1208      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
1209      The final dimension contains the logit values for each class. [D1, ... DN]
1210      must match `labels`.
1211    labels: `int64` `Tensor` or `SparseTensor` with shape
1212      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1213      target classes for the associated prediction. Commonly, N=1 and `labels`
1214      has shape [batch_size, num_labels]. [D1, ... DN] must match
1215      `predictions`. Values should be in range [0, num_classes), where
1216      num_classes is the last dimension of `predictions`. Values outside this
1217      range are ignored.
1218    k: Integer, k for @k metric.
1219    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
1221      `predictions`. If `class_id` is outside this range, the method returns
1222      NAN.
1223    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1224      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1225      dimensions must be either `1`, or the same as the corresponding `labels`
1226      dimension).
1227    metrics_collections: An optional list of collections that values should
1228      be added to.
1229    updates_collections: An optional list of collections that updates should
1230      be added to.
1231    name: Name of new update operation, and namespace for other dependent ops.
1232
1233  Returns:
1234    precision: Scalar `float64` `Tensor` with the value of `true_positives`
1235      divided by the sum of `true_positives` and `false_positives`.
1236    update_op: `Operation` that increments `true_positives` and
1237      `false_positives` variables appropriately, and whose value matches
1238      `precision`.
1239
1240  Raises:
1241    ValueError: If `weights` is not `None` and its shape doesn't match
1242      `predictions`, or if either `metrics_collections` or `updates_collections`
1243      are not a list or tuple.
1244  """
1245  return metrics.sparse_precision_at_k(
1246      k=k, class_id=class_id,
1247      predictions=predictions, labels=labels, weights=weights,
1248      metrics_collections=metrics_collections,
1249      updates_collections=updates_collections, name=name)
1250
1251
1252# TODO(ptucker): Validate range of values in labels?
1253def streaming_sparse_precision_at_top_k(top_k_predictions,
1254                                        labels,
1255                                        class_id=None,
1256                                        weights=None,
1257                                        metrics_collections=None,
1258                                        updates_collections=None,
1259                                        name=None):
1260  """Computes precision@k of top-k predictions with respect to sparse labels.
1261
1262  If `class_id` is not specified, we calculate precision as the ratio of
1263      true positives (i.e., correct predictions, items in `top_k_predictions`
1264      that are found in the corresponding row in `labels`) to positives (all
1265      `top_k_predictions`).
1266  If `class_id` is specified, we calculate precision by considering only the
1267      rows in the batch for which `class_id` is in the top `k` highest
1268      `predictions`, and computing the fraction of them for which `class_id` is
1269      in the corresponding row in `labels`.
1270
1271  We expect precision to decrease as `k` increases.
1272
1273  `streaming_sparse_precision_at_top_k` creates two local variables,
1274  `true_positive_at_k` and `false_positive_at_k`, that are used to compute
1275  the precision@k frequency. This frequency is ultimately returned as
1276  `precision_at_k`: an idempotent operation that simply divides
1277  `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`).
1278
1279  For estimation of the metric over a stream of data, the function creates an
1280  `update_op` operation that updates these variables and returns the
1281  `precision_at_k`. Internally, set operations applied to `top_k_predictions`
1282  and `labels` calculate the true positives and false positives weighted by
1283  `weights`. Then `update_op` increments `true_positive_at_k` and
1284  `false_positive_at_k` using these values.
1285
1286  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
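
  A minimal sketch (illustrative values; assumes the TF 1.x `Session` API and
  that this function is exposed as
  `tf.contrib.metrics.streaming_sparse_precision_at_top_k`):

  ```python
    import tensorflow as tf

    # Illustrative precomputed top-2 class indices for 2 examples.
    top_k_predictions = tf.constant([[1, 2], [0, 2]], dtype=tf.int64)
    labels = tf.constant([[1], [2]], dtype=tf.int64)
    precision, update_op = (
        tf.contrib.metrics.streaming_sparse_precision_at_top_k(
            top_k_predictions, labels))

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(precision))  # 2 of 4 predicted classes hit -> 0.5
  ```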
1287
1288  Args:
1289    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
1290      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
1291      The final dimension contains the indices of top-k labels. [D1, ... DN]
1292      must match `labels`.
1293    labels: `int64` `Tensor` or `SparseTensor` with shape
1294      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1295      target classes for the associated prediction. Commonly, N=1 and `labels`
1296      has shape [batch_size, num_labels]. [D1, ... DN] must match
1297      `top_k_predictions`. Values should be in range [0, num_classes), where
      num_classes is the number of possible classes. Values outside this
1299      range are ignored.
1300    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the number of possible
      classes. If `class_id` is outside this range, the method returns NAN.
1304    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1305      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1306      dimensions must be either `1`, or the same as the corresponding `labels`
1307      dimension).
1308    metrics_collections: An optional list of collections that values should
1309      be added to.
1310    updates_collections: An optional list of collections that updates should
1311      be added to.
1312    name: Name of new update operation, and namespace for other dependent ops.
1313
1314  Returns:
1315    precision: Scalar `float64` `Tensor` with the value of `true_positives`
1316      divided by the sum of `true_positives` and `false_positives`.
1317    update_op: `Operation` that increments `true_positives` and
1318      `false_positives` variables appropriately, and whose value matches
1319      `precision`.
1320
1321  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `top_k_predictions`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
1325    ValueError: If `top_k_predictions` has rank < 2.
1326  """
1327  default_name = _at_k_name('precision', class_id=class_id)
1328  with ops.name_scope(
1329      name, default_name,
1330      (top_k_predictions, labels, weights)) as name_scope:
1331    return metrics_impl._sparse_precision_at_top_k(  # pylint: disable=protected-access
1332        labels=labels,
1333        predictions_idx=top_k_predictions,
1334        class_id=class_id,
1335        weights=weights,
1336        metrics_collections=metrics_collections,
1337        updates_collections=updates_collections,
1338        name=name_scope)
1339
1340
1341def streaming_sparse_average_precision_at_k(predictions,
1342                                            labels,
1343                                            k,
1344                                            weights=None,
1345                                            metrics_collections=None,
1346                                            updates_collections=None,
1347                                            name=None):
1348  """Computes average precision@k of predictions with respect to sparse labels.
1349
  See `sparse_average_precision_at_k` for details on the formula. `weights`
  are applied to the result of `sparse_average_precision_at_k`.
1352
1353  `streaming_sparse_average_precision_at_k` creates two local variables,
1354  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
1355  are used to compute the frequency. This frequency is ultimately returned as
1356  `average_precision_at_<k>`: an idempotent operation that simply divides
1357  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
1358
1359  For estimation of the metric over a stream of data, the function creates an
1360  `update_op` operation that updates these variables and returns the
  `average_precision_at_<k>`. Internally, a `top_k` operation computes a
  `Tensor` indicating the top `k` `predictions`, and the per-example average
  precision is computed from it and `labels`, weighted by `weights`. Then
  `update_op` increments `average_precision_at_<k>/total` and
  `average_precision_at_<k>/max` using these values.
1366
1367  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
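
  For example, a minimal sketch (illustrative values; assumes the TF 1.x
  `Session` API):

  ```python
    import tensorflow as tf

    # Illustrative logits for one example over 4 classes; class 2 is relevant.
    predictions = tf.constant([[0.1, 0.2, 0.9, 0.4]])
    labels = tf.constant([[2]], dtype=tf.int64)
    avg_prec, update_op = (
        tf.contrib.metrics.streaming_sparse_average_precision_at_k(
            predictions, labels, k=3))

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(avg_prec))  # the relevant class is ranked first -> 1.0
  ```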
1368
1369  Args:
1370    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
1371      N >= 1. Commonly, N=1 and `predictions` has shape
1372      [batch size, num_classes]. The final dimension contains the logit values
1373      for each class. [D1, ... DN] must match `labels`.
1374    labels: `int64` `Tensor` or `SparseTensor` with shape
1375      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1376      target classes for the associated prediction. Commonly, N=1 and `labels`
1377      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions`. Values should be in range [0, num_classes), where
1379      num_classes is the last dimension of `predictions`. Values outside this
1380      range are ignored.
1381    k: Integer, k for @k metric. This will calculate an average precision for
1382      range `[1,k]`, as documented above.
1383    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1384      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1385      dimensions must be either `1`, or the same as the corresponding `labels`
1386      dimension).
1387    metrics_collections: An optional list of collections that values should
1388      be added to.
1389    updates_collections: An optional list of collections that updates should
1390      be added to.
1391    name: Name of new update operation, and namespace for other dependent ops.
1392
1393  Returns:
1394    mean_average_precision: Scalar `float64` `Tensor` with the mean average
1395      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `mean_average_precision`.
1398  """
1399  return metrics.sparse_average_precision_at_k(
1400      k=k, predictions=predictions, labels=labels, weights=weights,
1401      metrics_collections=metrics_collections,
1402      updates_collections=updates_collections, name=name)
1403
1404
1405def streaming_sparse_average_precision_at_top_k(top_k_predictions,
1406                                                labels,
1407                                                weights=None,
1408                                                metrics_collections=None,
1409                                                updates_collections=None,
1410                                                name=None):
1411  """Computes average precision@k of predictions with respect to sparse labels.
1412
1413  `streaming_sparse_average_precision_at_top_k` creates two local variables,
1414  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
1415  are used to compute the frequency. This frequency is ultimately returned as
1416  `average_precision_at_<k>`: an idempotent operation that simply divides
1417  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
1418
1419  For estimation of the metric over a stream of data, the function creates an
1420  `update_op` operation that updates these variables and returns the
  `average_precision_at_<k>`. The per-example average precision is computed
  from `top_k_predictions` and `labels`, weighted by `weights`. Then
  `update_op` increments `average_precision_at_<k>/total` and
  `average_precision_at_<k>/max` using these values.
1425
1426  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
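
  A minimal sketch (illustrative values; assumes the TF 1.x `Session` API and
  that this function is exposed as
  `tf.contrib.metrics.streaming_sparse_average_precision_at_top_k`):

  ```python
    import tensorflow as tf

    # Illustrative top-3 class indices for one example; class 2 is relevant.
    top_k_predictions = tf.constant([[2, 3, 1]], dtype=tf.int64)
    labels = tf.constant([[2]], dtype=tf.int64)
    avg_prec, update_op = (
        tf.contrib.metrics.streaming_sparse_average_precision_at_top_k(
            top_k_predictions, labels))

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(avg_prec))  # the relevant class is ranked first -> 1.0
  ```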
1427
1428  Args:
1429    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and `top_k_predictions` has shape [batch size, k]. The
      final dimension must be set and contains the top `k` predicted class
      indices. [D1, ... DN] must match `labels`. Values should be in range
      [0, num_classes).
1434    labels: `int64` `Tensor` or `SparseTensor` with shape
1435      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
1436      num_labels=1. N >= 1 and num_labels is the number of target classes for
1437      the associated prediction. Commonly, N=1 and `labels` has shape
1438      [batch_size, num_labels]. [D1, ... DN] must match `top_k_predictions`.
1439      Values should be in range [0, num_classes).
1440    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1441      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1442      dimensions must be either `1`, or the same as the corresponding `labels`
1443      dimension).
1444    metrics_collections: An optional list of collections that values should
1445      be added to.
1446    updates_collections: An optional list of collections that updates should
1447      be added to.
1448    name: Name of new update operation, and namespace for other dependent ops.
1449
1450  Returns:
1451    mean_average_precision: Scalar `float64` `Tensor` with the mean average
1452      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `mean_average_precision`.
1455
1456  Raises:
1457    ValueError: if the last dimension of top_k_predictions is not set.
1458  """
1459  return metrics_impl._streaming_sparse_average_precision_at_top_k(  # pylint: disable=protected-access
1460      predictions_idx=top_k_predictions,
1461      labels=labels,
1462      weights=weights,
1463      metrics_collections=metrics_collections,
1464      updates_collections=updates_collections,
1465      name=name)
1466
1467
1468def streaming_mean_absolute_error(predictions, labels, weights=None,
1469                                  metrics_collections=None,
1470                                  updates_collections=None,
1471                                  name=None):
1472  """Computes the mean absolute error between the labels and predictions.
1473
1474  The `streaming_mean_absolute_error` function creates two local variables,
1475  `total` and `count` that are used to compute the mean absolute error. This
1476  average is weighted by `weights`, and it is ultimately returned as
1477  `mean_absolute_error`: an idempotent operation that simply divides `total` by
1478  `count`.
1479
1480  For estimation of the metric over a stream of data, the function creates an
1481  `update_op` operation that updates these variables and returns the
1482  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
1483  absolute value of the differences between `predictions` and `labels`. Then
1484  `update_op` increments `total` with the reduced sum of the product of
1485  `weights` and `absolute_errors`, and it increments `count` with the reduced
  sum of `weights`.
1487
1488  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
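
  For example, a minimal sketch (illustrative values; assumes the TF 1.x
  `Session` API):

  ```python
    import tensorflow as tf

    predictions = tf.constant([2.0, 4.0, 6.0])
    labels = tf.constant([1.0, 5.0, 6.0])
    mae, update_op = tf.contrib.metrics.streaming_mean_absolute_error(
        predictions, labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)   # total += |2-1| + |4-5| + |6-6| = 2; count += 3
      print(sess.run(mae))  # 2 / 3 ~= 0.667
  ```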
1489
1490  Args:
1491    predictions: A `Tensor` of arbitrary shape.
1492    labels: A `Tensor` of the same shape as `predictions`.
1493    weights: Optional `Tensor` indicating the frequency with which an example is
1494      sampled. Rank must be 0, or the same rank as `labels`, and must be
1495      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
1496      the same as the corresponding `labels` dimension).
1497    metrics_collections: An optional list of collections that
1498      `mean_absolute_error` should be added to.
1499    updates_collections: An optional list of collections that `update_op` should
1500      be added to.
1501    name: An optional variable_scope name.
1502
1503  Returns:
1504    mean_absolute_error: A `Tensor` representing the current mean, the value of
1505      `total` divided by `count`.
1506    update_op: An operation that increments the `total` and `count` variables
1507      appropriately and whose value matches `mean_absolute_error`.
1508
1509  Raises:
1510    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1511      `weights` is not `None` and its shape doesn't match `predictions`, or if
1512      either `metrics_collections` or `updates_collections` are not a list or
1513      tuple.
1514  """
1515  return metrics.mean_absolute_error(
1516      predictions=predictions, labels=labels, weights=weights,
1517      metrics_collections=metrics_collections,
1518      updates_collections=updates_collections, name=name)
1519
1520
1521def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
1522                                  metrics_collections=None,
1523                                  updates_collections=None,
1524                                  name=None):
1525  """Computes the mean relative error by normalizing with the given values.
1526
1527  The `streaming_mean_relative_error` function creates two local variables,
1528  `total` and `count` that are used to compute the mean relative absolute error.
1529  This average is weighted by `weights`, and it is ultimately returned as
1530  `mean_relative_error`: an idempotent operation that simply divides `total` by
1531  `count`.
1532
1533  For estimation of the metric over a stream of data, the function creates an
1534  `update_op` operation that updates these variables and returns the
  `mean_relative_error`. Internally, a `relative_errors` operation divides the
1536  absolute value of the differences between `predictions` and `labels` by the
1537  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
1538  product of `weights` and `relative_errors`, and it increments `count` with the
1539  reduced sum of `weights`.
1540
1541  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
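
  For example, a minimal sketch using the labels themselves as the normalizer
  (illustrative values; assumes the TF 1.x `Session` API):

  ```python
    import tensorflow as tf

    predictions = tf.constant([2.0, 4.0])
    labels = tf.constant([1.0, 5.0])
    # Here the labels themselves serve as the normalizer.
    mre, update_op = tf.contrib.metrics.streaming_mean_relative_error(
        predictions, labels, normalizer=labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)   # relative errors: |2-1|/1 = 1.0, |4-5|/5 = 0.2
      print(sess.run(mre))  # (1.0 + 0.2) / 2 = 0.6
  ```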
1542
1543  Args:
1544    predictions: A `Tensor` of arbitrary shape.
1545    labels: A `Tensor` of the same shape as `predictions`.
1546    normalizer: A `Tensor` of the same shape as `predictions`.
1547    weights: Optional `Tensor` indicating the frequency with which an example is
1548      sampled. Rank must be 0, or the same rank as `labels`, and must be
1549      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
1550      the same as the corresponding `labels` dimension).
1551    metrics_collections: An optional list of collections that
1552      `mean_relative_error` should be added to.
1553    updates_collections: An optional list of collections that `update_op` should
1554      be added to.
1555    name: An optional variable_scope name.
1556
1557  Returns:
1558    mean_relative_error: A `Tensor` representing the current mean, the value of
1559      `total` divided by `count`.
1560    update_op: An operation that increments the `total` and `count` variables
1561      appropriately and whose value matches `mean_relative_error`.
1562
1563  Raises:
1564    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1565      `weights` is not `None` and its shape doesn't match `predictions`, or if
1566      either `metrics_collections` or `updates_collections` are not a list or
1567      tuple.
1568  """
1569  return metrics.mean_relative_error(
1570      normalizer=normalizer, predictions=predictions, labels=labels,
1571      weights=weights, metrics_collections=metrics_collections,
1572      updates_collections=updates_collections, name=name)
1573
1574
1575def streaming_mean_squared_error(predictions, labels, weights=None,
1576                                 metrics_collections=None,
1577                                 updates_collections=None,
1578                                 name=None):
1579  """Computes the mean squared error between the labels and predictions.
1580
1581  The `streaming_mean_squared_error` function creates two local variables,
1582  `total` and `count` that are used to compute the mean squared error.
1583  This average is weighted by `weights`, and it is ultimately returned as
1584  `mean_squared_error`: an idempotent operation that simply divides `total` by
1585  `count`.
1586
1587  For estimation of the metric over a stream of data, the function creates an
1588  `update_op` operation that updates these variables and returns the
1589  `mean_squared_error`. Internally, a `squared_error` operation computes the
1590  element-wise square of the difference between `predictions` and `labels`. Then
1591  `update_op` increments `total` with the reduced sum of the product of
1592  `weights` and `squared_error`, and it increments `count` with the reduced sum
1593  of `weights`.
1594
1595  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
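
  For example, a minimal sketch accumulating two batches (illustrative values;
  assumes the TF 1.x `Session` API):

  ```python
    import tensorflow as tf

    predictions = tf.placeholder(tf.float32, shape=[2])
    labels = tf.placeholder(tf.float32, shape=[2])
    mse, update_op = tf.contrib.metrics.streaming_mean_squared_error(
        predictions, labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      # Both batches accumulate into the same `total` and `count` variables.
      sess.run(update_op, {predictions: [1.0, 2.0], labels: [1.0, 4.0]})
      sess.run(update_op, {predictions: [3.0, 3.0], labels: [3.0, 5.0]})
      print(sess.run(mse))  # (0 + 4 + 0 + 4) / 4 = 2.0
  ```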
1596
1597  Args:
1598    predictions: A `Tensor` of arbitrary shape.
1599    labels: A `Tensor` of the same shape as `predictions`.
1600    weights: Optional `Tensor` indicating the frequency with which an example is
1601      sampled. Rank must be 0, or the same rank as `labels`, and must be
1602      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
1603      the same as the corresponding `labels` dimension).
1604    metrics_collections: An optional list of collections that
1605      `mean_squared_error` should be added to.
1606    updates_collections: An optional list of collections that `update_op` should
1607      be added to.
1608    name: An optional variable_scope name.
1609
1610  Returns:
1611    mean_squared_error: A `Tensor` representing the current mean, the value of
1612      `total` divided by `count`.
1613    update_op: An operation that increments the `total` and `count` variables
1614      appropriately and whose value matches `mean_squared_error`.
1615
1616  Raises:
1617    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1618      `weights` is not `None` and its shape doesn't match `predictions`, or if
1619      either `metrics_collections` or `updates_collections` are not a list or
1620      tuple.
1621  """
1622  return metrics.mean_squared_error(
1623      predictions=predictions, labels=labels, weights=weights,
1624      metrics_collections=metrics_collections,
1625      updates_collections=updates_collections, name=name)
1626
1627
1628def streaming_root_mean_squared_error(predictions, labels, weights=None,
1629                                      metrics_collections=None,
1630                                      updates_collections=None,
1631                                      name=None):
1632  """Computes the root mean squared error between the labels and predictions.
1633
1634  The `streaming_root_mean_squared_error` function creates two local variables,
1635  `total` and `count` that are used to compute the root mean squared error.
1636  This average is weighted by `weights`, and it is ultimately returned as
1637  `root_mean_squared_error`: an idempotent operation that takes the square root
1638  of the division of `total` by `count`.
1639
1640  For estimation of the metric over a stream of data, the function creates an
1641  `update_op` operation that updates these variables and returns the
1642  `root_mean_squared_error`. Internally, a `squared_error` operation computes
1643  the element-wise square of the difference between `predictions` and `labels`.
1644  Then `update_op` increments `total` with the reduced sum of the product of
1645  `weights` and `squared_error`, and it increments `count` with the reduced sum
1646  of `weights`.
1647
1648  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
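
  For example, a minimal sketch (illustrative values; assumes the TF 1.x
  `Session` API):

  ```python
    import tensorflow as tf

    predictions = tf.constant([1.0, 2.0, 3.0])
    labels = tf.constant([1.0, 4.0, 3.0])
    rmse, update_op = tf.contrib.metrics.streaming_root_mean_squared_error(
        predictions, labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)    # squared errors: 0, 4, 0
      print(sess.run(rmse))  # sqrt(4 / 3) ~= 1.155
  ```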
1649
1650  Args:
1651    predictions: A `Tensor` of arbitrary shape.
1652    labels: A `Tensor` of the same shape as `predictions`.
1653    weights: Optional `Tensor` indicating the frequency with which an example is
1654      sampled. Rank must be 0, or the same rank as `labels`, and must be
1655      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
1656      the same as the corresponding `labels` dimension).
1657    metrics_collections: An optional list of collections that
1658      `root_mean_squared_error` should be added to.
1659    updates_collections: An optional list of collections that `update_op` should
1660      be added to.
1661    name: An optional variable_scope name.
1662
1663  Returns:
1664    root_mean_squared_error: A `Tensor` representing the current mean, the value
1665      of `total` divided by `count`.
1666    update_op: An operation that increments the `total` and `count` variables
1667      appropriately and whose value matches `root_mean_squared_error`.
1668
1669  Raises:
1670    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1671      `weights` is not `None` and its shape doesn't match `predictions`, or if
1672      either `metrics_collections` or `updates_collections` are not a list or
1673      tuple.
1674  """
1675  return metrics.root_mean_squared_error(
1676      predictions=predictions, labels=labels, weights=weights,
1677      metrics_collections=metrics_collections,
1678      updates_collections=updates_collections, name=name)
1679
1680
1681def streaming_covariance(predictions,
1682                         labels,
1683                         weights=None,
1684                         metrics_collections=None,
1685                         updates_collections=None,
1686                         name=None):
1687  """Computes the unbiased sample covariance between `predictions` and `labels`.
1688
1689  The `streaming_covariance` function creates four local variables,
1690  `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to
1691  compute the sample covariance between predictions and labels across multiple
1692  batches of data. The covariance is ultimately returned as an idempotent
1693  operation that simply divides `comoment` by `count` - 1. We use `count` - 1
1694  in order to get an unbiased estimate.
1695
1696  The algorithm used for this online computation is described in
1697  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
1698  Specifically, the formula used to combine two sample comoments is
1699  `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB`
1700  The comoment for a single batch of data is simply
1701  `sum((x - E[x]) * (y - E[y]))`, optionally weighted.
1702
1703  If `weights` is not None, then it is used to compute weighted comoments,
1704  means, and count. NOTE: these weights are treated as "frequency weights", as
1705  opposed to "reliability weights". See discussion of the difference on
1706  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance
1707
1708  To facilitate the computation of covariance across multiple batches of data,
1709  the function creates an `update_op` operation, which updates underlying
1710  variables and returns the updated covariance.
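
  For example, a minimal sketch (illustrative values; assumes the TF 1.x
  `Session` API and that this function is exposed under `tf.contrib.metrics`):

  ```python
    import tensorflow as tf

    predictions = tf.constant([1.0, 2.0, 3.0, 4.0])
    labels = tf.constant([2.0, 4.0, 6.0, 8.0])
    cov, update_op = tf.contrib.metrics.streaming_covariance(
        predictions, labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      # Unbiased sample covariance of x = [1, 2, 3, 4] and y = 2 * x:
      # sum((x - 2.5) * (y - 5)) / (4 - 1) = 10 / 3 ~= 3.33.
      print(sess.run(cov))
  ```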
1711
1712  Args:
1713    predictions: A `Tensor` of arbitrary size.
1714    labels: A `Tensor` of the same size as `predictions`.
1715    weights: Optional `Tensor` indicating the frequency with which an example is
1716      sampled. Rank must be 0, or the same rank as `labels`, and must be
1717      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
1718      the same as the corresponding `labels` dimension).
1719    metrics_collections: An optional list of collections that the metric
1720      value variable should be added to.
1721    updates_collections: An optional list of collections that the metric update
1722      ops should be added to.
1723    name: An optional variable_scope name.
1724
1725  Returns:
1726    covariance: A `Tensor` representing the current unbiased sample covariance,
1727      `comoment` / (`count` - 1).
1728    update_op: An operation that updates the local variables appropriately.
1729
1730  Raises:
1731    ValueError: If labels and predictions are of different sizes or if either
1732      `metrics_collections` or `updates_collections` are not a list or tuple.
1733  """
1734  with variable_scope.variable_scope(
1735      name, 'covariance', (predictions, labels, weights)):
1736    predictions, labels, weights = _remove_squeezable_dimensions(
1737        predictions, labels, weights)
1738    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1739    count = _create_local('count', [])
1740    mean_prediction = _create_local('mean_prediction', [])
1741    mean_label = _create_local('mean_label', [])
1742    comoment = _create_local('comoment', [])  # C_A in update equation
1743
1744    if weights is None:
1745      batch_count = math_ops.to_float(array_ops.size(labels))  # n_B in eqn
1746      weighted_predictions = predictions
1747      weighted_labels = labels
1748    else:
1749      weights = _broadcast_weights(weights, labels)
1750      batch_count = math_ops.reduce_sum(weights)  # n_B in eqn
1751      weighted_predictions = math_ops.multiply(predictions, weights)
1752      weighted_labels = math_ops.multiply(labels, weights)
1753
1754    update_count = state_ops.assign_add(count, batch_count)  # n_AB in eqn
1755    prev_count = update_count - batch_count  # n_A in update equation
1756
1757    # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
1758    # batch_mean_prediction is E[x_B] in the update equation
1759    batch_mean_prediction = _safe_div(
1760        math_ops.reduce_sum(weighted_predictions), batch_count,
1761        'batch_mean_prediction')
1762    delta_mean_prediction = _safe_div(
1763        (batch_mean_prediction - mean_prediction) * batch_count, update_count,
1764        'delta_mean_prediction')
1765    update_mean_prediction = state_ops.assign_add(mean_prediction,
1766                                                  delta_mean_prediction)
1767    # prev_mean_prediction is E[x_A] in the update equation
1768    prev_mean_prediction = update_mean_prediction - delta_mean_prediction
1769
1770    # batch_mean_label is E[y_B] in the update equation
1771    batch_mean_label = _safe_div(
1772        math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
1773    delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
1774                                 update_count, 'delta_mean_label')
1775    update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
1776    # prev_mean_label is E[y_A] in the update equation
1777    prev_mean_label = update_mean_label - delta_mean_label
1778
1779    unweighted_batch_coresiduals = (
1780        (predictions - batch_mean_prediction) * (labels - batch_mean_label))
1781    # batch_comoment is C_B in the update equation
1782    if weights is None:
1783      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals)
1784    else:
1785      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals *
1786                                           weights)
1787
1788    # View delta_comoment as = C_AB - C_A in the update equation above.
1789    # Since C_A is stored in a var, by how much do we need to increment that var
1790    # to make the var = C_AB?
1791    delta_comoment = (batch_comoment +
1792                      (prev_mean_prediction - batch_mean_prediction) *
1793                      (prev_mean_label - batch_mean_label) *
1794                      (prev_count * batch_count / update_count))
1795    update_comoment = state_ops.assign_add(comoment, delta_comoment)
1796
1797    covariance = _safe_div(comoment, count - 1, 'covariance')
1798    with ops.control_dependencies([update_comoment]):
1799      update_op = _safe_div(comoment, count - 1, 'update_op')
1800
1801  if metrics_collections:
1802    ops.add_to_collections(metrics_collections, covariance)
1803
1804  if updates_collections:
1805    ops.add_to_collections(updates_collections, update_op)
1806
1807  return covariance, update_op
1808
1809
1810def streaming_pearson_correlation(predictions,
1811                                  labels,
1812                                  weights=None,
1813                                  metrics_collections=None,
1814                                  updates_collections=None,
1815                                  name=None):
1816  """Computes Pearson correlation coefficient between `predictions`, `labels`.
1817
1818  The `streaming_pearson_correlation` function delegates to
1819  `streaming_covariance` the tracking of three [co]variances:
1820
1821  - `streaming_covariance(predictions, labels)`, i.e. covariance
1822  - `streaming_covariance(predictions, predictions)`, i.e. variance
1823  - `streaming_covariance(labels, labels)`, i.e. variance
1824
1825  The product-moment correlation ultimately returned is an idempotent operation
1826  `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To
1827  facilitate correlation computation across multiple batches, the function
1828  groups the `update_op`s of the underlying streaming_covariance and returns an
1829  `update_op`.
1830
1831  If `weights` is not None, then it is used to compute a weighted correlation.
1832  NOTE: these weights are treated as "frequency weights", as opposed to
1833  "reliability weights". See discussion of the difference on
1834  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance
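
  For example, a minimal sketch (illustrative values; assumes the TF 1.x
  `Session` API and that this function is exposed under `tf.contrib.metrics`):

  ```python
    import tensorflow as tf

    predictions = tf.constant([1.0, 2.0, 3.0, 4.0])
    labels = tf.constant([8.0, 6.0, 4.0, 2.0])
    pearson_r, update_op = tf.contrib.metrics.streaming_pearson_correlation(
        predictions, labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(pearson_r))  # perfectly anti-correlated -> -1.0
  ```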
1835
1836  Args:
1837    predictions: A `Tensor` of arbitrary size.
1838    labels: A `Tensor` of the same size as predictions.
1839    weights: Optional `Tensor` indicating the frequency with which an example is
1840      sampled. Rank must be 0, or the same rank as `labels`, and must be
1841      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
1842      the same as the corresponding `labels` dimension).
1843    metrics_collections: An optional list of collections that the metric
1844      value variable should be added to.
1845    updates_collections: An optional list of collections that the metric update
1846      ops should be added to.
1847    name: An optional variable_scope name.
1848
1849  Returns:
1850    pearson_r: A `Tensor` representing the current Pearson product-moment
1851      correlation coefficient, the value of
1852      `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`.
1853    update_op: An operation that updates the underlying variables appropriately.
1854
1855  Raises:
1856    ValueError: If `labels` and `predictions` are of different sizes, or if
1857      `weights` is the wrong size, or if either `metrics_collections` or
1858      `updates_collections` are not a `list` or `tuple`.
1859  """
1860  with variable_scope.variable_scope(
1861      name, 'pearson_r', (predictions, labels, weights)):
1862    predictions, labels, weights = _remove_squeezable_dimensions(
1863        predictions, labels, weights)
1864    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1865    # Broadcast weights here to avoid duplicate broadcasting in each call to
1866    # `streaming_covariance`.
1867    if weights is not None:
1868      weights = _broadcast_weights(weights, labels)
1869    cov, update_cov = streaming_covariance(
1870        predictions, labels, weights=weights, name='covariance')
1871    var_predictions, update_var_predictions = streaming_covariance(
1872        predictions, predictions, weights=weights, name='variance_predictions')
1873    var_labels, update_var_labels = streaming_covariance(
1874        labels, labels, weights=weights, name='variance_labels')
1875
1876    pearson_r = _safe_div(
1877        cov,
1878        math_ops.multiply(math_ops.sqrt(var_predictions),
1879                          math_ops.sqrt(var_labels)),
1880        'pearson_r')
1881    update_op = _safe_div(
1882        update_cov,
1883        math_ops.multiply(math_ops.sqrt(update_var_predictions),
1884                          math_ops.sqrt(update_var_labels)),
1885        'update_op')
1886
1887  if metrics_collections:
1888    ops.add_to_collections(metrics_collections, pearson_r)
1889
1890  if updates_collections:
1891    ops.add_to_collections(updates_collections, update_op)
1892
1893  return pearson_r, update_op
1894
1895
1896# TODO(nsilberman): add a 'normalized' flag so that the user can request
1897# normalization if the inputs are not normalized.
1898def streaming_mean_cosine_distance(predictions, labels, dim, weights=None,
1899                                   metrics_collections=None,
1900                                   updates_collections=None,
1901                                   name=None):
1902  """Computes the cosine distance between the labels and predictions.
1903
1904  The `streaming_mean_cosine_distance` function creates two local variables,
1905  `total` and `count` that are used to compute the average cosine distance
1906  between `predictions` and `labels`. This average is weighted by `weights`,
1907  and it is ultimately returned as `mean_distance`, which is an idempotent
1908  operation that simply divides `total` by `count`.
1909
1910  For estimation of the metric over a stream of data, the function creates an
1911  `update_op` operation that updates these variables and returns the
1912  `mean_distance`.
1913
1914  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
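
  For example, a minimal sketch with pre-normalized vectors (illustrative
  values; assumes the TF 1.x `Session` API):

  ```python
    import tensorflow as tf

    # Unit-length vectors; this op assumes the inputs are already normalized.
    predictions = tf.constant([[1.0, 0.0], [0.0, 1.0]])
    labels = tf.constant([[1.0, 0.0], [1.0, 0.0]])
    mean_distance, update_op = (
        tf.contrib.metrics.streaming_mean_cosine_distance(
            predictions, labels, dim=1))

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      # The cosine distances are 0.0 and 1.0, so the mean is 0.5.
      print(sess.run(mean_distance))
  ```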
1915
1916  Args:
1917    predictions: A `Tensor` of the same shape as `labels`.
1918    labels: A `Tensor` of arbitrary shape.
1919    dim: The dimension along which the cosine distance is computed.
1920    weights: An optional `Tensor` whose shape is broadcastable to `predictions`,
1921      and whose dimension `dim` is 1.
1922    metrics_collections: An optional list of collections that the metric
1923      value variable should be added to.
1924    updates_collections: An optional list of collections that the metric update
1925      ops should be added to.
1926    name: An optional variable_scope name.
1927
1928  Returns:
1929    mean_distance: A `Tensor` representing the current mean, the value of
1930      `total` divided by `count`.
1931    update_op: An operation that increments the `total` and `count` variables
1932      appropriately.
1933
1934  Raises:
1935    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1936      `weights` is not `None` and its shape doesn't match `predictions`, or if
1937      either `metrics_collections` or `updates_collections` are not a list or
1938      tuple.
1939  """
1940  predictions, labels, weights = _remove_squeezable_dimensions(
1941      predictions, labels, weights)
1942  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1943  radial_diffs = math_ops.multiply(predictions, labels)
1944  radial_diffs = math_ops.reduce_sum(radial_diffs,
1945                                     reduction_indices=[dim,],
1946                                     keep_dims=True)
1947  mean_distance, update_op = streaming_mean(radial_diffs, weights,
1948                                            None,
1949                                            None,
1950                                            name or 'mean_cosine_distance')
1951  mean_distance = math_ops.subtract(1.0, mean_distance)
1952  update_op = math_ops.subtract(1.0, update_op)
1953
1954  if metrics_collections:
1955    ops.add_to_collections(metrics_collections, mean_distance)
1956
1957  if updates_collections:
1958    ops.add_to_collections(updates_collections, update_op)
1959
1960  return mean_distance, update_op
1961
1962
1963def streaming_percentage_less(values, threshold, weights=None,
1964                              metrics_collections=None,
1965                              updates_collections=None,
1966                              name=None):
1967  """Computes the percentage of values less than the given threshold.
1968
1969  The `streaming_percentage_less` function creates two local variables,
1970  `total` and `count` that are used to compute the percentage of `values` that
1971  fall below `threshold`. This rate is weighted by `weights`, and it is
1972  ultimately returned as `percentage` which is an idempotent operation that
1973  simply divides `total` by `count`.
1974
1975  For estimation of the metric over a stream of data, the function creates an
1976  `update_op` operation that updates these variables and returns the
1977  `percentage`.
1978
1979  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
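
  For example, a minimal sketch (illustrative values; assumes the TF 1.x
  `Session` API):

  ```python
    import tensorflow as tf

    values = tf.constant([1.0, 2.0, 3.0, 4.0])
    percentage, update_op = tf.contrib.metrics.streaming_percentage_less(
        values, threshold=3.0)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(percentage))  # 2 of 4 values are < 3.0 -> 0.5
  ```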
1980
1981  Args:
1982    values: A numeric `Tensor` of arbitrary size.
1983    threshold: A scalar threshold.
1984    weights: An optional `Tensor` whose shape is broadcastable to `values`.
1985    metrics_collections: An optional list of collections that the metric
1986      value variable should be added to.
1987    updates_collections: An optional list of collections that the metric update
1988      ops should be added to.
1989    name: An optional variable_scope name.
1990
1991  Returns:
1992    percentage: A `Tensor` representing the current mean, the value of `total`
1993      divided by `count`.
1994    update_op: An operation that increments the `total` and `count` variables
1995      appropriately.
1996
1997  Raises:
1998    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1999      or if either `metrics_collections` or `updates_collections` are not a list
2000      or tuple.
2001  """
2002  return metrics.percentage_below(
2003      values=values, threshold=threshold, weights=weights,
2004      metrics_collections=metrics_collections,
2005      updates_collections=updates_collections, name=name)
2006
2007
2008def streaming_mean_iou(predictions,
2009                       labels,
2010                       num_classes,
2011                       weights=None,
2012                       metrics_collections=None,
2013                       updates_collections=None,
2014                       name=None):
2015  """Calculate per-step mean Intersection-Over-Union (mIOU).
2016
2017  Mean Intersection-Over-Union is a common evaluation metric for
2018  semantic image segmentation, which first computes the IOU for each
2019  semantic class and then computes the average over classes.
2020  IOU is defined as follows:
2021    IOU = true_positive / (true_positive + false_positive + false_negative).
2022  The predictions are accumulated in a confusion matrix, weighted by `weights`,
2023  and mIOU is then calculated from it.
2024
2025  For estimation of the metric over a stream of data, the function creates an
2026  `update_op` operation that updates these variables and returns the `mean_iou`.
2027
2028  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
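
  For example, a minimal sketch (illustrative values; assumes the TF 1.x
  `Session` API):

  ```python
    import tensorflow as tf

    predictions = tf.constant([0, 0, 1, 1])
    labels = tf.constant([0, 1, 1, 1])
    mean_iou, update_op = tf.contrib.metrics.streaming_mean_iou(
        predictions, labels, num_classes=2)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      # IOU(class 0) = 1 / 2 and IOU(class 1) = 2 / 3, so mIOU = 7 / 12.
      print(sess.run(mean_iou))
  ```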
2029
2030  Args:
2031    predictions: A `Tensor` of prediction results for semantic labels, whose
2032      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
2034    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
2036    num_classes: The possible number of labels the prediction task can
2037      have. This value must be provided, since a confusion matrix of
2038      dimension = [num_classes, num_classes] will be allocated.
2039    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
2040    metrics_collections: An optional list of collections that `mean_iou`
2041      should be added to.
2042    updates_collections: An optional list of collections `update_op` should be
2043      added to.
2044    name: An optional variable_scope name.
2045
2046  Returns:
2047    mean_iou: A `Tensor` representing the mean intersection-over-union.
2048    update_op: An operation that increments the confusion matrix.
2049
2050  Raises:
2051    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2052      `weights` is not `None` and its shape doesn't match `predictions`, or if
2053      either `metrics_collections` or `updates_collections` are not a list or
2054      tuple.
2055  """
2056  return metrics.mean_iou(
2057      num_classes=num_classes, predictions=predictions, labels=labels,
2058      weights=weights, metrics_collections=metrics_collections,
2059      updates_collections=updates_collections, name=name)
2060
2061
2062def _next_array_size(required_size, growth_factor=1.5):
2063  """Calculate the next size for reallocating a dynamic array.
2064
2065  Args:
2066    required_size: number or tf.Tensor specifying required array capacity.
2067    growth_factor: optional number or tf.Tensor specifying the growth factor
2068      between subsequent allocations.
2069
2070  Returns:
2071    tf.Tensor with dtype=int32 giving the next array size.
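
  For example (illustrative): with `required_size=5` and the default
  `growth_factor=1.5`, the exponent is `ceil(log(5) / log(1.5)) = 4`, so the
  next size is `ceil(1.5 ** 4) = ceil(5.0625) = 6`.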
2072  """
2073  exponent = math_ops.ceil(
2074      math_ops.log(math_ops.cast(required_size, dtypes.float32))
2075      / math_ops.log(math_ops.cast(growth_factor, dtypes.float32)))
2076  return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32)
2077
2078
2079def streaming_concat(values,
2080                     axis=0,
2081                     max_size=None,
2082                     metrics_collections=None,
2083                     updates_collections=None,
2084                     name=None):
2085  """Concatenate values along an axis across batches.
2086
2087  The function `streaming_concat` creates two local variables, `array` and
2088  `size`, that are used to store concatenated values. Internally, `array` is
  used as storage for a dynamic array (if `max_size` is `None`), which ensures
2090  that updates can be run in amortized constant time.
2091
2092  For estimation of the metric over a stream of data, the function creates an
2093  `update_op` operation that appends the values of a tensor and returns the
2094  length of the concatenated axis.
2095
2096  This op allows for evaluating metrics that cannot be updated incrementally
2097  using the same framework as other streaming metrics.
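
  For example, a minimal sketch (illustrative; assumes the TF 1.x `Session`
  API and that this function is exposed under `tf.contrib.metrics`):

  ```python
    import tensorflow as tf

    values = tf.placeholder(tf.int32, shape=[None])
    concatenated, update_op = tf.contrib.metrics.streaming_concat(values)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op, {values: [1, 2]})
      sess.run(update_op, {values: [3, 4, 5]})
      print(sess.run(concatenated))  # [1 2 3 4 5]
  ```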
2098
2099  Args:
2100    values: `Tensor` to concatenate. Rank and the shape along all axes other
2101      than the axis to concatenate along must be statically known.
2102    axis: optional integer axis to concatenate along.
2103    max_size: optional integer maximum size of `value` along the given axis.
2104      Once the maximum size is reached, further updates are no-ops. By default,
2105      there is no maximum size: the array is resized as necessary.
2106    metrics_collections: An optional list of collections that `value`
2107      should be added to.
2108    updates_collections: An optional list of collections `update_op` should be
2109      added to.
2110    name: An optional variable_scope name.
2111
2112  Returns:
2113    value: A `Tensor` representing the concatenated values.
2114    update_op: An operation that concatenates the next values.
2115
2116  Raises:
2117    ValueError: if `values` does not have a statically known rank, `axis` is
2118      not in the valid range or the size of `values` is not statically known
2119      along any axis other than `axis`.
2120  """
2121  with variable_scope.variable_scope(name, 'streaming_concat', (values,)):
2122    # pylint: disable=invalid-slice-index
2123    values_shape = values.get_shape()
2124    if values_shape.dims is None:
      raise ValueError('`values` must have statically known rank')
2126
2127    ndim = len(values_shape)
2128    if axis < 0:
2129      axis += ndim
2130    if not 0 <= axis < ndim:
2131      raise ValueError('axis = %r not in [0, %r)' % (axis, ndim))
2132
2133    fixed_shape = [dim.value for n, dim in enumerate(values_shape)
2134                   if n != axis]
2135    if any(value is None for value in fixed_shape):
2136      raise ValueError('all dimensions of `values` other than the dimension to '
2137                       'concatenate along must have statically known size')
2138
2139    # We move `axis` to the front of the internal array so assign ops can be
2140    # applied to contiguous slices
2141    init_size = 0 if max_size is None else max_size
2142    init_shape = [init_size] + fixed_shape
2143    array = _create_local(
2144        'array', shape=init_shape, validate_shape=False, dtype=values.dtype)
2145    size = _create_local('size', shape=[], dtype=dtypes.int32)
2146
2147    perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
2148    valid_array = array[:size]
2149    valid_array.set_shape([None] + fixed_shape)
2150    value = array_ops.transpose(valid_array, perm, name='concat')
2151
2152    values_size = array_ops.shape(values)[axis]
2153    if max_size is None:
2154      batch_size = values_size
2155    else:
2156      batch_size = math_ops.minimum(values_size, max_size - size)
2157
2158    perm = [axis] + [n for n in range(ndim) if n != axis]
2159    batch_values = array_ops.transpose(values, perm)[:batch_size]
2160
2161    def reallocate():
2162      next_size = _next_array_size(new_size)
2163      next_shape = array_ops.stack([next_size] + fixed_shape)
2164      new_value = array_ops.zeros(next_shape, dtype=values.dtype)
2165      old_value = array.value()
2166      assign_op = state_ops.assign(array, new_value, validate_shape=False)
2167      with ops.control_dependencies([assign_op]):
2168        copy_op = array[:size].assign(old_value[:size])
2169      # return value needs to be the same dtype as no_op() for cond
2170      with ops.control_dependencies([copy_op]):
2171        return control_flow_ops.no_op()
2172
2173    new_size = size + batch_size
2174    array_size = array_ops.shape_internal(array, optimize=False)[0]
2175    maybe_reallocate_op = control_flow_ops.cond(
2176        new_size > array_size, reallocate, control_flow_ops.no_op)
2177    with ops.control_dependencies([maybe_reallocate_op]):
2178      append_values_op = array[size:new_size].assign(batch_values)
2179    with ops.control_dependencies([append_values_op]):
2180      update_op = size.assign(new_size)
2181
2182    if metrics_collections:
2183      ops.add_to_collections(metrics_collections, value)
2184
2185    if updates_collections:
2186      ops.add_to_collections(updates_collections, update_op)
2187
2188    return value, update_op
2189    # pylint: enable=invalid-slice-index
2190
2191
2192def aggregate_metrics(*value_update_tuples):
2193  """Aggregates the metric value tensors and update ops into two lists.
2194
2195  Args:
2196    *value_update_tuples: a variable number of tuples, each of which contain the
2197      pair of (value_tensor, update_op) from a streaming metric.
2198
2199  Returns:
2200    A list of value `Tensor` objects and a list of update ops.
2201
2202  Raises:
2203    ValueError: if `value_update_tuples` is empty.
2204  """
2205  if not value_update_tuples:
2206    raise ValueError('Expected at least one value_tensor/update_op pair')
2207  value_ops, update_ops = zip(*value_update_tuples)
2208  return list(value_ops), list(update_ops)
2209
2210
2211def aggregate_metric_map(names_to_tuples):
2212  """Aggregates the metric names to tuple dictionary.
2213
2214  This function is useful for pairing metric names with their associated value
2215  and update ops when the list of metrics is long. For example:
2216
2217  ```python
2218    metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
        'Mean Absolute Error': slim.metrics.streaming_mean_absolute_error(
            predictions, labels, weights),
        'Mean Relative Error': slim.metrics.streaming_mean_relative_error(
            predictions, labels, labels, weights),
        'RMSE Linear': slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
        'RMSE Log': slim.metrics.streaming_root_mean_squared_error(
            predictions, labels, weights),
2227    })
2228  ```
2229
2230  Args:
2231    names_to_tuples: a map of metric names to tuples, each of which contain the
2232      pair of (value_tensor, update_op) from a streaming metric.
2233
2234  Returns:
2235    A dictionary from metric names to value ops and a dictionary from metric
2236    names to update ops.
2237  """
2238  metric_names = names_to_tuples.keys()
2239  value_ops, update_ops = zip(*names_to_tuples.values())
2240  return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
2241
2242
2243def _remove_squeezable_dimensions(predictions, labels, weights):
2244  """Squeeze last dim if needed.
2245
2246  Squeezes `predictions` and `labels` if their rank differs by 1.
  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`.
2248
2249  This will use static shape if available. Otherwise, it will add graph
2250  operations, which could result in a performance hit.
2251
2252  Args:
2253    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
2254    labels: Label values, a `Tensor` whose dimensions match `predictions`.
2255    weights: Optional weight `Tensor`. It will be squeezed if its rank is 1
      more than the new rank of `predictions`.
2257
2258  Returns:
2259    Tuple of `predictions`, `labels` and `weights`, possibly with the last
2260    dimension squeezed.
2261  """
2262  predictions, labels = tensor_util.remove_squeezable_dimensions(
2263      predictions, labels)
2264  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
2265
2266  if weights is not None:
2267    weights = ops.convert_to_tensor(weights)
2268    predictions_shape = predictions.get_shape()
2269    predictions_rank = predictions_shape.ndims
2270    weights_shape = weights.get_shape()
2271    weights_rank = weights_shape.ndims
2272
2273    if (predictions_rank is not None) and (weights_rank is not None):
2274      # Use static rank.
2275      if weights_rank - predictions_rank == 1:
2276        weights = array_ops.squeeze(weights, [-1])
2277    elif (weights_rank is None) or (
2278        weights_shape.dims[-1].is_compatible_with(1)):
2279      # Use dynamic rank
2280      weights = control_flow_ops.cond(
2281          math_ops.equal(array_ops.rank(weights),
2282                         math_ops.add(array_ops.rank(predictions), 1)),
2283          lambda: array_ops.squeeze(weights, [-1]),
2284          lambda: weights)
2285  return predictions, labels, weights
2286
2287
2288__all__ = [
2289    'aggregate_metric_map',
2290    'aggregate_metrics',
2291    'streaming_accuracy',
2292    'streaming_auc',
2293    'streaming_false_negatives',
2294    'streaming_false_negatives_at_thresholds',
2295    'streaming_false_positives',
2296    'streaming_false_positives_at_thresholds',
2297    'streaming_mean',
2298    'streaming_mean_absolute_error',
2299    'streaming_mean_cosine_distance',
2300    'streaming_mean_iou',
2301    'streaming_mean_relative_error',
2302    'streaming_mean_squared_error',
2303    'streaming_mean_tensor',
2304    'streaming_percentage_less',
2305    'streaming_precision',
2306    'streaming_precision_at_thresholds',
2307    'streaming_recall',
2308    'streaming_recall_at_k',
2309    'streaming_recall_at_thresholds',
2310    'streaming_root_mean_squared_error',
2311    'streaming_sensitivity_at_specificity',
2312    'streaming_sparse_average_precision_at_k',
2313    'streaming_sparse_precision_at_k',
2314    'streaming_sparse_recall_at_k',
2315    'streaming_specificity_at_sensitivity',
2316    'streaming_true_negatives',
2317    'streaming_true_negatives_at_thresholds',
2318    'streaming_true_positives',
2319    'streaming_true_positives_at_thresholds',
2320]
2321