metric_ops.py revision ebcae4a5e3bf5c840d73a0d90f1b5bf01a68f82c
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains metric-computing operations on streamed tensors.
16
17Module documentation, including "@@" callouts, should be put in
18third_party/tensorflow/contrib/metrics/__init__.py
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import collections as collections_lib
26
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import check_ops
31from tensorflow.python.ops import confusion_matrix
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import metrics
35from tensorflow.python.ops import metrics_impl
36from tensorflow.python.ops import nn
37from tensorflow.python.ops import state_ops
38from tensorflow.python.ops import variable_scope
39from tensorflow.python.ops import weights_broadcast_ops
40from tensorflow.python.util.deprecation import deprecated
41
42
43def _safe_div(numerator, denominator, name):
44  """Divides two values, returning 0 if the denominator is <= 0.
45
46  Args:
47    numerator: A real `Tensor`.
48    denominator: A real `Tensor`, with dtype matching `numerator`.
49    name: Name for the returned op.
50
51  Returns:
52    0 if `denominator` <= 0, else `numerator` / `denominator`
53  """
54  return array_ops.where(
55      math_ops.greater(denominator, 0),
56      math_ops.truediv(numerator, denominator),
57      0,
58      name=name)
59
60
61def _create_local(name,
62                  shape,
63                  collections=None,
64                  validate_shape=True,
65                  dtype=dtypes.float32):
66  """Creates a new local variable.
67
68  Args:
69    name: The name of the new or existing variable.
70    shape: Shape of the new or existing variable.
71    collections: A list of collection names to which the Variable will be added.
72    validate_shape: Whether to validate the shape of the variable.
73    dtype: Data type of the variables.
74
75  Returns:
76    The created variable.
77  """
78  # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES
79  collections = list(collections or [])
80  collections += [ops.GraphKeys.LOCAL_VARIABLES]
81  return variable_scope.variable(
82      initial_value=array_ops.zeros(shape, dtype=dtype),
83      name=name,
84      trainable=False,
85      collections=collections,
86      validate_shape=validate_shape)
87
88
89# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py.
90def _assert_weights_rank(weights, values):
91  """`weights` rank must be either `0`, or the same as 'values'."""
92  return check_ops.assert_rank_in(weights, (0, array_ops.rank(values)))
93
94
95def _count_condition(values,
96                     weights=None,
97                     metrics_collections=None,
98                     updates_collections=None):
99  """Sums the weights of cases where the given values are True.
100
101  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
102
103  Args:
104    values: A `bool` `Tensor` of arbitrary size.
105    weights: Optional `Tensor` whose rank is either 0, or the same rank as
106      `values`, and must be broadcastable to `values` (i.e., all dimensions
107      must be either `1`, or the same as the corresponding `values`
108      dimension).
109    metrics_collections: An optional list of collections that the metric
110      value variable should be added to.
111    updates_collections: An optional list of collections that the metric update
112      ops should be added to.
113
114  Returns:
115    value_tensor: A `Tensor` representing the current value of the metric.
116    update_op: An operation that accumulates the error from a batch of data.
117
118  Raises:
119    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
120      or if either `metrics_collections` or `updates_collections` are not a list
121      or tuple.
122  """
123  check_ops.assert_type(values, dtypes.bool)
124  count = _create_local('count', shape=[])
125
126  values = math_ops.to_float(values)
127  if weights is not None:
128    weights = math_ops.to_float(weights)
129    with ops.control_dependencies((_assert_weights_rank(weights, values),)):
130      values = math_ops.multiply(values, weights)
131
132  value_tensor = array_ops.identity(count)
133  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
134
135  if metrics_collections:
136    ops.add_to_collections(metrics_collections, value_tensor)
137
138  if updates_collections:
139    ops.add_to_collections(updates_collections, update_op)
140
141  return value_tensor, update_op
142
143
144def streaming_true_positives(predictions,
145                             labels,
146                             weights=None,
147                             metrics_collections=None,
148                             updates_collections=None,
149                             name=None):
150  """Sum the weights of true_positives.
151
152  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
153
154  Args:
155    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
156      be cast to `bool`.
157    labels: The ground truth values, a `Tensor` whose dimensions must match
158      `predictions`. Will be cast to `bool`.
159    weights: Optional `Tensor` whose rank is either 0, or the same rank as
160      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
161      must be either `1`, or the same as the corresponding `labels`
162      dimension).
163    metrics_collections: An optional list of collections that the metric
164      value variable should be added to.
165    updates_collections: An optional list of collections that the metric update
166      ops should be added to.
167    name: An optional variable_scope name.
168
169  Returns:
170    value_tensor: A `Tensor` representing the current value of the metric.
171    update_op: An operation that accumulates the error from a batch of data.
172
173  Raises:
174    ValueError: If `predictions` and `labels` have mismatched shapes, or if
175      `weights` is not `None` and its shape doesn't match `predictions`, or if
176      either `metrics_collections` or `updates_collections` are not a list or
177      tuple.
178  """
179  return metrics.true_positives(
180      predictions=predictions,
181      labels=labels,
182      weights=weights,
183      metrics_collections=metrics_collections,
184      updates_collections=updates_collections,
185      name=name)
186
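# Example usage (an illustrative sketch): assumes TF 1.x graph execution with
# `import tensorflow as tf`; the same op is typically reached through
# `tf.contrib.metrics.streaming_true_positives`. The returned value tensor
# reads the running total, while `update_op` folds in one batch of data.
#
#   predictions = tf.constant([True, False, True, True])
#   labels = tf.constant([True, True, True, False])
#   tp, tp_update_op = streaming_true_positives(predictions, labels)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(tp_update_op)
#     print(sess.run(tp))  # 2.0: predictions and labels are both True twice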
187
188def streaming_true_negatives(predictions,
189                             labels,
190                             weights=None,
191                             metrics_collections=None,
192                             updates_collections=None,
193                             name=None):
194  """Sum the weights of true_negatives.
195
196  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
197
198  Args:
199    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
200      be cast to `bool`.
201    labels: The ground truth values, a `Tensor` whose dimensions must match
202      `predictions`. Will be cast to `bool`.
203    weights: Optional `Tensor` whose rank is either 0, or the same rank as
204      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
205      must be either `1`, or the same as the corresponding `labels`
206      dimension).
207    metrics_collections: An optional list of collections that the metric
208      value variable should be added to.
209    updates_collections: An optional list of collections that the metric update
210      ops should be added to.
211    name: An optional variable_scope name.
212
213  Returns:
214    value_tensor: A `Tensor` representing the current value of the metric.
215    update_op: An operation that accumulates the error from a batch of data.
216
217  Raises:
218    ValueError: If `predictions` and `labels` have mismatched shapes, or if
219      `weights` is not `None` and its shape doesn't match `predictions`, or if
220      either `metrics_collections` or `updates_collections` are not a list or
221      tuple.
222  """
223  with variable_scope.variable_scope(name, 'true_negatives',
224                                     (predictions, labels, weights)):
225
226    predictions, labels, weights = _remove_squeezable_dimensions(
227        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
228        labels=math_ops.cast(labels, dtype=dtypes.bool),
229        weights=weights)
230    is_true_negative = math_ops.logical_and(
231        math_ops.equal(labels, False), math_ops.equal(predictions, False))
232    return _count_condition(is_true_negative, weights, metrics_collections,
233                            updates_collections)
234
235
236def streaming_false_positives(predictions,
237                              labels,
238                              weights=None,
239                              metrics_collections=None,
240                              updates_collections=None,
241                              name=None):
242  """Sum the weights of false positives.
243
244  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
245
246  Args:
247    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
248      be cast to `bool`.
249    labels: The ground truth values, a `Tensor` whose dimensions must match
250      `predictions`. Will be cast to `bool`.
251    weights: Optional `Tensor` whose rank is either 0, or the same rank as
252      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
253      must be either `1`, or the same as the corresponding `labels`
254      dimension).
255    metrics_collections: An optional list of collections that the metric
256      value variable should be added to.
257    updates_collections: An optional list of collections that the metric update
258      ops should be added to.
259    name: An optional variable_scope name.
260
261  Returns:
262    value_tensor: A `Tensor` representing the current value of the metric.
263    update_op: An operation that accumulates the error from a batch of data.
264
265  Raises:
266    ValueError: If `predictions` and `labels` have mismatched shapes, or if
267      `weights` is not `None` and its shape doesn't match `predictions`, or if
268      either `metrics_collections` or `updates_collections` are not a list or
269      tuple.
270  """
271  return metrics.false_positives(
272      predictions=predictions,
273      labels=labels,
274      weights=weights,
275      metrics_collections=metrics_collections,
276      updates_collections=updates_collections,
277      name=name)
278
279
280def streaming_false_negatives(predictions,
281                              labels,
282                              weights=None,
283                              metrics_collections=None,
284                              updates_collections=None,
285                              name=None):
286  """Computes the total number of false negatives.
287
288  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
289
290  Args:
291    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
292      be cast to `bool`.
293    labels: The ground truth values, a `Tensor` whose dimensions must match
294      `predictions`. Will be cast to `bool`.
295    weights: Optional `Tensor` whose rank is either 0, or the same rank as
296      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
297      must be either `1`, or the same as the corresponding `labels`
298      dimension).
299    metrics_collections: An optional list of collections that the metric
300      value variable should be added to.
301    updates_collections: An optional list of collections that the metric update
302      ops should be added to.
303    name: An optional variable_scope name.
304
305  Returns:
306    value_tensor: A `Tensor` representing the current value of the metric.
307    update_op: An operation that accumulates the error from a batch of data.
308
309  Raises:
310    ValueError: If `weights` is not `None` and its shape doesn't match
311      `predictions`, or if either `metrics_collections` or
312      `updates_collections` are not a list or tuple.
313  """
314  return metrics.false_negatives(
315      predictions=predictions,
316      labels=labels,
317      weights=weights,
318      metrics_collections=metrics_collections,
319      updates_collections=updates_collections,
320      name=name)
321
322
323# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py.
324def _broadcast_weights(weights, values):
325  """Broadcast `weights` to the same shape as `values`.
326
327  This returns a version of `weights` following the same broadcast rules as
328  `mul(weights, values)`. When computing a weighted average, use this function
329  to broadcast `weights` before summing them; e.g.,
330  `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`.
331
332  Args:
333    weights: `Tensor` whose rank is either 0, or the same rank as `values`, and
334      must be broadcastable to `values` (i.e., all dimensions must be either
335      `1`, or the same as the corresponding `values` dimension).
336    values: `Tensor` of any shape.
337
338  Returns:
339    `weights` broadcast to `values` shape.
340  """
341  with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope:
342    weights_shape = weights.get_shape()
343    values_shape = values.get_shape()
344    if (weights_shape.is_fully_defined() and values_shape.is_fully_defined() and
345        weights_shape.is_compatible_with(values_shape)):
346      return weights
347    with ops.control_dependencies((_assert_weights_rank(weights, values),)):
348      return math_ops.multiply(weights, array_ops.ones_like(values), name=scope)
349
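# Illustrative sketch of the weighted-mean pattern described in the docstring
# above (assumes `import tensorflow as tf`, TF 1.x; the numbers are made up):
#
#   v = tf.constant([[1.0, 2.0], [3.0, 4.0]])
#   w = tf.constant([[0.5], [1.0]])  # rank 2, broadcastable to `v`
#   weighted_mean = (tf.reduce_sum(w * v) /
#                    tf.reduce_sum(_broadcast_weights(w, v)))
#   # numerator   = 0.5*1 + 0.5*2 + 1*3 + 1*4 = 8.5
#   # denominator = 0.5 + 0.5 + 1 + 1 = 3.0, so weighted_mean is about 2.83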
350
351def streaming_mean(values,
352                   weights=None,
353                   metrics_collections=None,
354                   updates_collections=None,
355                   name=None):
356  """Computes the (weighted) mean of the given values.
357
358  The `streaming_mean` function creates two local variables, `total` and `count`
359  that are used to compute the average of `values`. This average is ultimately
360  returned as `mean` which is an idempotent operation that simply divides
361  `total` by `count`.
362
363  For estimation of the metric over a stream of data, the function creates an
364  `update_op` operation that updates these variables and returns the `mean`.
365  `update_op` increments `total` with the reduced sum of the product of `values`
366  and `weights`, and it increments `count` with the reduced sum of `weights`.
367
368  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
369
370  Args:
371    values: A `Tensor` of arbitrary dimensions.
372    weights: `Tensor` whose rank is either 0, or the same rank as `values`, and
373      must be broadcastable to `values` (i.e., all dimensions must be either
374      `1`, or the same as the corresponding `values` dimension).
375    metrics_collections: An optional list of collections that `mean`
376      should be added to.
377    updates_collections: An optional list of collections that `update_op`
378      should be added to.
379    name: An optional variable_scope name.
380
381  Returns:
382    mean: A `Tensor` representing the current mean, the value of `total` divided
383      by `count`.
384    update_op: An operation that increments the `total` and `count` variables
385      appropriately and whose value matches `mean`.
386
387  Raises:
388    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
389      or if either `metrics_collections` or `updates_collections` are not a list
390      or tuple.
391  """
392  return metrics.mean(
393      values=values,
394      weights=weights,
395      metrics_collections=metrics_collections,
396      updates_collections=updates_collections,
397      name=name)
398
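# Example usage (an illustrative sketch): assumes TF 1.x graph execution with
# `import tensorflow as tf`; the op is typically reached through
# `tf.contrib.metrics.streaming_mean`. Each `update_op` run adds one batch to
# the running `total` and `count`:
#
#   values = tf.placeholder(tf.float32, shape=[None])
#   mean, update_op = streaming_mean(values)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op, feed_dict={values: [1.0, 2.0]})
#     sess.run(update_op, feed_dict={values: [3.0, 4.0]})
#     print(sess.run(mean))  # 2.5, the mean over both batches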
399
400def streaming_mean_tensor(values,
401                          weights=None,
402                          metrics_collections=None,
403                          updates_collections=None,
404                          name=None):
405  """Computes the element-wise (weighted) mean of the given tensors.
406
407  In contrast to the `streaming_mean` function, which returns a scalar mean,
408  this function returns an average tensor with the same shape as the input
409  tensors.
410
411  The `streaming_mean_tensor` function creates two local variables,
412  `total_tensor` and `count_tensor` that are used to compute the average of
413  `values`. This average is ultimately returned as `mean` which is an idempotent
414  operation that simply divides `total` by `count`.
415
416  For estimation of the metric over a stream of data, the function creates an
417  `update_op` operation that updates these variables and returns the `mean`.
418  `update_op` increments `total` with the reduced sum of the product of `values`
419  and `weights`, and it increments `count` with the reduced sum of `weights`.
420
421  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
422
423  Args:
424    values: A `Tensor` of arbitrary dimensions.
425    weights: `Tensor` whose rank is either 0, or the same rank as `values`, and
426      must be broadcastable to `values` (i.e., all dimensions must be either
427      `1`, or the same as the corresponding `values` dimension).
428    metrics_collections: An optional list of collections that `mean`
429      should be added to.
430    updates_collections: An optional list of collections that `update_op`
431      should be added to.
432    name: An optional variable_scope name.
433
434  Returns:
435    mean: A float `Tensor` representing the current mean, the value of `total`
436      divided by `count`.
437    update_op: An operation that increments the `total` and `count` variables
438      appropriately and whose value matches `mean`.
439
440  Raises:
441    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
442      or if either `metrics_collections` or `updates_collections` are not a list
443      or tuple.
444  """
445  return metrics.mean_tensor(
446      values=values,
447      weights=weights,
448      metrics_collections=metrics_collections,
449      updates_collections=updates_collections,
450      name=name)
451
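# Example usage (an illustrative sketch, TF 1.x graph execution assumed): the
# running average keeps the shape of the input, averaging element-wise across
# `update_op` calls:
#
#   x = tf.placeholder(tf.float32, shape=[2])
#   mean_t, update_op = streaming_mean_tensor(x)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op, feed_dict={x: [0.0, 2.0]})
#     sess.run(update_op, feed_dict={x: [4.0, 6.0]})
#     print(sess.run(mean_t))  # [2.0, 4.0]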
452
453def streaming_accuracy(predictions,
454                       labels,
455                       weights=None,
456                       metrics_collections=None,
457                       updates_collections=None,
458                       name=None):
459  """Calculates how often `predictions` matches `labels`.
460
461  The `streaming_accuracy` function creates two local variables, `total` and
462  `count` that are used to compute the frequency with which `predictions`
463  matches `labels`. This frequency is ultimately returned as `accuracy`: an
464  idempotent operation that simply divides `total` by `count`.
465
466  For estimation of the metric over a stream of data, the function creates an
467  `update_op` operation that updates these variables and returns the `accuracy`.
468  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
469  where the corresponding elements of `predictions` and `labels` match and 0.0
470  otherwise. Then `update_op` increments `total` with the reduced sum of the
471  product of `weights` and `is_correct`, and it increments `count` with the
472  reduced sum of `weights`.
473
474  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
475
476  Args:
477    predictions: The predicted values, a `Tensor` of any shape.
478    labels: The ground truth values, a `Tensor` whose shape matches
479      `predictions`.
480    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
481      must be broadcastable to `labels` (i.e., all dimensions must be either
482      `1`, or the same as the corresponding `labels` dimension).
483    metrics_collections: An optional list of collections that `accuracy` should
484      be added to.
485    updates_collections: An optional list of collections that `update_op` should
486      be added to.
487    name: An optional variable_scope name.
488
489  Returns:
490    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
491      by `count`.
492    update_op: An operation that increments the `total` and `count` variables
493      appropriately and whose value matches `accuracy`.
494
495  Raises:
496    ValueError: If `predictions` and `labels` have mismatched shapes, or if
497      `weights` is not `None` and its shape doesn't match `predictions`, or if
498      either `metrics_collections` or `updates_collections` are not a list or
499      tuple.
500  """
501  return metrics.accuracy(
502      predictions=predictions,
503      labels=labels,
504      weights=weights,
505      metrics_collections=metrics_collections,
506      updates_collections=updates_collections,
507      name=name)
508
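# Example usage (an illustrative sketch, TF 1.x graph execution assumed; also
# reachable as `tf.contrib.metrics.streaming_accuracy`):
#
#   predictions = tf.constant([1, 0, 1, 1])
#   labels = tf.constant([1, 1, 1, 0])
#   accuracy, update_op = streaming_accuracy(predictions, labels)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(accuracy))  # 0.5, since 2 of 4 predictions match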
509
510def streaming_precision(predictions,
511                        labels,
512                        weights=None,
513                        metrics_collections=None,
514                        updates_collections=None,
515                        name=None):
516  """Computes the precision of the predictions with respect to the labels.
517
518  The `streaming_precision` function creates two local variables,
519  `true_positives` and `false_positives`, that are used to compute the
520  precision. This value is ultimately returned as `precision`, an idempotent
521  operation that simply divides `true_positives` by the sum of `true_positives`
522  and `false_positives`.
523
524  For estimation of the metric over a stream of data, the function creates an
525  `update_op` operation that updates these variables and returns the
526  `precision`. `update_op` weights each prediction by the corresponding value in
527  `weights`.
528
529  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
530
531  Args:
532    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
533    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
534      match `predictions`.
535    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
536      must be broadcastable to `labels` (i.e., all dimensions must be either
537      `1`, or the same as the corresponding `labels` dimension).
538    metrics_collections: An optional list of collections that `precision` should
539      be added to.
540    updates_collections: An optional list of collections that `update_op` should
541      be added to.
542    name: An optional variable_scope name.
543
544  Returns:
545    precision: Scalar float `Tensor` with the value of `true_positives`
546      divided by the sum of `true_positives` and `false_positives`.
547    update_op: `Operation` that increments `true_positives` and
548      `false_positives` variables appropriately and whose value matches
549      `precision`.
550
551  Raises:
552    ValueError: If `predictions` and `labels` have mismatched shapes, or if
553      `weights` is not `None` and its shape doesn't match `predictions`, or if
554      either `metrics_collections` or `updates_collections` are not a list or
555      tuple.
556  """
557  return metrics.precision(
558      predictions=predictions,
559      labels=labels,
560      weights=weights,
561      metrics_collections=metrics_collections,
562      updates_collections=updates_collections,
563      name=name)
564
565
566def streaming_recall(predictions,
567                     labels,
568                     weights=None,
569                     metrics_collections=None,
570                     updates_collections=None,
571                     name=None):
572  """Computes the recall of the predictions with respect to the labels.
573
574  The `streaming_recall` function creates two local variables, `true_positives`
575  and `false_negatives`, that are used to compute the recall. This value is
576  ultimately returned as `recall`, an idempotent operation that simply divides
577  `true_positives` by the sum of `true_positives` and `false_negatives`.
578
579  For estimation of the metric over a stream of data, the function creates an
580  `update_op` that updates these variables and returns the `recall`. `update_op`
581  weights each prediction by the corresponding value in `weights`.
582
583  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
584
585  Args:
586    predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
587    labels: The ground truth values, a `bool` `Tensor` whose dimensions must
588      match `predictions`.
589    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
590      must be broadcastable to `labels` (i.e., all dimensions must be either
591      `1`, or the same as the corresponding `labels` dimension).
592    metrics_collections: An optional list of collections that `recall` should
593      be added to.
594    updates_collections: An optional list of collections that `update_op` should
595      be added to.
596    name: An optional variable_scope name.
597
598  Returns:
599    recall: Scalar float `Tensor` with the value of `true_positives` divided
600      by the sum of `true_positives` and `false_negatives`.
601    update_op: `Operation` that increments `true_positives` and
602      `false_negatives` variables appropriately and whose value matches
603      `recall`.
604
605  Raises:
606    ValueError: If `predictions` and `labels` have mismatched shapes, or if
607      `weights` is not `None` and its shape doesn't match `predictions`, or if
608      either `metrics_collections` or `updates_collections` are not a list or
609      tuple.
610  """
611  return metrics.recall(
612      predictions=predictions,
613      labels=labels,
614      weights=weights,
615      metrics_collections=metrics_collections,
616      updates_collections=updates_collections,
617      name=name)
618
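# Example usage of `streaming_precision` and `streaming_recall` together (an
# illustrative sketch, TF 1.x graph execution assumed):
#
#   predictions = tf.constant([True, False, True, True])
#   labels = tf.constant([True, True, True, False])
#   precision, precision_update = streaming_precision(predictions, labels)
#   recall, recall_update = streaming_recall(predictions, labels)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run([precision_update, recall_update])
#     # TP=2, FP=1, FN=1, so precision and recall both evaluate to about 0.667.
#     print(sess.run([precision, recall]))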
619
620def _true_negatives(labels,
621                    predictions,
622                    weights=None,
623                    metrics_collections=None,
624                    updates_collections=None,
625                    name=None):
626  """Sum the weights of true negatives.
627
628  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
629
630  Args:
631    labels: The ground truth values, a `Tensor` whose dimensions must match
632      `predictions`. Will be cast to `bool`.
633    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
634      be cast to `bool`.
635    weights: Optional `Tensor` whose rank is either 0, or the same rank as
636      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
637      be either `1`, or the same as the corresponding `labels` dimension).
638    metrics_collections: An optional list of collections that the metric
639      value variable should be added to.
640    updates_collections: An optional list of collections that the metric update
641      ops should be added to.
642    name: An optional variable_scope name.
643
644  Returns:
645    value_tensor: A `Tensor` representing the current value of the metric.
646    update_op: An operation that accumulates the error from a batch of data.
647
648  Raises:
649    ValueError: If `predictions` and `labels` have mismatched shapes, or if
650      `weights` is not `None` and its shape doesn't match `predictions`, or if
651      either `metrics_collections` or `updates_collections` are not a list or
652      tuple.
653  """
654  with variable_scope.variable_scope(name, 'true_negatives',
655                                     (predictions, labels, weights)):
656
657    predictions, labels, weights = _remove_squeezable_dimensions(
658        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
659        labels=math_ops.cast(labels, dtype=dtypes.bool),
660        weights=weights)
661    is_true_negative = math_ops.logical_and(
662        math_ops.equal(labels, False), math_ops.equal(predictions, False))
663    return _count_condition(is_true_negative, weights, metrics_collections,
664                            updates_collections)
665
666
667def streaming_false_positive_rate(predictions,
668                                  labels,
669                                  weights=None,
670                                  metrics_collections=None,
671                                  updates_collections=None,
672                                  name=None):
673  """Computes the false positive rate of predictions with respect to labels.
674
675  The `false_positive_rate` function creates two local variables,
676  `false_positives` and `true_negatives`, that are used to compute the
677  false positive rate. This value is ultimately returned as
678  `false_positive_rate`, an idempotent operation that simply divides
679  `false_positives` by the sum of `false_positives` and `true_negatives`.
680
681  For estimation of the metric over a stream of data, the function creates an
682  `update_op` operation that updates these variables and returns the
683  `false_positive_rate`. `update_op` weights each prediction by the
684  corresponding value in `weights`.
685
686  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
687
688  Args:
689    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
690      be cast to `bool`.
691    labels: The ground truth values, a `Tensor` whose dimensions must match
692      `predictions`. Will be cast to `bool`.
693    weights: Optional `Tensor` whose rank is either 0, or the same rank as
694      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
695      be either `1`, or the same as the corresponding `labels` dimension).
696    metrics_collections: An optional list of collections that
697      `false_positive_rate` should be added to.
698    updates_collections: An optional list of collections that `update_op` should
699      be added to.
700    name: An optional variable_scope name.
701
702  Returns:
703    false_positive_rate: Scalar float `Tensor` with the value of
704      `false_positives` divided by the sum of `false_positives` and
705      `true_negatives`.
706    update_op: `Operation` that increments `false_positives` and
707      `true_negatives` variables appropriately and whose value matches
708      `false_positive_rate`.
709
710  Raises:
711    ValueError: If `predictions` and `labels` have mismatched shapes, or if
712      `weights` is not `None` and its shape doesn't match `predictions`, or if
713      either `metrics_collections` or `updates_collections` are not a list or
714      tuple.
715  """
716  with variable_scope.variable_scope(name, 'false_positive_rate',
717                                     (predictions, labels, weights)):
718    predictions, labels, weights = _remove_squeezable_dimensions(
719        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
720        labels=math_ops.cast(labels, dtype=dtypes.bool),
721        weights=weights)
722
723    false_p, false_positives_update_op = metrics.false_positives(
724        labels,
725        predictions,
726        weights,
727        metrics_collections=None,
728        updates_collections=None,
729        name=None)
730    true_n, true_negatives_update_op = _true_negatives(
731        labels,
732        predictions,
733        weights,
734        metrics_collections=None,
735        updates_collections=None,
736        name=None)
737
738    def compute_fpr(fp, tn, name):
739      return array_ops.where(
740          math_ops.greater(fp + tn, 0), math_ops.div(fp, fp + tn), 0, name)
741
742    fpr = compute_fpr(false_p, true_n, 'value')
743    update_op = compute_fpr(false_positives_update_op, true_negatives_update_op,
744                            'update_op')
745
746    if metrics_collections:
747      ops.add_to_collections(metrics_collections, fpr)
748
749    if updates_collections:
750      ops.add_to_collections(updates_collections, update_op)
751
752    return fpr, update_op
753
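# Example usage (an illustrative sketch, TF 1.x graph execution assumed):
#
#   predictions = tf.constant([True, False, True, False])
#   labels = tf.constant([False, False, True, True])
#   fpr, update_op = streaming_false_positive_rate(predictions, labels)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(fpr))  # 0.5, since FP=1 and TN=1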
754
755def streaming_false_negative_rate(predictions,
756                                  labels,
757                                  weights=None,
758                                  metrics_collections=None,
759                                  updates_collections=None,
760                                  name=None):
761  """Computes the false negative rate of predictions with respect to labels.
762
763  The `false_negative_rate` function creates two local variables,
764  `false_negatives` and `true_positives`, that are used to compute the
765  false negative rate. This value is ultimately returned as
766  `false_negative_rate`, an idempotent operation that simply divides
767  `false_negatives` by the sum of `false_negatives` and `true_positives`.
768
769  For estimation of the metric over a stream of data, the function creates an
770  `update_op` operation that updates these variables and returns the
771  `false_negative_rate`. `update_op` weights each prediction by the
772  corresponding value in `weights`.
773
774  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
775
776  Args:
777    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
778      be cast to `bool`.
779    labels: The ground truth values, a `Tensor` whose dimensions must match
780      `predictions`. Will be cast to `bool`.
781    weights: Optional `Tensor` whose rank is either 0, or the same rank as
782      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
783      be either `1`, or the same as the corresponding `labels` dimension).
784    metrics_collections: An optional list of collections that
785      `false_negative_rate` should be added to.
786    updates_collections: An optional list of collections that `update_op` should
787      be added to.
788    name: An optional variable_scope name.
789
790  Returns:
791    false_negative_rate: Scalar float `Tensor` with the value of
792      `false_negatives` divided by the sum of `false_negatives` and
793      `true_positives`.
794    update_op: `Operation` that increments `false_negatives` and
795      `true_positives` variables appropriately and whose value matches
796      `false_negative_rate`.
797
798  Raises:
799    ValueError: If `predictions` and `labels` have mismatched shapes, or if
800      `weights` is not `None` and its shape doesn't match `predictions`, or if
801      either `metrics_collections` or `updates_collections` are not a list or
802      tuple.
803  """
804  with variable_scope.variable_scope(name, 'false_negative_rate',
805                                     (predictions, labels, weights)):
806    predictions, labels, weights = _remove_squeezable_dimensions(
807        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
808        labels=math_ops.cast(labels, dtype=dtypes.bool),
809        weights=weights)
810
811    false_n, false_negatives_update_op = metrics.false_negatives(
812        labels,
813        predictions,
814        weights,
815        metrics_collections=None,
816        updates_collections=None,
817        name=None)
818    true_p, true_positives_update_op = metrics.true_positives(
819        labels,
820        predictions,
821        weights,
822        metrics_collections=None,
823        updates_collections=None,
824        name=None)
825
826    def compute_fnr(fn, tp, name):
827      return array_ops.where(
828          math_ops.greater(fn + tp, 0), math_ops.div(fn, fn + tp), 0, name)
829
830    fnr = compute_fnr(false_n, true_p, 'value')
831    update_op = compute_fnr(false_negatives_update_op, true_positives_update_op,
832                            'update_op')
833
834    if metrics_collections:
835      ops.add_to_collections(metrics_collections, fnr)
836
837    if updates_collections:
838      ops.add_to_collections(updates_collections, update_op)
839
840    return fnr, update_op
841
842
843def _streaming_confusion_matrix_at_thresholds(predictions,
844                                              labels,
845                                              thresholds,
846                                              weights=None,
847                                              includes=None):
848  """Computes true_positives, false_negatives, true_negatives, false_positives.
849
850  This function creates up to four local variables, `true_positives`,
851  `true_negatives`, `false_positives` and `false_negatives`.
852  `true_positives[i]` is defined as the total weight of values in `predictions`
853  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
854  `false_negatives[i]` is defined as the total weight of values in `predictions`
855  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
856  `true_negatives[i]` is defined as the total weight of values in `predictions`
857  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
858  `false_positives[i]` is defined as the total weight of values in `predictions`
859  above `thresholds[i]` whose corresponding entry in `labels` is `False`.
860
861  For estimation of these metrics over a stream of data, for each metric the
862  function respectively creates an `update_op` operation that updates the
863  variable and returns its value.
864
865  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
866
867  Args:
868    predictions: A floating point `Tensor` of arbitrary shape and whose values
869      are in the range `[0, 1]`.
870    labels: A `Tensor` whose shape matches `predictions`. `labels` will be cast
871      to `bool`.
872    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
873    weights: Optional `Tensor` whose rank is either 0, or the same rank as
874      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
875      must be either `1`, or the same as the corresponding `labels`
876      dimension).
877    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
878      default to all four.
879
880  Returns:
881    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
882        `includes`.
883    update_ops: Dict of operations that increments the `values`. Keys are from
884        `includes`.
885
886  Raises:
887    ValueError: If `predictions` and `labels` have mismatched shapes, or if
888      `weights` is not `None` and its shape doesn't match `predictions`, or if
889      `includes` contains invalid keys.
890  """
891  all_includes = ('tp', 'fn', 'tn', 'fp')
892  if includes is None:
893    includes = all_includes
894  else:
895    for include in includes:
896      if include not in all_includes:
897        raise ValueError('Invalid key: %s.' % include)
898
899  predictions, labels, weights = _remove_squeezable_dimensions(
900      predictions, labels, weights)
901  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
902
903  num_thresholds = len(thresholds)
904
905  # Reshape predictions and labels.
906  predictions_2d = array_ops.reshape(predictions, [-1, 1])
907  labels_2d = array_ops.reshape(
908      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])
909
910  # Use static shape if known.
911  num_predictions = predictions_2d.get_shape().as_list()[0]
912
913  # Otherwise use dynamic shape.
914  if num_predictions is None:
915    num_predictions = array_ops.shape(predictions_2d)[0]
916  thresh_tiled = array_ops.tile(
917      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
918      array_ops.stack([1, num_predictions]))
919
920  # Tile the predictions after thresholding them across different thresholds.
921  pred_is_pos = math_ops.greater(
922      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
923      thresh_tiled)
924  if ('fn' in includes) or ('tn' in includes):
925    pred_is_neg = math_ops.logical_not(pred_is_pos)
926
927  # Tile labels by number of thresholds
928  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
929  if ('fp' in includes) or ('tn' in includes):
930    label_is_neg = math_ops.logical_not(label_is_pos)
931
932  if weights is not None:
933    broadcast_weights = weights_broadcast_ops.broadcast_weights(
934        math_ops.to_float(weights), predictions)
935    weights_tiled = array_ops.tile(
936        array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1])
937    thresh_tiled.get_shape().assert_is_compatible_with(
938        weights_tiled.get_shape())
939  else:
940    weights_tiled = None
941
942  values = {}
943  update_ops = {}
944
945  if 'tp' in includes:
946    true_positives = _create_local('true_positives', shape=[num_thresholds])
947    is_true_positive = math_ops.to_float(
948        math_ops.logical_and(label_is_pos, pred_is_pos))
949    if weights_tiled is not None:
950      is_true_positive *= weights_tiled
951    update_ops['tp'] = state_ops.assign_add(true_positives,
952                                            math_ops.reduce_sum(
953                                                is_true_positive, 1))
954    values['tp'] = true_positives
955
956  if 'fn' in includes:
957    false_negatives = _create_local('false_negatives', shape=[num_thresholds])
958    is_false_negative = math_ops.to_float(
959        math_ops.logical_and(label_is_pos, pred_is_neg))
960    if weights_tiled is not None:
961      is_false_negative *= weights_tiled
962    update_ops['fn'] = state_ops.assign_add(false_negatives,
963                                            math_ops.reduce_sum(
964                                                is_false_negative, 1))
965    values['fn'] = false_negatives
966
967  if 'tn' in includes:
968    true_negatives = _create_local('true_negatives', shape=[num_thresholds])
969    is_true_negative = math_ops.to_float(
970        math_ops.logical_and(label_is_neg, pred_is_neg))
971    if weights_tiled is not None:
972      is_true_negative *= weights_tiled
973    update_ops['tn'] = state_ops.assign_add(true_negatives,
974                                            math_ops.reduce_sum(
975                                                is_true_negative, 1))
976    values['tn'] = true_negatives
977
978  if 'fp' in includes:
979    false_positives = _create_local('false_positives', shape=[num_thresholds])
980    is_false_positive = math_ops.to_float(
981        math_ops.logical_and(label_is_neg, pred_is_pos))
982    if weights_tiled is not None:
983      is_false_positive *= weights_tiled
984    update_ops['fp'] = state_ops.assign_add(false_positives,
985                                            math_ops.reduce_sum(
986                                                is_false_positive, 1))
987    values['fp'] = false_positives
988
989  return values, update_ops
990
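# Illustrative sketch of how the four counters behave (TF 1.x graph execution
# assumed; the numbers are made up). A prediction counts as positive at
# threshold `t` when it is strictly greater than `t`:
#
#   values, update_ops = _streaming_confusion_matrix_at_thresholds(
#       predictions=tf.constant([0.1, 0.4, 0.8]),
#       labels=tf.constant([False, True, True]),
#       thresholds=[0.15, 0.5])
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_ops)
#     print(sess.run(values))
#     # {'tp': [2., 1.], 'fn': [0., 1.], 'tn': [1., 1.], 'fp': [0., 0.]}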
991
992def streaming_true_positives_at_thresholds(predictions,
993                                           labels,
994                                           thresholds,
995                                           weights=None):
996  values, update_ops = _streaming_confusion_matrix_at_thresholds(
997      predictions, labels, thresholds, weights=weights, includes=('tp',))
998  return values['tp'], update_ops['tp']
999
1000
1001def streaming_false_negatives_at_thresholds(predictions,
1002                                            labels,
1003                                            thresholds,
1004                                            weights=None):
1005  values, update_ops = _streaming_confusion_matrix_at_thresholds(
1006      predictions, labels, thresholds, weights=weights, includes=('fn',))
1007  return values['fn'], update_ops['fn']
1008
1009
1010def streaming_false_positives_at_thresholds(predictions,
1011                                            labels,
1012                                            thresholds,
1013                                            weights=None):
1014  values, update_ops = _streaming_confusion_matrix_at_thresholds(
1015      predictions, labels, thresholds, weights=weights, includes=('fp',))
1016  return values['fp'], update_ops['fp']
1017
1018
1019def streaming_true_negatives_at_thresholds(predictions,
1020                                           labels,
1021                                           thresholds,
1022                                           weights=None):
1023  values, update_ops = _streaming_confusion_matrix_at_thresholds(
1024      predictions, labels, thresholds, weights=weights, includes=('tn',))
1025  return values['tn'], update_ops['tn']
1026
1027
1028def streaming_curve_points(labels=None,
1029                           predictions=None,
1030                           weights=None,
1031                           num_thresholds=200,
1032                           metrics_collections=None,
1033                           updates_collections=None,
1034                           curve='ROC',
1035                           name=None):
1036  """Computes curve (ROC or PR) values for a prespecified number of points.
1037
1038  The `streaming_curve_points` function creates four local variables,
1039  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
1040  that are used to compute the curve values. To discretize the curve, a linearly
1041  spaced set of thresholds is used to compute pairs of recall and precision
1042  values.
1043
1044  For best results, `predictions` should be distributed approximately uniformly
1045  in the range [0, 1] and not peaked around 0 or 1.
1046
1047  For estimation of the metric over a stream of data, the function creates an
1048  `update_op` operation that updates these variables.
1049
1050  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1051
1052  Args:
1053    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1054      `bool`.
1055    predictions: A floating point `Tensor` of arbitrary shape and whose values
1056      are in the range `[0, 1]`.
1057    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1058      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1059      be either `1`, or the same as the corresponding `labels` dimension).
1060    num_thresholds: The number of thresholds to use when discretizing the roc
1061      curve.
1062    metrics_collections: An optional list of collections that `auc` should be
1063      added to.
1064    updates_collections: An optional list of collections that `update_op` should
1065      be added to.
1066    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
1067      'PR' for the Precision-Recall-curve.
1068    name: An optional variable_scope name.
1069
1070  Returns:
1071    points: A `Tensor` with shape [num_thresholds, 2] that contains points of
1072      the curve.
1073    update_op: An operation that increments the `true_positives`,
1074      `true_negatives`, `false_positives` and `false_negatives` variables.
1075
1076  Raises:
1077    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1078      `weights` is not `None` and its shape doesn't match `predictions`, or if
1079      either `metrics_collections` or `updates_collections` are not a list or
1080      tuple.
1081
1082  TODO(chizeng): Consider rewriting this method to make use of logic within the
1083  streaming_precision_recall_at_equal_thresholds method (to improve run time).
1084  """
1085  with variable_scope.variable_scope(name, 'curve_points',
1086                                     (labels, predictions, weights)):
1087    if curve != 'ROC' and curve != 'PR':
1088      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
1089    kepsilon = 1e-7  # to account for floating point imprecisions
1090    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
1091                  for i in range(num_thresholds - 2)]
1092    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
1093
1094    values, update_ops = _streaming_confusion_matrix_at_thresholds(
1095        labels=labels,
1096        predictions=predictions,
1097        thresholds=thresholds,
1098        weights=weights)
1099
1100    # Add epsilons to avoid dividing by 0.
1101    epsilon = 1.0e-6
1102
1103    def compute_points(tp, fn, tn, fp):
1104      """Computes the roc-auc or pr-auc based on confusion counts."""
1105      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
1106      if curve == 'ROC':
1107        fp_rate = math_ops.div(fp, fp + tn + epsilon)
1108        return fp_rate, rec
1109      else:  # curve == 'PR'.
1110        prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
1111        return rec, prec
1112
1113    xs, ys = compute_points(values['tp'], values['fn'], values['tn'],
1114                            values['fp'])
1115    points = array_ops.stack([xs, ys], axis=1)
1116    update_op = control_flow_ops.group(*update_ops.values())
1117
1118    if metrics_collections:
1119      ops.add_to_collections(metrics_collections, points)
1120
1121    if updates_collections:
1122      ops.add_to_collections(updates_collections, update_op)
1123
1124    return points, update_op
1125
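# Example usage (an illustrative sketch, TF 1.x graph execution assumed): for
# `curve='ROC'` each row of `points` is (false_positive_rate, recall); for
# `curve='PR'` each row is (recall, precision):
#
#   points, update_op = streaming_curve_points(
#       labels=tf.constant([False, True, True, False]),
#       predictions=tf.constant([0.1, 0.9, 0.6, 0.4]),
#       num_thresholds=5,
#       curve='PR')
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(points))  # an array of shape [5, 2]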
1126
1127def streaming_auc(predictions,
1128                  labels,
1129                  weights=None,
1130                  num_thresholds=200,
1131                  metrics_collections=None,
1132                  updates_collections=None,
1133                  curve='ROC',
1134                  name=None):
1135  """Computes the approximate AUC via a Riemann sum.
1136
1137  The `streaming_auc` function creates four local variables, `true_positives`,
1138  `true_negatives`, `false_positives` and `false_negatives` that are used to
1139  compute the AUC. To discretize the AUC curve, a linearly spaced set of
1140  thresholds is used to compute pairs of recall and precision values. The area
1141  under the ROC-curve is therefore computed using the height of the recall
1142  values by the false positive rate, while the area under the PR-curve is
1143  computed using the height of the precision values by the recall.
1144
1145  This value is ultimately returned as `auc`, an idempotent operation that
1146  computes the area under a discretized curve of precision versus recall values
1147  (computed using the aforementioned variables). The `num_thresholds` variable
1148  controls the degree of discretization with larger numbers of thresholds more
1149  closely approximating the true AUC. The quality of the approximation may vary
1150  dramatically depending on `num_thresholds`.
1151
1152  For best results, `predictions` should be distributed approximately uniformly
1153  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
1154  approximation may be poor if this is not the case.
1155
1156  For estimation of the metric over a stream of data, the function creates an
1157  `update_op` operation that updates these variables and returns the `auc`.
1158
1159  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1160
1161  Args:
1162    predictions: A floating point `Tensor` of arbitrary shape and whose values
1163      are in the range `[0, 1]`.
1164    labels: A `bool` `Tensor` whose shape matches `predictions`.
1165    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1166      must be broadcastable to `labels` (i.e., all dimensions must be either
1167      `1`, or the same as the corresponding `labels` dimension).
1168    num_thresholds: The number of thresholds to use when discretizing the roc
1169      curve.
1170    metrics_collections: An optional list of collections that `auc` should be
1171      added to.
1172    updates_collections: An optional list of collections that `update_op` should
1173      be added to.
1174    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
1175      'PR' for the Precision-Recall-curve.
1176    name: An optional variable_scope name.
1177
1178  Returns:
1179    auc: A scalar `Tensor` representing the current area-under-curve.
1180    update_op: An operation that increments the `true_positives`,
1181      `true_negatives`, `false_positives` and `false_negatives` variables
1182      appropriately and whose value matches `auc`.
1183
1184  Raises:
1185    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1186      `weights` is not `None` and its shape doesn't match `predictions`, or if
1187      either `metrics_collections` or `updates_collections` are not a list or
1188      tuple.
1189  """
1190  return metrics.auc(
1191      predictions=predictions,
1192      labels=labels,
1193      weights=weights,
1194      metrics_collections=metrics_collections,
1195      num_thresholds=num_thresholds,
1196      curve=curve,
1197      updates_collections=updates_collections,
1198      name=name)
1199
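# Example usage (an illustrative sketch, TF 1.x graph execution assumed; also
# reachable as `tf.contrib.metrics.streaming_auc`):
#
#   predictions = tf.constant([0.1, 0.4, 0.35, 0.8])
#   labels = tf.constant([False, False, True, True])
#   auc, update_op = streaming_auc(predictions, labels)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(auc))  # approximately 0.75 for this toy batch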
1200
1201def streaming_precision_recall_at_equal_thresholds(predictions,
1202                                                   labels,
1203                                                   num_thresholds=None,
1204                                                   weights=None,
1205                                                   name=None,
1206                                                   use_locking=None):
1207  """A helper method for creating metrics related to precision-recall curves.
1208
1209  These values are true positives, false negatives, true negatives, false
1210  positives, precision, and recall. This function returns a data structure that
1211  contains ops within it.
1212
1213  Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N)
1214  space and run time), this op exhibits O(T + N) space and run time, where T is
1215  the number of thresholds and N is the size of the predictions tensor. Hence,
1216  it may be advantageous to use this function when `predictions` is big.
1217
1218  For instance, prefer this method for per-pixel classification tasks, for which
1219  the predictions tensor may be very large.
1220
1221  Each number in `predictions`, a float in `[0, 1]`, is compared with its
1222  corresponding label in `labels`, and counts as a single tp/fp/tn/fn value at
1223  each threshold. This is then multiplied with `weights` which can be used to
1224  reweight certain values, or more commonly used for masking values.
1225
1226  Args:
1227    predictions: A floating point `Tensor` of arbitrary shape and whose values
1228      are in the range `[0, 1]`.
1229    labels: A bool `Tensor` whose shape matches `predictions`.
1230    num_thresholds: Optional; Number of thresholds, evenly distributed in
1231      `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins
1232      is 1 less than `num_thresholds`. Using an even `num_thresholds` value
1233      instead of an odd one may yield unfriendly edges for bins.
1234    weights: Optional; If provided, a `Tensor` that has the same dtype as,
1235      and broadcastable to, `predictions`. This tensor is multiplied by counts.
1236    name: Optional; variable_scope name. If not provided, the string
1237      'precision_recall_at_equal_threshold' is used.
1238      'precision_recall_at_equal_thresholds' is used.
1239      Otherwise, the behavior is undefined, but may exhibit less contention.
1240      Defaults to True.
1241
1242  Returns:
1243    result: A named tuple (See PrecisionRecallData within the implementation of
1244      this function) with properties that are variables of shape
1245      `[num_thresholds]`. The names of the properties are tp, fp, tn, fn,
1246      precision, recall, thresholds.
1247    update_op: An op that accumulates values.
1248
1249  Raises:
1250    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1251      `weights` is not `None` and its shape doesn't match `predictions`.
1253  """
1254  # Disable the invalid-name checker so that we can capitalize the name.
1255  # pylint: disable=invalid-name
1256  PrecisionRecallData = collections_lib.namedtuple(
1257      'PrecisionRecallData',
1258      ['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds'])
1259  # pylint: enable=invalid-name
1260
1261  if num_thresholds is None:
1262    num_thresholds = 201
1263
1264  if weights is None:
1265    weights = 1.0
1266
1267  if use_locking is None:
1268    use_locking = True
1269
1270  check_ops.assert_type(labels, dtypes.bool)
1271
1272  dtype = predictions.dtype
1273  with variable_scope.variable_scope(name,
1274                                     'precision_recall_at_equal_thresholds',
1275                                     (labels, predictions, weights)):
1276    # Make sure that predictions are within [0.0, 1.0].
1277    with ops.control_dependencies([
1278        check_ops.assert_greater_equal(
1279            predictions,
1280            math_ops.cast(0.0, dtype=predictions.dtype),
1281            message='predictions must be in [0, 1]'),
1282        check_ops.assert_less_equal(
1283            predictions,
1284            math_ops.cast(1.0, dtype=predictions.dtype),
1285            message='predictions must be in [0, 1]')
1286    ]):
1287      predictions, labels, weights = _remove_squeezable_dimensions(
1288          predictions=predictions, labels=labels, weights=weights)
1289
1290    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1291
1292    # We cast to float to ensure we have 0.0 or 1.0.
1293    f_labels = math_ops.cast(labels, dtype)
1294
1295    # Get weighted true/false labels.
1296    true_labels = f_labels * weights
1297    false_labels = (1.0 - f_labels) * weights
1298
1299    # Flatten predictions and labels.
1300    predictions = array_ops.reshape(predictions, [-1])
1301    true_labels = array_ops.reshape(true_labels, [-1])
1302    false_labels = array_ops.reshape(false_labels, [-1])
1303
1304    # To compute TP/FP/TN/FN, we are measuring a binary classifier
1305    #   C(t) = (predictions >= t)
1306    # at each threshold 't'. So we have
1307    #   TP(t) = sum( C(t) * true_labels )
1308    #   FP(t) = sum( C(t) * false_labels )
1309    #
1310    # But, computing C(t) requires computation for each t. To make it fast,
1311    # observe that C(t) is a cumulative integral, and so if we have
1312    #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
1313    # where n = num_thresholds, and if we can compute the bucket function
1314    #   B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
1315    # then we get
1316    #   C(t_i) = sum( B(j), j >= i )
1317    # which is the reversed cumulative sum in tf.cumsum().
1318    #
1319    # We can compute B(i) efficiently by taking advantage of the fact that
1320    # our thresholds are evenly distributed, in that
1321    #   width = 1.0 / (num_thresholds - 1)
1322    #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
1323    # Given a prediction value p, we can map it to its bucket by
1324    #   bucket_index(p) = floor( p * (num_thresholds - 1) )
1325    # so we can use tf.scatter_add() to update the buckets in one pass.
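    # As a concrete example of the bucket mapping: with num_thresholds = 5 the
    # thresholds are [0.0, 0.25, 0.5, 0.75, 1.0], so a prediction p = 0.6 lands
    # in bucket_index(0.6) = floor(0.6 * 4) = 2 and, via the reversed cumulative
    # sum, contributes to C(t_0), C(t_1) and C(t_2) but not to C(t_3) or C(t_4).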
1326    #
1327    # This implementation exhibits a run time and space complexity of O(T + N),
1328    # where T is the number of thresholds and N is the size of predictions.
1329    # Metrics that rely on _streaming_confusion_matrix_at_thresholds instead
1330    # exhibit a complexity of O(T * N).
1331
1332    # Compute the bucket indices for each prediction value.
1333    bucket_indices = math_ops.cast(
1334        math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32)
1335
1336    with ops.name_scope('variables'):
1337      tp_buckets_v = _create_local(
1338          'tp_buckets', shape=[num_thresholds], dtype=dtype)
1339      fp_buckets_v = _create_local(
1340          'fp_buckets', shape=[num_thresholds], dtype=dtype)
1341
1342    with ops.name_scope('update_op'):
1343      update_tp = state_ops.scatter_add(
1344          tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking)
1345      update_fp = state_ops.scatter_add(
1346          fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking)
1347
1348    # Set up the cumulative sums to compute the actual metrics.
1349    tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp')
1350    fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp')
1351    # fn = sum(true_labels) - tp
1352    #    = sum(tp_buckets) - tp
1353    #    = tp[0] - tp
1354    # Similarly,
1355    # tn = fp[0] - fp
1356    tn = fp[0] - fp
1357    fn = tp[0] - tp
1358
1359    # We use a minimum to prevent division by 0.
1360    epsilon = 1e-7
1361    precision = tp / math_ops.maximum(epsilon, tp + fp)
1362    recall = tp / math_ops.maximum(epsilon, tp + fn)
1363
1364    result = PrecisionRecallData(
1365        tp=tp,
1366        fp=fp,
1367        tn=tn,
1368        fn=fn,
1369        precision=precision,
1370        recall=recall,
1371        thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds))
1372    update_op = control_flow_ops.group(update_tp, update_fp)
1373    return result, update_op
1374
1375
1376def streaming_specificity_at_sensitivity(predictions,
1377                                         labels,
1378                                         sensitivity,
1379                                         weights=None,
1380                                         num_thresholds=200,
1381                                         metrics_collections=None,
1382                                         updates_collections=None,
1383                                         name=None):
1384  """Computes the specificity at a given sensitivity.
1385
1386  The `streaming_specificity_at_sensitivity` function creates four local
1387  variables, `true_positives`, `true_negatives`, `false_positives` and
1388  `false_negatives` that are used to compute the specificity at the given
1389  sensitivity value. The threshold for the given sensitivity value is computed
1390  and used to evaluate the corresponding specificity.
1391
1392  For estimation of the metric over a stream of data, the function creates an
1393  `update_op` operation that updates these variables and returns the
1394  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
1395  `false_positives` and `false_negatives` counts with the weight of each case
1396  found in the `predictions` and `labels`.
1397
1398  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1399
1400  For additional information about specificity and sensitivity, see the
1401  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
1402
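  A minimal usage sketch (illustrative values; assumes `import tensorflow as
  tf` and a TF 1.x session):

    predictions = tf.constant([0.1, 0.4, 0.35, 0.8])
    labels = tf.constant([False, False, True, True])
    specificity, update_op = streaming_specificity_at_sensitivity(
        predictions, labels, sensitivity=0.95)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)           # accumulate the batch
      print(sess.run(specificity))  # specificity at the requested sensitivity
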
1403  Args:
1404    predictions: A floating point `Tensor` of arbitrary shape and whose values
1405      are in the range `[0, 1]`.
1406    labels: A `bool` `Tensor` whose shape matches `predictions`.
1407    sensitivity: A scalar value in range `[0, 1]`.
1408    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1409      must be broadcastable to `labels` (i.e., all dimensions must be either
1410      `1`, or the same as the corresponding `labels` dimension).
1411    num_thresholds: The number of thresholds to use for matching the given
1412      sensitivity.
1413    metrics_collections: An optional list of collections that `specificity`
1414      should be added to.
1415    updates_collections: An optional list of collections that `update_op` should
1416      be added to.
1417    name: An optional variable_scope name.
1418
1419  Returns:
1420    specificity: A scalar `Tensor` representing the specificity at the given
1421      `sensitivity` value.
1422    update_op: An operation that increments the `true_positives`,
1423      `true_negatives`, `false_positives` and `false_negatives` variables
1424      appropriately and whose value matches `specificity`.
1425
1426  Raises:
1427    ValueError: If `predictions` and `labels` have mismatched shapes, if
1428      `weights` is not `None` and its shape doesn't match `predictions`, or if
1429      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
1430      or `updates_collections` are not a list or tuple.
1431  """
1432  return metrics.specificity_at_sensitivity(
1433      sensitivity=sensitivity,
1434      num_thresholds=num_thresholds,
1435      predictions=predictions,
1436      labels=labels,
1437      weights=weights,
1438      metrics_collections=metrics_collections,
1439      updates_collections=updates_collections,
1440      name=name)
1441
1442
1443def streaming_sensitivity_at_specificity(predictions,
1444                                         labels,
1445                                         specificity,
1446                                         weights=None,
1447                                         num_thresholds=200,
1448                                         metrics_collections=None,
1449                                         updates_collections=None,
1450                                         name=None):
1451  """Computes the sensitivity at a given specificity.
1452
1453  The `streaming_sensitivity_at_specificity` function creates four local
1454  variables, `true_positives`, `true_negatives`, `false_positives` and
1455  `false_negatives` that are used to compute the sensitivity at the given
1456  specificity value. The threshold for the given specificity value is computed
1457  and used to evaluate the corresponding sensitivity.
1458
1459  For estimation of the metric over a stream of data, the function creates an
1460  `update_op` operation that updates these variables and returns the
1461  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
1462  `false_positives` and `false_negatives` counts with the weight of each case
1463  found in the `predictions` and `labels`.
1464
1465  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1466
1467  For additional information about specificity and sensitivity, see the
1468  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
1469
1470  Args:
1471    predictions: A floating point `Tensor` of arbitrary shape and whose values
1472      are in the range `[0, 1]`.
1473    labels: A `bool` `Tensor` whose shape matches `predictions`.
1474    specificity: A scalar value in range `[0, 1]`.
1475    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1476      must be broadcastable to `labels` (i.e., all dimensions must be either
1477      `1`, or the same as the corresponding `labels` dimension).
1478    num_thresholds: The number of thresholds to use for matching the given
1479      specificity.
1480    metrics_collections: An optional list of collections that `sensitivity`
1481      should be added to.
1482    updates_collections: An optional list of collections that `update_op` should
1483      be added to.
1484    name: An optional variable_scope name.
1485
1486  Returns:
1487    sensitivity: A scalar `Tensor` representing the sensitivity at the given
1488      `specificity` value.
1489    update_op: An operation that increments the `true_positives`,
1490      `true_negatives`, `false_positives` and `false_negatives` variables
1491      appropriately and whose value matches `sensitivity`.
1492
1493  Raises:
1494    ValueError: If `predictions` and `labels` have mismatched shapes, if
1495      `weights` is not `None` and its shape doesn't match `predictions`, or if
1496      `specificity` is not between 0 and 1, or if either `metrics_collections`
1497      or `updates_collections` are not a list or tuple.
1498  """
1499  return metrics.sensitivity_at_specificity(
1500      specificity=specificity,
1501      num_thresholds=num_thresholds,
1502      predictions=predictions,
1503      labels=labels,
1504      weights=weights,
1505      metrics_collections=metrics_collections,
1506      updates_collections=updates_collections,
1507      name=name)
1508
1509
1510def streaming_precision_at_thresholds(predictions,
1511                                      labels,
1512                                      thresholds,
1513                                      weights=None,
1514                                      metrics_collections=None,
1515                                      updates_collections=None,
1516                                      name=None):
1517  """Computes precision values for different `thresholds` on `predictions`.
1518
1519  The `streaming_precision_at_thresholds` function creates four local variables,
1520  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
1521  for various values of thresholds. `precision[i]` is defined as the total
1522  weight of values in `predictions` above `thresholds[i]` whose corresponding
1523  entry in `labels` is `True`, divided by the total weight of values in
1524  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
1525  false_positives[i])`).
1526
1527  For estimation of the metric over a stream of data, the function creates an
1528  `update_op` operation that updates these variables and returns the
1529  `precision`.
1530
1531  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1532
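  A minimal usage sketch (illustrative values; TF 1.x session assumed):

    predictions = tf.constant([0.1, 0.6, 0.9, 0.4])
    labels = tf.constant([False, True, True, False])
    precision, update_op = streaming_precision_at_thresholds(
        predictions, labels, thresholds=[0.3, 0.5, 0.7])
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(precision))  # one precision value per threshold
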
1533  Args:
1534    predictions: A floating point `Tensor` of arbitrary shape and whose values
1535      are in the range `[0, 1]`.
1536    labels: A `bool` `Tensor` whose shape matches `predictions`.
1537    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1538    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1539      must be broadcastable to `labels` (i.e., all dimensions must be either
1540      `1`, or the same as the corresponding `labels` dimension).
1541    metrics_collections: An optional list of collections that `precision` should
1542      be added to.
1543    updates_collections: An optional list of collections that `update_op` should
1544      be added to.
1545    name: An optional variable_scope name.
1546
1547  Returns:
1548    precision: A float `Tensor` of shape `[len(thresholds)]`.
1549    update_op: An operation that increments the `true_positives`,
1550      `true_negatives`, `false_positives` and `false_negatives` variables that
1551      are used in the computation of `precision`.
1552
1553  Raises:
1554    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1555      `weights` is not `None` and its shape doesn't match `predictions`, or if
1556      either `metrics_collections` or `updates_collections` are not a list or
1557      tuple.
1558  """
1559  return metrics.precision_at_thresholds(
1560      thresholds=thresholds,
1561      predictions=predictions,
1562      labels=labels,
1563      weights=weights,
1564      metrics_collections=metrics_collections,
1565      updates_collections=updates_collections,
1566      name=name)
1567
1568
1569def streaming_recall_at_thresholds(predictions,
1570                                   labels,
1571                                   thresholds,
1572                                   weights=None,
1573                                   metrics_collections=None,
1574                                   updates_collections=None,
1575                                   name=None):
1576  """Computes various recall values for different `thresholds` on `predictions`.
1577
1578  The `streaming_recall_at_thresholds` function creates four local variables,
1579  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
1580  for various values of thresholds. `recall[i]` is defined as the total weight
1581  of values in `predictions` above `thresholds[i]` whose corresponding entry in
1582  `labels` is `True`, divided by the total weight of `True` values in `labels`
1583  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).
1584
1585  For estimation of the metric over a stream of data, the function creates an
1586  `update_op` operation that updates these variables and returns the `recall`.
1587
1588  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1589
1590  Args:
1591    predictions: A floating point `Tensor` of arbitrary shape and whose values
1592      are in the range `[0, 1]`.
1593    labels: A `bool` `Tensor` whose shape matches `predictions`.
1594    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1595    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1596      must be broadcastable to `labels` (i.e., all dimensions must be either
1597      `1`, or the same as the corresponding `labels` dimension).
1598    metrics_collections: An optional list of collections that `recall` should be
1599      added to.
1600    updates_collections: An optional list of collections that `update_op` should
1601      be added to.
1602    name: An optional variable_scope name.
1603
1604  Returns:
1605    recall: A float `Tensor` of shape `[len(thresholds)]`.
1606    update_op: An operation that increments the `true_positives`,
1607      `true_negatives`, `false_positives` and `false_negatives` variables that
1608      are used in the computation of `recall`.
1609
1610  Raises:
1611    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1612      `weights` is not `None` and its shape doesn't match `predictions`, or if
1613      either `metrics_collections` or `updates_collections` are not a list or
1614      tuple.
1615  """
1616  return metrics.recall_at_thresholds(
1617      thresholds=thresholds,
1618      predictions=predictions,
1619      labels=labels,
1620      weights=weights,
1621      metrics_collections=metrics_collections,
1622      updates_collections=updates_collections,
1623      name=name)
1624
1625
1626def streaming_false_positive_rate_at_thresholds(predictions,
1627                                                labels,
1628                                                thresholds,
1629                                                weights=None,
1630                                                metrics_collections=None,
1631                                                updates_collections=None,
1632                                                name=None):
1633  """Computes various fpr values for different `thresholds` on `predictions`.
1634
1635  The `streaming_false_positive_rate_at_thresholds` function creates two
1636  local variables, `false_positives` and `true_negatives`, for various values of
1637  thresholds. `false_positive_rate[i]` is defined as the total weight
1638  of values in `predictions` above `thresholds[i]` whose corresponding entry in
1639  `labels` is `False`, divided by the total weight of `False` values in `labels`
1640  (`false_positives[i] / (false_positives[i] + true_negatives[i])`).
1641
1642  For estimation of the metric over a stream of data, the function creates an
1643  `update_op` operation that updates these variables and returns the
1644  `false_positive_rate`.
1645
1646  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1647
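  For example, pairing this metric with `streaming_recall_at_thresholds` (the
  true positive rate) yields one ROC point per threshold. A sketch with
  illustrative values, assuming a TF 1.x session:

    predictions = tf.constant([0.2, 0.7, 0.55, 0.9])
    labels = tf.constant([False, True, False, True])
    thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
    fpr, fpr_update = streaming_false_positive_rate_at_thresholds(
        predictions, labels, thresholds)
    tpr, tpr_update = streaming_recall_at_thresholds(
        predictions, labels, thresholds)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run([fpr_update, tpr_update])
      roc_points = sess.run([fpr, tpr])  # one (fpr, tpr) pair per threshold
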
1648  Args:
1649    predictions: A floating point `Tensor` of arbitrary shape and whose values
1650      are in the range `[0, 1]`.
1651    labels: A `bool` `Tensor` whose shape matches `predictions`.
1652    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1653    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1654      must be broadcastable to `labels` (i.e., all dimensions must be either
1655      `1`, or the same as the corresponding `labels` dimension).
1656    metrics_collections: An optional list of collections that
1657      `false_positive_rate` should be added to.
1658    updates_collections: An optional list of collections that `update_op` should
1659      be added to.
1660    name: An optional variable_scope name.
1661
1662  Returns:
1663    false_positive_rate: A float `Tensor` of shape `[len(thresholds)]`.
1664    update_op: An operation that increments the `false_positives` and
1665      `true_negatives` variables that are used in the computation of
1666      `false_positive_rate`.
1667
1668  Raises:
1669    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1670      `weights` is not `None` and its shape doesn't match `predictions`, or if
1671      either `metrics_collections` or `updates_collections` are not a list or
1672      tuple.
1673  """
1674  with variable_scope.variable_scope(name, 'false_positive_rate_at_thresholds',
1675                                     (predictions, labels, weights)):
1676    values, update_ops = _streaming_confusion_matrix_at_thresholds(
1677        predictions, labels, thresholds, weights, includes=('fp', 'tn'))
1678
1679    # Avoid division by zero.
1680    epsilon = 1e-7
1681
1682    def compute_fpr(fp, tn, name):
1683      return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name)
1684
1685    fpr = compute_fpr(values['fp'], values['tn'], 'value')
1686    update_op = compute_fpr(update_ops['fp'], update_ops['tn'], 'update_op')
1687
1688    if metrics_collections:
1689      ops.add_to_collections(metrics_collections, fpr)
1690
1691    if updates_collections:
1692      ops.add_to_collections(updates_collections, update_op)
1693
1694    return fpr, update_op
1695
1696
1697def streaming_false_negative_rate_at_thresholds(predictions,
1698                                                labels,
1699                                                thresholds,
1700                                                weights=None,
1701                                                metrics_collections=None,
1702                                                updates_collections=None,
1703                                                name=None):
1704  """Computes various fnr values for different `thresholds` on `predictions`.
1705
1706  The `streaming_false_negative_rate_at_thresholds` function creates two
1707  local variables, `false_negatives` and `true_positives`, for various values
1708  of thresholds. `false_negative_rate[i]` is defined as the total weight of
1709  values in `predictions` at or below `thresholds[i]` whose corresponding entry
1710  in `labels` is `True`, divided by the total weight of `True` values in `labels`
1711  (`false_negatives[i] / (false_negatives[i] + true_positives[i])`).
1712
1713  For estimation of the metric over a stream of data, the function creates an
1714  `update_op` operation that updates these variables and returns the
1715  `false_negative_rate`.
1716
1717  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1718
1719  Args:
1720    predictions: A floating point `Tensor` of arbitrary shape and whose values
1721      are in the range `[0, 1]`.
1722    labels: A `bool` `Tensor` whose shape matches `predictions`.
1723    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1724    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1725      must be broadcastable to `labels` (i.e., all dimensions must be either
1726      `1`, or the same as the corresponding `labels` dimension).
1727    metrics_collections: An optional list of collections that
1728      `false_negative_rate` should be added to.
1729    updates_collections: An optional list of collections that `update_op` should
1730      be added to.
1731    name: An optional variable_scope name.
1732
1733  Returns:
1734    false_negative_rate: A float `Tensor` of shape `[len(thresholds)]`.
1735    update_op: An operation that increments the `false_negatives` and
1736      `true_positives` variables that are used in the computation of
1737      `false_negative_rate`.
1738
1739  Raises:
1740    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1741      `weights` is not `None` and its shape doesn't match `predictions`, or if
1742      either `metrics_collections` or `updates_collections` are not a list or
1743      tuple.
1744  """
1745  with variable_scope.variable_scope(name, 'false_negative_rate_at_thresholds',
1746                                     (predictions, labels, weights)):
1747    values, update_ops = _streaming_confusion_matrix_at_thresholds(
1748        predictions, labels, thresholds, weights, includes=('fn', 'tp'))
1749
1750    # Avoid division by zero.
1751    epsilon = 1e-7
1752
1753    def compute_fnr(fn, tp, name):
1754      return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name)
1755
1756    fnr = compute_fnr(values['fn'], values['tp'], 'value')
1757    update_op = compute_fnr(update_ops['fn'], update_ops['tp'], 'update_op')
1758
1759    if metrics_collections:
1760      ops.add_to_collections(metrics_collections, fnr)
1761
1762    if updates_collections:
1763      ops.add_to_collections(updates_collections, update_op)
1764
1765    return fnr, update_op
1766
1767
1768def _at_k_name(name, k=None, class_id=None):
1769  if k is not None:
1770    name = '%s_at_%d' % (name, k)
1771  else:
1772    name = '%s_at_k' % (name)
1773  if class_id is not None:
1774    name = '%s_class%d' % (name, class_id)
1775  return name
1776
1777
1778@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, '
1779            'and reshape labels from [batch_size] to [batch_size, 1].')
1780def streaming_recall_at_k(predictions,
1781                          labels,
1782                          k,
1783                          weights=None,
1784                          metrics_collections=None,
1785                          updates_collections=None,
1786                          name=None):
1787  """Computes the recall@k of the predictions with respect to dense labels.
1788
1789  The `streaming_recall_at_k` function creates two local variables, `total` and
1790  `count`, that are used to compute the recall@k frequency. This frequency is
1791  ultimately returned as `recall_at_<k>`: an idempotent operation that simply
1792  divides `total` by `count`.
1793
1794  For estimation of the metric over a stream of data, the function creates an
1795  `update_op` operation that updates these variables and returns the
1796  `recall_at_<k>`. Internally, an `in_top_k` operation computes a `Tensor` with
1797  shape [batch_size] whose elements indicate whether or not the corresponding
1798  label is in the top `k` `predictions`. Then `update_op` increments `total`
1799  with the reduced sum of `weights` where `in_top_k` is `True`, and it
1800  increments `count` with the reduced sum of `weights`.
1801
1802  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1803
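  As the deprecation notice above suggests, a sketch of the migration (the
  shapes and values are illustrative):

    # Deprecated:
    #   recall, update_op = streaming_recall_at_k(predictions, labels, k=5)
    # Replacement: reshape dense labels from [batch_size] to [batch_size, 1]
    # and call streaming_sparse_recall_at_k instead.
    sparse_labels = tf.reshape(tf.cast(labels, tf.int64), [-1, 1])
    recall, update_op = streaming_sparse_recall_at_k(
        predictions, sparse_labels, k=5)
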
1804  Args:
1805    predictions: A float `Tensor` of dimension [batch_size, num_classes].
1806    labels: A `Tensor` of dimension [batch_size] whose type is in `int32`,
1807      `int64`.
1808    k: The number of top elements to look at for computing recall.
1809    weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
1810      must be broadcastable to `labels` (i.e., all dimensions must be either
1811      `1`, or the same as the corresponding `labels` dimension).
1812    metrics_collections: An optional list of collections that `recall_at_k`
1813      should be added to.
1814    updates_collections: An optional list of collections `update_op` should be
1815      added to.
1816    name: An optional variable_scope name.
1817
1818  Returns:
1819    recall_at_k: A `Tensor` representing the recall@k, the fraction of labels
1820      which fall into the top `k` predictions.
1821    update_op: An operation that increments the `total` and `count` variables
1822      appropriately and whose value matches `recall_at_k`.
1823
1824  Raises:
1825    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1826      `weights` is not `None` and its shape doesn't match `predictions`, or if
1827      either `metrics_collections` or `updates_collections` are not a list or
1828      tuple.
1829  """
1830  in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k))
1831  return streaming_mean(in_top_k, weights, metrics_collections,
1832                        updates_collections, name or _at_k_name('recall', k))
1833
1834
1835# TODO(ptucker): Validate range of values in labels?
1836def streaming_sparse_recall_at_k(predictions,
1837                                 labels,
1838                                 k,
1839                                 class_id=None,
1840                                 weights=None,
1841                                 metrics_collections=None,
1842                                 updates_collections=None,
1843                                 name=None):
1844  """Computes recall@k of the predictions with respect to sparse labels.
1845
1846  If `class_id` is not specified, we'll calculate recall as the ratio of true
1847      positives (i.e., correct predictions, items in the top `k` highest
1848      `predictions` that are found in the corresponding row in `labels`) to
1849      actual positives (the full `labels` row).
1850  If `class_id` is specified, we calculate recall by considering only the rows
1851      in the batch for which `class_id` is in `labels`, and computing the
1852      fraction of them for which `class_id` is among the top `k`
1853      `predictions`.
1854
1855  `streaming_sparse_recall_at_k` creates two local variables,
1856  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
1857  the recall_at_k frequency. This frequency is ultimately returned as
1858  `recall_at_<k>`: an idempotent operation that simply divides
1859  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
1860  `false_negative_at_<k>`).
1861
1862  For estimation of the metric over a stream of data, the function creates an
1863  `update_op` operation that updates these variables and returns the
1864  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
1865  indicating the top `k` `predictions`. Set operations applied to `top_k` and
1866  `labels` calculate the true positives and false negatives weighted by
1867  `weights`. Then `update_op` increments `true_positive_at_<k>` and
1868  `false_negative_at_<k>` using these values.
1869
1870  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1871
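  A minimal usage sketch (illustrative values; TF 1.x session assumed):

    predictions = tf.constant([[0.4, 0.3, 0.2, 0.1],
                               [0.1, 0.2, 0.3, 0.4]])
    labels = tf.constant([[3], [3]], dtype=tf.int64)
    recall, update_op = streaming_sparse_recall_at_k(predictions, labels, k=2)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(recall))  # 0.5: class 3 is in the top-2 only for row 1
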
1872  Args:
1873    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
1874      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
1875      The final dimension contains the logit values for each class. [D1, ... DN]
1876      must match `labels`.
1877    labels: `int64` `Tensor` or `SparseTensor` with shape
1878      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1879      target classes for the associated prediction. Commonly, N=1 and `labels`
1880      has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`.
1881      Values should be in range [0, num_classes), where num_classes is the last
1882      dimension of `predictions`. Values outside this range always count
1883      towards `false_negative_at_<k>`.
1884    k: Integer, k for @k metric.
1885    class_id: Integer class ID for which we want binary metrics. This should be
1886      in range [0, num_classes), where num_classes is the last dimension of
1887      `predictions`. If class_id is outside this range, the method returns NAN.
1888    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1889      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1890      dimensions must be either `1`, or the same as the corresponding `labels`
1891      dimension).
1892    metrics_collections: An optional list of collections that values should
1893      be added to.
1894    updates_collections: An optional list of collections that updates should
1895      be added to.
1896    name: Name of new update operation, and namespace for other dependent ops.
1897
1898  Returns:
1899    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
1900      by the sum of `true_positives` and `false_negatives`.
1901    update_op: `Operation` that increments `true_positives` and
1902      `false_negatives` variables appropriately, and whose value matches
1903      `recall`.
1904
1905  Raises:
1906    ValueError: If `weights` is not `None` and its shape doesn't match
1907      `predictions`, or if either `metrics_collections` or `updates_collections`
1908      are not a list or tuple.
1909  """
1910  return metrics.recall_at_k(
1911      k=k,
1912      class_id=class_id,
1913      predictions=predictions,
1914      labels=labels,
1915      weights=weights,
1916      metrics_collections=metrics_collections,
1917      updates_collections=updates_collections,
1918      name=name)
1919
1920
1921# TODO(ptucker): Validate range of values in labels?
1922def streaming_sparse_precision_at_k(predictions,
1923                                    labels,
1924                                    k,
1925                                    class_id=None,
1926                                    weights=None,
1927                                    metrics_collections=None,
1928                                    updates_collections=None,
1929                                    name=None):
1930  """Computes precision@k of the predictions with respect to sparse labels.
1931
1932  If `class_id` is not specified, we calculate precision as the ratio of true
1933      positives (i.e., correct predictions, items in the top `k` highest
1934      `predictions` that are found in the corresponding row in `labels`) to
1935      positives (all top `k` `predictions`).
1936  If `class_id` is specified, we calculate precision by considering only the
1937      rows in the batch for which `class_id` is in the top `k` highest
1938      `predictions`, and computing the fraction of them for which `class_id` is
1939      in the corresponding row in `labels`.
1940
1941  We expect precision to decrease as `k` increases.
1942
1943  `streaming_sparse_precision_at_k` creates two local variables,
1944  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
1945  the precision@k frequency. This frequency is ultimately returned as
1946  `precision_at_<k>`: an idempotent operation that simply divides
1947  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
1948  `false_positive_at_<k>`).
1949
1950  For estimation of the metric over a stream of data, the function creates an
1951  `update_op` operation that updates these variables and returns the
1952  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
1953  indicating the top `k` `predictions`. Set operations applied to `top_k` and
1954  `labels` calculate the true positives and false positives weighted by
1955  `weights`. Then `update_op` increments `true_positive_at_<k>` and
1956  `false_positive_at_<k>` using these values.
1957
1958  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1959
1960  Args:
1961    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
1962      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
1963      The final dimension contains the logit values for each class. [D1, ... DN]
1964      must match `labels`.
1965    labels: `int64` `Tensor` or `SparseTensor` with shape
1966      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
1967      target classes for the associated prediction. Commonly, N=1 and `labels`
1968      has shape [batch_size, num_labels]. [D1, ... DN] must match
1969      `predictions`. Values should be in range [0, num_classes), where
1970      num_classes is the last dimension of `predictions`. Values outside this
1971      range are ignored.
1972    k: Integer, k for @k metric.
1973    class_id: Integer class ID for which we want binary metrics. This should be
1974      in range [0, num_classes), where num_classes is the last dimension of
1975      `predictions`. If `class_id` is outside this range, the method returns
1976      NAN.
1977    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
1978      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
1979      dimensions must be either `1`, or the same as the corresponding `labels`
1980      dimension).
1981    metrics_collections: An optional list of collections that values should
1982      be added to.
1983    updates_collections: An optional list of collections that updates should
1984      be added to.
1985    name: Name of new update operation, and namespace for other dependent ops.
1986
1987  Returns:
1988    precision: Scalar `float64` `Tensor` with the value of `true_positives`
1989      divided by the sum of `true_positives` and `false_positives`.
1990    update_op: `Operation` that increments `true_positives` and
1991      `false_positives` variables appropriately, and whose value matches
1992      `precision`.
1993
1994  Raises:
1995    ValueError: If `weights` is not `None` and its shape doesn't match
1996      `predictions`, or if either `metrics_collections` or `updates_collections`
1997      are not a list or tuple.
1998  """
1999  return metrics.sparse_precision_at_k(
2000      k=k,
2001      class_id=class_id,
2002      predictions=predictions,
2003      labels=labels,
2004      weights=weights,
2005      metrics_collections=metrics_collections,
2006      updates_collections=updates_collections,
2007      name=name)
2008
2009
2010# TODO(ptucker): Validate range of values in labels?
2011def streaming_sparse_precision_at_top_k(top_k_predictions,
2012                                        labels,
2013                                        class_id=None,
2014                                        weights=None,
2015                                        metrics_collections=None,
2016                                        updates_collections=None,
2017                                        name=None):
2018  """Computes precision@k of top-k predictions with respect to sparse labels.
2019
2020  If `class_id` is not specified, we calculate precision as the ratio of
2021      true positives (i.e., correct predictions, items in `top_k_predictions`
2022      that are found in the corresponding row in `labels`) to positives (all
2023      `top_k_predictions`).
2024  If `class_id` is specified, we calculate precision by considering only the
2025      rows in the batch for which `class_id` is in the top `k` highest
2026      `predictions`, and computing the fraction of them for which `class_id` is
2027      in the corresponding row in `labels`.
2028
2029  We expect precision to decrease as `k` increases.
2030
2031  `streaming_sparse_precision_at_top_k` creates two local variables,
2032  `true_positive_at_k` and `false_positive_at_k`, that are used to compute
2033  the precision@k frequency. This frequency is ultimately returned as
2034  `precision_at_k`: an idempotent operation that simply divides
2035  `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`).
2036
2037  For estimation of the metric over a stream of data, the function creates an
2038  `update_op` operation that updates these variables and returns the
2039  `precision_at_k`. Internally, set operations applied to `top_k_predictions`
2040  and `labels` calculate the true positives and false positives weighted by
2041  `weights`. Then `update_op` increments `true_positive_at_k` and
2042  `false_positive_at_k` using these values.
2043
2044  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2045
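  For example, when the top-k class indices have already been materialized
  (here with `tf.nn.top_k`; the values are illustrative):

    logits = tf.constant([[0.1, 0.6, 0.3],
                          [0.8, 0.15, 0.05]])
    labels = tf.constant([[1], [2]], dtype=tf.int64)
    _, top_k_idx = tf.nn.top_k(logits, k=2)
    precision, update_op = streaming_sparse_precision_at_top_k(
        tf.cast(top_k_idx, tf.int64), labels)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(precision))  # true positives / all top-2 predictions
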
2046  Args:
2047    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
2048      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
2049      The final dimension contains the indices of top-k labels. [D1, ... DN]
2050      must match `labels`.
2051    labels: `int64` `Tensor` or `SparseTensor` with shape
2052      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2053      target classes for the associated prediction. Commonly, N=1 and `labels`
2054      has shape [batch_size, num_labels]. [D1, ... DN] must match
2055      `top_k_predictions`. Values should be in range [0, num_classes), where
2056      num_classes is the last dimension of `predictions`. Values outside this
2057      range are ignored.
2058    class_id: Integer class ID for which we want binary metrics. This should be
2059      in range [0, num_classes), where num_classes is the last dimension of
2060      `predictions`. If `class_id` is outside this range, the method returns
2061      NAN.
2062    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2063      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2064      dimensions must be either `1`, or the same as the corresponding `labels`
2065      dimension).
2066    metrics_collections: An optional list of collections that values should
2067      be added to.
2068    updates_collections: An optional list of collections that updates should
2069      be added to.
2070    name: Name of new update operation, and namespace for other dependent ops.
2071
2072  Returns:
2073    precision: Scalar `float64` `Tensor` with the value of `true_positives`
2074      divided by the sum of `true_positives` and `false_positives`.
2075    update_op: `Operation` that increments `true_positives` and
2076      `false_positives` variables appropriately, and whose value matches
2077      `precision`.
2078
2079  Raises:
2080    ValueError: If `weights` is not `None` and its shape doesn't match
2081      `predictions`, or if either `metrics_collections` or `updates_collections`
2082      are not a list or tuple.
2083    ValueError: If `top_k_predictions` has rank < 2.
2084  """
2085  default_name = _at_k_name('precision', class_id=class_id)
2086  with ops.name_scope(name, default_name,
2087                      (top_k_predictions, labels, weights)) as name_scope:
2088    return metrics_impl._sparse_precision_at_top_k(  # pylint: disable=protected-access
2089        labels=labels,
2090        predictions_idx=top_k_predictions,
2091        class_id=class_id,
2092        weights=weights,
2093        metrics_collections=metrics_collections,
2094        updates_collections=updates_collections,
2095        name=name_scope)
2096
2097
2098def sparse_recall_at_top_k(labels,
2099                           top_k_predictions,
2100                           class_id=None,
2101                           weights=None,
2102                           metrics_collections=None,
2103                           updates_collections=None,
2104                           name=None):
2105  """Computes recall@k of top-k predictions with respect to sparse labels.
2106
2107  If `class_id` is specified, we calculate recall by considering only the
2108      entries in the batch for which `class_id` is in the label, and computing
2109      the fraction of them for which `class_id` is in the top-k `predictions`.
2110  If `class_id` is not specified, we'll calculate recall as how often on
2111      average a class among the labels of a batch entry is in the top-k
2112      `predictions`.
2113
2114  `sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>`
2115  and `false_negative_at_<k>`, that are used to compute the recall_at_k
2116  frequency. This frequency is ultimately returned as `recall_at_<k>`: an
2117  idempotent operation that simply divides `true_positive_at_<k>` by total
2118  (`true_positive_at_<k>` + `false_negative_at_<k>`).
2119
2120  For estimation of the metric over a stream of data, the function creates an
2121  `update_op` operation that updates these variables and returns the
2122  `recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the
2123  true positives and false negatives weighted by `weights`. Then `update_op`
2124  increments `true_positive_at_<k>` and `false_negative_at_<k>` using these
2125  values.
2126
2127  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2128
2129  Args:
2130    labels: `int64` `Tensor` or `SparseTensor` with shape
2131      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2132      target classes for the associated prediction. Commonly, N=1 and `labels`
2133      has shape [batch_size, num_labels]. [D1, ... DN] must match
2134      `top_k_predictions`. Values should be in range [0, num_classes), where
2135      num_classes is the last dimension of `predictions`. Values outside this
2136      range always count towards `false_negative_at_<k>`.
2137    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
2138      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
2139      The final dimension contains the indices of top-k labels. [D1, ... DN]
2140      must match `labels`.
2141    class_id: Integer class ID for which we want binary metrics. This should be
2142      in range [0, num_classes), where num_classes is the last dimension of
2143      `predictions`. If class_id is outside this range, the method returns NAN.
2144    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2145      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2146      dimensions must be either `1`, or the same as the corresponding `labels`
2147      dimension).
2148    metrics_collections: An optional list of collections that values should
2149      be added to.
2150    updates_collections: An optional list of collections that updates should
2151      be added to.
2152    name: Name of new update operation, and namespace for other dependent ops.
2153
2154  Returns:
2155    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
2156      by the sum of `true_positives` and `false_negatives`.
2157    update_op: `Operation` that increments `true_positives` and
2158      `false_negatives` variables appropriately, and whose value matches
2159      `recall`.
2160
2161  Raises:
2162    ValueError: If `weights` is not `None` and its shape doesn't match
2163      `predictions`, or if either `metrics_collections` or `updates_collections`
2164      are not a list or tuple.
2165  """
2166  default_name = _at_k_name('recall', class_id=class_id)
2167  with ops.name_scope(name, default_name,
2168                      (top_k_predictions, labels, weights)) as name_scope:
2169    return metrics_impl._sparse_recall_at_top_k(  # pylint: disable=protected-access
2170        labels=labels,
2171        predictions_idx=top_k_predictions,
2172        class_id=class_id,
2173        weights=weights,
2174        metrics_collections=metrics_collections,
2175        updates_collections=updates_collections,
2176        name=name_scope)
2177
2178
2179def streaming_sparse_average_precision_at_k(predictions,
2180                                            labels,
2181                                            k,
2182                                            weights=None,
2183                                            metrics_collections=None,
2184                                            updates_collections=None,
2185                                            name=None):
2186  """Computes average precision@k of predictions with respect to sparse labels.
2187
2188  See `sparse_average_precision_at_k` for details on the formula. `weights`
2189  are applied to the result of `sparse_average_precision_at_k`.
2190
2191  `streaming_sparse_average_precision_at_k` creates two local variables,
2192  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
2193  are used to compute the frequency. This frequency is ultimately returned as
2194  `average_precision_at_<k>`: an idempotent operation that simply divides
2195  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
2196
2197  For estimation of the metric over a stream of data, the function creates an
2198  `update_op` operation that updates these variables and returns the
2199  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
2200  indicating the top `k` `predictions`. Set operations applied to `top_k` and
2201  `labels` calculate the true positives and false positives weighted by
2202  `weights`. Then `update_op` increments `true_positive_at_<k>` and
2203  `false_positive_at_<k>` using these values.
2204
2205  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2206
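  A minimal usage sketch (illustrative values; TF 1.x session assumed):

    predictions = tf.constant([[0.2, 0.5, 0.3]])
    labels = tf.constant([[2]], dtype=tf.int64)
    mean_ap, update_op = streaming_sparse_average_precision_at_k(
        predictions, labels, k=2)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(mean_ap))  # 0.5: the label ranks second among the top 2
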
2207  Args:
2208    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
2209      N >= 1. Commonly, N=1 and `predictions` has shape
2210      [batch size, num_classes]. The final dimension contains the logit values
2211      for each class. [D1, ... DN] must match `labels`.
2212    labels: `int64` `Tensor` or `SparseTensor` with shape
2213      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2214      target classes for the associated prediction. Commonly, N=1 and `labels`
2215      has shape [batch_size, num_labels]. [D1, ... DN] must match
2216      `predictions`. Values should be in range [0, num_classes), where
2217      num_classes is the last dimension of `predictions`. Values outside this
2218      range are ignored.
2219    k: Integer, k for @k metric. This will calculate an average precision for
2220      range `[1,k]`, as documented above.
2221    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2222      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2223      dimensions must be either `1`, or the same as the corresponding `labels`
2224      dimension).
2225    metrics_collections: An optional list of collections that values should
2226      be added to.
2227    updates_collections: An optional list of collections that updates should
2228      be added to.
2229    name: Name of new update operation, and namespace for other dependent ops.
2230
2231  Returns:
2232    mean_average_precision: Scalar `float64` `Tensor` with the mean average
2233      precision values.
2234    update: `Operation` that increments variables appropriately, and whose
2235      value matches `mean_average_precision`.
2236  """
2237  return metrics.sparse_average_precision_at_k(
2238      k=k,
2239      predictions=predictions,
2240      labels=labels,
2241      weights=weights,
2242      metrics_collections=metrics_collections,
2243      updates_collections=updates_collections,
2244      name=name)
2245
2246
2247def streaming_sparse_average_precision_at_top_k(top_k_predictions,
2248                                                labels,
2249                                                weights=None,
2250                                                metrics_collections=None,
2251                                                updates_collections=None,
2252                                                name=None):
2253  """Computes average precision@k of predictions with respect to sparse labels.
2254
2255  `streaming_sparse_average_precision_at_top_k` creates two local variables,
2256  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
2257  are used to compute the frequency. This frequency is ultimately returned as
2258  `average_precision_at_<k>`: an idempotent operation that simply divides
2259  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
2260
2261  For estimation of the metric over a stream of data, the function creates an
2262  `update_op` operation that updates these variables and returns the
2263  `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate
2264  the true positives and false positives weighted by `weights`. Then `update_op`
2265  increments `true_positive_at_<k>` and `false_positive_at_<k>` using these
2266  values.
2267
2268  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2269
2270  Args:
2271    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
2272      Commonly, N=1 and `top_k_predictions` has shape [batch size, k]. The final
2273      dimension must be set and contains the top `k` predicted class indices.
2274      [D1, ... DN] must match `labels`. Values should be in range
2275      [0, num_classes).
2276    labels: `int64` `Tensor` or `SparseTensor` with shape
2277      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
2278      num_labels=1. N >= 1 and num_labels is the number of target classes for
2279      the associated prediction. Commonly, N=1 and `labels` has shape
2280      [batch_size, num_labels]. [D1, ... DN] must match `top_k_predictions`.
2281      Values should be in range [0, num_classes).
2282    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2283      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2284      dimensions must be either `1`, or the same as the corresponding `labels`
2285      dimension).
2286    metrics_collections: An optional list of collections that values should
2287      be added to.
2288    updates_collections: An optional list of collections that updates should
2289      be added to.
2290    name: Name of new update operation, and namespace for other dependent ops.
2291
2292  Returns:
2293    mean_average_precision: Scalar `float64` `Tensor` with the mean average
2294      precision values.
2295    update: `Operation` that increments variables appropriately, and whose
2296      value matches `mean_average_precision`.
2297
2298  Raises:
2299    ValueError: if the last dimension of top_k_predictions is not set.
2300  """
2301  return metrics_impl._streaming_sparse_average_precision_at_top_k(  # pylint: disable=protected-access
2302      predictions_idx=top_k_predictions,
2303      labels=labels,
2304      weights=weights,
2305      metrics_collections=metrics_collections,
2306      updates_collections=updates_collections,
2307      name=name)
2308
2309
2310def streaming_mean_absolute_error(predictions,
2311                                  labels,
2312                                  weights=None,
2313                                  metrics_collections=None,
2314                                  updates_collections=None,
2315                                  name=None):
2316  """Computes the mean absolute error between the labels and predictions.
2317
2318  The `streaming_mean_absolute_error` function creates two local variables,
2319  `total` and `count` that are used to compute the mean absolute error. This
2320  average is weighted by `weights`, and it is ultimately returned as
2321  `mean_absolute_error`: an idempotent operation that simply divides `total` by
2322  `count`.
2323
2324  For estimation of the metric over a stream of data, the function creates an
2325  `update_op` operation that updates these variables and returns the
2326  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
2327  absolute value of the differences between `predictions` and `labels`. Then
2328  `update_op` increments `total` with the reduced sum of the product of
2329  `weights` and `absolute_errors`, and it increments `count` with the reduced
2330  sum of `weights`.
2331
2332  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2333
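  A minimal usage sketch with illustrative values (TF 1.x session assumed):

    predictions = tf.constant([2.0, 4.0, 6.0])
    labels = tf.constant([1.0, 5.0, 9.0])
    mae, update_op = streaming_mean_absolute_error(predictions, labels)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(mae))  # (1 + 1 + 3) / 3 = 5/3
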
2334  Args:
2335    predictions: A `Tensor` of arbitrary shape.
2336    labels: A `Tensor` of the same shape as `predictions`.
2337    weights: Optional `Tensor` indicating the frequency with which an example is
2338      sampled. Rank must be 0, or the same rank as `labels`, and must be
2339      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
2340      the same as the corresponding `labels` dimension).
2341    metrics_collections: An optional list of collections that
2342      `mean_absolute_error` should be added to.
2343    updates_collections: An optional list of collections that `update_op` should
2344      be added to.
2345    name: An optional variable_scope name.
2346
2347  Returns:
2348    mean_absolute_error: A `Tensor` representing the current mean, the value of
2349      `total` divided by `count`.
2350    update_op: An operation that increments the `total` and `count` variables
2351      appropriately and whose value matches `mean_absolute_error`.
2352
2353  Raises:
2354    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2355      `weights` is not `None` and its shape doesn't match `predictions`, or if
2356      either `metrics_collections` or `updates_collections` are not a list or
2357      tuple.
2358  """
2359  return metrics.mean_absolute_error(
2360      predictions=predictions,
2361      labels=labels,
2362      weights=weights,
2363      metrics_collections=metrics_collections,
2364      updates_collections=updates_collections,
2365      name=name)
2366
2367
2368def streaming_mean_relative_error(predictions,
2369                                  labels,
2370                                  normalizer,
2371                                  weights=None,
2372                                  metrics_collections=None,
2373                                  updates_collections=None,
2374                                  name=None):
2375  """Computes the mean relative error by normalizing with the given values.
2376
2377  The `streaming_mean_relative_error` function creates two local variables,
2378  `total` and `count` that are used to compute the mean relative error.
2379  This average is weighted by `weights`, and it is ultimately returned as
2380  `mean_relative_error`: an idempotent operation that simply divides `total` by
2381  `count`.
2382
2383  For estimation of the metric over a stream of data, the function creates an
2384  `update_op` operation that updates these variables and returns the
2385  `mean_relative_error`. Internally, a `relative_errors` operation divides the
2386  absolute value of the differences between `predictions` and `labels` by the
2387  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
2388  product of `weights` and `relative_errors`, and it increments `count` with the
2389  reduced sum of `weights`.
2390
2391  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
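
  For example, a brief sketch (assuming the `tf.contrib.metrics` export path)
  that normalizes by the labels themselves, as the `aggregate_metric_map`
  example later in this file also does:

  ```python
    rel_err, update_op = tf.contrib.metrics.streaming_mean_relative_error(
        predictions, labels, normalizer=labels)
  ```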
2392
2393  Args:
2394    predictions: A `Tensor` of arbitrary shape.
2395    labels: A `Tensor` of the same shape as `predictions`.
2396    normalizer: A `Tensor` of the same shape as `predictions`.
2397    weights: Optional `Tensor` indicating the frequency with which an example is
2398      sampled. Rank must be 0, or the same rank as `labels`, and must be
2399      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
2400      the same as the corresponding `labels` dimension).
2401    metrics_collections: An optional list of collections that
2402      `mean_relative_error` should be added to.
2403    updates_collections: An optional list of collections that `update_op` should
2404      be added to.
2405    name: An optional variable_scope name.
2406
2407  Returns:
2408    mean_relative_error: A `Tensor` representing the current mean, the value of
2409      `total` divided by `count`.
2410    update_op: An operation that increments the `total` and `count` variables
2411      appropriately and whose value matches `mean_relative_error`.
2412
2413  Raises:
2414    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2415      `weights` is not `None` and its shape doesn't match `predictions`, or if
2416      either `metrics_collections` or `updates_collections` are not a list or
2417      tuple.
2418  """
2419  return metrics.mean_relative_error(
2420      normalizer=normalizer,
2421      predictions=predictions,
2422      labels=labels,
2423      weights=weights,
2424      metrics_collections=metrics_collections,
2425      updates_collections=updates_collections,
2426      name=name)
2427
2428
2429def streaming_mean_squared_error(predictions,
2430                                 labels,
2431                                 weights=None,
2432                                 metrics_collections=None,
2433                                 updates_collections=None,
2434                                 name=None):
2435  """Computes the mean squared error between the labels and predictions.
2436
2437  The `streaming_mean_squared_error` function creates two local variables,
2438  `total` and `count` that are used to compute the mean squared error.
2439  This average is weighted by `weights`, and it is ultimately returned as
2440  `mean_squared_error`: an idempotent operation that simply divides `total` by
2441  `count`.
2442
2443  For estimation of the metric over a stream of data, the function creates an
2444  `update_op` operation that updates these variables and returns the
2445  `mean_squared_error`. Internally, a `squared_error` operation computes the
2446  element-wise square of the difference between `predictions` and `labels`. Then
2447  `update_op` increments `total` with the reduced sum of the product of
2448  `weights` and `squared_error`, and it increments `count` with the reduced sum
2449  of `weights`.
2450
2451  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2452
2453  Args:
2454    predictions: A `Tensor` of arbitrary shape.
2455    labels: A `Tensor` of the same shape as `predictions`.
2456    weights: Optional `Tensor` indicating the frequency with which an example is
2457      sampled. Rank must be 0, or the same rank as `labels`, and must be
2458      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
2459      the same as the corresponding `labels` dimension).
2460    metrics_collections: An optional list of collections that
2461      `mean_squared_error` should be added to.
2462    updates_collections: An optional list of collections that `update_op` should
2463      be added to.
2464    name: An optional variable_scope name.
2465
2466  Returns:
2467    mean_squared_error: A `Tensor` representing the current mean, the value of
2468      `total` divided by `count`.
2469    update_op: An operation that increments the `total` and `count` variables
2470      appropriately and whose value matches `mean_squared_error`.
2471
2472  Raises:
2473    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2474      `weights` is not `None` and its shape doesn't match `predictions`, or if
2475      either `metrics_collections` or `updates_collections` are not a list or
2476      tuple.
2477  """
2478  return metrics.mean_squared_error(
2479      predictions=predictions,
2480      labels=labels,
2481      weights=weights,
2482      metrics_collections=metrics_collections,
2483      updates_collections=updates_collections,
2484      name=name)
2485
2486
2487def streaming_root_mean_squared_error(predictions,
2488                                      labels,
2489                                      weights=None,
2490                                      metrics_collections=None,
2491                                      updates_collections=None,
2492                                      name=None):
2493  """Computes the root mean squared error between the labels and predictions.
2494
2495  The `streaming_root_mean_squared_error` function creates two local variables,
2496  `total` and `count` that are used to compute the root mean squared error.
2497  This average is weighted by `weights`, and it is ultimately returned as
2498  `root_mean_squared_error`: an idempotent operation that takes the square root
2499  of the division of `total` by `count`.
2500
2501  For estimation of the metric over a stream of data, the function creates an
2502  `update_op` operation that updates these variables and returns the
2503  `root_mean_squared_error`. Internally, a `squared_error` operation computes
2504  the element-wise square of the difference between `predictions` and `labels`.
2505  Then `update_op` increments `total` with the reduced sum of the product of
2506  `weights` and `squared_error`, and it increments `count` with the reduced sum
2507  of `weights`.
2508
2509  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2510
2511  Args:
2512    predictions: A `Tensor` of arbitrary shape.
2513    labels: A `Tensor` of the same shape as `predictions`.
2514    weights: Optional `Tensor` indicating the frequency with which an example is
2515      sampled. Rank must be 0, or the same rank as `labels`, and must be
2516      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
2517      the same as the corresponding `labels` dimension).
2518    metrics_collections: An optional list of collections that
2519      `root_mean_squared_error` should be added to.
2520    updates_collections: An optional list of collections that `update_op` should
2521      be added to.
2522    name: An optional variable_scope name.
2523
2524  Returns:
2525    root_mean_squared_error: A `Tensor` representing the current mean, the value
2526      of `total` divided by `count`.
2527    update_op: An operation that increments the `total` and `count` variables
2528      appropriately and whose value matches `root_mean_squared_error`.
2529
2530  Raises:
2531    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2532      `weights` is not `None` and its shape doesn't match `predictions`, or if
2533      either `metrics_collections` or `updates_collections` are not a list or
2534      tuple.
2535  """
2536  return metrics.root_mean_squared_error(
2537      predictions=predictions,
2538      labels=labels,
2539      weights=weights,
2540      metrics_collections=metrics_collections,
2541      updates_collections=updates_collections,
2542      name=name)
2543
2544
2545def streaming_covariance(predictions,
2546                         labels,
2547                         weights=None,
2548                         metrics_collections=None,
2549                         updates_collections=None,
2550                         name=None):
2551  """Computes the unbiased sample covariance between `predictions` and `labels`.
2552
2553  The `streaming_covariance` function creates four local variables,
2554  `comoment`, `mean_prediction`, `mean_label`, and `count`, which are used to
2555  compute the sample covariance between predictions and labels across multiple
2556  batches of data. The covariance is ultimately returned as an idempotent
2557  operation that simply divides `comoment` by `count` - 1. We use `count` - 1
2558  in order to get an unbiased estimate.
2559
2560  The algorithm used for this online computation is described in
2561  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
2562  Specifically, the formula used to combine two sample comoments is
2563  `C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB`
2564  The comoment for a single batch of data is simply
2565  `sum((x - E[x]) * (y - E[y]))`, optionally weighted.
2566
2567  If `weights` is not None, then it is used to compute weighted comoments,
2568  means, and count. NOTE: these weights are treated as "frequency weights", as
2569  opposed to "reliability weights". See discussion of the difference on
2570  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance
2571
2572  To facilitate the computation of covariance across multiple batches of data,
2573  the function creates an `update_op` operation, which updates underlying
2574  variables and returns the updated covariance.
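
  For example, a minimal sketch (the placeholders and the `tf.contrib.metrics`
  export path are illustrative):

  ```python
    x = tf.placeholder(tf.float32, shape=[None])
    y = tf.placeholder(tf.float32, shape=[None])
    cov, update_op = tf.contrib.metrics.streaming_covariance(x, y)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      # After each update, `cov` holds the unbiased sample covariance over all
      # points seen so far (comparable to numpy.cov, which also divides by
      # N - 1).
      sess.run(update_op, feed_dict={x: [1., 2., 3.], y: [2., 4., 6.]})
      sess.run(update_op, feed_dict={x: [4., 5.], y: [8., 10.]})
      print(sess.run(cov))  # covariance over all five points
  ```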
2575
2576  Args:
2577    predictions: A `Tensor` of arbitrary size.
2578    labels: A `Tensor` of the same size as `predictions`.
2579    weights: Optional `Tensor` indicating the frequency with which an example is
2580      sampled. Rank must be 0, or the same rank as `labels`, and must be
2581      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
2582      the same as the corresponding `labels` dimension).
2583    metrics_collections: An optional list of collections that the metric
2584      value variable should be added to.
2585    updates_collections: An optional list of collections that the metric update
2586      ops should be added to.
2587    name: An optional variable_scope name.
2588
2589  Returns:
2590    covariance: A `Tensor` representing the current unbiased sample covariance,
2591      `comoment` / (`count` - 1).
2592    update_op: An operation that updates the local variables appropriately.
2593
2594  Raises:
2595    ValueError: If labels and predictions are of different sizes or if either
2596      `metrics_collections` or `updates_collections` are not a list or tuple.
2597  """
2598  with variable_scope.variable_scope(name, 'covariance',
2599                                     (predictions, labels, weights)):
2600    predictions, labels, weights = _remove_squeezable_dimensions(
2601        predictions, labels, weights)
2602    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
2603    count = _create_local('count', [])
2604    mean_prediction = _create_local('mean_prediction', [])
2605    mean_label = _create_local('mean_label', [])
2606    comoment = _create_local('comoment', [])  # C_A in update equation
2607
2608    if weights is None:
2609      batch_count = math_ops.to_float(array_ops.size(labels))  # n_B in eqn
2610      weighted_predictions = predictions
2611      weighted_labels = labels
2612    else:
2613      weights = weights_broadcast_ops.broadcast_weights(weights, labels)
2614      batch_count = math_ops.reduce_sum(weights)  # n_B in eqn
2615      weighted_predictions = math_ops.multiply(predictions, weights)
2616      weighted_labels = math_ops.multiply(labels, weights)
2617
2618    update_count = state_ops.assign_add(count, batch_count)  # n_AB in eqn
2619    prev_count = update_count - batch_count  # n_A in update equation
2620
2621    # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
2622    # batch_mean_prediction is E[x_B] in the update equation
2623    batch_mean_prediction = _safe_div(
2624        math_ops.reduce_sum(weighted_predictions), batch_count,
2625        'batch_mean_prediction')
2626    delta_mean_prediction = _safe_div(
2627        (batch_mean_prediction - mean_prediction) * batch_count, update_count,
2628        'delta_mean_prediction')
2629    update_mean_prediction = state_ops.assign_add(mean_prediction,
2630                                                  delta_mean_prediction)
2631    # prev_mean_prediction is E[x_A] in the update equation
2632    prev_mean_prediction = update_mean_prediction - delta_mean_prediction
2633
2634    # batch_mean_label is E[y_B] in the update equation
2635    batch_mean_label = _safe_div(
2636        math_ops.reduce_sum(weighted_labels), batch_count, 'batch_mean_label')
2637    delta_mean_label = _safe_div((batch_mean_label - mean_label) * batch_count,
2638                                 update_count, 'delta_mean_label')
2639    update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
2640    # prev_mean_label is E[y_A] in the update equation
2641    prev_mean_label = update_mean_label - delta_mean_label
2642
2643    unweighted_batch_coresiduals = ((predictions - batch_mean_prediction) *
2644                                    (labels - batch_mean_label))
2645    # batch_comoment is C_B in the update equation
2646    if weights is None:
2647      batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals)
2648    else:
2649      batch_comoment = math_ops.reduce_sum(
2650          unweighted_batch_coresiduals * weights)
2651
2652    # View delta_comoment as = C_AB - C_A in the update equation above.
2653    # Since C_A is stored in a var, by how much do we need to increment that var
2654    # to make the var = C_AB?
2655    delta_comoment = (
2656        batch_comoment + (prev_mean_prediction - batch_mean_prediction) *
2657        (prev_mean_label - batch_mean_label) *
2658        (prev_count * batch_count / update_count))
2659    update_comoment = state_ops.assign_add(comoment, delta_comoment)
2660
2661    covariance = array_ops.where(
2662        math_ops.less_equal(count, 1.),
2663        float('nan'),
2664        math_ops.truediv(comoment, count - 1),
2665        name='covariance')
2666    with ops.control_dependencies([update_comoment]):
2667      update_op = array_ops.where(
2668          math_ops.less_equal(count, 1.),
2669          float('nan'),
2670          math_ops.truediv(comoment, count - 1),
2671          name='update_op')
2672
2673  if metrics_collections:
2674    ops.add_to_collections(metrics_collections, covariance)
2675
2676  if updates_collections:
2677    ops.add_to_collections(updates_collections, update_op)
2678
2679  return covariance, update_op
2680
2681
2682def streaming_pearson_correlation(predictions,
2683                                  labels,
2684                                  weights=None,
2685                                  metrics_collections=None,
2686                                  updates_collections=None,
2687                                  name=None):
2688  """Computes Pearson correlation coefficient between `predictions`, `labels`.
2689
2690  The `streaming_pearson_correlation` function delegates to
2691  `streaming_covariance` the tracking of three [co]variances:
2692
2693  - `streaming_covariance(predictions, labels)`, i.e. covariance
2694  - `streaming_covariance(predictions, predictions)`, i.e. variance
2695  - `streaming_covariance(labels, labels)`, i.e. variance
2696
2697  The product-moment correlation ultimately returned is an idempotent operation
2698  `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`. To
2699  facilitate correlation computation across multiple batches, the function
2700  groups the `update_op`s of the three underlying `streaming_covariance` calls
2701  and returns an `update_op`.
2702
2703  If `weights` is not None, then it is used to compute a weighted correlation.
2704  NOTE: these weights are treated as "frequency weights", as opposed to
2705  "reliability weights". See discussion of the difference on
2706  https://wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_variance
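
  For example, a brief sketch (assuming the `tf.contrib.metrics` export path):

  ```python
    r, update_op = tf.contrib.metrics.streaming_pearson_correlation(
        predictions, labels)
    # Run `update_op` on each batch; `r` then holds the correlation over all
    # data seen so far. With fewer than two examples seen, the underlying
    # covariances are NaN, so `r` is NaN as well.
  ```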
2707
2708  Args:
2709    predictions: A `Tensor` of arbitrary size.
2710    labels: A `Tensor` of the same size as predictions.
2711    weights: Optional `Tensor` indicating the frequency with which an example is
2712      sampled. Rank must be 0, or the same rank as `labels`, and must be
2713      broadcastable to `labels` (i.e., all dimensions must be either `1`, or
2714      the same as the corresponding `labels` dimension).
2715    metrics_collections: An optional list of collections that the metric
2716      value variable should be added to.
2717    updates_collections: An optional list of collections that the metric update
2718      ops should be added to.
2719    name: An optional variable_scope name.
2720
2721  Returns:
2722    pearson_r: A `Tensor` representing the current Pearson product-moment
2723      correlation coefficient, the value of
2724      `cov(predictions, labels) / sqrt(var(predictions) * var(labels))`.
2725    update_op: An operation that updates the underlying variables appropriately.
2726
2727  Raises:
2728    ValueError: If `labels` and `predictions` are of different sizes, or if
2729      `weights` is the wrong size, or if either `metrics_collections` or
2730      `updates_collections` are not a `list` or `tuple`.
2731  """
2732  with variable_scope.variable_scope(name, 'pearson_r',
2733                                     (predictions, labels, weights)):
2734    predictions, labels, weights = _remove_squeezable_dimensions(
2735        predictions, labels, weights)
2736    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
2737    # Broadcast weights here to avoid duplicate broadcasting in each call to
2738    # `streaming_covariance`.
2739    if weights is not None:
2740      weights = weights_broadcast_ops.broadcast_weights(weights, labels)
2741    cov, update_cov = streaming_covariance(
2742        predictions, labels, weights=weights, name='covariance')
2743    var_predictions, update_var_predictions = streaming_covariance(
2744        predictions, predictions, weights=weights, name='variance_predictions')
2745    var_labels, update_var_labels = streaming_covariance(
2746        labels, labels, weights=weights, name='variance_labels')
2747
2748    pearson_r = math_ops.truediv(
2749        cov,
2750        math_ops.multiply(
2751            math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)),
2752        name='pearson_r')
2753    update_op = math_ops.truediv(
2754        update_cov,
2755        math_ops.multiply(
2756            math_ops.sqrt(update_var_predictions),
2757            math_ops.sqrt(update_var_labels)),
2758        name='update_op')
2759
2760  if metrics_collections:
2761    ops.add_to_collections(metrics_collections, pearson_r)
2762
2763  if updates_collections:
2764    ops.add_to_collections(updates_collections, update_op)
2765
2766  return pearson_r, update_op
2767
2768
2769# TODO(nsilberman): add a 'normalized' flag so that the user can request
2770# normalization if the inputs are not normalized.
2771def streaming_mean_cosine_distance(predictions,
2772                                   labels,
2773                                   dim,
2774                                   weights=None,
2775                                   metrics_collections=None,
2776                                   updates_collections=None,
2777                                   name=None):
2778  """Computes the cosine distance between the labels and predictions.
2779
2780  The `streaming_mean_cosine_distance` function creates two local variables,
2781  `total` and `count` that are used to compute the average cosine distance
2782  between `predictions` and `labels`. This average is weighted by `weights`,
2783  and it is ultimately returned as `mean_distance`, which is an idempotent
2784  operation that simply divides `total` by `count`.
2785
2786  For estimation of the metric over a stream of data, the function creates an
2787  `update_op` operation that updates these variables and returns the
2788  `mean_distance`.
2789
2790  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
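
  For example, a sketch for batches of embedding vectors (the
  `prediction_vectors` and `label_vectors` tensors are hypothetical, and the
  `tf.contrib.metrics` export path is assumed). Because the op simply sums
  element-wise products (see the TODO above), the inputs should already be
  unit-normalized along `dim`:

  ```python
    predictions = tf.nn.l2_normalize(prediction_vectors, dim=1)
    labels = tf.nn.l2_normalize(label_vectors, dim=1)
    distance, update_op = tf.contrib.metrics.streaming_mean_cosine_distance(
        predictions, labels, dim=1)
  ```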
2791
2792  Args:
2793    predictions: A `Tensor` of the same shape as `labels`.
2794    labels: A `Tensor` of arbitrary shape.
2795    dim: The dimension along which the cosine distance is computed.
2796    weights: An optional `Tensor` whose shape is broadcastable to `predictions`,
2797      and whose dimension `dim` is 1.
2798    metrics_collections: An optional list of collections that the metric
2799      value variable should be added to.
2800    updates_collections: An optional list of collections that the metric update
2801      ops should be added to.
2802    name: An optional variable_scope name.
2803
2804  Returns:
2805    mean_distance: A `Tensor` representing the current mean, the value of
2806      `total` divided by `count`.
2807    update_op: An operation that increments the `total` and `count` variables
2808      appropriately.
2809
2810  Raises:
2811    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2812      `weights` is not `None` and its shape doesn't match `predictions`, or if
2813      either `metrics_collections` or `updates_collections` are not a list or
2814      tuple.
2815  """
2816  predictions, labels, weights = _remove_squeezable_dimensions(
2817      predictions, labels, weights)
2818  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
2819  radial_diffs = math_ops.multiply(predictions, labels)
2820  radial_diffs = math_ops.reduce_sum(
2821      radial_diffs, reduction_indices=[
2822          dim,
2823      ], keep_dims=True)
2824  mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None,
2825                                            name or 'mean_cosine_distance')
2826  mean_distance = math_ops.subtract(1.0, mean_distance)
2827  update_op = math_ops.subtract(1.0, update_op)
2828
2829  if metrics_collections:
2830    ops.add_to_collections(metrics_collections, mean_distance)
2831
2832  if updates_collections:
2833    ops.add_to_collections(updates_collections, update_op)
2834
2835  return mean_distance, update_op
2836
2837
2838def streaming_percentage_less(values,
2839                              threshold,
2840                              weights=None,
2841                              metrics_collections=None,
2842                              updates_collections=None,
2843                              name=None):
2844  """Computes the percentage of values less than the given threshold.
2845
2846  The `streaming_percentage_less` function creates two local variables,
2847  `total` and `count` that are used to compute the percentage of `values` that
2848  fall below `threshold`. This rate is weighted by `weights`, and it is
2849  ultimately returned as `percentage`, which is an idempotent operation that
2850  simply divides `total` by `count`.
2851
2852  For estimation of the metric over a stream of data, the function creates an
2853  `update_op` operation that updates these variables and returns the
2854  `percentage`.
2855
2856  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
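
  For example, a sketch that tracks the running fraction of per-example
  absolute errors below a threshold (assuming the `tf.contrib.metrics` export
  path):

  ```python
    errors = tf.abs(predictions - labels)
    below_half, update_op = tf.contrib.metrics.streaming_percentage_less(
        errors, threshold=0.5)
  ```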
2857
2858  Args:
2859    values: A numeric `Tensor` of arbitrary size.
2860    threshold: A scalar threshold.
2861    weights: An optional `Tensor` whose shape is broadcastable to `values`.
2862    metrics_collections: An optional list of collections that the metric
2863      value variable should be added to.
2864    updates_collections: An optional list of collections that the metric update
2865      ops should be added to.
2866    name: An optional variable_scope name.
2867
2868  Returns:
2869    percentage: A `Tensor` representing the current mean, the value of `total`
2870      divided by `count`.
2871    update_op: An operation that increments the `total` and `count` variables
2872      appropriately.
2873
2874  Raises:
2875    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
2876      or if either `metrics_collections` or `updates_collections` are not a list
2877      or tuple.
2878  """
2879  return metrics.percentage_below(
2880      values=values,
2881      threshold=threshold,
2882      weights=weights,
2883      metrics_collections=metrics_collections,
2884      updates_collections=updates_collections,
2885      name=name)
2886
2887
2888def streaming_mean_iou(predictions,
2889                       labels,
2890                       num_classes,
2891                       weights=None,
2892                       metrics_collections=None,
2893                       updates_collections=None,
2894                       name=None):
2895  """Calculate per-step mean Intersection-Over-Union (mIOU).
2896
2897  Mean Intersection-Over-Union is a common evaluation metric for
2898  semantic image segmentation, which first computes the IOU for each
2899  semantic class and then computes the average over classes.
2900  IOU is defined as follows:
2901    IOU = true_positive / (true_positive + false_positive + false_negative).
2902  The predictions are accumulated in a confusion matrix, weighted by `weights`,
2903  and mIOU is then calculated from it.
2904
2905  For estimation of the metric over a stream of data, the function creates an
2906  `update_op` operation that updates these variables and returns the `mean_iou`.
2907
2908  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
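
  For example, a sketch for a segmentation evaluation (the class-id tensors,
  the class count, and the `tf.contrib.metrics` export path are illustrative):

  ```python
    # `predicted_classes` and `gt_classes` are int32 class ids; a
    # [num_classes, num_classes] confusion matrix is accumulated internally.
    miou, update_op = tf.contrib.metrics.streaming_mean_iou(
        predicted_classes, gt_classes, num_classes=21)
  ```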
2909
2910  Args:
2911    predictions: A `Tensor` of prediction results for semantic labels, whose
2912      shape is [batch size] and type `int32` or `int64`. The tensor will be
2913      flattened, if its rank > 1.
2914    labels: A `Tensor` of ground truth labels with shape [batch size] and of
2915      type `int32` or `int64`. The tensor will be flattened, if its rank > 1.
2916    num_classes: The possible number of labels the prediction task can
2917      have. This value must be provided, since a confusion matrix of
2918      dimension = [num_classes, num_classes] will be allocated.
2919    weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
2920    metrics_collections: An optional list of collections that `mean_iou`
2921      should be added to.
2922    updates_collections: An optional list of collections `update_op` should be
2923      added to.
2924    name: An optional variable_scope name.
2925
2926  Returns:
2927    mean_iou: A `Tensor` representing the mean intersection-over-union.
2928    update_op: An operation that increments the confusion matrix.
2929
2930  Raises:
2931    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2932      `weights` is not `None` and its shape doesn't match `predictions`, or if
2933      either `metrics_collections` or `updates_collections` are not a list or
2934      tuple.
2935  """
2936  return metrics.mean_iou(
2937      num_classes=num_classes,
2938      predictions=predictions,
2939      labels=labels,
2940      weights=weights,
2941      metrics_collections=metrics_collections,
2942      updates_collections=updates_collections,
2943      name=name)
2944
2945
2946def _next_array_size(required_size, growth_factor=1.5):
2947  """Calculate the next size for reallocating a dynamic array.
2948
2949  Args:
2950    required_size: number or tf.Tensor specifying required array capacity.
2951    growth_factor: optional number or tf.Tensor specifying the growth factor
2952      between subsequent allocations.
2953
2954  Returns:
2955    tf.Tensor with dtype=int32 giving the next array size.
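
  For example, with the default `growth_factor` of 1.5, a `required_size` of 5
  yields 6 (the next integer at or above 1.5**4 = 5.0625), and a
  `required_size` of 7 yields 8 (ceil(1.5**5) = ceil(7.59375) = 8).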
2956  """
2957  exponent = math_ops.ceil(
2958      math_ops.log(math_ops.cast(required_size, dtypes.float32)) / math_ops.log(
2959          math_ops.cast(growth_factor, dtypes.float32)))
2960  return math_ops.cast(math_ops.ceil(growth_factor**exponent), dtypes.int32)
2961
2962
2963def streaming_concat(values,
2964                     axis=0,
2965                     max_size=None,
2966                     metrics_collections=None,
2967                     updates_collections=None,
2968                     name=None):
2969  """Concatenate values along an axis across batches.
2970
2971  The function `streaming_concat` creates two local variables, `array` and
2972  `size`, that are used to store concatenated values. Internally, `array` is
2973  used as storage for a dynamic array (if `max_size` is `None`), which ensures
2974  that updates can be run in amortized constant time.
2975
2976  For estimation of the metric over a stream of data, the function creates an
2977  `update_op` operation that appends the values of a tensor and returns the
2978  length of the concatenated axis.
2979
2980  This op allows for evaluating metrics that cannot be updated incrementally
2981  using the same framework as other streaming metrics.
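
  For example, a sketch that collects every value seen during evaluation into
  a single tensor (the placeholder and the `tf.contrib.metrics` export path
  are illustrative):

  ```python
    # Only the concatenation axis may be statically unknown.
    batch_losses = tf.placeholder(tf.float32, shape=[None])
    all_losses, update_op = tf.contrib.metrics.streaming_concat(batch_losses)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op, feed_dict={batch_losses: [0.1, 0.5]})
      sess.run(update_op, feed_dict={batch_losses: [0.3]})
      print(sess.run(all_losses))  # the concatenated values: [0.1, 0.5, 0.3]
  ```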
2982
2983  Args:
2984    values: `Tensor` to concatenate. Rank and the shape along all axes other
2985      than the axis to concatenate along must be statically known.
2986    axis: optional integer axis to concatenate along.
2987    max_size: optional integer maximum size of `value` along the given axis.
2988      Once the maximum size is reached, further updates are no-ops. By default,
2989      there is no maximum size: the array is resized as necessary.
2990    metrics_collections: An optional list of collections that `value`
2991      should be added to.
2992    updates_collections: An optional list of collections `update_op` should be
2993      added to.
2994    name: An optional variable_scope name.
2995
2996  Returns:
2997    value: A `Tensor` representing the concatenated values.
2998    update_op: An operation that concatenates the next values.
2999
3000  Raises:
3001    ValueError: if `values` does not have a statically known rank, `axis` is
3002      not in the valid range or the size of `values` is not statically known
3003      along any axis other than `axis`.
3004  """
3005  with variable_scope.variable_scope(name, 'streaming_concat', (values,)):
3006    # pylint: disable=invalid-slice-index
3007    values_shape = values.get_shape()
3008    if values_shape.dims is None:
3009      raise ValueError('`values` must have statically known rank')
3010
3011    ndim = len(values_shape)
3012    if axis < 0:
3013      axis += ndim
3014    if not 0 <= axis < ndim:
3015      raise ValueError('axis = %r not in [0, %r)' % (axis, ndim))
3016
3017    fixed_shape = [dim.value for n, dim in enumerate(values_shape) if n != axis]
3018    if any(value is None for value in fixed_shape):
3019      raise ValueError('all dimensions of `values` other than the dimension to '
3020                       'concatenate along must have statically known size')
3021
3022    # We move `axis` to the front of the internal array so assign ops can be
3023    # applied to contiguous slices
3024    init_size = 0 if max_size is None else max_size
3025    init_shape = [init_size] + fixed_shape
3026    array = _create_local(
3027        'array', shape=init_shape, validate_shape=False, dtype=values.dtype)
3028    size = _create_local('size', shape=[], dtype=dtypes.int32)
3029
3030    perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
3031    valid_array = array[:size]
3032    valid_array.set_shape([None] + fixed_shape)
3033    value = array_ops.transpose(valid_array, perm, name='concat')
3034
3035    values_size = array_ops.shape(values)[axis]
3036    if max_size is None:
3037      batch_size = values_size
3038    else:
3039      batch_size = math_ops.minimum(values_size, max_size - size)
3040
3041    perm = [axis] + [n for n in range(ndim) if n != axis]
3042    batch_values = array_ops.transpose(values, perm)[:batch_size]
3043
3044    def reallocate():
3045      next_size = _next_array_size(new_size)
3046      next_shape = array_ops.stack([next_size] + fixed_shape)
3047      new_value = array_ops.zeros(next_shape, dtype=values.dtype)
3048      old_value = array.value()
3049      assign_op = state_ops.assign(array, new_value, validate_shape=False)
3050      with ops.control_dependencies([assign_op]):
3051        copy_op = array[:size].assign(old_value[:size])
3052      # return value needs to be the same type as no_op() for cond
3053      with ops.control_dependencies([copy_op]):
3054        return control_flow_ops.no_op()
3055
3056    new_size = size + batch_size
3057    array_size = array_ops.shape_internal(array, optimize=False)[0]
3058    maybe_reallocate_op = control_flow_ops.cond(
3059        new_size > array_size, reallocate, control_flow_ops.no_op)
3060    with ops.control_dependencies([maybe_reallocate_op]):
3061      append_values_op = array[size:new_size].assign(batch_values)
3062    with ops.control_dependencies([append_values_op]):
3063      update_op = size.assign(new_size)
3064
3065    if metrics_collections:
3066      ops.add_to_collections(metrics_collections, value)
3067
3068    if updates_collections:
3069      ops.add_to_collections(updates_collections, update_op)
3070
3071    return value, update_op
3072    # pylint: enable=invalid-slice-index
3073
3074
3075def aggregate_metrics(*value_update_tuples):
3076  """Aggregates the metric value tensors and update ops into two lists.
3077
3078  Args:
3079    *value_update_tuples: a variable number of tuples, each of which contain the
3080      pair of (value_tensor, update_op) from a streaming metric.
3081
3082  Returns:
3083    A list of value `Tensor` objects and a list of update ops.
3084
3085  Raises:
3086    ValueError: if `value_update_tuples` is empty.
3087  """
3088  if not value_update_tuples:
3089    raise ValueError('Expected at least one value_tensor/update_op pair')
3090  value_ops, update_ops = zip(*value_update_tuples)
3091  return list(value_ops), list(update_ops)
3092
3093
3094def aggregate_metric_map(names_to_tuples):
3095  """Aggregates the metric names to tuple dictionary.
3096
3097  This function is useful for pairing metric names with their associated value
3098  and update ops when the list of metrics is long. For example:
3099
3100  ```python
3101    metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
3102        'Mean Absolute Error': slim.metrics.streaming_mean_absolute_error(
3103            predictions, labels, weights),
3104        'Mean Relative Error': slim.metrics.streaming_mean_relative_error(
3105            predictions, labels, labels, weights),
3106        'RMSE Linear': slim.metrics.streaming_root_mean_squared_error(
3107            predictions, labels, weights),
3108        'RMSE Log': slim.metrics.streaming_root_mean_squared_error(
3109            predictions, labels, weights),
3110    })
3111  ```
3112
3113  Args:
3114    names_to_tuples: a map of metric names to tuples, each of which contain the
3115      pair of (value_tensor, update_op) from a streaming metric.
3116
3117  Returns:
3118    A dictionary from metric names to value ops and a dictionary from metric
3119    names to update ops.
3120  """
3121  metric_names = names_to_tuples.keys()
3122  value_ops, update_ops = zip(*names_to_tuples.values())
3123  return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
3124
3125
3126def _remove_squeezable_dimensions(predictions, labels, weights):
3127  """Squeeze last dim if needed.
3128
3129  Squeezes `predictions` and `labels` if their rank differs by 1.
3130  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`
3131
3132  This will use static shape if available. Otherwise, it will add graph
3133  operations, which could result in a performance hit.
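
  For example, `predictions` with shape `[batch_size, 1]` and `labels` with
  shape `[batch_size]` are returned with `predictions` squeezed to
  `[batch_size]`; `weights` of shape `[batch_size, 1]` would then be squeezed
  to `[batch_size]` as well.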
3134
3135  Args:
3136    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
3137    labels: Label values, a `Tensor` whose dimensions match `predictions`.
3138    weights: Optional weight `Tensor`. It will be squeezed if its rank is 1
3139      more than the new rank of `predictions`
3140
3141  Returns:
3142    Tuple of `predictions`, `labels` and `weights`, possibly with the last
3143    dimension squeezed.
3144  """
3145  labels, predictions = confusion_matrix.remove_squeezable_dimensions(
3146      labels, predictions)
3147  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
3148
3149  if weights is not None:
3150    weights = ops.convert_to_tensor(weights)
3151    predictions_shape = predictions.get_shape()
3152    predictions_rank = predictions_shape.ndims
3153    weights_shape = weights.get_shape()
3154    weights_rank = weights_shape.ndims
3155
3156    if (predictions_rank is not None) and (weights_rank is not None):
3157      # Use static rank.
3158      if weights_rank - predictions_rank == 1:
3159        weights = array_ops.squeeze(weights, [-1])
3160    elif (weights_rank is
3161          None) or (weights_shape.dims[-1].is_compatible_with(1)):
3162      # Use dynamic rank
3163      weights = control_flow_ops.cond(
3164          math_ops.equal(
3165              array_ops.rank(weights),
3166              math_ops.add(array_ops.rank(predictions), 1)),
3167          lambda: array_ops.squeeze(weights, [-1]), lambda: weights)
3168  return predictions, labels, weights
3169
3170
3171__all__ = [
3172    'aggregate_metric_map',
3173    'aggregate_metrics',
3174    'sparse_recall_at_top_k',
3175    'streaming_accuracy',
3176    'streaming_auc',
3177    'streaming_curve_points',
3178    'streaming_false_negative_rate',
3179    'streaming_false_negative_rate_at_thresholds',
3180    'streaming_false_negatives',
3181    'streaming_false_negatives_at_thresholds',
3182    'streaming_false_positive_rate',
3183    'streaming_false_positive_rate_at_thresholds',
3184    'streaming_false_positives',
3185    'streaming_false_positives_at_thresholds',
3186    'streaming_mean',
3187    'streaming_mean_absolute_error',
3188    'streaming_mean_cosine_distance',
3189    'streaming_mean_iou',
3190    'streaming_mean_relative_error',
3191    'streaming_mean_squared_error',
3192    'streaming_mean_tensor',
3193    'streaming_percentage_less',
3194    'streaming_precision',
3195    'streaming_precision_at_thresholds',
3196    'streaming_recall',
3197    'streaming_recall_at_k',
3198    'streaming_recall_at_thresholds',
3199    'streaming_root_mean_squared_error',
3200    'streaming_sensitivity_at_specificity',
3201    'streaming_sparse_average_precision_at_k',
3202    'streaming_sparse_average_precision_at_top_k',
3203    'streaming_sparse_precision_at_k',
3204    'streaming_sparse_precision_at_top_k',
3205    'streaming_sparse_recall_at_k',
3206    'streaming_specificity_at_sensitivity',
3207    'streaming_true_negatives',
3208    'streaming_true_negatives_at_thresholds',
3209    'streaming_true_positives',
3210    'streaming_true_positives_at_thresholds',
3211]
3212