# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for tensorflow.ops.random_ops.random_gamma."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22
23import numpy as np
24from six.moves import xrange  # pylint: disable=redefined-builtin
25
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import random_seed
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import random_ops
32from tensorflow.python.platform import test
33from tensorflow.python.platform import tf_logging
34
35
class RandomGammaTest(test.TestCase):
  """This is a medium test due to the moments computation taking some time."""

  def setUp(self):
    np.random.seed(137)
    random_seed.set_random_seed(137)

  def _Sampler(self, num, alpha, beta, dtype, use_gpu, seed=None):
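    """Returns a sampling function for `random_gamma` in a fresh graph.

    `alpha` is the shape parameter and `beta` the inverse scale (rate)
    forwarded to `random_ops.random_gamma`. The returned function runs the
    op 10 times and returns the draws as a [10, num] numpy array.
    """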

    def func():
      with self.test_session(use_gpu=use_gpu, graph=ops.Graph()) as sess:
        rng = random_ops.random_gamma(
            [num], alpha, beta=beta, dtype=dtype, seed=seed)
        ret = np.empty([10, num])
        for i in xrange(10):
          ret[i, :] = sess.run(rng)
      return ret

    return func

  def testMomentsFloat32(self):
    self._testMoments(dtypes.float32)

  def testMomentsFloat64(self):
    self._testMoments(dtypes.float64)

  def _testMoments(self, dt):
    try:
      from scipy import stats  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn("Cannot test moments: %s" % e)
      return

    # Check that the given array of samples matches the theoretical moments
    # of the distribution at several orders. The test passes if the z-scores
    # of all statistical moments are below z_limit.
    # Parameters:
    #   max_moment: the largest moment of the distribution to be tested
    #   stride: the distance between samples used to form a statistic;
    #       0 means the i-th moment of each individual sample,
    #       any other stride tests for spatial correlation between samples.
    #   z_limit: the largest z-score for which the test is considered passed.

    # The moments test is a z-score test. This is the largest z-score
    # we want to tolerate. Since the z-score approximates a unit normal
    # distribution, it should almost never exceed 6.
    z_limit = 6.0
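
    # For reference: the k-th raw moment of a gamma distribution with shape
    # alpha and scale theta is E[X**k] = theta**k * Gamma(alpha + k) /
    # Gamma(alpha); stats.gamma(alpha, scale=scale).moment(k) below computes
    # exactly this. The z-score of an empirical mean m of a statistic with
    # expectation mu and variance var, estimated from n samples, is
    #   z = (m - mu) / sqrt(var / n),
    # which is approximately standard normal when the sampler is correct.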

    for stride in 0, 1, 4, 17:
      alphas = [0.2, 1.0, 3.0]
      if dt == dtypes.float64:
        alphas = [0.01] + alphas
      for alpha in alphas:
        for scale in 9, 17:
          # Gamma moments only defined for values less than the scale param.
          max_moment = min(6, scale // 2)
          sampler = self._Sampler(
              20000, alpha, 1 / scale, dt, use_gpu=False, seed=12345)
          moments = [0] * (max_moment + 1)
          moments_sample_count = [0] * (max_moment + 1)
          x = np.array(sampler().flat)  # sampler does 10x samples
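          # Accumulate running products x[k] * x[k + stride] * ... so that,
          # after normalization, moments[i] is the average product of i
          # samples spaced `stride` apart. With stride == 0 this is the plain
          # i-th sample moment; with stride > 0 it should match E[X]**i only
          # if the samples are statistically independent.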
          for k in range(len(x)):
            moment = 1.
            for i in range(max_moment + 1):
              index = k + i * stride
              if index >= len(x):
                break
              moments[i] += moment
              moments_sample_count[i] += 1
              moment *= x[index]
          for i in range(max_moment + 1):
            moments[i] /= moments_sample_count[i]
          for i in range(1, max_moment + 1):
            g = stats.gamma(alpha, scale=scale)
            if stride == 0:
              moments_i_mean = g.moment(i)
              moments_i_squared = g.moment(2 * i)
            else:
              moments_i_mean = pow(g.moment(1), i)
              moments_i_squared = pow(g.moment(2), i)
            # Calculate moment variance safely:
            # This is just
            #  (moments_i_squared - moments_i_mean**2) / moments_sample_count[i]
            normalized_moments_i_var = (
                moments_i_mean / moments_sample_count[i] *
                (moments_i_squared / moments_i_mean - moments_i_mean))
            # Assume every operation has a small numerical error.
            # It takes i multiplications to calculate one i-th moment.
            error_per_moment = i * np.finfo(dt.as_numpy_dtype).eps
            total_variance = (normalized_moments_i_var + error_per_moment)
            tiny = np.finfo(dt.as_numpy_dtype).tiny
            self.assertGreaterEqual(total_variance, 0)
            if total_variance < tiny:
              total_variance = tiny
            # Under a correct sampler, z_test approximately follows a unit
            # normal distribution.
            z_test = abs(
                (moments[i] - moments_i_mean) / math.sqrt(total_variance))
            self.assertLess(z_test, z_limit)

  def _testZeroDensity(self, alpha):
    """Zero isn't in the support of the gamma distribution.

    But quantized floating point math has its limits.
    TODO(bjp): Implement log-gamma sampler for small-shape distributions.

    Args:
      alpha: float shape value to test
    """
    try:
      from scipy import stats  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn("Cannot test zero density proportions: %s" % e)
      return
    allowable_zeros = {
        dtypes.float16: stats.gamma(alpha).cdf(np.finfo(np.float16).tiny),
        dtypes.float32: stats.gamma(alpha).cdf(np.finfo(np.float32).tiny),
        dtypes.float64: stats.gamma(alpha).cdf(np.finfo(np.float64).tiny)
    }
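    # allowable_zeros[dt] is the probability mass that a true gamma
    # distribution with this shape places at or below the smallest positive
    # normal value representable in dt, i.e. roughly the fraction of samples
    # an ideal sampler could legitimately round down to zero at that
    # precision. Exceeding it (beyond the slack applied below) indicates a
    # broken sampler.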
    failures = []
    for use_gpu in [False, True]:
      for dt in dtypes.float16, dtypes.float32, dtypes.float64:
        sampler = self._Sampler(
            10000, alpha, 1.0, dt, use_gpu=use_gpu, seed=12345)
        x = sampler()
        allowable = allowable_zeros[dt] * x.size
        # Slack: double the allowance when the expected count is small,
        # otherwise allow 5% over.
        allowable = allowable * 2 if allowable < 10 else allowable * 1.05
        if np.sum(x <= 0) > allowable:
          failures += [(use_gpu, dt)]
    self.assertEqual([], failures)

  def testNonZeroSmallShape(self):
    self._testZeroDensity(0.01)

  def testNonZeroSmallishShape(self):
    self._testZeroDensity(0.35)

  # Asserts that different trials (1000 samples per trial) are unlikely to
  # produce the same sequence of values. Catches buggy implementations that
  # reuse the same random number seed.
  def testDistinct(self):
    for use_gpu in [False, True]:
      for dt in dtypes.float16, dtypes.float32, dtypes.float64:
        sampler = self._Sampler(1000, 2.0, 1.0, dt, use_gpu=use_gpu)
        x = sampler()
        y = sampler()
        # Number of positions at which the two draws are identical.
        count = (x == y).sum()
        count_limit = 20 if dt == dtypes.float16 else 10
        if count >= count_limit:
          print(use_gpu, dt)
          print("x = ", x)
          print("y = ", y)
          print("count = ", count)
        self.assertLess(count, count_limit)

  # Checks that the CPU and GPU implementations return the same results,
  # given the same random seed.
  def testCPUGPUMatch(self):
    for dt in dtypes.float16, dtypes.float32, dtypes.float64:
      results = {}
      for use_gpu in [False, True]:
        sampler = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=12345)
        results[use_gpu] = sampler()
      if dt == dtypes.float16:
        self.assertAllClose(results[False], results[True], rtol=1e-3, atol=1e-3)
      else:
        self.assertAllClose(results[False], results[True], rtol=1e-6, atol=1e-6)

  def testSeed(self):
    for use_gpu in [False, True]:
      for dt in dtypes.float16, dtypes.float32, dtypes.float64:
        sx = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345)
        sy = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345)
        self.assertAllEqual(sx(), sy())

  def testNoCSE(self):
    """CSE = common subexpression elimination.

    SetIsStateful() should prevent two identical random ops from getting
    merged.
    """
    for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
      for use_gpu in [False, True]:
        with self.test_session(use_gpu=use_gpu):
          rnd1 = random_ops.random_gamma([24], 2.0, dtype=dtype)
          rnd2 = random_ops.random_gamma([24], 2.0, dtype=dtype)
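          # If CSE merged the two stateful ops, rnd1 and rnd2 would be the
          # same tensor and the difference below would be exactly zero.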
          diff = rnd2 - rnd1
          self.assertGreater(np.linalg.norm(diff.eval()), 0.1)

  def testShape(self):
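    # The static shape of random_gamma's output is the requested sample
    # shape followed by the broadcast shape of alpha and beta, with None for
    # any dimension that is not known at graph construction time.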
    # Fully known shape.
    rnd = random_ops.random_gamma([150], 2.0)
    self.assertEqual([150], rnd.get_shape().as_list())
    rnd = random_ops.random_gamma([150], 2.0, beta=[3.0, 4.0])
    self.assertEqual([150, 2], rnd.get_shape().as_list())
    rnd = random_ops.random_gamma([150], array_ops.ones([1, 2, 3]))
    self.assertEqual([150, 1, 2, 3], rnd.get_shape().as_list())
    rnd = random_ops.random_gamma([20, 30], array_ops.ones([1, 2, 3]))
    self.assertEqual([20, 30, 1, 2, 3], rnd.get_shape().as_list())
    rnd = random_ops.random_gamma(
        [123], array_ops.placeholder(
            dtypes.float32, shape=(2,)))
    self.assertEqual([123, 2], rnd.get_shape().as_list())
    # Partially known shape.
    rnd = random_ops.random_gamma(
        array_ops.placeholder(
            dtypes.int32, shape=(1,)), array_ops.ones([7, 3]))
    self.assertEqual([None, 7, 3], rnd.get_shape().as_list())
    rnd = random_ops.random_gamma(
        array_ops.placeholder(
            dtypes.int32, shape=(3,)), array_ops.ones([9, 6]))
    self.assertEqual([None, None, None, 9, 6], rnd.get_shape().as_list())
    # Unknown shape.
    rnd = random_ops.random_gamma(
        array_ops.placeholder(dtypes.int32),
        array_ops.placeholder(dtypes.float32))
    self.assertIs(None, rnd.get_shape().ndims)
    rnd = random_ops.random_gamma([50], array_ops.placeholder(dtypes.float32))
    self.assertIs(None, rnd.get_shape().ndims)

  def testPositive(self):
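    # With a very small alpha most of the mass lies extremely close to zero,
    # yet every sample must still be strictly positive.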
    n = int(10e3)
    for dt in [dtypes.float16, dtypes.float32, dtypes.float64]:
      with self.test_session():
        x = random_ops.random_gamma(shape=[n], alpha=0.001, dtype=dt, seed=0)
        self.assertEqual(0, math_ops.reduce_sum(math_ops.cast(
            math_ops.less_equal(x, 0.), dtype=dtypes.int64)).eval())


if __name__ == "__main__":
  test.main()