# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Wishart."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from scipy import linalg
from tensorflow.contrib import distributions as distributions_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test

distributions = distributions_lib


def make_pd(start, n):
  """Deterministically create a positive definite matrix."""
  x = np.tril(linalg.circulant(np.arange(start, start + n)))
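  # x is lower triangular with diagonal entries equal to `start`, so
  # np.dot(x, x.T) is symmetric positive definite for the positive `start`
  # values used in these tests, e.g. make_pd(1., 2) == [[1., 2.], [2., 5.]].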
  return np.dot(x, x.T)


def chol(x):
  """Compute the lower-triangular Cholesky factor of `x`."""
  return linalg.cholesky(x).T


def wishart_var(df, x):
  """Compute Wishart variance for numpy scale matrix."""
  x = np.sqrt(df) * np.asarray(x)
  d = np.expand_dims(np.diag(x), -1)
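  # Element-wise this computes Var(W_ij) = df * (scale_ij**2 +
  # scale_ii * scale_jj), the standard variance of Wishart matrix entries;
  # multiplying by sqrt(df) above folds df into both terms.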
  return x**2 + np.dot(d, d.T)


class WishartCholeskyTest(test.TestCase):

  def testEntropy(self):
    with self.test_session():
      scale = make_pd(1., 2)
      df = 4
      w = distributions.WishartCholesky(df, chol(scale))
      # sp.stats.wishart(df=4, scale=make_pd(1., 2)).entropy()
      self.assertAllClose(6.301387092430769, w.entropy().eval())

      w = distributions.WishartCholesky(df=1, scale=[[1.]])
      # sp.stats.wishart(df=1, scale=1).entropy()
      self.assertAllClose(0.78375711047393404, w.entropy().eval())

  def testMeanLogDetAndLogNormalizingConstant(self):
    with self.test_session():

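      # The differential entropy can be written as
      #   log_normalization() - (df - dimension - 1)/2 * E[log det X]
      #   + df * dimension / 2,
      # using E[tr(scale^{-1} X)] = df * dimension; entropy_alt checks
      # entropy() against this identity.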
      def entropy_alt(w):
        return (
            w.log_normalization()
            - 0.5 * (w.df - w.dimension - 1.) * w.mean_log_det()
            + 0.5 * w.df * w.dimension).eval()

      w = distributions.WishartCholesky(df=4,
                                        scale=chol(make_pd(1., 2)))
      self.assertAllClose(w.entropy().eval(), entropy_alt(w))

      w = distributions.WishartCholesky(df=5, scale=[[1.]])
      self.assertAllClose(w.entropy().eval(), entropy_alt(w))

  def testMean(self):
    with self.test_session():
      scale = make_pd(1., 2)
      df = 4
      w = distributions.WishartCholesky(df, chol(scale))
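      # The Wishart mean is df * scale.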
      self.assertAllEqual(df * scale, w.mean().eval())

  def testMode(self):
    with self.test_session():
      scale = make_pd(1., 2)
      df = 4
      w = distributions.WishartCholesky(df, chol(scale))
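      # The Wishart mode is (df - dimension - 1) * scale when
      # df >= dimension + 1; here that is (4 - 2 - 1) * scale.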
      self.assertAllEqual((df - 2. - 1.) * scale, w.mode().eval())

  def testStd(self):
    with self.test_session():
      scale = make_pd(1., 2)
      df = 4
      w = distributions.WishartCholesky(df, chol(scale))
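      # This test expects stddev() to be the Cholesky factor of variance(),
      # i.e. a matrix square root rather than an element-wise sqrt.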
      self.assertAllEqual(chol(wishart_var(df, scale)), w.stddev().eval())

  def testVariance(self):
    with self.test_session():
      scale = make_pd(1., 2)
      df = 4
      w = distributions.WishartCholesky(df, chol(scale))
      self.assertAllEqual(wishart_var(df, scale), w.variance().eval())

  def testSample(self):
    with self.test_session():
      scale = make_pd(1., 2)
      df = 4

      chol_w = distributions.WishartCholesky(
          df, chol(scale), cholesky_input_output_matrices=False)

      x = chol_w.sample(1, seed=42).eval()
      chol_x = [chol(x[0])]

      full_w = distributions.WishartFull(
          df, scale, cholesky_input_output_matrices=False)
      self.assertAllClose(x, full_w.sample(1, seed=42).eval())

      chol_w_chol = distributions.WishartCholesky(
          df, chol(scale), cholesky_input_output_matrices=True)
      self.assertAllClose(chol_x, chol_w_chol.sample(1, seed=42).eval())
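      # With cholesky_input_output_matrices=True, samples come back as
      # Cholesky factors, so their diagonals should be strictly positive.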
      eigen_values = array_ops.matrix_diag_part(
          chol_w_chol.sample(
              1000, seed=42))
      np.testing.assert_array_less(0., eigen_values.eval())

      full_w_chol = distributions.WishartFull(
          df, scale, cholesky_input_output_matrices=True)
      self.assertAllClose(chol_x, full_w_chol.sample(1, seed=42).eval())
      eigen_values = array_ops.matrix_diag_part(
          full_w_chol.sample(
              1000, seed=42))
      np.testing.assert_array_less(0., eigen_values.eval())

      # Check first and second moments.
      df = 4.
      chol_w = distributions.WishartCholesky(
          df=df,
          scale=chol(make_pd(1., 3)),
          cholesky_input_output_matrices=False)
      x = chol_w.sample(10000, seed=42)
      self.assertAllEqual((10000, 3, 3), x.get_shape())

      moment1_estimate = math_ops.reduce_mean(x, reduction_indices=[0]).eval()
      self.assertAllClose(chol_w.mean().eval(), moment1_estimate, rtol=0.05)

      # The variance estimate uses element-wise squares rather than
      # outer products because Wishart.variance() is the diagonal of the
      # Wishart covariance matrix.
      variance_estimate = (math_ops.reduce_mean(
          math_ops.square(x), reduction_indices=[0]) -
                           math_ops.square(moment1_estimate)).eval()
      self.assertAllClose(
          chol_w.variance().eval(), variance_estimate, rtol=0.05)

  # Test that sampling with the same seed twice gives the same results.
  def testSampleMultipleTimes(self):
    with self.test_session():
      df = 4.
      n_val = 100

      random_seed.set_random_seed(654321)
      chol_w1 = distributions.WishartCholesky(
          df=df,
          scale=chol(make_pd(1., 3)),
          cholesky_input_output_matrices=False,
          name="wishart1")
      samples1 = chol_w1.sample(n_val, seed=123456).eval()

      random_seed.set_random_seed(654321)
      chol_w2 = distributions.WishartCholesky(
          df=df,
          scale=chol(make_pd(1., 3)),
          cholesky_input_output_matrices=False,
          name="wishart2")
      samples2 = chol_w2.sample(n_val, seed=123456).eval()

      self.assertAllClose(samples1, samples2)

  def testProb(self):
    with self.test_session():
      # Generate some positive definite (pd) matrices and their Cholesky
      # factorizations.
      x = np.array(
          [make_pd(1., 2), make_pd(2., 2), make_pd(3., 2), make_pd(4., 2)])
      chol_x = np.array([chol(x[0]), chol(x[1]), chol(x[2]), chol(x[3])])

      # Since Wishart wasn't added to SciPy until 0.16, we'll spot check some
      # pdfs with hard-coded results from upstream SciPy.

      log_prob_df_seq = np.array([
          # math.log(stats.wishart.pdf(x[0], df=2+0, scale=x[0]))
          -3.5310242469692907,
          # math.log(stats.wishart.pdf(x[1], df=2+1, scale=x[1]))
          -7.689907330328961,
          # math.log(stats.wishart.pdf(x[2], df=2+2, scale=x[2]))
          -10.815845159537895,
          # math.log(stats.wishart.pdf(x[3], df=2+3, scale=x[3]))
          -13.640549882916691,
      ])

      # This test checks that batches don't interfere with correctness.
      w = distributions.WishartCholesky(
          df=[2, 3, 4, 5],
          scale=chol_x,
          cholesky_input_output_matrices=True)
      self.assertAllClose(log_prob_df_seq, w.log_prob(chol_x).eval())

      # Now we test various constructions of Wishart with different sample
      # shapes.
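      # With event shape (2, 2) and a scalar batch, feeding the four matrices
      # reshaped to (2, 2, 2, 2) should yield log-probs of shape (2, 2), as
      # the asserts below verify.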

      log_prob = np.array([
          # math.log(stats.wishart.pdf(x[0], df=4, scale=x[0]))
          -4.224171427529236,
          # math.log(stats.wishart.pdf(x[1], df=4, scale=x[0]))
          -6.3378770664093453,
          # math.log(stats.wishart.pdf(x[2], df=4, scale=x[0]))
          -12.026946850193017,
          # math.log(stats.wishart.pdf(x[3], df=4, scale=x[0]))
          -20.951582705289454,
      ])

      for w in (
          distributions.WishartCholesky(
              df=4,
              scale=chol_x[0],
              cholesky_input_output_matrices=False),
          distributions.WishartFull(
              df=4,
              scale=x[0],
              cholesky_input_output_matrices=False)):
        self.assertAllEqual((2, 2), w.event_shape_tensor().eval())
        self.assertEqual(2, w.dimension.eval())
        self.assertAllClose(log_prob[0], w.log_prob(x[0]).eval())
        self.assertAllClose(log_prob[0:2], w.log_prob(x[0:2]).eval())
        self.assertAllClose(
            np.reshape(log_prob, (2, 2)),
            w.log_prob(np.reshape(x, (2, 2, 2, 2))).eval())
        self.assertAllClose(
            np.reshape(np.exp(log_prob), (2, 2)),
            w.prob(np.reshape(x, (2, 2, 2, 2))).eval())
        self.assertAllEqual((2, 2),
                            w.log_prob(np.reshape(x, (2, 2, 2, 2))).get_shape())

      for w in (
          distributions.WishartCholesky(
              df=4,
              scale=chol_x[0],
              cholesky_input_output_matrices=True),
          distributions.WishartFull(
              df=4,
              scale=x[0],
              cholesky_input_output_matrices=True)):
        self.assertAllEqual((2, 2), w.event_shape_tensor().eval())
        self.assertEqual(2, w.dimension.eval())
        self.assertAllClose(log_prob[0], w.log_prob(chol_x[0]).eval())
        self.assertAllClose(log_prob[0:2], w.log_prob(chol_x[0:2]).eval())
        self.assertAllClose(
            np.reshape(log_prob, (2, 2)),
            w.log_prob(np.reshape(chol_x, (2, 2, 2, 2))).eval())
        self.assertAllClose(
            np.reshape(np.exp(log_prob), (2, 2)),
            w.prob(np.reshape(chol_x, (2, 2, 2, 2))).eval())
        self.assertAllEqual((2, 2),
                            w.log_prob(np.reshape(x, (2, 2, 2, 2))).get_shape())

  def testBatchShape(self):
    with self.test_session() as sess:
      scale = make_pd(1., 2)
      chol_scale = chol(scale)
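      # A single 2x2 scale with scalar df yields an empty batch shape; a
      # stack of two scales with df=[4., 4] yields batch shape [2].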

      w = distributions.WishartCholesky(df=4, scale=chol_scale)
      self.assertAllEqual([], w.batch_shape)
      self.assertAllEqual([], w.batch_shape_tensor().eval())

      w = distributions.WishartCholesky(
          df=[4., 4], scale=np.array([chol_scale, chol_scale]))
      self.assertAllEqual([2], w.batch_shape)
      self.assertAllEqual([2], w.batch_shape_tensor().eval())

      scale_deferred = array_ops.placeholder(dtypes.float32)
      w = distributions.WishartCholesky(df=4, scale=scale_deferred)
      self.assertAllEqual(
          [], sess.run(w.batch_shape_tensor(),
                       feed_dict={scale_deferred: chol_scale}))
      self.assertAllEqual(
          [2],
          sess.run(w.batch_shape_tensor(),
                   feed_dict={scale_deferred: [chol_scale, chol_scale]}))

  def testEventShape(self):
    with self.test_session() as sess:
      scale = make_pd(1., 2)
      chol_scale = chol(scale)

      w = distributions.WishartCholesky(df=4, scale=chol_scale)
      self.assertAllEqual([2, 2], w.event_shape)
      self.assertAllEqual([2, 2], w.event_shape_tensor().eval())

      w = distributions.WishartCholesky(
          df=[4., 4], scale=np.array([chol_scale, chol_scale]))
      self.assertAllEqual([2, 2], w.event_shape)
      self.assertAllEqual([2, 2], w.event_shape_tensor().eval())

      scale_deferred = array_ops.placeholder(dtypes.float32)
      w = distributions.WishartCholesky(df=4, scale=scale_deferred)
      self.assertAllEqual(
          [2, 2],
          sess.run(w.event_shape_tensor(),
                   feed_dict={scale_deferred: chol_scale}))
      self.assertAllEqual(
          [2, 2],
          sess.run(w.event_shape_tensor(),
                   feed_dict={scale_deferred: [chol_scale, chol_scale]}))

  def testValidateArgs(self):
    with self.test_session() as sess:
      df_deferred = array_ops.placeholder(dtypes.float32)
      chol_scale_deferred = array_ops.placeholder(dtypes.float32)
      x = make_pd(1., 3)
      chol_scale = chol(x)

      # Check expensive, deferred assertions.
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "cannot be less than"):
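        # The df fed below (2.) is less than the dimension of scale (3),
        # which should trip this deferred check.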
        chol_w = distributions.WishartCholesky(
            df=df_deferred,
            scale=chol_scale_deferred,
            validate_args=True)
        sess.run(chol_w.log_prob(np.asarray(
            x, dtype=np.float32)),
                 feed_dict={df_deferred: 2.,
                            chol_scale_deferred: chol_scale})

      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "Cholesky decomposition was not successful"):
        chol_w = distributions.WishartFull(
            df=df_deferred, scale=chol_scale_deferred)
        # np.ones((3, 3)) is not positive definite.
        sess.run(chol_w.log_prob(np.asarray(
            x, dtype=np.float32)),
                 feed_dict={
                     df_deferred: 4.,
                     chol_scale_deferred: np.ones(
                         (3, 3), dtype=np.float32)
                 })

      with self.assertRaisesOpError("scale must be square"):
        chol_w = distributions.WishartCholesky(
            df=4.,
            scale=np.array([[2., 3., 4.], [1., 2., 3.]], dtype=np.float32),
            validate_args=True)
        sess.run(chol_w.scale().eval())

      # Ensure no assertions.
      chol_w = distributions.WishartCholesky(
          df=df_deferred,
          scale=chol_scale_deferred,
          validate_args=False)
      sess.run(chol_w.log_prob(np.asarray(
          x, dtype=np.float32)),
               feed_dict={df_deferred: 4,
                          chol_scale_deferred: chol_scale})
      # Bogus log_prob, but since we have no checks running... c'est la vie.
      sess.run(chol_w.log_prob(np.asarray(
          x, dtype=np.float32)),
               feed_dict={df_deferred: 4,
                          chol_scale_deferred: np.ones((3, 3))})

  def testStaticAsserts(self):
    with self.test_session():
      x = make_pd(1., 3)
      chol_scale = chol(x)

      # Still has these assertions because they're resolvable at graph
      # construction time.
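      # df=2 is less than the dimension (3) of chol_scale, so construction
      # itself should raise.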
      with self.assertRaisesRegexp(ValueError, "cannot be less than"):
        distributions.WishartCholesky(
            df=2, scale=chol_scale, validate_args=False)
      with self.assertRaisesRegexp(TypeError, "Argument tril must have dtype"):
        distributions.WishartCholesky(
            df=4.,
            scale=np.asarray(
                chol_scale, dtype=np.int32),
            validate_args=False)


if __name__ == "__main__":
  test.main()