independent_test.py revision 49483793695247f27332c7db0b9740e95a5de3db
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the Independent distribution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import importlib
22import numpy as np
23
24from tensorflow.contrib.distributions.python.ops import independent as independent_lib
25from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops.distributions import normal as normal_lib
28from tensorflow.python.platform import test
29from tensorflow.python.platform import tf_logging
30
31
def try_import(name):  # pylint: disable=invalid-name
  """Imports the module called `name`, tolerating its absence.

  Args:
    name: Fully qualified module name, e.g. "scipy.stats".

  Returns:
    The imported module, or `None` if importing it raised `ImportError`
    (a warning is logged in that case).
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    # Hand the arguments to the logger rather than pre-formatting the message;
    # `%s` renders the exception exactly as `str(e)` would.
    tf_logging.warning("Could not import %s: %s", name, e)
    return None
39
# `stats` is `None` when SciPy cannot be imported (see `try_import`), so any
# code relying on it has to check for that before use.
stats = try_import("scipy.stats")
41
42
class ProductDistributionTest(test.TestCase):
  """Tests sampling, log-probabilities and statistics of `Independent`."""

  def testSampleAndLogProbUnivariate(self):
    """Checks shapes and `log_prob` of `Independent` wrapping a `Normal`.

    The reference log-probability is the `scipy.stats.norm` log-density summed
    over the last axis of the drawn samples.
    """
    if stats is None:
      # The reference values come from SciPy; without this guard the test
      # would die with an AttributeError on `None` instead of reporting the
      # missing optional dependency.
      self.skipTest("scipy.stats is not available.")

    loc = np.float32([-1., 1])
    scale = np.float32([0.1, 0.5])
    with self.test_session() as sess:
      ind = independent_lib.Independent(
          distribution=normal_lib.Normal(loc=loc, scale=scale),
          reduce_batch_ndims=1)

      # NOTE(review): sampling is unseeded here (unlike
      # testSampleConsistentStats), so `x_` differs between runs -- consider
      # passing `seed=` if the tolerance below ever proves flaky.
      x = ind.sample([4, 5])
      log_prob_x = ind.log_prob(x)
      x_, actual_log_prob_x = sess.run([x, log_prob_x])

      self.assertEqual([], ind.batch_shape)
      self.assertEqual([2], ind.event_shape)
      self.assertEqual([4, 5, 2], x.shape)
      self.assertEqual([4, 5], log_prob_x.shape)

      expected_log_prob_x = stats.norm(loc, scale).logpdf(x_).sum(-1)
      self.assertAllClose(expected_log_prob_x, actual_log_prob_x,
                          rtol=1e-5, atol=0.)

  def testSampleAndLogProbMultivariate(self):
    """Checks shapes and `log_prob` when wrapping `MultivariateNormalDiag`.

    The reference log-probability is the `scipy.stats.norm` log-density summed
    over the last two axes of the drawn samples.
    """
    if stats is None:
      # See testSampleAndLogProbUnivariate: SciPy supplies the reference.
      self.skipTest("scipy.stats is not available.")

    loc = np.float32([[-1., 1], [1, -1]])
    scale = np.float32([1., 0.5])
    with self.test_session() as sess:
      ind = independent_lib.Independent(
          distribution=mvn_diag_lib.MultivariateNormalDiag(
              loc=loc,
              scale_identity_multiplier=scale),
          reduce_batch_ndims=1)

      # NOTE(review): unseeded sampling combined with rtol=1e-6 on float32
      # inputs -- consider passing `seed=` if this proves flaky.
      x = ind.sample([4, 5])
      log_prob_x = ind.log_prob(x)
      x_, actual_log_prob_x = sess.run([x, log_prob_x])

      self.assertEqual([], ind.batch_shape)
      self.assertEqual([2, 2], ind.event_shape)
      self.assertEqual([4, 5, 2, 2], x.shape)
      self.assertEqual([4, 5], log_prob_x.shape)

      # `scale[:, None]` lines up one scale per row of `loc`.
      expected_log_prob_x = stats.norm(loc, scale[:, None]).logpdf(
          x_).sum(-1).sum(-1)
      self.assertAllClose(expected_log_prob_x, actual_log_prob_x,
                          rtol=1e-6, atol=0.)

  def testSampleConsistentStats(self):
    """Compares sample statistics with the values `Independent` reports.

    Mean, variance, standard deviation and entropy estimated from seeded
    samples are checked against `ind.mean()`, `ind.variance()`,
    `ind.stddev()` and `ind.entropy()`; `ind.mode()` is checked against `loc`.
    """
    loc = np.float32([[-1., 1], [1, -1]])
    scale = np.float32([1., 0.5])
    n_samp = int(1e4)
    with self.test_session() as sess:
      ind = independent_lib.Independent(
          distribution=mvn_diag_lib.MultivariateNormalDiag(
              loc=loc,
              scale_identity_multiplier=scale),
          reduce_batch_ndims=1)

      x = ind.sample(n_samp, seed=42)
      sample_mean = math_ops.reduce_mean(x, axis=0)
      sample_var = math_ops.reduce_mean(
          math_ops.squared_difference(x, sample_mean), axis=0)
      sample_std = math_ops.sqrt(sample_var)
      # Entropy estimate: negated average log-probability of the samples.
      sample_entropy = -math_ops.reduce_mean(ind.log_prob(x), axis=0)

      [
          sample_mean_, sample_var_, sample_std_, sample_entropy_,
          actual_mean_, actual_var_, actual_std_, actual_entropy_,
          actual_mode_,
      ] = sess.run([
          sample_mean, sample_var, sample_std, sample_entropy,
          ind.mean(), ind.variance(), ind.stddev(), ind.entropy(), ind.mode(),
      ])

      self.assertAllClose(sample_mean_, actual_mean_, rtol=0.02, atol=0.)
      self.assertAllClose(sample_var_, actual_var_, rtol=0.04, atol=0.)
      self.assertAllClose(sample_std_, actual_std_, rtol=0.02, atol=0.)
      self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.)
      self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.)
122
123
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()
126