1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The Deterministic distribution class."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22
23import six
24
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import check_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops.distributions import distribution
35
# Public API of this module; `_BaseDeterministic` is intentionally not exported.
__all__ = ["Deterministic", "VectorDeterministic"]
40
41
@six.add_metaclass(abc.ABCMeta)
class _BaseDeterministic(distribution.Distribution):
  """Base class for Deterministic distributions.

  Stores the support point `loc` together with the absolute / relative
  tolerances `atol` / `rtol`, precomputes the comparison slack, and implements
  the statistics (`mean`, `variance`, `mode`) and sampling.
  """

  def __init__(self,
               loc,
               atol=None,
               rtol=None,
               is_vector=False,
               validate_args=False,
               allow_nan_stats=True,
               name="_BaseDeterministic"):
    """Initialize a batch of `_BaseDeterministic` distributions.

    The `atol` and `rtol` parameters allow for some slack in `pmf`, `cdf`
    computations, e.g. due to floating-point error.

    ```
    pmf(x; loc)
      = 1, if Abs(x - loc) <= atol + rtol * Abs(loc),
      = 0, otherwise.
    ```

    Args:
      loc: Numeric `Tensor`.  The point (or batch of points) on which this
        distribution is supported.
      atol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The absolute tolerance for comparing closeness to `loc`.
        Default is `0`.
      rtol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The relative tolerance for comparing closeness to `loc`.
        Default is `0`.
      is_vector:  Python `bool`.  If `True`, this is for `VectorDeterministic`,
        else `Deterministic`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError:  If `is_vector` and `validate_args` are both `True` and
        `loc` has a statically known rank of 0 (i.e. is a scalar).
    """
    # Take a private copy of the constructor arguments.  The mapping returned
    # by `locals()` can be refreshed by the interpreter later on (e.g. when a
    # trace function or debugger is active), which would leak names assigned
    # below (`msg`, the converted `loc`) into the recorded parameters.
    parameters = dict(locals())
    with ops.name_scope(name, values=[loc, atol, rtol]):
      loc = ops.convert_to_tensor(loc, name="loc")
      if is_vector and validate_args:
        msg = "Argument loc must be at least rank 1."
        static_rank = loc.get_shape().ndims
        if static_rank is None:
          # Rank unknown at graph-construction time: defer the check to a
          # runtime assertion attached to `loc`.
          loc = control_flow_ops.with_dependencies(
              [check_ops.assert_rank_at_least(loc, 1, message=msg)], loc)
        elif static_rank < 1:
          raise ValueError(msg)
      self._loc = loc

      super(_BaseDeterministic, self).__init__(
          dtype=self._loc.dtype,
          reparameterization_type=distribution.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=[self._loc],
          name=name)

      # `_get_tol` reads `self.validate_args`, so it has to run after the
      # base-class constructor above.
      self._atol = self._get_tol(atol)
      self._rtol = self._get_tol(rtol)
      # Avoid using the large broadcast with self.loc if possible.
      if rtol is None:
        self._slack = self.atol
      else:
        self._slack = self.atol + self.rtol * math_ops.abs(self.loc)

  def _get_tol(self, tol):
    """Converts a tolerance argument to a `Tensor` with the dtype of `loc`.

    Args:
      tol: `None` or a value convertible to a `Tensor` of `loc`'s dtype.

    Returns:
      A zero `Tensor` when `tol` is `None`; otherwise `tol` as a `Tensor`,
      guarded by a non-negativity assertion when `self.validate_args` is set.
    """
    if tol is None:
      return ops.convert_to_tensor(0, dtype=self.loc.dtype)

    tol = ops.convert_to_tensor(tol, dtype=self.loc.dtype)
    if self.validate_args:
      tol = control_flow_ops.with_dependencies([
          check_ops.assert_non_negative(
              tol, message="Argument 'tol' must be non-negative")
      ], tol)
    return tol

  @property
  def loc(self):
    """Point (or batch of points) at which this distribution is supported."""
    return self._loc

  @property
  def atol(self):
    """Absolute tolerance for comparing points to `self.loc`."""
    return self._atol

  @property
  def rtol(self):
    """Relative tolerance for comparing points to `self.loc`."""
    return self._rtol

  def _mean(self):
    """Returns a copy of `loc`."""
    return array_ops.identity(self.loc)

  def _variance(self):
    """Returns zeros shaped like `loc`."""
    return array_ops.zeros_like(self.loc)

  def _mode(self):
    """Returns the mean."""
    return self.mean()

  def _sample_n(self, n, seed=None):  # pylint: disable=unused-argument
    """Returns `loc` tiled `n` times along a new leading axis.

    `seed` is accepted for interface compatibility but is not used, since the
    samples are not random.
    """
    n_static = tensor_util.constant_value(ops.convert_to_tensor(n))
    if n_static is not None and self.loc.get_shape().ndims is not None:
      # Both `n` and the rank of `loc` are known statically, so `multiples`
      # can be an ordinary Python list.
      ones = [1] * self.loc.get_shape().ndims
      multiples = [n_static] + ones
    else:
      ones = array_ops.ones_like(array_ops.shape(self.loc))
      multiples = array_ops.concat(([n], ones), axis=0)

    return array_ops.tile(self.loc[array_ops.newaxis, ...], multiples=multiples)
