1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""The Deterministic distribution class.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22 23import six 24 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.framework import tensor_util 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import check_ops 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops.distributions import distribution 35 36__all__ = [ 37 "Deterministic", 38 "VectorDeterministic", 39] 40 41 42@six.add_metaclass(abc.ABCMeta) 43class _BaseDeterministic(distribution.Distribution): 44 """Base class for Deterministic distributions.""" 45 46 def __init__(self, 47 loc, 48 atol=None, 49 rtol=None, 50 is_vector=False, 51 validate_args=False, 52 allow_nan_stats=True, 53 name="_BaseDeterministic"): 54 """Initialize a batch of `_BaseDeterministic` distributions. 55 56 The `atol` and `rtol` parameters allow for some slack in `pmf`, `cdf` 57 computations, e.g. due to floating-point error. 58 59 ``` 60 pmf(x; loc) 61 = 1, if Abs(x - loc) <= atol + rtol * Abs(loc), 62 = 0, otherwise. 63 ``` 64 65 Args: 66 loc: Numeric `Tensor`. The point (or batch of points) on which this 67 distribution is supported. 68 atol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable 69 shape. The absolute tolerance for comparing closeness to `loc`. 70 Default is `0`. 71 rtol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable 72 shape. The relative tolerance for comparing closeness to `loc`. 73 Default is `0`. 74 is_vector: Python `bool`. If `True`, this is for `VectorDeterministic`, 75 else `Deterministic`. 76 validate_args: Python `bool`, default `False`. When `True` distribution 77 parameters are checked for validity despite possibly degrading runtime 78 performance. When `False` invalid inputs may silently render incorrect 79 outputs. 80 allow_nan_stats: Python `bool`, default `True`. When `True`, statistics 81 (e.g., mean, mode, variance) use the value "`NaN`" to indicate the 82 result is undefined. When `False`, an exception is raised if one or 83 more of the statistic's batch members are undefined. 84 name: Python `str` name prefixed to Ops created by this class. 85 86 Raises: 87 ValueError: If `loc` is a scalar. 88 """ 89 parameters = locals() 90 with ops.name_scope(name, values=[loc, atol, rtol]): 91 loc = ops.convert_to_tensor(loc, name="loc") 92 if is_vector and validate_args: 93 msg = "Argument loc must be at least rank 1." 94 if loc.get_shape().ndims is not None: 95 if loc.get_shape().ndims < 1: 96 raise ValueError(msg) 97 else: 98 loc = control_flow_ops.with_dependencies( 99 [check_ops.assert_rank_at_least(loc, 1, message=msg)], loc) 100 self._loc = loc 101 102 super(_BaseDeterministic, self).__init__( 103 dtype=self._loc.dtype, 104 reparameterization_type=distribution.NOT_REPARAMETERIZED, 105 validate_args=validate_args, 106 allow_nan_stats=allow_nan_stats, 107 parameters=parameters, 108 graph_parents=[self._loc], 109 name=name) 110 111 self._atol = self._get_tol(atol) 112 self._rtol = self._get_tol(rtol) 113 # Avoid using the large broadcast with self.loc if possible. 114 if rtol is None: 115 self._slack = self.atol 116 else: 117 self._slack = self.atol + self.rtol * math_ops.abs(self.loc) 118 119 def _get_tol(self, tol): 120 if tol is None: 121 return ops.convert_to_tensor(0, dtype=self.loc.dtype) 122 123 tol = ops.convert_to_tensor(tol, dtype=self.loc.dtype) 124 if self.validate_args: 125 tol = control_flow_ops.with_dependencies([ 126 check_ops.assert_non_negative( 127 tol, message="Argument 'tol' must be non-negative") 128 ], tol) 129 return tol 130 131 @property 132 def loc(self): 133 """Point (or batch of points) at which this distribution is supported.""" 134 return self._loc 135 136 @property 137 def atol(self): 138 """Absolute tolerance for comparing points to `self.loc`.""" 139 return self._atol 140 141 @property 142 def rtol(self): 143 """Relative tolerance for comparing points to `self.loc`.""" 144 return self._rtol 145 146 def _mean(self): 147 return array_ops.identity(self.loc) 148 149 def _variance(self): 150 return array_ops.zeros_like(self.loc) 151 152 def _mode(self): 153 return self.mean() 154 155 def _sample_n(self, n, seed=None): # pylint: disable=unused-arg 156 n_static = tensor_util.constant_value(ops.convert_to_tensor(n)) 157 if n_static is not None and self.loc.get_shape().ndims is not None: 158 ones = [1] * self.loc.get_shape().ndims 159 multiples = [n_static] + ones 160 else: 161 ones = array_ops.ones_like(array_ops.shape(self.loc)) 162 multiples = array_ops.concat(([n], ones), axis=0) 163 164 return array_ops.tile(self.loc[array_ops.newaxis, ...], multiples=multiples) 165 166 167class Deterministic(_BaseDeterministic): 168 """Scalar `Deterministic` distribution on the real line. 169 170 The scalar `Deterministic` distribution is parameterized by a [batch] point 171 `loc` on the real line. The distribution is supported at this point only, 172 and corresponds to a random variable that is constant, equal to `loc`. 173 174 See [Degenerate rv](https://en.wikipedia.org/wiki/Degenerate_distribution). 175 176 #### Mathematical Details 177 178 The probability mass function (pmf) and cumulative distribution function (cdf) 179 are 180 181 ```none 182 pmf(x; loc) = 1, if x == loc, else 0 183 cdf(x; loc) = 1, if x >= loc, else 0 184 ``` 185 186 #### Examples 187 188 ```python 189 # Initialize a single Deterministic supported at zero. 190 constant = tf.contrib.distributions.Deterministic(0.) 191 constant.prob(0.) 192 ==> 1. 193 constant.prob(2.) 194 ==> 0. 195 196 # Initialize a [2, 2] batch of scalar constants. 197 loc = [[0., 1.], [2., 3.]] 198 x = [[0., 1.1], [1.99, 3.]] 199 constant = tf.contrib.distributions.Deterministic(loc) 200 constant.prob(x) 201 ==> [[1., 0.], [0., 1.]] 202 ``` 203 204 """ 205 206 def __init__(self, 207 loc, 208 atol=None, 209 rtol=None, 210 validate_args=False, 211 allow_nan_stats=True, 212 name="Deterministic"): 213 """Initialize a scalar `Deterministic` distribution. 214 215 The `atol` and `rtol` parameters allow for some slack in `pmf`, `cdf` 216 computations, e.g. due to floating-point error. 217 218 ``` 219 pmf(x; loc) 220 = 1, if Abs(x - loc) <= atol + rtol * Abs(loc), 221 = 0, otherwise. 222 ``` 223 224 Args: 225 loc: Numeric `Tensor` of shape `[B1, ..., Bb]`, with `b >= 0`. 226 The point (or batch of points) on which this distribution is supported. 227 atol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable 228 shape. The absolute tolerance for comparing closeness to `loc`. 229 Default is `0`. 230 rtol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable 231 shape. The relative tolerance for comparing closeness to `loc`. 232 Default is `0`. 233 validate_args: Python `bool`, default `False`. When `True` distribution 234 parameters are checked for validity despite possibly degrading runtime 235 performance. When `False` invalid inputs may silently render incorrect 236 outputs. 237 allow_nan_stats: Python `bool`, default `True`. When `True`, statistics 238 (e.g., mean, mode, variance) use the value "`NaN`" to indicate the 239 result is undefined. When `False`, an exception is raised if one or 240 more of the statistic's batch members are undefined. 241 name: Python `str` name prefixed to Ops created by this class. 242 """ 243 super(Deterministic, self).__init__( 244 loc, 245 atol=atol, 246 rtol=rtol, 247 validate_args=validate_args, 248 allow_nan_stats=allow_nan_stats, 249 name=name) 250 251 def _batch_shape_tensor(self): 252 return array_ops.shape(self.loc) 253 254 def _batch_shape(self): 255 return self.loc.get_shape() 256 257 def _event_shape_tensor(self): 258 return constant_op.constant([], dtype=dtypes.int32) 259 260 def _event_shape(self): 261 return tensor_shape.scalar() 262 263 def _prob(self, x): 264 return math_ops.cast( 265 math_ops.abs(x - self.loc) <= self._slack, dtype=self.dtype) 266 267 def _cdf(self, x): 268 return math_ops.cast(x >= self.loc - self._slack, dtype=self.dtype) 269 270 271class VectorDeterministic(_BaseDeterministic): 272 """Vector `Deterministic` distribution on `R^k`. 273 274 The `VectorDeterministic` distribution is parameterized by a [batch] point 275 `loc in R^k`. The distribution is supported at this point only, 276 and corresponds to a random variable that is constant, equal to `loc`. 277 278 See [Degenerate rv](https://en.wikipedia.org/wiki/Degenerate_distribution). 279 280 #### Mathematical Details 281 282 The probability mass function (pmf) is 283 284 ```none 285 pmf(x; loc) 286 = 1, if All[Abs(x - loc) <= atol + rtol * Abs(loc)], 287 = 0, otherwise. 288 ``` 289 290 #### Examples 291 292 ```python 293 tfd = tf.contrib.distributions 294 295 # Initialize a single VectorDeterministic supported at [0., 2.] in R^2. 296 constant = tfd.Deterministic([0., 2.]) 297 constant.prob([0., 2.]) 298 ==> 1. 299 constant.prob([0., 3.]) 300 ==> 0. 301 302 # Initialize a [3] batch of constants on R^2. 303 loc = [[0., 1.], [2., 3.], [4., 5.]] 304 constant = tfd.VectorDeterministic(loc) 305 constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]]) 306 ==> [1., 0., 0.] 307 ``` 308 309 """ 310 311 def __init__(self, 312 loc, 313 atol=None, 314 rtol=None, 315 validate_args=False, 316 allow_nan_stats=True, 317 name="VectorDeterministic"): 318 """Initialize a `VectorDeterministic` distribution on `R^k`, for `k >= 0`. 319 320 Note that there is only one point in `R^0`, the "point" `[]`. So if `k = 0` 321 then `self.prob([]) == 1`. 322 323 The `atol` and `rtol` parameters allow for some slack in `pmf` 324 computations, e.g. due to floating-point error. 325 326 ``` 327 pmf(x; loc) 328 = 1, if All[Abs(x - loc) <= atol + rtol * Abs(loc)], 329 = 0, otherwise 330 ``` 331 332 Args: 333 loc: Numeric `Tensor` of shape `[B1, ..., Bb, k]`, with `b >= 0`, `k >= 0` 334 The point (or batch of points) on which this distribution is supported. 335 atol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable 336 shape. The absolute tolerance for comparing closeness to `loc`. 337 Default is `0`. 338 rtol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable 339 shape. The relative tolerance for comparing closeness to `loc`. 340 Default is `0`. 341 validate_args: Python `bool`, default `False`. When `True` distribution 342 parameters are checked for validity despite possibly degrading runtime 343 performance. When `False` invalid inputs may silently render incorrect 344 outputs. 345 allow_nan_stats: Python `bool`, default `True`. When `True`, statistics 346 (e.g., mean, mode, variance) use the value "`NaN`" to indicate the 347 result is undefined. When `False`, an exception is raised if one or 348 more of the statistic's batch members are undefined. 349 name: Python `str` name prefixed to Ops created by this class. 350 """ 351 super(VectorDeterministic, self).__init__( 352 loc, 353 atol=atol, 354 rtol=rtol, 355 is_vector=True, 356 validate_args=validate_args, 357 allow_nan_stats=allow_nan_stats, 358 name=name) 359 360 def _batch_shape_tensor(self): 361 return array_ops.shape(self.loc)[:-1] 362 363 def _batch_shape(self): 364 return self.loc.get_shape()[:-1] 365 366 def _event_shape_tensor(self): 367 return array_ops.shape(self.loc)[-1] 368 369 def _event_shape(self): 370 return self.loc.get_shape()[-1:] 371 372 def _prob(self, x): 373 if self.validate_args: 374 is_vector_check = check_ops.assert_rank_at_least(x, 1) 375 right_vec_space_check = check_ops.assert_equal( 376 self.event_shape_tensor(), 377 array_ops.gather(array_ops.shape(x), array_ops.rank(x) - 1), 378 message= 379 "Argument 'x' not defined in the same space R^k as this distribution") 380 with ops.control_dependencies([is_vector_check]): 381 with ops.control_dependencies([right_vec_space_check]): 382 x = array_ops.identity(x) 383 return math_ops.cast( 384 math_ops.reduce_all(math_ops.abs(x - self.loc) <= self._slack, axis=-1), 385 dtype=self.dtype) 386