1fe8406149feec453250905965a14285465cd2063Shanqing Cai# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2fe8406149feec453250905965a14285465cd2063Shanqing Cai# 3fe8406149feec453250905965a14285465cd2063Shanqing Cai# Licensed under the Apache License, Version 2.0 (the "License"); 4fe8406149feec453250905965a14285465cd2063Shanqing Cai# you may not use this file except in compliance with the License. 5fe8406149feec453250905965a14285465cd2063Shanqing Cai# You may obtain a copy of the License at 6fe8406149feec453250905965a14285465cd2063Shanqing Cai# 7fe8406149feec453250905965a14285465cd2063Shanqing Cai# http://www.apache.org/licenses/LICENSE-2.0 8fe8406149feec453250905965a14285465cd2063Shanqing Cai# 9fe8406149feec453250905965a14285465cd2063Shanqing Cai# Unless required by applicable law or agreed to in writing, software 10fe8406149feec453250905965a14285465cd2063Shanqing Cai# distributed under the License is distributed on an "AS IS" BASIS, 11fe8406149feec453250905965a14285465cd2063Shanqing Cai# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12fe8406149feec453250905965a14285465cd2063Shanqing Cai# See the License for the specific language governing permissions and 13fe8406149feec453250905965a14285465cd2063Shanqing Cai# limitations under the License. 
14fe8406149feec453250905965a14285465cd2063Shanqing Cai# ============================================================================== 15fe8406149feec453250905965a14285465cd2063Shanqing Cai"""The Half Normal distribution class.""" 16fe8406149feec453250905965a14285465cd2063Shanqing Cai 17fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import absolute_import 18fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import division 19fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import print_function 20fe8406149feec453250905965a14285465cd2063Shanqing Cai 21fe8406149feec453250905965a14285465cd2063Shanqing Caiimport numpy as np 22fe8406149feec453250905965a14285465cd2063Shanqing Cai 23fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.framework import constant_op 24fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.framework import dtypes 25fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.framework import ops 26fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.framework import tensor_shape 27fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import array_ops 28fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import check_ops 29fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import math_ops 30fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import nn 31fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import random_ops 32fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops.distributions import distribution 33fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops.distributions import special_math 34fe8406149feec453250905965a14285465cd2063Shanqing Cai 35fe8406149feec453250905965a14285465cd2063Shanqing Cai 
__all__ = [
    "HalfNormal",
]


class HalfNormal(distribution.Distribution):
  """The Half Normal distribution with scale `scale`.

  #### Mathematical details

  The half normal is a transformation of a centered normal distribution.
  If some random variable `X` has normal distribution,

  ```none
  X ~ Normal(0.0, scale)
  Y = |X|
  ```

  then `Y` will have half normal distribution. The probability density
  function (pdf) is:

  ```none
  pdf(x; scale) = sqrt(2) / (scale * sqrt(pi)) * exp(- 1/2 * (x / scale) ** 2)
                                                            for x >= 0
  pdf(x; scale) = 0                                         for x < 0
  ```

  where `scale = sigma` is the standard deviation of the underlying normal
  distribution.

  #### Examples

  Examples of initialization of one or a batch of distributions.

  ```python
  # Define a single scalar HalfNormal distribution.
  dist = tf.contrib.distributions.HalfNormal(scale=3.0)

  # Evaluate the cdf at 1, returning a scalar.
  dist.cdf(1.)

  # Define a batch of two scalar valued HalfNormals.
  # The first has scale 11.0, the second 22.0
  dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0])

  # Evaluate the pdf of the first distribution on 1.0, and the second on 1.5,
  # returning a length two tensor.
  dist.prob([1.0, 1.5])

  # Get 3 samples, returning a 3 x 2 tensor.
  dist.sample([3])
  ```

  """

  def __init__(self,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="HalfNormal"):
    """Construct HalfNormals with scale `scale`.

    Args:
      scale: Floating point tensor; the scales of the distribution(s).
        Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must be the first statement so that `locals()` captures only the
    # constructor arguments, which are forwarded to the base class below.
    parameters = locals()
    with ops.name_scope(name, values=[scale]):
      # With validate_args, `scale` is routed through an assert_positive
      # control dependency; otherwise it is wrapped without any check.
      with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                    validate_args else []):
        self._scale = array_ops.identity(scale, name="scale")
    super(HalfNormal, self).__init__(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._scale],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    """Returns the shape of `scale` for a given sample shape.

    The only parameter, `scale`, has the same shape as the sample shape.
    """
    return {"scale": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}

  @property
  def scale(self):
    """Distribution parameter for the scale."""
    return self._scale

  def _batch_shape_tensor(self):
    # The batch shape is the (dynamic) shape of `scale`.
    return array_ops.shape(self.scale)

  def _batch_shape(self):
    # The batch shape is the (static) shape of `scale`.
    return self.scale.shape

  def _event_shape_tensor(self):
    # Scalar events: empty event shape.
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.scalar()

  def _sample_n(self, n, seed=None):
    """Draws `n` samples per batch member, shaped `[n] + batch_shape`.

    Samples are `|scale * Z|` with `Z` standard normal. This is a
    differentiable function of `scale`, hence FULLY_REPARAMETERIZED.
    """
    shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    sampled = random_ops.random_normal(
        shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
    return math_ops.abs(sampled * self.scale)

  def _prob(self, x):
    """Density at `x`; zero wherever `x < 0`."""
    coeff = np.sqrt(2) / self.scale / np.sqrt(np.pi)
    pdf = coeff * math_ops.exp(- 0.5 * (x / self.scale) ** 2)
    # Zero out the density outside the support [0, inf).
    return pdf * math_ops.cast(x >= 0, self.dtype)

  def _cdf(self, x):
    """CDF at `x`: erf(x / (scale * sqrt(2))) for x >= 0, else 0."""
    # relu maps negative x to 0, and erf(0) == 0, giving cdf 0 below support.
    truncated_x = nn.relu(x)
    return math_ops.erf(truncated_x / self.scale / np.sqrt(2.0))

  def _entropy(self):
    """Differential entropy: 0.5 * log(pi * scale**2 / 2) + 0.5."""
    return 0.5 * math_ops.log(np.pi * self.scale ** 2.0 / 2.0) + 0.5

  def _mean(self):
    """Mean: scale * sqrt(2 / pi)."""
    return self.scale * np.sqrt(2.0) / np.sqrt(np.pi)

  def _quantile(self, p):
    """Inverse CDF: sqrt(2) * scale * erfinv(p)."""
    return np.sqrt(2.0) * self.scale * special_math.erfinv(p)

  def _mode(self):
    """Mode: 0 for every batch member, in the distribution's dtype."""
    # Pass dtype explicitly: `zeros` otherwise defaults to float32, which
    # would not match distributions built with float64 (or float16) scales.
    return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype)

  def _variance(self):
    """Variance: scale**2 * (1 - 2 / pi)."""
    return self.scale ** 2.0 * (1.0 - 2.0 / np.pi)