165
166
class Deterministic(_BaseDeterministic):
  """Scalar `Deterministic` distribution on the real line.

  A scalar `Deterministic` distribution is described by a single [batch] point
  `loc` on the real line.  All of its mass sits at that point, i.e. it models a
  random variable that is constant and equal to `loc`.

  See [Degenerate rv](https://en.wikipedia.org/wiki/Degenerate_distribution).

  #### Mathematical Details

  The probability mass function (pmf) and cumulative distribution function (cdf)
  are

  ```none
  pmf(x; loc) = 1, if x == loc, else 0
  cdf(x; loc) = 1, if x >= loc, else 0
  ```

  #### Examples

  ```python
  # A single Deterministic supported at zero.
  constant = tf.contrib.distributions.Deterministic(0.)
  constant.prob(0.)
  ==> 1.
  constant.prob(2.)
  ==> 0.

  # A [2, 2] batch of scalar constants.
  loc = [[0., 1.], [2., 3.]]
  x = [[0., 1.1], [1.99, 3.]]
  constant = tf.contrib.distributions.Deterministic(loc)
  constant.prob(x)
  ==> [[1., 0.], [0., 1.]]
  ```

  """

  def __init__(self,
               loc,
               atol=None,
               rtol=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Deterministic"):
    """Initialize a scalar `Deterministic` distribution.

    `atol` and `rtol` introduce some slack into the `pmf` and `cdf`
    computations, e.g. to absorb floating-point error.

    ```
    pmf(x; loc)
      = 1, if Abs(x - loc) <= atol + rtol * Abs(loc),
      = 0, otherwise.
    ```

    Args:
      loc: Numeric `Tensor` of shape `[B1, ..., Bb]`, with `b >= 0`.
        The point (or batch of points) on which this distribution is supported.
      atol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  Absolute tolerance used when comparing against `loc`.
        Default is `0`.
      rtol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  Relative tolerance used when comparing against `loc`.
        Default is `0`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Everything is handled by the base class; `is_vector` keeps its default
    # of `False` there.
    super(Deterministic, self).__init__(
        loc,
        atol=atol,
        rtol=rtol,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)

  def _batch_shape_tensor(self):
    """Dynamic batch shape: the full shape of `loc`."""
    return array_ops.shape(self.loc)

  def _batch_shape(self):
    """Static batch shape: the full static shape of `loc`."""
    return self.loc.get_shape()

  def _event_shape_tensor(self):
    """Dynamic event shape: an empty `int32` vector (scalar events)."""
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    """Static event shape: the scalar shape."""
    return tensor_shape.TensorShape([])

  def _prob(self, x):
    """Returns 1 where `x` lies within the slack around `loc`, else 0."""
    distance = math_ops.abs(x - self.loc)
    inside = distance <= self._slack
    return math_ops.cast(inside, dtype=self.dtype)

  def _cdf(self, x):
    """Returns 1 where `x >= loc - slack`, else 0."""
    threshold = self.loc - self._slack
    return math_ops.cast(x >= threshold, dtype=self.dtype)
269
270
class VectorDeterministic(_BaseDeterministic):
  """Vector `Deterministic` distribution on `R^k`.

  The `VectorDeterministic` distribution is parameterized by a [batch] point
  `loc in R^k`.  The distribution is supported at this point only,
  and corresponds to a random variable that is constant, equal to `loc`.

  See [Degenerate rv](https://en.wikipedia.org/wiki/Degenerate_distribution).

  #### Mathematical Details

  The probability mass function (pmf) is

  ```none
  pmf(x; loc)
    = 1, if All[Abs(x - loc) <= atol + rtol * Abs(loc)],
    = 0, otherwise.
  ```

  #### Examples

  ```python
  tfd = tf.contrib.distributions

  # Initialize a single VectorDeterministic supported at [0., 2.] in R^2.
  constant = tfd.VectorDeterministic([0., 2.])
  constant.prob([0., 2.])
  ==> 1.
  constant.prob([0., 3.])
  ==> 0.

  # Initialize a [3] batch of constants on R^2.
  loc = [[0., 1.], [2., 3.], [4., 5.]]
  constant = tfd.VectorDeterministic(loc)
  constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]])
  ==> [1., 0., 0.]
  ```

  """

  def __init__(self,
               loc,
               atol=None,
               rtol=None,
               validate_args=False,
               allow_nan_stats=True,
               name="VectorDeterministic"):
    """Initialize a `VectorDeterministic` distribution on `R^k`, for `k >= 0`.

    Note that there is only one point in `R^0`, the "point" `[]`.  So if `k = 0`
    then `self.prob([]) == 1`.

    The `atol` and `rtol` parameters allow for some slack in `pmf`
    computations, e.g. due to floating-point error.

    ```
    pmf(x; loc)
      = 1, if All[Abs(x - loc) <= atol + rtol * Abs(loc)],
      = 0, otherwise
    ```

    Args:
      loc: Numeric `Tensor` of shape `[B1, ..., Bb, k]`, with `b >= 0`, `k >= 0`
        The point (or batch of points) on which this distribution is supported.
      atol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The absolute tolerance for comparing closeness to `loc`.
        Default is `0`.
      rtol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The relative tolerance for comparing closeness to `loc`.
        Default is `0`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    super(VectorDeterministic, self).__init__(
        loc,
        atol=atol,
        rtol=rtol,
        is_vector=True,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)

  def _batch_shape_tensor(self):
    """Dynamic batch shape: every dimension of `loc` except the last."""
    return array_ops.shape(self.loc)[:-1]

  def _batch_shape(self):
    """Static batch shape: every static dimension of `loc` except the last."""
    return self.loc.get_shape()[:-1]

  def _event_shape_tensor(self):
    """Dynamic event shape: the last dimension of `loc` as a 1-D tensor."""
    # Slice with `[-1:]` (not `[-1]`) so the result is the rank-1 tensor `[k]`,
    # mirroring the slice used by `_event_shape` below.
    return array_ops.shape(self.loc)[-1:]

  def _event_shape(self):
    """Static event shape: the last static dimension of `loc`."""
    return self.loc.get_shape()[-1:]

  def _prob(self, x):
    """Returns 1 where every component of `x` is within slack of `loc`, else 0.

    When `validate_args` is set, `x` is additionally asserted to have rank at
    least 1 and a last dimension equal to this distribution's event size.
    """
    if self.validate_args:
      is_vector_check = check_ops.assert_rank_at_least(x, 1)
      # Compare the event size `k` with the size of the last axis of `x`.
      right_vec_space_check = check_ops.assert_equal(
          self.event_shape_tensor(),
          array_ops.gather(array_ops.shape(x), array_ops.rank(x) - 1),
          message=
          "Argument 'x' not defined in the same space R^k as this distribution")
      with ops.control_dependencies(
          [is_vector_check, right_vec_space_check]):
        x = array_ops.identity(x)
    # Reduce over the event axis: all k components must match.
    return math_ops.cast(
        math_ops.reduce_all(math_ops.abs(x - self.loc) <= self._slack, axis=-1),
        dtype=self.dtype)
386