# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multivariate Normal distribution classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops


__all__ = [
    "MultivariateNormalDiagPlusLowRank",
]


class MultivariateNormalDiagPlusLowRank(
    mvn_linop.MultivariateNormalLinearOperator):
  """The multivariate normal distribution on `R^k`.

  The Multivariate Normal distribution is defined over `R^k` and parameterized
  by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
  `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
  matrix-multiplication.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
  y = inv(scale) @ (x - loc),
  Z = (2 pi)**(0.5 k) |det(scale)|,
  ```

  where:

  * `loc` is a vector in `R^k`,
  * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
  * `Z` denotes the normalization constant, and,
  * `||y||**2` denotes the squared Euclidean norm of `y`.

  A (non-batch) `scale` matrix is:

  ```none
  scale = diag(scale_diag + scale_identity_multiplier ones(k)) +
      scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
  ```

  where:

  * `scale_diag.shape = [k]`,
  * `scale_identity_multiplier.shape = []`,
  * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
  * `scale_perturb_diag.shape = [r]`.

  Additional leading dimensions (if any) will index batches.

  If both `scale_diag` and `scale_identity_multiplier` are `None`, then
  `scale` is the Identity matrix.

  The MultivariateNormal distribution is a member of the [location-scale
  family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed as,

  ```none
  X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
  Y = scale @ X + loc
  ```

  #### Examples

  ```python
  tfd = tf.contrib.distributions

  # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`,
  # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is
  # a rank-2 update.
  mu = [-0.5, 0, 0.5]   # shape: [3]
  d = [1.5, 0.5, 2]     # shape: [3]
  U = [[1., 2],
       [-1, 1],
       [2, -0.5]]       # shape: [3, 2]
  m = [4., 5]           # shape: [2]
  mvn = tfd.MultivariateNormalDiagPlusLowRank(
      loc=mu,
      scale_diag=d,
      scale_perturb_factor=U,
      scale_perturb_diag=m)

  # Evaluate this on an observation in `R^3`, returning a scalar.
  mvn.prob([-1, 0, 1]).eval()  # shape: []

  # Initialize a 2-batch of 3-variate Gaussians; since neither `scale_diag`
  # nor `scale_identity_multiplier` is given, `S = I + U @ diag(m) @ U.T`.
  mu = [[1., 2, 3],
        [11, 22, 33]]      # shape: [b, k] = [2, 3]
  U = [[[1., 2],
        [3, 4],
        [5, 6]],
       [[0.5, 0.75],
        [1.0, 0.25],
        [1.5, 1.25]]]      # shape: [b, k, r] = [2, 3, 2]
  m = [[0.1, 0.2],
       [0.4, 0.5]]         # shape: [b, r] = [2, 2]

  mvn = tfd.MultivariateNormalDiagPlusLowRank(
      loc=mu,
      scale_perturb_factor=U,
      scale_perturb_diag=m)

  mvn.covariance().eval()   # shape: [2, 3, 3]
  # ==> [[[  15.63   31.57    48.51]
  #       [  31.57   69.31   105.05]
  #       [  48.51  105.05   162.59]]
  #
  #      [[   2.59    1.41    3.35]
  #       [   1.41    2.71    3.34]
  #       [   3.35    3.34    8.35]]]

  # Compute the pdf of two `R^3` observations (one from each batch);
  # return a length-2 vector.
  x = [[-0.9, 0, 0.1],
       [-10, 0, 9]]     # shape: [2, 3]
  mvn.prob(x).eval()    # shape: [2]
  ```

  """

  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               scale_perturb_factor=None,
               scale_perturb_diag=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalDiagPlusLowRank"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier ones(k)) +
        scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
    ```

    where:

    * `scale_diag.shape = [k]`,
    * `scale_identity_multiplier.shape = []`,
    * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
    * `scale_perturb_diag.shape = [r]`.

    Additional leading dimensions (if any) will index batches.

    If both `scale_diag` and `scale_identity_multiplier` are `None`, then
    `scale` is the Identity matrix.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scaled-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
        the `Identity`.
      scale_perturb_factor: Floating-point `Tensor` representing a rank-`r`
        perturbation added to `scale`. May have shape `[B1, ..., Bb, k, r]`,
        `b >= 0`, and characterizes `b`-batches of rank-`r` updates to `scale`.
        When `None`, no rank-`r` update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing a diagonal matrix
        inside the rank-`r` perturbation added to `scale`. May have shape
        `[B1, ..., Bb, r]`, `b >= 0`, and characterizes `b`-batches of `r x r`
        diagonal matrices inside the perturbation added to `scale`. When
        `None`, an identity matrix is used inside the perturbation. Can only be
        specified if `scale_perturb_factor` is also specified.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified.
      ValueError: if `scale_perturb_diag` is specified without
        `scale_perturb_factor`.
    """
    # Snapshot the constructor arguments. `dict(...)` makes an independent
    # copy so that locals bound later in this frame (e.g. `scale`) cannot
    # show up in the recorded parameters if the frame-locals mapping is
    # refreshed by a later `locals()` call or by a tracing tool.
    parameters = dict(locals())

    # Enforce the documented contract up front, before any ops are created,
    # instead of letting `None` reach the low-rank operator as its `u` factor.
    if scale_perturb_diag is not None and scale_perturb_factor is None:
      raise ValueError(
          "`scale_perturb_diag` can only be specified if "
          "`scale_perturb_factor` is also specified.")

    def _convert_to_tensor(x, name):
      # Keep `None` as-is so "argument not provided" survives the conversion.
      return None if x is None else ops.convert_to_tensor(x, name=name)

    with ops.name_scope(name):
      with ops.name_scope("init", values=[
          loc, scale_diag, scale_identity_multiplier, scale_perturb_factor,
          scale_perturb_diag]):
        # Given the check above, a low-rank term exists exactly when the
        # perturbation factor was supplied.
        has_low_rank = scale_perturb_factor is not None
        # The diagonal part is asserted positive only when a low-rank update
        # is stacked on top of it (the combined operator is declared positive
        # definite below).
        scale = distribution_util.make_diag_scale(
            loc=loc,
            scale_diag=scale_diag,
            scale_identity_multiplier=scale_identity_multiplier,
            validate_args=validate_args,
            assert_positive=has_low_rank)
        scale_perturb_factor = _convert_to_tensor(
            scale_perturb_factor,
            name="scale_perturb_factor")
        scale_perturb_diag = _convert_to_tensor(
            scale_perturb_diag,
            name="scale_perturb_diag")
        if has_low_rank:
          scale = linalg.LinearOperatorLowRankUpdate(
              scale,
              u=scale_perturb_factor,
              diag_update=scale_perturb_diag,
              # With no explicit diagonal the update uses an identity, which
              # is positive; a user-supplied diagonal carries no such hint.
              is_diag_update_positive=scale_perturb_diag is None,
              is_non_singular=True,  # Implied by is_positive_definite=True.
              is_self_adjoint=True,
              is_positive_definite=True,
              is_square=True)
    super(MultivariateNormalDiagPlusLowRank, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    # Overwrite whatever the base class recorded so that `parameters` reflects
    # this constructor's own arguments.
    self._parameters = parameters