1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functions for computing moving statistics."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import init_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import state_ops
25from tensorflow.python.ops import variable_scope
26
27
28__all__ = [
29    "assign_moving_mean_variance",
30    "assign_log_moving_mean_exp",
31    "moving_mean_variance",
32]
33
34
35def assign_moving_mean_variance(
36    mean_var, variance_var, value, decay, name=None):
37  """Compute exponentially weighted moving {mean,variance} of a streaming value.
38
39  The `value` updated exponentially weighted moving `mean_var` and
40  `variance_var` are given by the following recurrence relations:
41
42  ```python
43  variance_var = decay * (variance_var + (1-decay) * (value - mean_var)**2)
44  mean_var     = decay * mean_var + (1 - decay) * value
45  ```
46
47  Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses
48  the lag-1 mean.
49
50  For derivation justification, see equation 143 of:
51    T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance".
52    http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
53
54  Args:
55    mean_var: `float`-like `Variable` representing the exponentially weighted
56      moving mean. Same shape as `variance_var` and `value`.
57    variance_var: `float`-like `Variable` representing the
58      exponentially weighted moving variance. Same shape as `mean_var` and
59      `value`.
60    value: `float`-like `Tensor`. Same shape as `mean_var` and `variance_var`.
61    decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
62      `1.`, e.g., `0.999`.
63    name: Optional name of the returned operation.
64
65  Returns:
66    mean_var: `Variable` representing the `value`-updated exponentially weighted
67      moving mean.
68    variance_var: `Variable` representing the `value`-updated
69      exponentially weighted moving variance.
70
71  Raises:
72    TypeError: if `mean_var` does not have float type `dtype`.
73    TypeError: if `mean_var`, `variance_var`, `value`, `decay` have different
74      `base_dtype`.
75  """
76  with ops.name_scope(name, "assign_moving_mean_variance",
77                      [variance_var, mean_var, value, decay]):
78    with ops.colocate_with(variance_var):
79      with ops.colocate_with(mean_var):
80        base_dtype = mean_var.dtype.base_dtype
81        if not base_dtype.is_floating:
82          raise TypeError(
83              "mean_var.base_dtype({}) does not have float type "
84              "`dtype`.".format(base_dtype.name))
85        if base_dtype != variance_var.dtype.base_dtype:
86          raise TypeError(
87              "mean_var.base_dtype({}) != variance_var.base_dtype({})".format(
88                  base_dtype.name,
89                  variance_var.dtype.base_dtype.name))
90        value = ops.convert_to_tensor(value, dtype=base_dtype, name="value")
91        decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
92        delta = value - mean_var
93        with ops.control_dependencies([delta]):
94          mean_var = state_ops.assign_add(
95              mean_var,
96              (1. - decay) * delta)
97          variance_var = state_ops.assign_sub(
98              variance_var,
99              (1. - decay) * (variance_var - decay * math_ops.square(delta)))
100        return mean_var, variance_var
101
102
103def assign_log_moving_mean_exp(
104    log_mean_exp_var, log_value, decay, name=None):
105  """Compute the log of the exponentially weighted moving mean of the exp.
106
107  If `log_value` is a draw from a stationary random variable, this function
108  approximates `log(E[exp(log_value)])`, i.e., a weighted log-sum-exp. More
109  precisely, a `tf.Variable`, `log_mean_exp_var`, is updated by `log_value`
110  using the following identity:
111
112  ```none
113  log_mean_exp_var =
114  = log(decay exp(log_mean_exp_var) + (1 - decay) exp(log_value))
115  = log(exp(log_mean_exp_var + log(decay)) + exp(log_value + log1p(-decay)))
116  = log_mean_exp_var
117    + log(  exp(log_mean_exp_var   - log_mean_exp_var + log(decay))
118          + exp(log_value - log_mean_exp_var + log1p(-decay)))
119  = log_mean_exp_var
120    + log_sum_exp([log(decay), log_value - log_mean_exp_var + log1p(-decay)]).
121  ```
122
123  In addition to numerical stability, this formulation is advantageous because
124  `log_mean_exp_var` can be updated in a lock-free manner, i.e., using
125  `assign_add`. (Note: the updates are not thread-safe; it's just that the
126  update to the tf.Variable is presumed efficient due to being lock-free.)
127
128  Args:
129    log_mean_exp_var: `float`-like `Variable` representing the log of the
130      exponentially weighted moving mean of the exp. Same shape as `log_value`.
131    log_value: `float`-like `Tensor` representing a new (streaming) observation.
132      Same shape as `log_mean_exp_var`.
133    decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
134      `1.`, e.g., `0.999`.
135    name: Optional name of the returned operation.
136
137  Returns:
138    log_mean_exp_var: A reference to the input 'Variable' tensor with the
139      `log_value`-updated log of the exponentially weighted moving mean of exp.
140
141  Raises:
142    TypeError: if `log_mean_exp_var` does not have float type `dtype`.
143    TypeError: if `log_mean_exp_var`, `log_value`, `decay` have different
144      `base_dtype`.
145  """
146  with ops.name_scope(name, "assign_log_moving_mean_exp",
147                      [log_mean_exp_var, log_value, decay]):
148    # We want to update the variable in a numerically stable and lock-free way.
149    # To do this, observe that variable `x` updated by `v` is:
150    # x = log(w exp(x) + (1-w) exp(v))
151    #   = log(exp(x + log(w)) + exp(v + log1p(-w)))
152    #   = x + log(exp(x - x + log(w)) + exp(v - x + log1p(-w)))
153    #   = x + lse([log(w), v - x + log1p(-w)])
154    with ops.colocate_with(log_mean_exp_var):
155      base_dtype = log_mean_exp_var.dtype.base_dtype
156      if not base_dtype.is_floating:
157        raise TypeError(
158            "log_mean_exp_var.base_dtype({}) does not have float type "
159            "`dtype`.".format(base_dtype.name))
160      log_value = ops.convert_to_tensor(log_value, dtype=base_dtype,
161                                        name="log_value")
162      decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
163      delta = (log_value - log_mean_exp_var)[array_ops.newaxis, ...]
164      x = array_ops.concat([
165          math_ops.log(decay) * array_ops.ones_like(delta),
166          delta + math_ops.log1p(-decay)
167      ], axis=0)
168      x = math_ops.reduce_logsumexp(x, axis=0)
169      return log_mean_exp_var.assign_add(x)
170
171
172def moving_mean_variance(value, decay, collections=None, name=None):
173  """Compute exponentially weighted moving {mean,variance} of a streaming value.
174
175  The exponentially-weighting moving `mean_var` and `variance_var` are updated
176  by `value` according to the following recurrence:
177
178  ```python
179  variance_var = decay * (variance_var + (1-decay) * (value - mean_var)**2)
180  mean_var     = decay * mean_var + (1 - decay) * value
181  ```
182
183  Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses
184  the lag-`1` mean.
185
186  For derivation justification, see equation 143 of:
187    T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance".
188    http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
189
190  Unlike `assign_moving_mean_variance`, this function handles
191  variable creation.
192
193  Args:
194    value: `float`-like `Tensor`. Same shape as `mean_var` and `variance_var`.
195    decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
196      `1.`, e.g., `0.999`.
197    collections: Python list of graph-collections keys to which the internal
198      variables `mean_var` and `variance_var` are added.
199      Default value is `[GraphKeys.GLOBAL_VARIABLES]`.
200    name: Optional name of the returned operation.
201
202  Returns:
203    mean_var: `Variable` representing the `value`-updated exponentially weighted
204      moving mean.
205    variance_var: `Variable` representing the `value`-updated
206      exponentially weighted moving variance.
207
208  Raises:
209    TypeError: if `value_var` does not have float type `dtype`.
210    TypeError: if `value`, `decay` have different `base_dtype`.
211  """
212  if collections is None:
213    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
214  with variable_scope.variable_scope(
215      name, "moving_mean_variance", [value, decay]):
216    value = ops.convert_to_tensor(value, name="value")
217    base_dtype = value.dtype.base_dtype
218    if not base_dtype.is_floating:
219      raise TypeError(
220          "value.base_dtype({}) does not have float type `dtype`.".format(
221              base_dtype.name))
222    decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
223    variance_var = variable_scope.get_variable(
224        "moving_variance",
225        shape=value.shape,
226        dtype=value.dtype,
227        initializer=init_ops.zeros_initializer(),
228        trainable=False,
229        collections=collections)
230    mean_var = variable_scope.get_variable(
231        "moving_mean",
232        shape=value.shape,
233        dtype=value.dtype,
234        initializer=init_ops.zeros_initializer(),
235        trainable=False,
236        collections=collections)
237    return assign_moving_mean_variance(
238        mean_var, variance_var, value, decay)
239