# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multi-dimensional (Vector) SinhArcsinh transformation of a distribution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.distributions import transformed_distribution

__all__ = [
    "VectorSinhArcsinhDiag",
]


class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
  """The (diagonal) SinhArcsinh transformation of a distribution on `R^k`.

  This distribution models a random vector `Y = (Y1,...,Yk)`, making use of
  a `SinhArcsinh` transformation (which has adjustable tailweight and skew),
  a rescaling, and a shift.

  The `SinhArcsinh` transformation of the Normal is described in great depth in
  [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865).
  Here we use a slightly different parameterization, in terms of `tailweight`
  and `skewness`.  Additionally we allow for distributions other than Normal,
  and control over `scale` as well as a "shift" parameter `loc`.

  #### Mathematical Details

  Given iid random vector `Z = (Z1,...,Zk)`, we define the VectorSinhArcsinhDiag
  transformation of `Z`, `Y`, parameterized by
  `(loc, scale, skewness, tailweight)`, via the relation (with `@` denoting
  matrix multiplication):

  ```
  Y := loc + scale @ F(Z) * (2 / F_0(2))
  F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
  F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
  ```

  This distribution is similar to the location-scale transformation
  `L(Z) := loc + scale @ Z` in the following ways:

  * If `skewness = 0` and `tailweight = 1` (the defaults), `F(Z) = Z`, and then
    `Y = L(Z)` exactly.
  * `loc` is used in both to shift the result by a constant factor.
  * The multiplication of `scale` by `2 / F_0(2)` ensures that if `skewness = 0`
    `P[Y - loc <= 2 * scale] = P[L(Z) - loc <= 2 * scale]`.
    Thus it can be said that the weights in the tails of `Y` and `L(Z)` beyond
    `loc + 2 * scale` are the same.

  This distribution is different than `loc + scale @ Z` due to the
  reshaping done by `F`:

  * Positive (negative) `skewness` leads to positive (negative) skew.
    * positive skew means, the mode of `F(Z)` is "tilted" to the right.
    * positive skew means positive values of `F(Z)` become more likely, and
      negative values become less likely.
  * Larger (smaller) `tailweight` leads to fatter (thinner) tails.
    * Fatter tails mean larger values of `|F(Z)|` become more likely.
    * `tailweight < 1` leads to a distribution that is "flat" around `Y = loc`,
      and a very steep drop-off in the tails.
    * `tailweight > 1` leads to a distribution more peaked at the mode with
      heavier tails.

  To see the argument about the tails, note that for `|Z| >> 1` and
  `|Z| >> (|skewness| * tailweight)**tailweight`, we have
  `Y approx 0.5 Z**tailweight e**(sign(Z) skewness * tailweight)`.

  To see the argument regarding multiplying `scale` by `2 / F_0(2)`,

  ```
  P[(Y - loc) / scale <= 2] = P[F(Z) * (2 / F_0(2)) <= 2]
                            = P[F(Z) <= F_0(2)]
                            = P[Z <= 2]  (if F = F_0).
  ```
  """

  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name="VectorSinhArcsinhDiag"):
    # NOTE: The default `name` previously read "MultivariateNormalLinearOperator"
    # (a copy-paste from a sibling distribution); it now matches this class.
    """Construct VectorSinhArcsinhDiag distribution on `R^k`.

    The arguments `scale_diag` and `scale_identity_multiplier` combine to
    define the diagonal `scale` referred to in this class docstring:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k))
    ```

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with
    this.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scale-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale`
        is the `Identity`.
      skewness: Skewness parameter. floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      tailweight: Tailweight parameter. floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      distribution: `tf.Distribution`-like instance. Distribution from which `k`
        iid samples are used as input to transformation `F`. Default is
        `tf.distributions.Normal(loc=0., scale=1.)`.
        Must be a scalar-batch, scalar-event distribution. Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a VectorSinhArcsinhDiag sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified.
    """
    # Capture all ctor args for the `parameters` property before any of the
    # locals below are rebound (e.g. `loc`, `skewness` become Tensors).
    parameters = locals()

    with ops.name_scope(
        name,
        values=[
            loc, scale_diag, scale_identity_multiplier, skewness, tailweight
        ]):
      loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc
      tailweight = 1. if tailweight is None else tailweight
      # Remember whether skewness was defaulted: in that case `F = F_0` and we
      # can reuse the same bijector when computing the `2 / F_0(2)` factor.
      has_default_skewness = skewness is None
      skewness = 0. if skewness is None else skewness

      # Recall, with Z a random variable,
      #   Y := loc + C * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      #   C := 2 * scale / F_0(2)

      # Construct shapes and 'scale' out of the scale_* and loc kwargs.
      # scale_linop is only an intermediary to:
      #  1. get shapes from looking at loc and the two scale args.
      #  2. combine scale_diag with scale_identity_multiplier, which gives us
      #     'scale', which in turn gives us 'C'.
      # NOTE(review): `validate_args` is intentionally not forwarded here, and
      # `assert_positive=False` allows a negative diagonal — presumably by
      # design, since sign flips commute with the elementwise transform.
      # Confirm against `make_diag_scale` semantics.
      scale_linop = distribution_util.make_diag_scale(
          loc=loc,
          scale_diag=scale_diag,
          scale_identity_multiplier=scale_identity_multiplier,
          validate_args=False,
          assert_positive=False)
      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, scale_linop)
      # scale_linop.diag_part() is efficient since it is a diag type linop.
      scale_diag_part = scale_linop.diag_part()
      dtype = scale_diag_part.dtype

      if distribution is None:
        distribution = normal.Normal(
            loc=array_ops.zeros([], dtype=dtype),
            scale=array_ops.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats)
      else:
        # User-supplied distribution must be scalar-batch/scalar-event with a
        # matching dtype; attach the resulting assertions (if any) so they run
        # before the scale is used.
        asserts = distribution_util.maybe_check_scalar_distribution(
            distribution, dtype, validate_args)
        if asserts:
          scale_diag_part = control_flow_ops.with_dependencies(
              asserts, scale_diag_part)

      # Make the SAS bijector, 'F'.
      skewness = ops.convert_to_tensor(skewness, dtype=dtype, name="skewness")
      tailweight = ops.convert_to_tensor(
          tailweight, dtype=dtype, name="tailweight")
      f = bijectors.SinhArcsinh(
          skewness=skewness, tailweight=tailweight, event_ndims=1)
      if has_default_skewness:
        # skewness == 0, so F == F_0 and `f` itself computes F_0.
        f_noskew = f
      else:
        # A zero-skew copy used only to evaluate the scalar F_0(2) below.
        f_noskew = bijectors.SinhArcsinh(
            skewness=skewness.dtype.as_numpy_dtype(0.),
            tailweight=tailweight, event_ndims=0)

      # Make the Affine bijector, Z --> loc + C * Z.
      c = 2 * scale_diag_part / f_noskew.forward(
          ops.convert_to_tensor(2, dtype=dtype))
      affine = bijectors.Affine(
          shift=loc, scale_diag=c, validate_args=validate_args, event_ndims=1)

      # Y = affine(f(Z)) = loc + C * F(Z); Chain applies right-to-left.
      bijector = bijectors.Chain([affine, f])

      super(VectorSinhArcsinhDiag, self).__init__(
          distribution=distribution,
          bijector=bijector,
          batch_shape=batch_shape,
          event_shape=event_shape,
          validate_args=validate_args,
          name=name)
    self._parameters = parameters
    self._loc = loc
    self._scale = scale_linop
    self._tailweight = tailweight
    self._skewness = skewness

  @property
  def loc(self):
    """The `loc` in `Y := loc + scale @ F(Z) * (2 / F_0(2))`."""
    return self._loc

  @property
  def scale(self):
    """The `LinearOperator` `scale` in `Y := loc + scale @ F(Z) * (2 / F_0(2))`."""
    return self._scale

  @property
  def tailweight(self):
    """Controls the tail decay. `tailweight > 1` means fatter tails than Normal."""
    return self._tailweight

  @property
  def skewness(self):
    """Controls the skewness. `Skewness > 0` means right skew."""
    return self._skewness