1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Multivariate Normal distribution classes.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.contrib.distributions.python.ops import distribution_util 22from tensorflow.contrib.distributions.python.ops.bijectors import AffineLinearOperator 23from tensorflow.python.framework import ops 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import math_ops 26from tensorflow.python.ops.distributions import kullback_leibler 27from tensorflow.python.ops.distributions import normal 28from tensorflow.python.ops.distributions import transformed_distribution 29from tensorflow.python.ops.linalg import linalg 30 31 32__all__ = [ 33 "MultivariateNormalLinearOperator", 34] 35 36 37_mvn_sample_note = """ 38`value` is a batch vector with compatible shape if `value` is a `Tensor` whose 39shape can be broadcast up to either: 40 41```python 42self.batch_shape + self.event_shape 43``` 44 45or 46 47```python 48[M1, ..., Mm] + self.batch_shape + self.event_shape 49``` 50 51""" 52 53 54# TODO(b/35290280): Import in `../../__init__.py` after adding unit-tests. 
class MultivariateNormalLinearOperator(
    transformed_distribution.TransformedDistribution):
  """The multivariate normal distribution on `R^k`.

  The Multivariate Normal distribution is defined over `R^k` and parameterized
  by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
  `scale` matrix; `covariance = scale @ scale.T`, where `@` denotes
  matrix-multiplication.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
  y = inv(scale) @ (x - loc),
  Z = (2 pi)**(0.5 k) |det(scale)|,
  ```

  where:

  * `loc` is a vector in `R^k`,
  * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
  * `Z` denotes the normalization constant, and,
  * `||y||**2` denotes the squared Euclidean norm of `y`.

  The MultivariateNormal distribution is a member of the [location-scale
  family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed as,

  ```none
  X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
  Y = scale @ X + loc
  ```

  #### Examples

  ```python
  tfd = tf.contrib.distributions

  # Initialize a single 3-variate Gaussian.
  mu = [1., 2, 3]
  cov = [[ 0.36,  0.12,  0.06],
         [ 0.12,  0.29, -0.13],
         [ 0.06, -0.13,  0.26]]
  scale = tf.cholesky(cov)
  # ==> [[ 0.6,  0. ,  0. ],
  #      [ 0.2,  0.5,  0. ],
  #      [ 0.1, -0.3,  0.4]]

  mvn = tfd.MultivariateNormalLinearOperator(
      loc=mu,
      scale=tf.linalg.LinearOperatorLowerTriangular(scale))

  # Covariance agrees with cholesky(cov) parameterization.
  mvn.covariance().eval()
  # ==> [[ 0.36,  0.12,  0.06],
  #      [ 0.12,  0.29, -0.13],
  #      [ 0.06, -0.13,  0.26]]

  # Compute the pdf of an `R^3` observation; return a scalar.
  mvn.prob([-1., 0, 1]).eval()  # shape: []

  # Initialize a 2-batch of 3-variate Gaussians.
  mu = [[1., 2, 3],
        [11, 22, 33]]              # shape: [2, 3]
  scale_diag = [[1., 2, 3],
                [0.5, 1, 1.5]]     # shape: [2, 3]

  mvn = tfd.MultivariateNormalLinearOperator(
      loc=mu,
      scale=tf.linalg.LinearOperatorDiag(scale_diag))

  # Compute the pdf of two `R^3` observations; return a length-2 vector.
  x = [[-0.9, 0, 0.1],
       [-10, 0, 9]]     # shape: [2, 3]
  mvn.prob(x).eval()    # shape: [2]
  ```

  """

  def __init__(self,
               loc=None,
               scale=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalLinearOperator"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if not `scale.dtype.is_floating`
    """
    # Snapshot the constructor arguments. A bare `locals()` returns the frame's
    # own locals dict, which CPython re-syncs whenever a tracer/debugger
    # inspects the frame; that would leak later locals (`batch_shape`, the
    # converted `loc`, `parameters` itself, ...) into `self._parameters`.
    parameters = dict(locals())
    if scale is None:
      raise ValueError("Missing required `scale` parameter.")
    if not scale.dtype.is_floating:
      raise TypeError("`scale` parameter must have floating-point dtype.")

    with ops.name_scope(name, values=[loc] + scale.graph_parents):
      # Convert a user-supplied `loc` to a `Tensor` before shape inference;
      # `None` is passed through and means an implicit zero shift.
      loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc
      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, scale)

    # The `TransformedDistribution.__init__` call below takes no
    # `allow_nan_stats` argument (the transformed distribution inherits the
    # setting from its base distribution), so forward it to the base `Normal`;
    # otherwise the caller's value would be silently ignored.
    super(MultivariateNormalLinearOperator, self).__init__(
        distribution=normal.Normal(
            loc=array_ops.zeros([], dtype=scale.dtype),
            scale=array_ops.ones([], dtype=scale.dtype),
            allow_nan_stats=allow_nan_stats),
        bijector=AffineLinearOperator(
            shift=loc, scale=scale, validate_args=validate_args),
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters

  @property
  def loc(self):
    """The `loc` `Tensor` in `Y = scale @ X + loc`."""
    return self.bijector.shift

  @property
  def scale(self):
    """The `scale` `LinearOperator` in `Y = scale @ X + loc`."""
    return self.bijector.scale

  @distribution_util.AppendDocstring(_mvn_sample_note)
  def _log_prob(self, x):
    return super(MultivariateNormalLinearOperator, self)._log_prob(x)

  @distribution_util.AppendDocstring(_mvn_sample_note)
  def _prob(self, x):
    return super(MultivariateNormalLinearOperator, self)._prob(x)

  def _mean(self):
    # Full output shape is `batch_shape + event_shape`; fall back to the
    # dynamic shape tensors when it is not statically known.
    shape = self.batch_shape.concatenate(self.event_shape)
    has_static_shape = shape.is_fully_defined()
    if not has_static_shape:
      shape = array_ops.concat([
          self.batch_shape_tensor(),
          self.event_shape_tensor(),
      ], 0)

    # No shift: the mean is identically zero.
    if self.loc is None:
      return array_ops.zeros(shape, self.dtype)

    if has_static_shape and shape == self.loc.get_shape():
      return array_ops.identity(self.loc)

    # Add dummy tensor of zeros to broadcast.  This is only necessary if shape
    # != self.loc.shape, but we could not determine if this is the case.
    return array_ops.identity(self.loc) + array_ops.zeros(shape, self.dtype)

  def _covariance(self):
    # covariance = scale @ scale^H; diagonal scales avoid the dense matmul.
    if distribution_util.is_diagonal_scale(self.scale):
      return array_ops.matrix_diag(math_ops.square(self.scale.diag_part()))
    else:
      return self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)

  def _variance(self):
    # Variance is the diagonal of `scale @ scale^H`.
    if distribution_util.is_diagonal_scale(self.scale):
      return math_ops.square(self.scale.diag_part())
    elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and
          self.scale.is_self_adjoint):
      # For a self-adjoint scale, scale @ scale == scale @ scale^H, so no
      # adjoint is needed on the dense argument.
      return array_ops.matrix_diag_part(
          self.scale.matmul(self.scale.to_dense()))
    else:
      return array_ops.matrix_diag_part(
          self.scale.matmul(self.scale.to_dense(), adjoint_arg=True))

  def _stddev(self):
    # Standard deviation is the square root of the variance computed as in
    # `_variance`; for a diagonal scale this is simply |diag(scale)|.
    if distribution_util.is_diagonal_scale(self.scale):
      return math_ops.abs(self.scale.diag_part())
    elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and
          self.scale.is_self_adjoint):
      return math_ops.sqrt(array_ops.matrix_diag_part(
          self.scale.matmul(self.scale.to_dense())))
    else:
      return math_ops.sqrt(array_ops.matrix_diag_part(
          self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)))

  def _mode(self):
    # A Gaussian's mode coincides with its mean.
    return self._mean()


@kullback_leibler.RegisterKL(MultivariateNormalLinearOperator,
                             MultivariateNormalLinearOperator)
def _kl_brute_force(a, b, name=None):
  """Batched KL divergence `KL(a || b)` for multivariate Normals.

  With `a`, `b` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and
  covariance `C_a`, `C_b` respectively,

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  ```

  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
  methods for solving systems with `C_b` may be available, a dense version of
  (the square root of) `C_a` is used, so performance is `O(B s k**2)` where `B`
  is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
  and `y`.

  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """

  def squared_frobenius_norm(x):
    """Helper to make KL calculation slightly more readable."""
    # http://mathworld.wolfram.com/FrobeniusNorm.html
    # The gradient of KL[p,q] is not defined when p==q. The culprit is
    # linalg_ops.norm, i.e., we cannot use the commented out code.
    # return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1]))
    return math_ops.reduce_sum(math_ops.square(x), axis=[-2, -1])

  # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
  # supports something like:
  #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
  def is_diagonal(x):
    """Helper to identify if `LinearOperator` has only a diagonal component."""
    return (isinstance(x, linalg.LinearOperatorIdentity) or
            isinstance(x, linalg.LinearOperatorScaledIdentity) or
            isinstance(x, linalg.LinearOperatorDiag))

  with ops.name_scope(name, "kl_mvn", values=[a.loc, b.loc] +
                      a.scale.graph_parents + b.scale.graph_parents):
    # Calculation is based on:
    # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
    # and,
    # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
    # i.e.,
    #   If Ca = AA', Cb = BB', then
    #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
    #                  = tr[inv(B) A A' inv(B)']
    #                  = tr[(inv(B) A) (inv(B) A)']
    #                  = sum_{ij} (inv(B) A)_{ij}**2
    #                  = ||inv(B) A||_F**2
    # where ||.||_F is the Frobenius norm and the second equality follows from
    # the cyclic permutation property.
    if is_diagonal(a.scale) and is_diagonal(b.scale):
      # Using `stddev` because it handles expansion of Identity cases.
      b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis]
    else:
      b_inv_a = b.scale.solve(a.scale.to_dense())
    # 0.5 * L == log|det B| - log|det A| since C = scale @ scale^H.
    kl_div = (b.scale.log_abs_determinant()
              - a.scale.log_abs_determinant()
              + 0.5 * (
                  - math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype)
                  + squared_frobenius_norm(b_inv_a)
                  + squared_frobenius_norm(b.scale.solve(
                      (b.mean() - a.mean())[..., array_ops.newaxis]))))
    kl_div.set_shape(array_ops.broadcast_static_shape(
        a.batch_shape, b.batch_shape))
    return kl_div