14b935e00c584579dad18c41167f89a0e2e9e5f6fA. Unique TensorFlower# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo#
3cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo# Licensed under the Apache License, Version 2.0 (the "License");
4cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo# you may not use this file except in compliance with the License.
5cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo# You may obtain a copy of the License at
6cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo#
7cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo#     http://www.apache.org/licenses/LICENSE-2.0
8cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo#
9cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo# Unless required by applicable law or agreed to in writing, software
10cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo# distributed under the License is distributed on an "AS IS" BASIS,
11cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo# See the License for the specific language governing permissions and
13cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo# limitations under the License.
14cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo# ==============================================================================
15cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo"""The Bernoulli distribution class."""
16cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo
17cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdofrom __future__ import absolute_import
18cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdofrom __future__ import division
19cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdofrom __future__ import print_function
20cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo
21cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdofrom tensorflow.python.framework import dtypes
22cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdofrom tensorflow.python.framework import ops
23cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdofrom tensorflow.python.framework import tensor_shape
24cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdofrom tensorflow.python.ops import array_ops
25cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdofrom tensorflow.python.ops import math_ops
26a89c54d57209f91161fa450605f645c9124d89acA. Unique TensorFlowerfrom tensorflow.python.ops import nn
27cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdofrom tensorflow.python.ops import random_ops
280db6371fd87bd9545a81a039258382d7151b309aEugene Brevdofrom tensorflow.python.ops.distributions import distribution
2968c514faa9470d2bd8aed797339f048c50ed6317Eugene Brevdofrom tensorflow.python.ops.distributions import kullback_leibler
300db6371fd87bd9545a81a039258382d7151b309aEugene Brevdofrom tensorflow.python.ops.distributions import util as distribution_util
317c4a128ac2d569d0b614cca8d8626ca382b90d40Anna Rfrom tensorflow.python.util.tf_export import tf_export
32cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo
33cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo
@tf_export("distributions.Bernoulli")
class Bernoulli(distribution.Distribution):
  """Bernoulli distribution.

  The Bernoulli distribution with `probs` parameter, i.e., the probability of a
  `1` outcome (vs a `0` outcome).

  Both parameterizations are kept: `logits` (log-odds) drives the batch shape,
  `log_prob` and `entropy`; `probs` drives sampling, `mean`, `variance` and
  `mode`.
  """

  def __init__(self,
               logits=None,
               probs=None,
               dtype=dtypes.int32,
               validate_args=False,
               allow_nan_stats=True,
               name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` representing the log-odds of a `1` event. Each
        entry in the `Tensor` parametrizes an independent Bernoulli distribution
        where the probability of an event is sigmoid(logits). Only one of
        `logits` or `probs` should be passed in.
      probs: An N-D `Tensor` representing the probability of a `1`
        event. Each entry in the `Tensor` parameterizes an independent
        Bernoulli distribution. Only one of `logits` or `probs` should be passed
        in.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither are
        passed.
    """
    # Snapshot the constructor arguments. Before Python 3.13, `locals()` inside
    # a function returns the frame's own locals dict, which is updated in place
    # by any later `locals()` call or by a tracer/debugger. Copying it keeps
    # `parameters` from silently picking up names bound after this line
    # (including a reference to `parameters` itself).
    parameters = dict(locals())
    with ops.name_scope(name):
      # Resolve the single supplied parameterization into both `logits` and
      # `probs`; the helper receives `validate_args` so it can add checks.
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          name=name)
    super(Bernoulli, self).__init__(
        dtype=dtype,
        # Samples come from a thresholded comparison cast to `dtype` (see
        # `_sample_n`), so there is no pathwise gradient through them.
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    # Events are scalar, so the `logits` parameter has exactly the shape of
    # the requested samples.
    return {"logits": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}

  @property
  def logits(self):
    """Log-odds of a `1` outcome (vs `0`)."""
    return self._logits

  @property
  def probs(self):
    """Probability of a `1` outcome (vs `0`)."""
    return self._probs

  def _batch_shape_tensor(self):
    # Dynamic batch shape, taken from the `logits` tensor.
    return array_ops.shape(self._logits)

  def _batch_shape(self):
    # Static batch shape, taken from the `logits` tensor.
    return self._logits.get_shape()

  def _event_shape_tensor(self):
    # Each event is a scalar: empty shape vector.
    return array_ops.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.scalar()

  def _sample_n(self, n, seed=None):
    """Draws `n` samples per batch member, output shape `[n] + batch_shape`."""
    new_shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    # Draw uniforms in the dtype of `probs`; with `random_uniform`'s default
    # [0, 1) range, P(uniform < p) = p, so the comparison is Bernoulli(p).
    uniform = random_ops.random_uniform(
        new_shape, seed=seed, dtype=self.probs.dtype)
    sample = math_ops.less(uniform, self.probs)
    return math_ops.cast(sample, self.dtype)

  def _log_prob(self, event):
    """Log-probability of `event`, computed as -sigmoid_cross_entropy."""
    if self.validate_args:
      # NOTE(review): presumably asserts `event` survives an integer cast to
      # bool (i.e. holds only 0/1) -- confirm against distribution_util.
      event = distribution_util.embed_check_integer_casting_closed(
          event, target_dtype=dtypes.bool)

    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits
    # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
    # so we do this here.

    def _broadcast(logits, event):
      # Multiplying by ones_like of the other operand broadcasts each tensor
      # to the common shape.
      return (array_ops.ones_like(event) * logits,
              array_ops.ones_like(logits) * event)

    # Skip the extra broadcast ops only when both static shapes are known
    # and already identical.
    if not (event.get_shape().is_fully_defined() and
            logits.get_shape().is_fully_defined() and
            event.get_shape() == logits.get_shape()):
      logits, event = _broadcast(logits, event)
    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)

  def _entropy(self):
    # With p = sigmoid(l): H = p*softplus(-l) + (1-p)*softplus(l), and since
    # softplus(l) = l + softplus(-l) this equals softplus(-l) - l*(p - 1).
    return (-self.logits * (math_ops.sigmoid(self.logits) - 1) +
            nn.softplus(-self.logits))

  def _mean(self):
    # identity() yields a distinct op rather than returning `probs` itself.
    return array_ops.identity(self.probs)

  def _variance(self):
    # Var = p * (1 - p).
    return self._mean() * (1. - self.probs)

  def _mode(self):
    """Returns `1` if `prob > 0.5` and `0` otherwise."""
    # Ties at exactly 0.5 resolve to `0`.
    return math_ops.cast(self.probs > 0.5, self.dtype)
158cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo
159cb44b3a001766f7ab633108014c38dadc4adab25Eugene Brevdo
@kullback_leibler.RegisterKL(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(a, b, name=None):
  """Compute the batchwise KL divergence KL(a || b) of two Bernoulli dists.

  The computation stays in logit space for numerical stability. Writing
  p = sigmoid(logits), the identities

      log p       = -softplus(-logits)
      log (1 - p) = -softplus(logits)

  turn the definition

      KL(a || b) = p_a * (log p_a - log p_b)
                 + (1 - p_a) * (log (1 - p_a) - log (1 - p_b))

  into differences of softplus terms weighted by sigmoid(+/- a.logits).

  Args:
    a: instance of a Bernoulli distribution object.
    b: instance of a Bernoulli distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_bernoulli_bernoulli".

  Returns:
    Batchwise KL(a || b)
  """
  with ops.name_scope(name, "kl_bernoulli_bernoulli",
                      values=[a.logits, b.logits]):
    logits_a = a.logits
    logits_b = b.logits
    # log p_a - log p_b: log-ratio for the `1` outcome.
    log_ratio_one = nn.softplus(-logits_b) - nn.softplus(-logits_a)
    # log (1 - p_a) - log (1 - p_b): log-ratio for the `0` outcome.
    log_ratio_zero = nn.softplus(logits_b) - nn.softplus(logits_a)
    # Weight each log-ratio by a's probability of that outcome.
    term_one = math_ops.sigmoid(logits_a) * log_ratio_one
    term_zero = math_ops.sigmoid(-logits_a) * log_ratio_zero
    return term_one + term_zero
179