1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Multivariate Normal distribution classes."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.distributions.python.ops import distribution_util
22from tensorflow.contrib.distributions.python.ops.bijectors import AffineLinearOperator
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops.distributions import kullback_leibler
27from tensorflow.python.ops.distributions import normal
28from tensorflow.python.ops.distributions import transformed_distribution
29from tensorflow.python.ops.linalg import linalg
30
31
32__all__ = [
33    "MultivariateNormalLinearOperator",
34]
35
36
37_mvn_sample_note = """
38`value` is a batch vector with compatible shape if `value` is a `Tensor` whose
39shape can be broadcast up to either:
40
41```python
42self.batch_shape + self.event_shape
43```
44
45or
46
47```python
48[M1, ..., Mm] + self.batch_shape + self.event_shape
49```
50
51"""
52
53
54# TODO(b/35290280): Import in `../../__init__.py` after adding unit-tests.
55class MultivariateNormalLinearOperator(
56    transformed_distribution.TransformedDistribution):
57  """The multivariate normal distribution on `R^k`.
58
59  The Multivariate Normal distribution is defined over `R^k` and parameterized
60  by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
61  `scale` matrix; `covariance = scale @ scale.T`, where `@` denotes
62  matrix-multiplication.
63
64  #### Mathematical Details
65
66  The probability density function (pdf) is,
67
68  ```none
69  pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
70  y = inv(scale) @ (x - loc),
71  Z = (2 pi)**(0.5 k) |det(scale)|,
72  ```
73
74  where:
75
76  * `loc` is a vector in `R^k`,
77  * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
78  * `Z` denotes the normalization constant, and,
79  * `||y||**2` denotes the squared Euclidean norm of `y`.
80
81  The MultivariateNormal distribution is a member of the [location-scale
82  family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
83  constructed as,
84
85  ```none
86  X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
87  Y = scale @ X + loc
88  ```
89
90  #### Examples
91
92  ```python
93  tfd = tf.contrib.distributions
94
95  # Initialize a single 3-variate Gaussian.
96  mu = [1., 2, 3]
97  cov = [[ 0.36,  0.12,  0.06],
98         [ 0.12,  0.29, -0.13],
99         [ 0.06, -0.13,  0.26]]
100  scale = tf.cholesky(cov)
101  # ==> [[ 0.6,  0. ,  0. ],
102  #      [ 0.2,  0.5,  0. ],
103  #      [ 0.1, -0.3,  0.4]])
104
105  mvn = tfd.MultivariateNormalLinearOperator(
106      loc=mu,
107      scale=tf.linalg.LinearOperatorLowerTriangular(scale))
108
109  # Covariance agrees with cholesky(cov) parameterization.
110  mvn.covariance().eval()
111  # ==> [[ 0.36,  0.12,  0.06],
112  #      [ 0.12,  0.29, -0.13],
113  #      [ 0.06, -0.13,  0.26]]
114
115  # Compute the pdf of an`R^3` observation; return a scalar.
116  mvn.prob([-1., 0, 1]).eval()  # shape: []
117
118  # Initialize a 2-batch of 3-variate Gaussians.
119  mu = [[1., 2, 3],
120        [11, 22, 33]]              # shape: [2, 3]
121  scale_diag = [[1., 2, 3],
122                [0.5, 1, 1.5]]     # shape: [2, 3]
123
124  mvn = tfd.MultivariateNormalLinearOperator(
125      loc=mu,
126      scale=tf.linalg.LinearOperatorDiag(scale_diag))
127
128  # Compute the pdf of two `R^3` observations; return a length-2 vector.
129  x = [[-0.9, 0, 0.1],
130       [-10, 0, 9]]     # shape: [2, 3]
131  mvn.prob(x).eval()    # shape: [2]
132  ```
133
134  """
135
136  def __init__(self,
137               loc=None,
138               scale=None,
139               validate_args=False,
140               allow_nan_stats=True,
141               name="MultivariateNormalLinearOperator"):
142    """Construct Multivariate Normal distribution on `R^k`.
143
144    The `batch_shape` is the broadcast shape between `loc` and `scale`
145    arguments.
146
147    The `event_shape` is given by last dimension of the matrix implied by
148    `scale`. The last dimension of `loc` (if provided) must broadcast with this.
149
150    Recall that `covariance = scale @ scale.T`.
151
152    Additional leading dimensions (if any) will index batches.
153
154    Args:
155      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
156        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
157        `b >= 0` and `k` is the event size.
158      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
159        `[B1, ..., Bb, k, k]`.
160      validate_args: Python `bool`, default `False`. Whether to validate input
161        with asserts. If `validate_args` is `False`, and the inputs are
162        invalid, correct behavior is not guaranteed.
163      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
164        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
165        batch member If `True`, batch members with valid parameters leading to
166        undefined statistics will return NaN for this statistic.
167      name: The name to give Ops created by the initializer.
168
169    Raises:
170      ValueError: if `scale` is unspecified.
171      TypeError: if not `scale.dtype.is_floating`
172    """
173    parameters = locals()
174    if scale is None:
175      raise ValueError("Missing required `scale` parameter.")
176    if not scale.dtype.is_floating:
177      raise TypeError("`scale` parameter must have floating-point dtype.")
178
179    with ops.name_scope(name, values=[loc] + scale.graph_parents):
180      # Since expand_dims doesn't preserve constant-ness, we obtain the
181      # non-dynamic value if possible.
182      loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc
183      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
184          loc, scale)
185
186    super(MultivariateNormalLinearOperator, self).__init__(
187        distribution=normal.Normal(
188            loc=array_ops.zeros([], dtype=scale.dtype),
189            scale=array_ops.ones([], dtype=scale.dtype)),
190        bijector=AffineLinearOperator(
191            shift=loc, scale=scale, validate_args=validate_args),
192        batch_shape=batch_shape,
193        event_shape=event_shape,
194        validate_args=validate_args,
195        name=name)
196    self._parameters = parameters
197
198  @property
199  def loc(self):
200    """The `loc` `Tensor` in `Y = scale @ X + loc`."""
201    return self.bijector.shift
202
203  @property
204  def scale(self):
205    """The `scale` `LinearOperator` in `Y = scale @ X + loc`."""
206    return self.bijector.scale
207
208  @distribution_util.AppendDocstring(_mvn_sample_note)
209  def _log_prob(self, x):
210    return super(MultivariateNormalLinearOperator, self)._log_prob(x)
211
212  @distribution_util.AppendDocstring(_mvn_sample_note)
213  def _prob(self, x):
214    return super(MultivariateNormalLinearOperator, self)._prob(x)
215
216  def _mean(self):
217    shape = self.batch_shape.concatenate(self.event_shape)
218    has_static_shape = shape.is_fully_defined()
219    if not has_static_shape:
220      shape = array_ops.concat([
221          self.batch_shape_tensor(),
222          self.event_shape_tensor(),
223      ], 0)
224
225    if self.loc is None:
226      return array_ops.zeros(shape, self.dtype)
227
228    if has_static_shape and shape == self.loc.get_shape():
229      return array_ops.identity(self.loc)
230
231    # Add dummy tensor of zeros to broadcast.  This is only necessary if shape
232    # != self.loc.shape, but we could not determine if this is the case.
233    return array_ops.identity(self.loc) + array_ops.zeros(shape, self.dtype)
234
235  def _covariance(self):
236    if distribution_util.is_diagonal_scale(self.scale):
237      return array_ops.matrix_diag(math_ops.square(self.scale.diag_part()))
238    else:
239      return self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)
240
241  def _variance(self):
242    if distribution_util.is_diagonal_scale(self.scale):
243      return math_ops.square(self.scale.diag_part())
244    elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and
245          self.scale.is_self_adjoint):
246      return array_ops.matrix_diag_part(
247          self.scale.matmul(self.scale.to_dense()))
248    else:
249      return array_ops.matrix_diag_part(
250          self.scale.matmul(self.scale.to_dense(), adjoint_arg=True))
251
252  def _stddev(self):
253    if distribution_util.is_diagonal_scale(self.scale):
254      return math_ops.abs(self.scale.diag_part())
255    elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and
256          self.scale.is_self_adjoint):
257      return math_ops.sqrt(array_ops.matrix_diag_part(
258          self.scale.matmul(self.scale.to_dense())))
259    else:
260      return math_ops.sqrt(array_ops.matrix_diag_part(
261          self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)))
262
263  def _mode(self):
264    return self._mean()
265
266
267@kullback_leibler.RegisterKL(MultivariateNormalLinearOperator,
268                             MultivariateNormalLinearOperator)
269def _kl_brute_force(a, b, name=None):
270  """Batched KL divergence `KL(a || b)` for multivariate Normals.
271
272  With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and
273  covariance `C_a`, `C_b` respectively,
274
275  ```
276  KL(a || b) = 0.5 * ( L - k + T + Q ),
277  L := Log[Det(C_b)] - Log[Det(C_a)]
278  T := trace(C_b^{-1} C_a),
279  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
280  ```
281
282  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
283  methods for solving systems with `C_b` may be available, a dense version of
284  (the square root of) `C_a` is used, so performance is `O(B s k**2)` where `B`
285  is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
286  and `y`.
287
288  Args:
289    a: Instance of `MultivariateNormalLinearOperator`.
290    b: Instance of `MultivariateNormalLinearOperator`.
291    name: (optional) name to use for created ops. Default "kl_mvn".
292
293  Returns:
294    Batchwise `KL(a || b)`.
295  """
296
297  def squared_frobenius_norm(x):
298    """Helper to make KL calculation slightly more readable."""
299    # http://mathworld.wolfram.com/FrobeniusNorm.html
300    # The gradient of KL[p,q] is not defined when p==q. The culprit is
301    # linalg_ops.norm, i.e., we cannot use the commented out code.
302    # return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1]))
303    return math_ops.reduce_sum(math_ops.square(x), axis=[-2, -1])
304
305  # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
306  # supports something like:
307  #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
308  def is_diagonal(x):
309    """Helper to identify if `LinearOperator` has only a diagonal component."""
310    return (isinstance(x, linalg.LinearOperatorIdentity) or
311            isinstance(x, linalg.LinearOperatorScaledIdentity) or
312            isinstance(x, linalg.LinearOperatorDiag))
313
314  with ops.name_scope(name, "kl_mvn", values=[a.loc, b.loc] +
315                      a.scale.graph_parents + b.scale.graph_parents):
316    # Calculation is based on:
317    # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
318    # and,
319    # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
320    # i.e.,
321    #   If Ca = AA', Cb = BB', then
322    #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
323    #                  = tr[inv(B) A A' inv(B)']
324    #                  = tr[(inv(B) A) (inv(B) A)']
325    #                  = sum_{ij} (inv(B) A)_{ij}**2
326    #                  = ||inv(B) A||_F**2
327    # where ||.||_F is the Frobenius norm and the second equality follows from
328    # the cyclic permutation property.
329    if is_diagonal(a.scale) and is_diagonal(b.scale):
330      # Using `stddev` because it handles expansion of Identity cases.
331      b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis]
332    else:
333      b_inv_a = b.scale.solve(a.scale.to_dense())
334    kl_div = (b.scale.log_abs_determinant()
335              - a.scale.log_abs_determinant()
336              + 0.5 * (
337                  - math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype)
338                  + squared_frobenius_norm(b_inv_a)
339                  + squared_frobenius_norm(b.scale.solve(
340                      (b.mean() - a.mean())[..., array_ops.newaxis]))))
341    kl_div.set_shape(array_ops.broadcast_static_shape(
342        a.batch_shape, b.batch_shape))
343    return kl_div
344