sinh_arcsinh.py revision e7ab55b01f25bc1c9023dcc9510667ea480c6186
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SinhArcsinh transformation of a distribution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.distributions import transformed_distribution

__all__ = [
    "SinhArcsinh",
]


class SinhArcsinh(transformed_distribution.TransformedDistribution):
  """The SinhArcsinh transformation of a distribution on `(-inf, inf)`.

  This distribution models a random variable, making use of
  a `SinhArcsinh` transformation (which has adjustable tailweight and skew),
  a rescaling, and a shift.

  The `SinhArcsinh` transformation of the Normal is described in great depth in
  [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865).
  Here we use a slightly different parameterization, in terms of `tailweight`
  and `skewness`.  Additionally we allow for distributions other than Normal,
  and control over `scale` as well as a "shift" parameter `loc`.

  #### Mathematical Details

  Given random variable `Z`, we define the SinhArcsinh
  transformation of `Z`, `Y`, parameterized by
  `(loc, scale, skewness, tailweight)`, via the relation:

  ```
  Y := loc + scale * F(Z) * (2 / F_0(2))
  F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
  F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
  ```

  This distribution is similar to the location-scale transformation
  `L(Z) := loc + scale * Z` in the following ways:

  * If `skewness = 0` and `tailweight = 1` (the defaults), `F(Z) = Z`, and then
    `Y = L(Z)` exactly.
  * `loc` is used in both to shift the result by a constant.
  * The multiplication of `scale` by `2 / F_0(2)` ensures that if `skewness = 0`
    `P[Y - loc <= 2 * scale] = P[L(Z) - loc <= 2 * scale]`.
    Thus it can be said that the weights in the tails of `Y` and `L(Z)` beyond
    `loc + 2 * scale` are the same.

  This distribution is different than `loc + scale * Z` due to the
  reshaping done by `F`:

  * Positive (negative) `skewness` leads to positive (negative) skew.
    * positive skew means, the mode of `F(Z)` is "tilted" to the right.
    * positive skew means positive values of `F(Z)` become more likely, and
      negative values become less likely.
  * Larger (smaller) `tailweight` leads to fatter (thinner) tails.
    * Fatter tails mean larger values of `|F(Z)|` become more likely.
    * `tailweight < 1` leads to a distribution that is "flat" around `Y = loc`,
      and a very steep drop-off in the tails.
    * `tailweight > 1` leads to a distribution more peaked at the mode with
      heavier tails.

  To see the argument about the tails, note that for `|Z| >> 1` and
  `|Z| >> (|skewness| * tailweight)**tailweight`, we have
  `Y approx 0.5 Z**tailweight e**(sign(Z) skewness * tailweight)`.

  To see the argument regarding multiplying `scale` by `2 / F_0(2)`,

  ```
  P[(Y - loc) / scale <= 2] = P[F(Z) * (2 / F_0(2)) <= 2]
                            = P[F(Z) <= F_0(2)]
                            = P[Z <= 2]  (if F = F_0).
  ```
  """

  def __init__(self,
               loc,
               scale,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name="SinhArcsinh"):
    """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged
        tailweight).
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Default is `ds.Normal(0., 1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Take a snapshot copy of the constructor arguments.  A bare `locals()`
    # returns the frame's live mapping, which can later be refreshed with the
    # rebound (tensor) values and internal temporaries created below.
    parameters = dict(locals())

    with ops.name_scope(name, values=[loc, scale, skewness, tailweight]):
      loc = ops.convert_to_tensor(loc, name="loc")
      # `loc` fixes the dtype that every other parameter is converted to.
      dtype = loc.dtype
      scale = ops.convert_to_tensor(scale, name="scale", dtype=dtype)
      tailweight = 1. if tailweight is None else tailweight
      # Remember whether skewness was defaulted so the zero-skew bijector used
      # for the normalizing constant can reuse `f` instead of being rebuilt.
      has_default_skewness = skewness is None
      skewness = 0. if skewness is None else skewness
      tailweight = ops.convert_to_tensor(
          tailweight, name="tailweight", dtype=dtype)
      skewness = ops.convert_to_tensor(skewness, name="skewness", dtype=dtype)

      batch_shape = distribution_util.get_broadcast_shape(
          loc, scale, tailweight, skewness)

      # Recall, with Z a random variable,
      #   Y := loc + C * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      #   C := 2 * scale / F_0(2)
      if distribution is None:
        distribution = normal.Normal(
            loc=array_ops.zeros([], dtype=dtype),
            scale=array_ops.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats)
      else:
        asserts = distribution_util.maybe_check_scalar_distribution(
            distribution, dtype, validate_args)
        if asserts:
          # Attach the checks to `loc` so they run whenever `loc` is used.
          loc = control_flow_ops.with_dependencies(asserts, loc)

      # Make the SAS bijector, 'F'.  `validate_args` is forwarded so that the
      # bijector's own parameter checks honor the caller's request.
      f = bijectors.SinhArcsinh(
          skewness=skewness,
          tailweight=tailweight,
          event_ndims=0,
          validate_args=validate_args)
      if has_default_skewness:
        f_noskew = f
      else:
        f_noskew = bijectors.SinhArcsinh(
            skewness=skewness.dtype.as_numpy_dtype(0.),
            tailweight=tailweight,
            event_ndims=0,
            validate_args=validate_args)

      # Make the Affine bijector, Z --> loc + scale * Z (2 / F_0(2))
      c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype))
      affine = bijectors.Affine(
          shift=loc,
          scale_identity_multiplier=c,
          validate_args=validate_args,
          event_ndims=0)

      # Chain applies right-to-left: first F, then the affine map.
      bijector = bijectors.Chain([affine, f])

      super(SinhArcsinh, self).__init__(
          distribution=distribution,
          bijector=bijector,
          batch_shape=batch_shape,
          validate_args=validate_args,
          name=name)
    self._parameters = parameters
    self._loc = loc
    self._scale = scale
    self._tailweight = tailweight
    self._skewness = skewness

  @property
  def loc(self):
    """The `loc` in `Y := loc + scale * F(Z) * (2 / F_0(2))`."""
    return self._loc

  @property
  def scale(self):
    """The `scale` in `Y := loc + scale * F(Z) * (2 / F_0(2))`."""
    return self._scale

  @property
  def tailweight(self):
    """Controls the tail decay.  Larger `tailweight` means heavier tails."""
    return self._tailweight

  @property
  def skewness(self):
    """Controls the skewness.  `skewness > 0` means right skew."""
    return self._skewness