# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for computing moving statistics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope


__all__ = [
    "assign_moving_mean_variance",
    "assign_log_moving_mean_exp",
    "moving_mean_variance",
]


def assign_moving_mean_variance(
    mean_var, variance_var, value, decay, name=None):
  """Compute exponentially weighted moving {mean,variance} of a streaming value.

  The `value`-updated exponentially weighted moving `mean_var` and
  `variance_var` are given by the following recurrence relations:

  ```python
  variance_var = decay * (variance_var + (1 - decay) * (value - mean_var)**2)
  mean_var = decay * mean_var + (1 - decay) * value
  ```

  Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses
  the lag-1 mean.

  For derivation justification, see equation 143 of:
    T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance".
    http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf

  Args:
    mean_var: `float`-like `Variable` representing the exponentially weighted
      moving mean. Same shape as `variance_var` and `value`.
    variance_var: `float`-like `Variable` representing the exponentially
      weighted moving variance. Same shape as `mean_var` and `value`.
    value: `float`-like `Tensor`. Same shape as `mean_var` and `variance_var`.
    decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
      `1.`, e.g., `0.999`.
    name: Optional name of the returned operation.

  Returns:
    mean_var: `Variable` representing the `value`-updated exponentially weighted
      moving mean.
    variance_var: `Variable` representing the `value`-updated exponentially
      weighted moving variance.

  Raises:
    TypeError: if `mean_var` does not have float type `dtype`.
    TypeError: if `mean_var`, `variance_var`, `value`, `decay` have different
      `base_dtype`.
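
  A minimal usage sketch (illustrative only; assumes TF1-style graph execution,
  a scalar stream of observations, and caller-created variables):

  ```python
  mean_var = tf.get_variable("mean", initializer=0., trainable=False)
  variance_var = tf.get_variable("variance", initializer=0., trainable=False)
  value = tf.placeholder(tf.float32, shape=[])
  update_mean, update_variance = assign_moving_mean_variance(
      mean_var, variance_var, value, decay=0.99)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for x in [1., 2., 3.]:
      # Each run applies one update and returns the new moving statistics.
      m, v = sess.run([update_mean, update_variance], feed_dict={value: x})
  ```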
75 """ 76 with ops.name_scope(name, "assign_moving_mean_variance", 77 [variance_var, mean_var, value, decay]): 78 with ops.colocate_with(variance_var): 79 with ops.colocate_with(mean_var): 80 base_dtype = mean_var.dtype.base_dtype 81 if not base_dtype.is_floating: 82 raise TypeError( 83 "mean_var.base_dtype({}) does not have float type " 84 "`dtype`.".format(base_dtype.name)) 85 if base_dtype != variance_var.dtype.base_dtype: 86 raise TypeError( 87 "mean_var.base_dtype({}) != variance_var.base_dtype({})".format( 88 base_dtype.name, 89 variance_var.dtype.base_dtype.name)) 90 value = ops.convert_to_tensor(value, dtype=base_dtype, name="value") 91 decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay") 92 delta = value - mean_var 93 with ops.control_dependencies([delta]): 94 mean_var = state_ops.assign_add( 95 mean_var, 96 (1. - decay) * delta) 97 variance_var = state_ops.assign_sub( 98 variance_var, 99 (1. - decay) * (variance_var - decay * math_ops.square(delta))) 100 return mean_var, variance_var 101 102 103def assign_log_moving_mean_exp( 104 log_mean_exp_var, log_value, decay, name=None): 105 """Compute the log of the exponentially weighted moving mean of the exp. 106 107 If `log_value` is a draw from a stationary random variable, this function 108 approximates `log(E[exp(log_value)])`, i.e., a weighted log-sum-exp. More 109 precisely, a `tf.Variable`, `log_mean_exp_var`, is updated by `log_value` 110 using the following identity: 111 112 ```none 113 log_mean_exp_var = 114 = log(decay exp(log_mean_exp_var) + (1 - decay) exp(log_value)) 115 = log(exp(log_mean_exp_var + log(decay)) + exp(log_value + log1p(-decay))) 116 = log_mean_exp_var 117 + log( exp(log_mean_exp_var - log_mean_exp_var + log(decay)) 118 + exp(log_value - log_mean_exp_var + log1p(-decay))) 119 = log_mean_exp_var 120 + log_sum_exp([log(decay), log_value - log_mean_exp_var + log1p(-decay)]). 121 ``` 122 123 In addition to numerical stability, this formulation is advantageous because 124 `log_mean_exp_var` can be updated in a lock-free manner, i.e., using 125 `assign_add`. (Note: the updates are not thread-safe; it's just that the 126 update to the tf.Variable is presumed efficient due to being lock-free.) 127 128 Args: 129 log_mean_exp_var: `float`-like `Variable` representing the log of the 130 exponentially weighted moving mean of the exp. Same shape as `log_value`. 131 log_value: `float`-like `Tensor` representing a new (streaming) observation. 132 Same shape as `log_mean_exp_var`. 133 decay: A `float`-like `Tensor`. The moving mean decay. Typically close to 134 `1.`, e.g., `0.999`. 135 name: Optional name of the returned operation. 136 137 Returns: 138 log_mean_exp_var: A reference to the input 'Variable' tensor with the 139 `log_value`-updated log of the exponentially weighted moving mean of exp. 140 141 Raises: 142 TypeError: if `log_mean_exp_var` does not have float type `dtype`. 143 TypeError: if `log_mean_exp_var`, `log_value`, `decay` have different 144 `base_dtype`. 145 """ 146 with ops.name_scope(name, "assign_log_moving_mean_exp", 147 [log_mean_exp_var, log_value, decay]): 148 # We want to update the variable in a numerically stable and lock-free way. 
    # To do this, observe that variable `x` updated by `v` is:
    #   x = log(w exp(x) + (1-w) exp(v))
    #     = log(exp(x + log(w)) + exp(v + log1p(-w)))
    #     = x + log(exp(x - x + log(w)) + exp(v - x + log1p(-w)))
    #     = x + lse([log(w), v - x + log1p(-w)])
    with ops.colocate_with(log_mean_exp_var):
      base_dtype = log_mean_exp_var.dtype.base_dtype
      if not base_dtype.is_floating:
        raise TypeError(
            "log_mean_exp_var.base_dtype({}) does not have float type "
            "`dtype`.".format(base_dtype.name))
      log_value = ops.convert_to_tensor(log_value, dtype=base_dtype,
                                        name="log_value")
      decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
      delta = (log_value - log_mean_exp_var)[array_ops.newaxis, ...]
      x = array_ops.concat([
          math_ops.log(decay) * array_ops.ones_like(delta),
          delta + math_ops.log1p(-decay)
      ], axis=0)
      x = math_ops.reduce_logsumexp(x, axis=0)
      return log_mean_exp_var.assign_add(x)


def moving_mean_variance(value, decay, collections=None, name=None):
  """Compute exponentially weighted moving {mean,variance} of a streaming value.

  The exponentially weighted moving `mean_var` and `variance_var` are updated
  by `value` according to the following recurrence:

  ```python
  variance_var = decay * (variance_var + (1 - decay) * (value - mean_var)**2)
  mean_var = decay * mean_var + (1 - decay) * value
  ```

  Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses
  the lag-`1` mean.

  For derivation justification, see equation 143 of:
    T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance".
    http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf

  Unlike `assign_moving_mean_variance`, this function handles
  variable creation.

  Args:
    value: `float`-like `Tensor`. Same shape as `mean_var` and `variance_var`.
    decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
      `1.`, e.g., `0.999`.
    collections: Python list of graph-collections keys to which the internal
      variables `mean_var` and `variance_var` are added.
      Default value is `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name of the returned operation.

  Returns:
    mean_var: `Variable` representing the `value`-updated exponentially weighted
      moving mean.
    variance_var: `Variable` representing the `value`-updated exponentially
      weighted moving variance.

  Raises:
    TypeError: if `value` does not have float type `dtype`.
    TypeError: if `value`, `decay` have different `base_dtype`.
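
  A minimal usage sketch (illustrative only; assumes TF1-style graph execution
  and a scalar stream of observations):

  ```python
  value = tf.placeholder(tf.float32, shape=[])
  mean_var, variance_var = moving_mean_variance(value, decay=0.99)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for x in [1., 2., 3.]:
      # Each run consumes one observation and applies one update.
      m, v = sess.run([mean_var, variance_var], feed_dict={value: x})
  ```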
211 """ 212 if collections is None: 213 collections = [ops.GraphKeys.GLOBAL_VARIABLES] 214 with variable_scope.variable_scope( 215 name, "moving_mean_variance", [value, decay]): 216 value = ops.convert_to_tensor(value, name="value") 217 base_dtype = value.dtype.base_dtype 218 if not base_dtype.is_floating: 219 raise TypeError( 220 "value.base_dtype({}) does not have float type `dtype`.".format( 221 base_dtype.name)) 222 decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay") 223 variance_var = variable_scope.get_variable( 224 "moving_variance", 225 shape=value.shape, 226 dtype=value.dtype, 227 initializer=init_ops.zeros_initializer(), 228 trainable=False, 229 collections=collections) 230 mean_var = variable_scope.get_variable( 231 "moving_mean", 232 shape=value.shape, 233 dtype=value.dtype, 234 initializer=init_ops.zeros_initializer(), 235 trainable=False, 236 collections=collections) 237 return assign_moving_mean_variance( 238 mean_var, variance_var, value, decay) 239