1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Wishart.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22from scipy import linalg 23from tensorflow.contrib import distributions as distributions_lib 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import errors_impl 26from tensorflow.python.framework import random_seed 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.platform import test 30 31distributions = distributions_lib 32 33 34def make_pd(start, n): 35 """Deterministically create a positive definite matrix.""" 36 x = np.tril(linalg.circulant(np.arange(start, start + n))) 37 return np.dot(x, x.T) 38 39 40def chol(x): 41 """Compute Cholesky factorization.""" 42 return linalg.cholesky(x).T 43 44 45def wishart_var(df, x): 46 """Compute Wishart variance for numpy scale matrix.""" 47 x = np.sqrt(df) * np.asarray(x) 48 d = np.expand_dims(np.diag(x), -1) 49 return x**2 + np.dot(d, d.T) 50 51 52class WishartCholeskyTest(test.TestCase): 53 54 def testEntropy(self): 55 with self.test_session(): 56 scale = make_pd(1., 2) 57 df = 4 58 w = distributions.WishartCholesky(df, chol(scale)) 59 # sp.stats.wishart(df=4, scale=make_pd(1., 2)).entropy() 60 self.assertAllClose(6.301387092430769, w.entropy().eval()) 61 62 w = distributions.WishartCholesky(df=1, scale=[[1.]]) 63 # sp.stats.wishart(df=1,scale=1).entropy() 64 self.assertAllClose(0.78375711047393404, w.entropy().eval()) 65 66 def testMeanLogDetAndLogNormalizingConstant(self): 67 with self.test_session(): 68 69 def entropy_alt(w): 70 return ( 71 w.log_normalization() 72 - 0.5 * (w.df - w.dimension - 1.) * w.mean_log_det() 73 + 0.5 * w.df * w.dimension).eval() 74 75 w = distributions.WishartCholesky(df=4, 76 scale=chol(make_pd(1., 2))) 77 self.assertAllClose(w.entropy().eval(), entropy_alt(w)) 78 79 w = distributions.WishartCholesky(df=5, scale=[[1.]]) 80 self.assertAllClose(w.entropy().eval(), entropy_alt(w)) 81 82 def testMean(self): 83 with self.test_session(): 84 scale = make_pd(1., 2) 85 df = 4 86 w = distributions.WishartCholesky(df, chol(scale)) 87 self.assertAllEqual(df * scale, w.mean().eval()) 88 89 def testMode(self): 90 with self.test_session(): 91 scale = make_pd(1., 2) 92 df = 4 93 w = distributions.WishartCholesky(df, chol(scale)) 94 self.assertAllEqual((df - 2. - 1.) * scale, w.mode().eval()) 95 96 def testStd(self): 97 with self.test_session(): 98 scale = make_pd(1., 2) 99 df = 4 100 w = distributions.WishartCholesky(df, chol(scale)) 101 self.assertAllEqual(chol(wishart_var(df, scale)), w.stddev().eval()) 102 103 def testVariance(self): 104 with self.test_session(): 105 scale = make_pd(1., 2) 106 df = 4 107 w = distributions.WishartCholesky(df, chol(scale)) 108 self.assertAllEqual(wishart_var(df, scale), w.variance().eval()) 109 110 def testSample(self): 111 with self.test_session(): 112 scale = make_pd(1., 2) 113 df = 4 114 115 chol_w = distributions.WishartCholesky( 116 df, chol(scale), cholesky_input_output_matrices=False) 117 118 x = chol_w.sample(1, seed=42).eval() 119 chol_x = [chol(x[0])] 120 121 full_w = distributions.WishartFull( 122 df, scale, cholesky_input_output_matrices=False) 123 self.assertAllClose(x, full_w.sample(1, seed=42).eval()) 124 125 chol_w_chol = distributions.WishartCholesky( 126 df, chol(scale), cholesky_input_output_matrices=True) 127 self.assertAllClose(chol_x, chol_w_chol.sample(1, seed=42).eval()) 128 eigen_values = array_ops.matrix_diag_part( 129 chol_w_chol.sample( 130 1000, seed=42)) 131 np.testing.assert_array_less(0., eigen_values.eval()) 132 133 full_w_chol = distributions.WishartFull( 134 df, scale, cholesky_input_output_matrices=True) 135 self.assertAllClose(chol_x, full_w_chol.sample(1, seed=42).eval()) 136 eigen_values = array_ops.matrix_diag_part( 137 full_w_chol.sample( 138 1000, seed=42)) 139 np.testing.assert_array_less(0., eigen_values.eval()) 140 141 # Check first and second moments. 142 df = 4. 143 chol_w = distributions.WishartCholesky( 144 df=df, 145 scale=chol(make_pd(1., 3)), 146 cholesky_input_output_matrices=False) 147 x = chol_w.sample(10000, seed=42) 148 self.assertAllEqual((10000, 3, 3), x.get_shape()) 149 150 moment1_estimate = math_ops.reduce_mean(x, reduction_indices=[0]).eval() 151 self.assertAllClose(chol_w.mean().eval(), moment1_estimate, rtol=0.05) 152 153 # The Variance estimate uses the squares rather than outer-products 154 # because Wishart.Variance is the diagonal of the Wishart covariance 155 # matrix. 156 variance_estimate = (math_ops.reduce_mean( 157 math_ops.square(x), reduction_indices=[0]) - 158 math_ops.square(moment1_estimate)).eval() 159 self.assertAllClose( 160 chol_w.variance().eval(), variance_estimate, rtol=0.05) 161 162 # Test that sampling with the same seed twice gives the same results. 163 def testSampleMultipleTimes(self): 164 with self.test_session(): 165 df = 4. 166 n_val = 100 167 168 random_seed.set_random_seed(654321) 169 chol_w1 = distributions.WishartCholesky( 170 df=df, 171 scale=chol(make_pd(1., 3)), 172 cholesky_input_output_matrices=False, 173 name="wishart1") 174 samples1 = chol_w1.sample(n_val, seed=123456).eval() 175 176 random_seed.set_random_seed(654321) 177 chol_w2 = distributions.WishartCholesky( 178 df=df, 179 scale=chol(make_pd(1., 3)), 180 cholesky_input_output_matrices=False, 181 name="wishart2") 182 samples2 = chol_w2.sample(n_val, seed=123456).eval() 183 184 self.assertAllClose(samples1, samples2) 185 186 def testProb(self): 187 with self.test_session(): 188 # Generate some positive definite (pd) matrices and their Cholesky 189 # factorizations. 190 x = np.array( 191 [make_pd(1., 2), make_pd(2., 2), make_pd(3., 2), make_pd(4., 2)]) 192 chol_x = np.array([chol(x[0]), chol(x[1]), chol(x[2]), chol(x[3])]) 193 194 # Since Wishart wasn"t added to SciPy until 0.16, we'll spot check some 195 # pdfs with hard-coded results from upstream SciPy. 196 197 log_prob_df_seq = np.array([ 198 # math.log(stats.wishart.pdf(x[0], df=2+0, scale=x[0])) 199 -3.5310242469692907, 200 # math.log(stats.wishart.pdf(x[1], df=2+1, scale=x[1])) 201 -7.689907330328961, 202 # math.log(stats.wishart.pdf(x[2], df=2+2, scale=x[2])) 203 -10.815845159537895, 204 # math.log(stats.wishart.pdf(x[3], df=2+3, scale=x[3])) 205 -13.640549882916691, 206 ]) 207 208 # This test checks that batches don't interfere with correctness. 209 w = distributions.WishartCholesky( 210 df=[2, 3, 4, 5], 211 scale=chol_x, 212 cholesky_input_output_matrices=True) 213 self.assertAllClose(log_prob_df_seq, w.log_prob(chol_x).eval()) 214 215 # Now we test various constructions of Wishart with different sample 216 # shape. 217 218 log_prob = np.array([ 219 # math.log(stats.wishart.pdf(x[0], df=4, scale=x[0])) 220 -4.224171427529236, 221 # math.log(stats.wishart.pdf(x[1], df=4, scale=x[0])) 222 -6.3378770664093453, 223 # math.log(stats.wishart.pdf(x[2], df=4, scale=x[0])) 224 -12.026946850193017, 225 # math.log(stats.wishart.pdf(x[3], df=4, scale=x[0])) 226 -20.951582705289454, 227 ]) 228 229 for w in ( 230 distributions.WishartCholesky( 231 df=4, 232 scale=chol_x[0], 233 cholesky_input_output_matrices=False), 234 distributions.WishartFull( 235 df=4, 236 scale=x[0], 237 cholesky_input_output_matrices=False)): 238 self.assertAllEqual((2, 2), w.event_shape_tensor().eval()) 239 self.assertEqual(2, w.dimension.eval()) 240 self.assertAllClose(log_prob[0], w.log_prob(x[0]).eval()) 241 self.assertAllClose(log_prob[0:2], w.log_prob(x[0:2]).eval()) 242 self.assertAllClose( 243 np.reshape(log_prob, (2, 2)), 244 w.log_prob(np.reshape(x, (2, 2, 2, 2))).eval()) 245 self.assertAllClose( 246 np.reshape(np.exp(log_prob), (2, 2)), 247 w.prob(np.reshape(x, (2, 2, 2, 2))).eval()) 248 self.assertAllEqual((2, 2), 249 w.log_prob(np.reshape(x, (2, 2, 2, 2))).get_shape()) 250 251 for w in ( 252 distributions.WishartCholesky( 253 df=4, 254 scale=chol_x[0], 255 cholesky_input_output_matrices=True), 256 distributions.WishartFull( 257 df=4, 258 scale=x[0], 259 cholesky_input_output_matrices=True)): 260 self.assertAllEqual((2, 2), w.event_shape_tensor().eval()) 261 self.assertEqual(2, w.dimension.eval()) 262 self.assertAllClose(log_prob[0], w.log_prob(chol_x[0]).eval()) 263 self.assertAllClose(log_prob[0:2], w.log_prob(chol_x[0:2]).eval()) 264 self.assertAllClose( 265 np.reshape(log_prob, (2, 2)), 266 w.log_prob(np.reshape(chol_x, (2, 2, 2, 2))).eval()) 267 self.assertAllClose( 268 np.reshape(np.exp(log_prob), (2, 2)), 269 w.prob(np.reshape(chol_x, (2, 2, 2, 2))).eval()) 270 self.assertAllEqual((2, 2), 271 w.log_prob(np.reshape(x, (2, 2, 2, 2))).get_shape()) 272 273 def testBatchShape(self): 274 with self.test_session() as sess: 275 scale = make_pd(1., 2) 276 chol_scale = chol(scale) 277 278 w = distributions.WishartCholesky(df=4, scale=chol_scale) 279 self.assertAllEqual([], w.batch_shape) 280 self.assertAllEqual([], w.batch_shape_tensor().eval()) 281 282 w = distributions.WishartCholesky( 283 df=[4., 4], scale=np.array([chol_scale, chol_scale])) 284 self.assertAllEqual([2], w.batch_shape) 285 self.assertAllEqual([2], w.batch_shape_tensor().eval()) 286 287 scale_deferred = array_ops.placeholder(dtypes.float32) 288 w = distributions.WishartCholesky(df=4, scale=scale_deferred) 289 self.assertAllEqual( 290 [], sess.run(w.batch_shape_tensor(), 291 feed_dict={scale_deferred: chol_scale})) 292 self.assertAllEqual( 293 [2], 294 sess.run(w.batch_shape_tensor(), 295 feed_dict={scale_deferred: [chol_scale, chol_scale]})) 296 297 def testEventShape(self): 298 with self.test_session() as sess: 299 scale = make_pd(1., 2) 300 chol_scale = chol(scale) 301 302 w = distributions.WishartCholesky(df=4, scale=chol_scale) 303 self.assertAllEqual([2, 2], w.event_shape) 304 self.assertAllEqual([2, 2], w.event_shape_tensor().eval()) 305 306 w = distributions.WishartCholesky( 307 df=[4., 4], scale=np.array([chol_scale, chol_scale])) 308 self.assertAllEqual([2, 2], w.event_shape) 309 self.assertAllEqual([2, 2], w.event_shape_tensor().eval()) 310 311 scale_deferred = array_ops.placeholder(dtypes.float32) 312 w = distributions.WishartCholesky(df=4, scale=scale_deferred) 313 self.assertAllEqual( 314 [2, 2], 315 sess.run(w.event_shape_tensor(), 316 feed_dict={scale_deferred: chol_scale})) 317 self.assertAllEqual( 318 [2, 2], 319 sess.run(w.event_shape_tensor(), 320 feed_dict={scale_deferred: [chol_scale, chol_scale]})) 321 322 def testValidateArgs(self): 323 with self.test_session() as sess: 324 df_deferred = array_ops.placeholder(dtypes.float32) 325 chol_scale_deferred = array_ops.placeholder(dtypes.float32) 326 x = make_pd(1., 3) 327 chol_scale = chol(x) 328 329 # Check expensive, deferred assertions. 330 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 331 "cannot be less than"): 332 chol_w = distributions.WishartCholesky( 333 df=df_deferred, 334 scale=chol_scale_deferred, 335 validate_args=True) 336 sess.run(chol_w.log_prob(np.asarray( 337 x, dtype=np.float32)), 338 feed_dict={df_deferred: 2., 339 chol_scale_deferred: chol_scale}) 340 341 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 342 "Cholesky decomposition was not successful"): 343 chol_w = distributions.WishartFull( 344 df=df_deferred, scale=chol_scale_deferred) 345 # np.ones((3, 3)) is not positive, definite. 346 sess.run(chol_w.log_prob(np.asarray( 347 x, dtype=np.float32)), 348 feed_dict={ 349 df_deferred: 4., 350 chol_scale_deferred: np.ones( 351 (3, 3), dtype=np.float32) 352 }) 353 354 with self.assertRaisesOpError("scale must be square"): 355 chol_w = distributions.WishartCholesky( 356 df=4., 357 scale=np.array([[2., 3., 4.], [1., 2., 3.]], dtype=np.float32), 358 validate_args=True) 359 sess.run(chol_w.scale().eval()) 360 361 # Ensure no assertions. 362 chol_w = distributions.WishartCholesky( 363 df=df_deferred, 364 scale=chol_scale_deferred, 365 validate_args=False) 366 sess.run(chol_w.log_prob(np.asarray( 367 x, dtype=np.float32)), 368 feed_dict={df_deferred: 4, 369 chol_scale_deferred: chol_scale}) 370 # Bogus log_prob, but since we have no checks running... c"est la vie. 371 sess.run(chol_w.log_prob(np.asarray( 372 x, dtype=np.float32)), 373 feed_dict={df_deferred: 4, 374 chol_scale_deferred: np.ones((3, 3))}) 375 376 def testStaticAsserts(self): 377 with self.test_session(): 378 x = make_pd(1., 3) 379 chol_scale = chol(x) 380 381 # Still has these assertions because they're resolveable at graph 382 # construction 383 with self.assertRaisesRegexp(ValueError, "cannot be less than"): 384 distributions.WishartCholesky( 385 df=2, scale=chol_scale, validate_args=False) 386 with self.assertRaisesRegexp(TypeError, "Argument tril must have dtype"): 387 distributions.WishartCholesky( 388 df=4., 389 scale=np.asarray( 390 chol_scale, dtype=np.int32), 391 validate_args=False) 392 393 394if __name__ == "__main__": 395 test.main() 396