1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.random_ops.random_gamma.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import math 22 23import numpy as np 24from six.moves import xrange # pylint: disable=redefined-builtin 25 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import random_seed 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops import random_ops 32from tensorflow.python.platform import test 33from tensorflow.python.platform import tf_logging 34 35 36class RandomGammaTest(test.TestCase): 37 """This is a medium test due to the moments computation taking some time.""" 38 39 def setUp(self): 40 np.random.seed(137) 41 random_seed.set_random_seed(137) 42 43 def _Sampler(self, num, alpha, beta, dtype, use_gpu, seed=None): 44 45 def func(): 46 with self.test_session(use_gpu=use_gpu, graph=ops.Graph()) as sess: 47 rng = random_ops.random_gamma( 48 [num], alpha, beta=beta, dtype=dtype, seed=seed) 49 ret = np.empty([10, num]) 50 for i in xrange(10): 51 ret[i, :] = sess.run(rng) 52 return ret 53 54 return func 55 56 def testMomentsFloat32(self): 57 self._testMoments(dtypes.float32) 58 59 def testMomentsFloat64(self): 60 self._testMoments(dtypes.float64) 61 62 def _testMoments(self, dt): 63 try: 64 from scipy import stats # pylint: disable=g-import-not-at-top 65 except ImportError as e: 66 tf_logging.warn("Cannot test moments: %s" % e) 67 return 68 69 # Check the given array of samples matches the given theoretical moment 70 # function at different orders. The test is considered passing if the 71 # z-tests of all statistical moments are all below z_limit. 72 # Parameters: 73 # max_moments: the largest moments of the distribution to be tested 74 # stride: the distance between samples to check for statistical properties 75 # 0 means the n-th moment of each sample 76 # any other strides tests for spatial correlation between samples; 77 # z_limit: the maximum z-test we would consider the test to pass; 78 79 # The moments test is a z-value test. This is the largest z-value 80 # we want to tolerate. Since the z-test approximates a unit normal 81 # distribution, it should almost definitely never exceed 6. 82 z_limit = 6.0 83 84 for stride in 0, 1, 4, 17: 85 alphas = [0.2, 1.0, 3.0] 86 if dt == dtypes.float64: 87 alphas = [0.01] + alphas 88 for alpha in alphas: 89 for scale in 9, 17: 90 # Gamma moments only defined for values less than the scale param. 91 max_moment = min(6, scale // 2) 92 sampler = self._Sampler( 93 20000, alpha, 1 / scale, dt, use_gpu=False, seed=12345) 94 moments = [0] * (max_moment + 1) 95 moments_sample_count = [0] * (max_moment + 1) 96 x = np.array(sampler().flat) # sampler does 10x samples 97 for k in range(len(x)): 98 moment = 1. 99 for i in range(max_moment + 1): 100 index = k + i * stride 101 if index >= len(x): 102 break 103 moments[i] += moment 104 moments_sample_count[i] += 1 105 moment *= x[index] 106 for i in range(max_moment + 1): 107 moments[i] /= moments_sample_count[i] 108 for i in range(1, max_moment + 1): 109 g = stats.gamma(alpha, scale=scale) 110 if stride == 0: 111 moments_i_mean = g.moment(i) 112 moments_i_squared = g.moment(2 * i) 113 else: 114 moments_i_mean = pow(g.moment(1), i) 115 moments_i_squared = pow(g.moment(2), i) 116 # Calculate moment variance safely: 117 # This is just 118 # (moments_i_squared - moments_i_mean**2) / moments_sample_count[i] 119 normalized_moments_i_var = ( 120 moments_i_mean / moments_sample_count[i] * 121 (moments_i_squared / moments_i_mean - moments_i_mean)) 122 # Assume every operation has a small numerical error. 123 # It takes i multiplications to calculate one i-th moment. 124 error_per_moment = i * np.finfo(dt.as_numpy_dtype).eps 125 total_variance = (normalized_moments_i_var + error_per_moment) 126 tiny = np.finfo(dt.as_numpy_dtype).tiny 127 self.assertGreaterEqual(total_variance, 0) 128 if total_variance < tiny: 129 total_variance = tiny 130 # z_test is approximately a unit normal distribution. 131 z_test = abs( 132 (moments[i] - moments_i_mean) / math.sqrt(total_variance)) 133 self.assertLess(z_test, z_limit) 134 135 def _testZeroDensity(self, alpha): 136 """Zero isn't in the support of the gamma distribution. 137 138 But quantized floating point math has its limits. 139 TODO(bjp): Implement log-gamma sampler for small-shape distributions. 140 141 Args: 142 alpha: float shape value to test 143 """ 144 try: 145 from scipy import stats # pylint: disable=g-import-not-at-top 146 except ImportError as e: 147 tf_logging.warn("Cannot test zero density proportions: %s" % e) 148 return 149 allowable_zeros = { 150 dtypes.float16: stats.gamma(alpha).cdf(np.finfo(np.float16).tiny), 151 dtypes.float32: stats.gamma(alpha).cdf(np.finfo(np.float32).tiny), 152 dtypes.float64: stats.gamma(alpha).cdf(np.finfo(np.float64).tiny) 153 } 154 failures = [] 155 for use_gpu in [False, True]: 156 for dt in dtypes.float16, dtypes.float32, dtypes.float64: 157 sampler = self._Sampler( 158 10000, alpha, 1.0, dt, use_gpu=use_gpu, seed=12345) 159 x = sampler() 160 allowable = allowable_zeros[dt] * x.size 161 allowable = allowable * 2 if allowable < 10 else allowable * 1.05 162 if np.sum(x <= 0) > allowable: 163 failures += [(use_gpu, dt)] 164 self.assertEqual([], failures) 165 166 def testNonZeroSmallShape(self): 167 self._testZeroDensity(0.01) 168 169 def testNonZeroSmallishShape(self): 170 self._testZeroDensity(0.35) 171 172 # Asserts that different trials (1000 samples per trial) is unlikely 173 # to see the same sequence of values. Will catch buggy 174 # implementations which uses the same random number seed. 175 def testDistinct(self): 176 for use_gpu in [False, True]: 177 for dt in dtypes.float16, dtypes.float32, dtypes.float64: 178 sampler = self._Sampler(1000, 2.0, 1.0, dt, use_gpu=use_gpu) 179 x = sampler() 180 y = sampler() 181 # Number of different samples. 182 count = (x == y).sum() 183 count_limit = 20 if dt == dtypes.float16 else 10 184 if count >= count_limit: 185 print(use_gpu, dt) 186 print("x = ", x) 187 print("y = ", y) 188 print("count = ", count) 189 self.assertLess(count, count_limit) 190 191 # Checks that the CPU and GPU implementation returns the same results, 192 # given the same random seed 193 def testCPUGPUMatch(self): 194 for dt in dtypes.float16, dtypes.float32, dtypes.float64: 195 results = {} 196 for use_gpu in [False, True]: 197 sampler = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=12345) 198 results[use_gpu] = sampler() 199 if dt == dtypes.float16: 200 self.assertAllClose(results[False], results[True], rtol=1e-3, atol=1e-3) 201 else: 202 self.assertAllClose(results[False], results[True], rtol=1e-6, atol=1e-6) 203 204 def testSeed(self): 205 for use_gpu in [False, True]: 206 for dt in dtypes.float16, dtypes.float32, dtypes.float64: 207 sx = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345) 208 sy = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345) 209 self.assertAllEqual(sx(), sy()) 210 211 def testNoCSE(self): 212 """CSE = constant subexpression eliminator. 213 214 SetIsStateful() should prevent two identical random ops from getting 215 merged. 216 """ 217 for dtype in dtypes.float16, dtypes.float32, dtypes.float64: 218 for use_gpu in [False, True]: 219 with self.test_session(use_gpu=use_gpu): 220 rnd1 = random_ops.random_gamma([24], 2.0, dtype=dtype) 221 rnd2 = random_ops.random_gamma([24], 2.0, dtype=dtype) 222 diff = rnd2 - rnd1 223 self.assertGreater(np.linalg.norm(diff.eval()), 0.1) 224 225 def testShape(self): 226 # Fully known shape. 227 rnd = random_ops.random_gamma([150], 2.0) 228 self.assertEqual([150], rnd.get_shape().as_list()) 229 rnd = random_ops.random_gamma([150], 2.0, beta=[3.0, 4.0]) 230 self.assertEqual([150, 2], rnd.get_shape().as_list()) 231 rnd = random_ops.random_gamma([150], array_ops.ones([1, 2, 3])) 232 self.assertEqual([150, 1, 2, 3], rnd.get_shape().as_list()) 233 rnd = random_ops.random_gamma([20, 30], array_ops.ones([1, 2, 3])) 234 self.assertEqual([20, 30, 1, 2, 3], rnd.get_shape().as_list()) 235 rnd = random_ops.random_gamma( 236 [123], array_ops.placeholder( 237 dtypes.float32, shape=(2,))) 238 self.assertEqual([123, 2], rnd.get_shape().as_list()) 239 # Partially known shape. 240 rnd = random_ops.random_gamma( 241 array_ops.placeholder( 242 dtypes.int32, shape=(1,)), array_ops.ones([7, 3])) 243 self.assertEqual([None, 7, 3], rnd.get_shape().as_list()) 244 rnd = random_ops.random_gamma( 245 array_ops.placeholder( 246 dtypes.int32, shape=(3,)), array_ops.ones([9, 6])) 247 self.assertEqual([None, None, None, 9, 6], rnd.get_shape().as_list()) 248 # Unknown shape. 249 rnd = random_ops.random_gamma( 250 array_ops.placeholder(dtypes.int32), 251 array_ops.placeholder(dtypes.float32)) 252 self.assertIs(None, rnd.get_shape().ndims) 253 rnd = random_ops.random_gamma([50], array_ops.placeholder(dtypes.float32)) 254 self.assertIs(None, rnd.get_shape().ndims) 255 256 def testPositive(self): 257 n = int(10e3) 258 for dt in [dtypes.float16, dtypes.float32, dtypes.float64]: 259 with self.test_session(): 260 x = random_ops.random_gamma(shape=[n], alpha=0.001, dtype=dt, seed=0) 261 self.assertEqual(0, math_ops.reduce_sum(math_ops.cast( 262 math_ops.less_equal(x, 0.), dtype=dtypes.int64)).eval()) 263 264 265if __name__ == "__main__": 266 test.main() 267