sinh_arcsinh.py revision e7ab55b01f25bc1c9023dcc9510667ea480c6186
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SinhArcsinh transformation of a distribution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.distributions.python.ops import bijectors
22from tensorflow.contrib.distributions.python.ops import distribution_util
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops.distributions import normal
27from tensorflow.python.ops.distributions import transformed_distribution
28
29__all__ = [
30    "SinhArcsinh",
31]
32
33
class SinhArcsinh(transformed_distribution.TransformedDistribution):
  """The SinhArcsinh transformation of a distribution on `(-inf, inf)`.

  This distribution models a random variable, making use of
  a `SinhArcsinh` transformation (which has adjustable tailweight and skew),
  a rescaling, and a shift.

  The `SinhArcsinh` transformation of the Normal is described in great depth in
  [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865).
  Here we use a slightly different parameterization, in terms of `tailweight`
  and `skewness`.  Additionally we allow for distributions other than Normal,
  and control over `scale` as well as a "shift" parameter `loc`.

  #### Mathematical Details

  Given random variable `Z`, we define the SinhArcsinh
  transformation of `Z`, `Y`, parameterized by
  `(loc, scale, skewness, tailweight)`, via the relation:

  ```
  Y := loc + scale * F(Z) * (2 / F_0(2))
  F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
  F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
  ```

  This distribution is similar to the location-scale transformation
  `L(Z) := loc + scale * Z` in the following ways:

  * If `skewness = 0` and `tailweight = 1` (the defaults), `F(Z) = Z`, and then
    `Y = L(Z)` exactly.
  * `loc` is used in both to shift the result by a constant.
  * The multiplication of `scale` by `2 / F_0(2)` ensures that if `skewness = 0`
    `P[Y - loc <= 2 * scale] = P[L(Z) - loc <= 2 * scale]`.
    Thus it can be said that the weights in the tails of `Y` and `L(Z)` beyond
    `loc + 2 * scale` are the same.

  This distribution is different than `loc + scale * Z` due to the
  reshaping done by `F`:

  * Positive (negative) `skewness` leads to positive (negative) skew.
    * positive skew means, the mode of `F(Z)` is "tilted" to the right.
    * positive skew means positive values of `F(Z)` become more likely, and
      negative values become less likely.
  * Larger (smaller) `tailweight` leads to fatter (thinner) tails.
    * Fatter tails mean larger values of `|F(Z)|` become more likely.
    * `tailweight < 1` leads to a distribution that is "flat" around `Y = loc`,
      and a very steep drop-off in the tails.
    * `tailweight > 1` leads to a distribution more peaked at the mode with
      heavier tails.

  To see the argument about the tails, note that for `|Z| >> 1` and
  `|Z| >> (|skewness| * tailweight)**tailweight`, we have
  `Y approx 0.5 Z**tailweight e**(sign(Z) skewness * tailweight)`.

  To see the argument regarding multiplying `scale` by `2 / F_0(2)`,

  ```
  P[(Y - loc) / scale <= 2] = P[F(Z) * (2 / F_0(2)) <= 2]
                            = P[F(Z) <= F_0(2)]
                            = P[Z <= 2]  (if F = F_0).
  ```
  """

  def __init__(self,
               loc,
               scale,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name="SinhArcsinh"):
    """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`; `scale`,
    `skewness` and `tailweight` are converted to the `dtype` of `loc`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight)
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Default is `ds.Normal(0., 1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Copy the frame locals so that `parameters` is a fixed snapshot of the
    # constructor arguments rather than a live view of this frame.
    parameters = dict(locals())

    with ops.name_scope(name, values=[loc, scale, skewness, tailweight]):
      loc = ops.convert_to_tensor(loc, name="loc")
      dtype = loc.dtype
      scale = ops.convert_to_tensor(scale, name="scale", dtype=dtype)
      tailweight = 1. if tailweight is None else tailweight
      # Remember whether the caller asked for skew, so that the zero-skew
      # bijector needed for the normalizing constant can be shared below.
      has_default_skewness = skewness is None
      skewness = 0. if skewness is None else skewness
      tailweight = ops.convert_to_tensor(
          tailweight, name="tailweight", dtype=dtype)
      skewness = ops.convert_to_tensor(skewness, name="skewness", dtype=dtype)

      batch_shape = distribution_util.get_broadcast_shape(
          loc, scale, tailweight, skewness)

      # Recall, with Z a random variable,
      #   Y := loc + C * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      #   C := 2 * scale / F_0(2)
      if distribution is None:
        distribution = normal.Normal(
            loc=array_ops.zeros([], dtype=dtype),
            scale=array_ops.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats)
      else:
        asserts = distribution_util.maybe_check_scalar_distribution(
            distribution, dtype, validate_args)
        if asserts:
          # Tie the checks to `loc` so they run whenever `loc` is evaluated.
          loc = control_flow_ops.with_dependencies(asserts, loc)

      # Make the SAS bijector, 'F'.  `validate_args` is forwarded so that the
      # bijector can apply its own argument checks when validation is
      # requested.
      f = bijectors.SinhArcsinh(
          skewness=skewness,
          tailweight=tailweight,
          event_ndims=0,
          validate_args=validate_args)
      if has_default_skewness:
        # With no skew, F and F_0 coincide; reuse the same bijector.
        f_noskew = f
      else:
        f_noskew = bijectors.SinhArcsinh(
            skewness=skewness.dtype.as_numpy_dtype(0.),
            tailweight=tailweight,
            event_ndims=0,
            validate_args=validate_args)

      # Make the Affine bijector, Z --> loc + scale * Z (2 / F_0(2))
      c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype))
      affine = bijectors.Affine(
          shift=loc,
          scale_identity_multiplier=c,
          validate_args=validate_args,
          event_ndims=0)

      # Chain applies the last bijector first: F, then the affine map.
      bijector = bijectors.Chain([affine, f])

      super(SinhArcsinh, self).__init__(
          distribution=distribution,
          bijector=bijector,
          batch_shape=batch_shape,
          validate_args=validate_args,
          name=name)
    self._parameters = parameters
    self._loc = loc
    self._scale = scale
    self._tailweight = tailweight
    self._skewness = skewness

  @property
  def loc(self):
    """The `loc` in `Y := loc + scale * F(Z) * (2 / F_0(2))`."""
    return self._loc

  @property
  def scale(self):
    """The `scale` `Tensor` in `Y := loc + scale * F(Z) * (2 / F_0(2))`."""
    return self._scale

  @property
  def tailweight(self):
    """Controls the tail decay.  `tailweight > 1` means heavier tails."""
    return self._tailweight

  @property
  def skewness(self):
    """Controls the skewness.  `skewness > 0` means right skew."""
    return self._skewness
218