# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Bernoulli distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export


@tf_export("distributions.Bernoulli")
class Bernoulli(distribution.Distribution):
  """Bernoulli distribution.

  The Bernoulli distribution with `probs` parameter, i.e., the probability of a
  `1` outcome (vs a `0` outcome).
  """

  def __init__(self,
               logits=None,
               probs=None,
               dtype=dtypes.int32,
               validate_args=False,
               allow_nan_stats=True,
               name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` representing the log-odds of a `1` event. Each
        entry in the `Tensor` parametrizes an independent Bernoulli distribution
        where the probability of an event is sigmoid(logits). Only one of
        `logits` or `probs` should be passed in.
      probs: An N-D `Tensor` representing the probability of a `1`
        event. Each entry in the `Tensor` parameterizes an independent
        Bernoulli distribution. Only one of `logits` or `probs` should be passed
        in.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither are
        passed.
    """
    # Snapshot the constructor arguments. On CPython before 3.13, `locals()`
    # returns the frame's own mapping, which is refreshed on later
    # introspection of this frame. Copying it keeps `parameters` fixed to the
    # arguments as they were passed.
    parameters = dict(locals())
    with ops.name_scope(name):
      # Exactly one of `logits`/`probs` is supplied by the caller; the helper
      # derives the other so both are always available below.
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          name=name)
    super(Bernoulli, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    """Maps a sample shape to the shape of the `logits` parameter.

    The `logits` parameter has the same shape as `sample_shape`.
    """
    return {"logits": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}

  @property
  def logits(self):
    """Log-odds of a `1` outcome (vs `0`)."""
    return self._logits

  @property
  def probs(self):
    """Probability of a `1` outcome (vs `0`)."""
    return self._probs

  def _batch_shape_tensor(self):
    # The batch shape is the (dynamic) shape of the logits tensor.
    return array_ops.shape(self._logits)

  def _batch_shape(self):
    # Static counterpart of `_batch_shape_tensor`.
    return self._logits.get_shape()

  def _event_shape_tensor(self):
    # Each event is a scalar, so the event shape is the empty shape.
    return array_ops.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.scalar()

  def _sample_n(self, n, seed=None):
    """Draws `n` samples per batch member.

    Each sample is `1` where a Uniform[0, 1) draw falls below `probs`, and `0`
    otherwise. The result has shape `[n] + batch_shape` and dtype `self.dtype`.
    """
    new_shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    uniform = random_ops.random_uniform(
        new_shape, seed=seed, dtype=self.probs.dtype)
    # `probs` broadcasts against the leading sample dimension.
    sample = math_ops.less(uniform, self.probs)
    return math_ops.cast(sample, self.dtype)

  def _log_prob(self, event):
    """Log-probability of `event` under the distribution.

    Computed as the negated sigmoid cross-entropy between `event` (as labels)
    and `logits`, i.e. `event * log(p) + (1 - event) * log(1 - p)` with
    `p = sigmoid(logits)`.
    """
    if self.validate_args:
      # NOTE(review): presumably asserts that `event` values survive a cast
      # to bool (i.e. are 0 or 1); see `distribution_util` to confirm.
      event = distribution_util.embed_check_integer_casting_closed(
          event, target_dtype=dtypes.bool)

    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits
    # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
    # so we do this here.

    def _broadcast(logits, event):
      # Multiplying by `ones_like` of the other operand expands both tensors
      # to their common broadcast shape.
      return (array_ops.ones_like(event) * logits,
              array_ops.ones_like(logits) * event)

    # Skip the explicit broadcast only when both static shapes are fully known
    # and already equal.
    if not (event.get_shape().is_fully_defined() and
            logits.get_shape().is_fully_defined() and
            event.get_shape() == logits.get_shape()):
      logits, event = _broadcast(logits, event)
    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)

  def _entropy(self):
    """Entropy `-p*log(p) - (1-p)*log(1-p)` with `p = sigmoid(logits)`.

    Written in terms of `logits` as `softplus(-logits) - logits * (p - 1)`.
    This follows from `log(p) = -softplus(-logits)` and
    `log(1 - p) = -softplus(logits)`.
    """
    return (-self.logits * (math_ops.sigmoid(self.logits) - 1) +
            nn.softplus(-self.logits))

  def _mean(self):
    # `identity` returns a new tensor rather than the stored `probs` itself.
    return array_ops.identity(self.probs)

  def _variance(self):
    # Var = p * (1 - p).
    return self._mean() * (1. - self.probs)

  def _mode(self):
    """Returns `1` if `probs > 0.5` and `0` otherwise."""
    return math_ops.cast(self.probs > 0.5, self.dtype)


@kullback_leibler.RegisterKL(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Bernoulli.

  Args:
    a: instance of a Bernoulli distribution object.
    b: instance of a Bernoulli distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_bernoulli_bernoulli".

  Returns:
    Batchwise KL(a || b)
  """
  with ops.name_scope(name, "kl_bernoulli_bernoulli",
                      values=[a.logits, b.logits]):
    # With p = sigmoid(l): log(p) = -softplus(-l), log(1 - p) = -softplus(l).
    # delta_probs0 = log(p_a) - log(p_b)
    delta_probs0 = nn.softplus(-b.logits) - nn.softplus(-a.logits)
    # delta_probs1 = log(1 - p_a) - log(1 - p_b)
    delta_probs1 = nn.softplus(b.logits) - nn.softplus(a.logits)
    # KL = p_a * delta_probs0 + (1 - p_a) * delta_probs1, where
    # sigmoid(-l_a) = 1 - p_a.
    return (math_ops.sigmoid(a.logits) * delta_probs0
            + math_ops.sigmoid(-a.logits) * delta_probs1)