1b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon#
3b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon# Licensed under the Apache License, Version 2.0 (the "License");
4b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon# you may not use this file except in compliance with the License.
5b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon# You may obtain a copy of the License at
6b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon#
7b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon#     http://www.apache.org/licenses/LICENSE-2.0
8b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon#
9b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon# Unless required by applicable law or agreed to in writing, software
10b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon# distributed under the License is distributed on an "AS IS" BASIS,
11b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon# See the License for the specific language governing permissions and
13b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon# limitations under the License.
14b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon# ==============================================================================
"""Multivariate Normal distribution with a diagonal-plus-low-rank scale."""
16b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon
17b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillonfrom __future__ import absolute_import
18b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillonfrom __future__ import division
19b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillonfrom __future__ import print_function
20b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon
21b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillonfrom tensorflow.contrib import linalg
22b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillonfrom tensorflow.contrib.distributions.python.ops import distribution_util
23b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillonfrom tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
24b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillonfrom tensorflow.python.framework import ops
25b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon
26b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon
# Names exported by `from <module> import *`; only the distribution class is public.
__all__ = ["MultivariateNormalDiagPlusLowRank"]
30b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon
31b0f76d112be9190ac03f5b6083afb12aa6bdbc35Joshua V. Dillon
class MultivariateNormalDiagPlusLowRank(
    mvn_linop.MultivariateNormalLinearOperator):
  """The multivariate normal distribution on `R^k`.

  The Multivariate Normal distribution is defined over `R^k` and parameterized
  by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
  `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
  matrix-multiplication.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
  y = inv(scale) @ (x - loc),
  Z = (2 pi)**(0.5 k) |det(scale)|,
  ```

  where:

  * `loc` is a vector in `R^k`,
  * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
  * `Z` denotes the normalization constant, and,
  * `||y||**2` denotes the squared Euclidean norm of `y`.

  A (non-batch) `scale` matrix is:

  ```none
  scale = diag(scale_diag + scale_identity_multiplier ones(k)) +
        scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
  ```

  where:

  * `scale_diag.shape = [k]`,
  * `scale_identity_multiplier.shape = []`,
  * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
  * `scale_perturb_diag.shape = [r]`.

  Additional leading dimensions (if any) will index batches.

  If both `scale_diag` and `scale_identity_multiplier` are `None`, then the
  diagonal part of `scale` is the Identity matrix.

  The MultivariateNormal distribution is a member of the [location-scale
  family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed as,

  ```none
  X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
  Y = scale @ X + loc
  ```

  #### Examples

  ```python
  tfd = tf.contrib.distributions

  # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`,
  # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is
  # a rank-2 update.
  mu = [-0.5, 0, 0.5]    # shape: [3]
  d = [1.5, 0.5, 2]      # shape: [3]
  U = [[1., 2],
       [-1, 1],
       [2, -0.5]]        # shape: [3, 2]
  m = [4., 5]            # shape: [2]
  mvn = tfd.MultivariateNormalDiagPlusLowRank(
      loc=mu,
      scale_diag=d,
      scale_perturb_factor=U,
      scale_perturb_diag=m)

  # Evaluate this on an observation in `R^3`, returning a scalar.
  mvn.prob([-1, 0, 1]).eval()  # shape: []

  # Initialize a 2-batch of 3-variate Gaussians; `S = I + U @ diag(m) @ U.T`
  # (no `scale_diag`, so the diagonal part is the identity).
  mu = [[1.,  2,  3],
        [11, 22, 33]]      # shape: [b, k] = [2, 3]
  U = [[[1., 2],
        [3,  4],
        [5,  6]],
       [[0.5, 0.75],
        [1.0, 0.25],
        [1.5, 1.25]]]      # shape: [b, k, r] = [2, 3, 2]
  m = [[0.1, 0.2],
       [0.4, 0.5]]         # shape: [b, r] = [2, 2]

  mvn = tfd.MultivariateNormalDiagPlusLowRank(
      loc=mu,
      scale_perturb_factor=U,
      scale_perturb_diag=m)

  mvn.covariance().eval()   # shape: [2, 3, 3]
  # ==> [[[  15.63   31.57    48.51]
  #       [  31.57   69.31   105.05]
  #       [  48.51  105.05   162.59]]
  #
  #      [[   2.59    1.41    3.35]
  #       [   1.41    2.71    3.34]
  #       [   3.35    3.34    8.35]]]

  # Compute the pdf of two `R^3` observations (one from each batch);
  # return a length-2 vector.
  x = [[-0.9, 0, 0.1],
       [-10, 0, 9]]     # shape: [2, 3]
  mvn.prob(x).eval()    # shape: [2]
  ```

  """

  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               scale_perturb_factor=None,
               scale_perturb_diag=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalDiagPlusLowRank"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier ones(k)) +
        scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
    ```

    where:

    * `scale_diag.shape = [k]`,
    * `scale_identity_multiplier.shape = []`,
    * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
    * `scale_perturb_diag.shape = [r]`.

    Additional leading dimensions (if any) will index batches.

    If both `scale_diag` and `scale_identity_multiplier` are `None`, then the
    diagonal part of `scale` is the Identity matrix.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scaled-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
        the `Identity`.
      scale_perturb_factor: Floating-point `Tensor` representing a rank-`r`
        perturbation added to `scale`. May have shape `[B1, ..., Bb, k, r]`,
        `b >= 0`, and characterizes `b`-batches of rank-`r` updates to `scale`.
        When `None`, no rank-`r` update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing a diagonal matrix
        inside the rank-`r` perturbation added to `scale`. May have shape
        `[B1, ..., Bb, r]`, `b >= 0`, and characterizes `b`-batches of `r x r`
        diagonal matrices inside the perturbation added to `scale`. When
        `None`, an identity matrix is used inside the perturbation. Can only be
        specified if `scale_perturb_factor` is also specified.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified.
      ValueError: if `scale_perturb_diag` is specified but
        `scale_perturb_factor` is not.
    """
    # Snapshot the constructor arguments. In CPython < 3.13 `locals()` returns
    # the frame's own dict, which a later `locals()` call or a trace function
    # would refresh with the bindings created below (helpers, converted
    # tensors, `scale`), corrupting the recorded parameters.
    parameters = dict(locals())
    # Enforce the documented contract up front; otherwise the low-rank update
    # fails later with an opaque "None values not supported" error.
    if scale_perturb_factor is None and scale_perturb_diag is not None:
      raise ValueError("`scale_perturb_diag` can only be specified if "
                       "`scale_perturb_factor` is also specified.")

    def _convert_to_tensor(x, name):
      """Converts `x` to a `Tensor`, passing `None` through unchanged."""
      return None if x is None else ops.convert_to_tensor(x, name=name)

    with ops.name_scope(name):
      with ops.name_scope("init", values=[
          loc, scale_diag, scale_identity_multiplier, scale_perturb_factor,
          scale_perturb_diag]):
        has_low_rank = (scale_perturb_factor is not None or
                        scale_perturb_diag is not None)
        # Diagonal (or scaled-identity / identity) part of `scale`. Positivity
        # is requested when a low-rank update follows, since the combined
        # operator below is declared positive definite.
        scale = distribution_util.make_diag_scale(
            loc=loc,
            scale_diag=scale_diag,
            scale_identity_multiplier=scale_identity_multiplier,
            validate_args=validate_args,
            assert_positive=has_low_rank)
        scale_perturb_factor = _convert_to_tensor(
            scale_perturb_factor,
            name="scale_perturb_factor")
        scale_perturb_diag = _convert_to_tensor(
            scale_perturb_diag,
            name="scale_perturb_diag")
        if has_low_rank:
          # scale := diag_part + U @ diag(scale_perturb_diag) @ U.T. When no
          # diag is given the identity is used, which is trivially positive.
          scale = linalg.LinearOperatorLowRankUpdate(
              scale,
              u=scale_perturb_factor,
              diag_update=scale_perturb_diag,
              is_diag_update_positive=scale_perturb_diag is None,
              is_non_singular=True,  # Implied by is_positive_definite=True.
              is_self_adjoint=True,
              is_positive_definite=True,
              is_square=True)
    super(MultivariateNormalDiagPlusLowRank, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    self._parameters = parameters
256