1fe8406149feec453250905965a14285465cd2063Shanqing Cai# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2fe8406149feec453250905965a14285465cd2063Shanqing Cai#
3fe8406149feec453250905965a14285465cd2063Shanqing Cai# Licensed under the Apache License, Version 2.0 (the "License");
4fe8406149feec453250905965a14285465cd2063Shanqing Cai# you may not use this file except in compliance with the License.
5fe8406149feec453250905965a14285465cd2063Shanqing Cai# You may obtain a copy of the License at
6fe8406149feec453250905965a14285465cd2063Shanqing Cai#
7fe8406149feec453250905965a14285465cd2063Shanqing Cai#     http://www.apache.org/licenses/LICENSE-2.0
8fe8406149feec453250905965a14285465cd2063Shanqing Cai#
9fe8406149feec453250905965a14285465cd2063Shanqing Cai# Unless required by applicable law or agreed to in writing, software
10fe8406149feec453250905965a14285465cd2063Shanqing Cai# distributed under the License is distributed on an "AS IS" BASIS,
11fe8406149feec453250905965a14285465cd2063Shanqing Cai# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12fe8406149feec453250905965a14285465cd2063Shanqing Cai# See the License for the specific language governing permissions and
13fe8406149feec453250905965a14285465cd2063Shanqing Cai# limitations under the License.
14fe8406149feec453250905965a14285465cd2063Shanqing Cai# ==============================================================================
15fe8406149feec453250905965a14285465cd2063Shanqing Cai"""The Half Normal distribution class."""
16fe8406149feec453250905965a14285465cd2063Shanqing Cai
17fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import absolute_import
18fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import division
19fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import print_function
20fe8406149feec453250905965a14285465cd2063Shanqing Cai
21fe8406149feec453250905965a14285465cd2063Shanqing Caiimport numpy as np
22fe8406149feec453250905965a14285465cd2063Shanqing Cai
23fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.framework import constant_op
24fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.framework import dtypes
25fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.framework import ops
26fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.framework import tensor_shape
27fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import array_ops
28fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import check_ops
29fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import math_ops
30fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import nn
31fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import random_ops
32fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops.distributions import distribution
33fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops.distributions import special_math
34fe8406149feec453250905965a14285465cd2063Shanqing Cai
35fe8406149feec453250905965a14285465cd2063Shanqing Cai
# Names exported by `from <module> import *`.
__all__ = ["HalfNormal"]
39fe8406149feec453250905965a14285465cd2063Shanqing Cai
40fe8406149feec453250905965a14285465cd2063Shanqing Cai
class HalfNormal(distribution.Distribution):
  """The Half Normal distribution with scale `scale`.

  #### Mathematical details

  The half normal is a transformation of a centered normal distribution.
  If some random variable `X` has normal distribution,

  ```none
  X ~ Normal(0.0, scale)
  Y = |X|
  ```

  then `Y` will have half normal distribution. The probability density
  function (pdf) is:

  ```none
  pdf(x; scale, x >= 0) = sqrt(2) / (scale * sqrt(pi)) *
    exp(- 1/2 * (x / scale) ** 2)
  ```

  and the density is zero for `x < 0`. Here `scale = sigma` is the standard
  deviation of the underlying normal distribution.

  #### Examples

  Examples of initialization of one or a batch of distributions.

  ```python
  # Define a single scalar HalfNormal distribution.
  dist = tf.contrib.distributions.HalfNormal(scale=3.0)

  # Evaluate the cdf at 1, returning a scalar.
  dist.cdf(1.)

  # Define a batch of two scalar valued HalfNormals.
  # The first has scale 11.0, the second 22.0
  dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0])

  # Evaluate the pdf of the first distribution on 1.0, and the second on 1.5,
  # returning a length two tensor.
  dist.prob([1.0, 1.5])

  # Get 3 samples, returning a 3 x 2 tensor.
  dist.sample([3])
  ```

  """

  def __init__(self,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="HalfNormal"):
    """Construct HalfNormals with scale `scale`.

    Args:
      scale: Floating point tensor; the scales of the distribution(s).
        Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Copy the constructor arguments into a fresh dict. In CPython, `locals()`
    # inside a function hands back the frame's own locals mapping, which a
    # later `locals()` call or a debugger/tracer can refresh; storing it
    # directly could let the saved parameters pick up extra entries
    # (including a self-referential "parameters" key).
    parameters = dict(locals())
    with ops.name_scope(name, values=[scale]):
      # The positivity assertion is attached only when validation is
      # requested, so the default path adds no runtime check.
      with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                    validate_args else []):
        self._scale = array_ops.identity(scale, name="scale")
    super(HalfNormal, self).__init__(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._scale],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    # The only parameter, `scale`, has exactly the requested sample shape.
    return {"scale": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}

  @property
  def scale(self):
    """Distribution parameter for the scale."""
    return self._scale

  def _batch_shape_tensor(self):
    # The batch shape is the (dynamic) shape of `scale`.
    return array_ops.shape(self.scale)

  def _batch_shape(self):
    # Static counterpart of `_batch_shape_tensor`.
    return self.scale.shape

  def _event_shape_tensor(self):
    # Each draw is a scalar, so the event shape is empty.
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.scalar()

  def _sample_n(self, n, seed=None):
    # Draw standard normals of shape [n] + batch_shape, scale them, and fold
    # onto the non-negative half-line with abs(). Gradients flow through
    # `self.scale`, consistent with FULLY_REPARAMETERIZED.
    shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    sampled = random_ops.random_normal(
        shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
    return math_ops.abs(sampled * self.scale)

  def _prob(self, x):
    # sqrt(2 / pi) / scale * exp(-x**2 / (2 * scale**2)) on x >= 0; the cast
    # mask zeroes the density for negative x.
    coeff = np.sqrt(2) / self.scale / np.sqrt(np.pi)
    pdf = coeff * math_ops.exp(- 0.5 * (x / self.scale) ** 2)
    return pdf * math_ops.cast(x >= 0, self.dtype)

  def _cdf(self, x):
    # relu clamps negative inputs to 0, where erf(0) = 0, giving cdf = 0
    # below the support.
    truncated_x = nn.relu(x)
    return math_ops.erf(truncated_x / self.scale / np.sqrt(2.0))

  def _entropy(self):
    # 0.5 * log(pi * scale**2 / 2) + 0.5
    return 0.5 * math_ops.log(np.pi * self.scale ** 2.0 / 2.0) + 0.5

  def _mean(self):
    # scale * sqrt(2 / pi)
    return self.scale * np.sqrt(2.0) / np.sqrt(np.pi)

  def _quantile(self, p):
    # Inverse of `_cdf` on [0, 1]: sqrt(2) * scale * erfinv(p).
    return np.sqrt(2.0) * self.scale * special_math.erfinv(p)

  def _mode(self):
    # The density is maximized at the boundary x = 0. Build the zeros in the
    # distribution's own dtype; the default (float32) would mismatch
    # distributions constructed with float64 `scale`.
    return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype)

  def _variance(self):
    # scale**2 * (1 - 2 / pi)
    return self.scale ** 2.0 * (1.0 - 2.0 / np.pi)
172