1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Multi-dimensional (Vector) SinhArcsinh transformation of a distribution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.distributions.python.ops import bijectors
22from tensorflow.contrib.distributions.python.ops import distribution_util
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops.distributions import normal
27from tensorflow.python.ops.distributions import transformed_distribution
28
29__all__ = [
30    "VectorSinhArcsinhDiag",
31]
32
33
class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
  """The (diagonal) SinhArcsinh transformation of a distribution on `R^k`.

  This distribution models a random vector `Y = (Y1,...,Yk)`, making use of
  a `SinhArcsinh` transformation (which has adjustable tailweight and skew),
  a rescaling, and a shift.

  The `SinhArcsinh` transformation of the Normal is described in great depth in
  [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865).
  Here we use a slightly different parameterization, in terms of `tailweight`
  and `skewness`.  Additionally we allow for distributions other than Normal,
  and control over `scale` as well as a "shift" parameter `loc`.

  #### Mathematical Details

  Given iid random vector `Z = (Z1,...,Zk)`, we define the VectorSinhArcsinhDiag
  transformation of `Z`, `Y`, parameterized by
  `(loc, scale, skewness, tailweight)`, via the relation (with `@` denoting
  matrix multiplication):

  ```
  Y := loc + scale @ F(Z) * (2 / F_0(2))
  F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
  F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
  ```

  This distribution is similar to the location-scale transformation
  `L(Z) := loc + scale @ Z` in the following ways:

  * If `skewness = 0` and `tailweight = 1` (the defaults), `F(Z) = Z`, and then
    `Y = L(Z)` exactly.
  * `loc` is used in both to shift the result by a constant factor.
  * The multiplication of `scale` by `2 / F_0(2)` ensures that if `skewness = 0`
    `P[Y - loc <= 2 * scale] = P[L(Z) - loc <= 2 * scale]`.
    Thus it can be said that the weights in the tails of `Y` and `L(Z)` beyond
    `loc + 2 * scale` are the same.

  This distribution is different than `loc + scale @ Z` due to the
  reshaping done by `F`:

  * Positive (negative) `skewness` leads to positive (negative) skew.
    * positive skew means, the mode of `F(Z)` is "tilted" to the right.
    * positive skew means positive values of `F(Z)` become more likely, and
      negative values become less likely.
  * Larger (smaller) `tailweight` leads to fatter (thinner) tails.
    * Fatter tails mean larger values of `|F(Z)|` become more likely.
    * `tailweight < 1` leads to a distribution that is "flat" around `Y = loc`,
      and a very steep drop-off in the tails.
    * `tailweight > 1` leads to a distribution more peaked at the mode with
      heavier tails.

  To see the argument about the tails, note that for `|Z| >> 1` and
  `|Z| >> (|skewness| * tailweight)**tailweight`, we have
  `Y approx 0.5 Z**tailweight e**(sign(Z) skewness * tailweight)`.

  To see the argument regarding multiplying `scale` by `2 / F_0(2)`,

  ```
  P[(Y - loc) / scale <= 2] = P[F(Z) * (2 / F_0(2)) <= 2]
                            = P[F(Z) <= F_0(2)]
                            = P[Z <= 2]  (if F = F_0).
  ```
  """

  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name="VectorSinhArcsinhDiag"):
    """Construct VectorSinhArcsinhDiag distribution on `R^k`.

    The arguments `scale_diag` and `scale_identity_multiplier` combine to
    define the diagonal `scale` referred to in this class docstring:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k))
    ```

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with
    this.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scale-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale`
        is the `Identity`.
      skewness:  Skewness parameter.  floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      tailweight:  Tailweight parameter.  floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      distribution: `tf.Distribution`-like instance. Distribution from which `k`
        iid samples are used as input to transformation `F`.  Default is
        `tf.distributions.Normal(loc=0., scale=1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a VectorSinhArcsinhDiag sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified.
    """
    # Capture constructor arguments before any locals are created, so that
    # `self.parameters` faithfully records how this instance was built.
    parameters = locals()

    with ops.name_scope(
        name,
        values=[
            loc, scale_diag, scale_identity_multiplier, skewness, tailweight
        ]):
      loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc
      # Defaults (tailweight=1, skewness=0) make F the identity, reducing this
      # distribution to the plain location-scale transform `loc + scale @ Z`.
      tailweight = 1. if tailweight is None else tailweight
      has_default_skewness = skewness is None
      skewness = 0. if skewness is None else skewness

      # Recall, with Z a random variable,
      #   Y := loc + C * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      #   C := 2 * scale / F_0(2)

      # Construct shapes and 'scale' out of the scale_* and loc kwargs.
      # scale_linop is only an intermediary to:
      #  1. get shapes from looking at loc and the two scale args.
      #  2. combine scale_diag with scale_identity_multiplier, which gives us
      #     'scale', which in turn gives us 'C'.
      scale_linop = distribution_util.make_diag_scale(
          loc=loc,
          scale_diag=scale_diag,
          scale_identity_multiplier=scale_identity_multiplier,
          validate_args=False,
          assert_positive=False)
      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, scale_linop)
      # scale_linop.diag_part() is efficient since it is a diag type linop.
      scale_diag_part = scale_linop.diag_part()
      dtype = scale_diag_part.dtype

      if distribution is None:
        distribution = normal.Normal(
            loc=array_ops.zeros([], dtype=dtype),
            scale=array_ops.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats)
      else:
        asserts = distribution_util.maybe_check_scalar_distribution(
            distribution, dtype, validate_args)
        if asserts:
          # Attach the scalar-distribution checks as control dependencies of a
          # tensor that is always used downstream, so they are guaranteed to
          # run.
          scale_diag_part = control_flow_ops.with_dependencies(
              asserts, scale_diag_part)

      # Make the SAS bijector, 'F'.
      skewness = ops.convert_to_tensor(skewness, dtype=dtype, name="skewness")
      tailweight = ops.convert_to_tensor(
          tailweight, dtype=dtype, name="tailweight")
      f = bijectors.SinhArcsinh(
          skewness=skewness, tailweight=tailweight, event_ndims=1)
      if has_default_skewness:
        # With skewness == 0, F == F_0, so reuse `f` for the normalization.
        f_noskew = f
      else:
        f_noskew = bijectors.SinhArcsinh(
            skewness=skewness.dtype.as_numpy_dtype(0.),
            tailweight=tailweight, event_ndims=0)

      # Make the Affine bijector, Z --> loc + C * Z, where
      # C := 2 * scale / F_0(2) (see the class docstring for why).
      c = 2 * scale_diag_part / f_noskew.forward(
          ops.convert_to_tensor(2, dtype=dtype))
      affine = bijectors.Affine(
          shift=loc, scale_diag=c, validate_args=validate_args, event_ndims=1)

      # Chain applies right-to-left: first F, then the affine shift/scale.
      bijector = bijectors.Chain([affine, f])

      super(VectorSinhArcsinhDiag, self).__init__(
          distribution=distribution,
          bijector=bijector,
          batch_shape=batch_shape,
          event_shape=event_shape,
          validate_args=validate_args,
          name=name)
    self._parameters = parameters
    self._loc = loc
    self._scale = scale_linop
    self._tailweight = tailweight
    self._skewness = skewness

  @property
  def loc(self):
    """The `loc` in `Y := loc + scale @ F(Z) * (2 / F_0(2))`."""
    return self._loc

  @property
  def scale(self):
    """The `LinearOperator` `scale` in `Y := loc + scale @ F(Z) * (2 / F_0(2))`."""
    return self._scale

  @property
  def tailweight(self):
    """Controls the tail decay.  `tailweight > 1` means faster than Normal."""
    return self._tailweight

  @property
  def skewness(self):
    """Controls the skewness.  `Skewness > 0` means right skew."""
    return self._skewness
266