# independent_test.py revision 49483793695247f27332c7db0b9740e95a5de3db
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for the Independent distribution.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import importlib 22import numpy as np 23 24from tensorflow.contrib.distributions.python.ops import independent as independent_lib 25from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib 26from tensorflow.python.ops import math_ops 27from tensorflow.python.ops.distributions import normal as normal_lib 28from tensorflow.python.platform import test 29from tensorflow.python.platform import tf_logging 30 31 32def try_import(name): # pylint: disable=invalid-name 33 module = None 34 try: 35 module = importlib.import_module(name) 36 except ImportError as e: 37 tf_logging.warning("Could not import %s: %s" % (name, str(e))) 38 return module 39 40stats = try_import("scipy.stats") 41 42 43class ProductDistributionTest(test.TestCase): 44 45 def testSampleAndLogProbUnivariate(self): 46 loc = np.float32([-1., 1]) 47 scale = np.float32([0.1, 0.5]) 48 with self.test_session() as sess: 49 ind = independent_lib.Independent( 50 distribution=normal_lib.Normal(loc=loc, scale=scale), 51 reduce_batch_ndims=1) 52 53 x = ind.sample([4, 5]) 54 log_prob_x = ind.log_prob(x) 55 x_, actual_log_prob_x = 
sess.run([x, log_prob_x]) 56 57 self.assertEqual([], ind.batch_shape) 58 self.assertEqual([2], ind.event_shape) 59 self.assertEqual([4, 5, 2], x.shape) 60 self.assertEqual([4, 5], log_prob_x.shape) 61 62 expected_log_prob_x = stats.norm(loc, scale).logpdf(x_).sum(-1) 63 self.assertAllClose(expected_log_prob_x, actual_log_prob_x, 64 rtol=1e-5, atol=0.) 65 66 def testSampleAndLogProbMultivariate(self): 67 loc = np.float32([[-1., 1], [1, -1]]) 68 scale = np.float32([1., 0.5]) 69 with self.test_session() as sess: 70 ind = independent_lib.Independent( 71 distribution=mvn_diag_lib.MultivariateNormalDiag( 72 loc=loc, 73 scale_identity_multiplier=scale), 74 reduce_batch_ndims=1) 75 76 x = ind.sample([4, 5]) 77 log_prob_x = ind.log_prob(x) 78 x_, actual_log_prob_x = sess.run([x, log_prob_x]) 79 80 self.assertEqual([], ind.batch_shape) 81 self.assertEqual([2, 2], ind.event_shape) 82 self.assertEqual([4, 5, 2, 2], x.shape) 83 self.assertEqual([4, 5], log_prob_x.shape) 84 85 expected_log_prob_x = stats.norm(loc, scale[:, None]).logpdf( 86 x_).sum(-1).sum(-1) 87 self.assertAllClose(expected_log_prob_x, actual_log_prob_x, 88 rtol=1e-6, atol=0.) 
89 90 def testSampleConsistentStats(self): 91 loc = np.float32([[-1., 1], [1, -1]]) 92 scale = np.float32([1., 0.5]) 93 n_samp = 1e4 94 with self.test_session() as sess: 95 ind = independent_lib.Independent( 96 distribution=mvn_diag_lib.MultivariateNormalDiag( 97 loc=loc, 98 scale_identity_multiplier=scale), 99 reduce_batch_ndims=1) 100 101 x = ind.sample(int(n_samp), seed=42) 102 sample_mean = math_ops.reduce_mean(x, axis=0) 103 sample_var = math_ops.reduce_mean( 104 math_ops.squared_difference(x, sample_mean), axis=0) 105 sample_std = math_ops.sqrt(sample_var) 106 sample_entropy = -math_ops.reduce_mean(ind.log_prob(x), axis=0) 107 108 [ 109 sample_mean_, sample_var_, sample_std_, sample_entropy_, 110 actual_mean_, actual_var_, actual_std_, actual_entropy_, 111 actual_mode_, 112 ] = sess.run([ 113 sample_mean, sample_var, sample_std, sample_entropy, 114 ind.mean(), ind.variance(), ind.stddev(), ind.entropy(), ind.mode(), 115 ]) 116 117 self.assertAllClose(sample_mean_, actual_mean_, rtol=0.02, atol=0.) 118 self.assertAllClose(sample_var_, actual_var_, rtol=0.04, atol=0.) 119 self.assertAllClose(sample_std_, actual_std_, rtol=0.02, atol=0.) 120 self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.) 121 self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.) 122 123 124if __name__ == "__main__": 125 test.main() 126