1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for reduction operators.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.compiler.tests.xla_test import XLATestCase 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import errors_impl 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import math_ops 28from tensorflow.python.platform import googletest 29 30 31class ReduceOpsTest(XLATestCase): 32 33 def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs, 34 rtol=1e-4, atol=1e-4): 35 """Tests that the output of 'tf_reduce_fn' matches numpy's output.""" 36 37 for test_input in test_inputs: 38 with self.test_session() as sess: 39 with self.test_scope(): 40 a = array_ops.placeholder(dtype) 41 index = array_ops.placeholder(dtypes.int32) 42 out = tf_reduce_fn(a, index) 43 result = sess.run(out, {a: test_input, index: [0]}) 44 self.assertAllClose(result, np_reduce_fn(test_input, axis=0), 45 rtol=rtol, atol=atol) 46 47 result = sess.run(out, {a: test_input, index: [1]}) 48 self.assertAllClose(result, np_reduce_fn(test_input, axis=1), 49 rtol=rtol, atol=atol) 50 51 result = sess.run(out, {a: test_input, index: [-1]}) 52 self.assertAllClose(result, np_reduce_fn(test_input, axis=1), 53 rtol=rtol, atol=atol) 54 55 with self.assertRaisesWithPredicateMatch( 56 errors_impl.InvalidArgumentError, 'Invalid reduction dim'): 57 sess.run(out, {a: test_input, index: [-33]}) 58 59 with self.assertRaisesWithPredicateMatch( 60 errors_impl.InvalidArgumentError, 'Invalid reduction dim'): 61 sess.run(out, {a: test_input, index: [2]}) 62 63 FLOAT_DATA = [ 64 np.zeros(shape=(2, 0)), 65 np.zeros(shape=(0, 30)), 66 np.arange(1, 7).reshape(2, 3), 67 np.arange(-10, -4).reshape(2, 3), 68 np.arange(-4, 2).reshape(2, 3), 69 ] 70 COMPLEX_DATA = [ 71 np.zeros(shape=(2, 0)).astype(np.complex64), 72 np.zeros(shape=(0, 30)).astype(np.complex64), 73 np.arange(1, 13, dtype=np.float32).view(np.complex64).reshape(2, 3), 74 np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3), 75 np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3), 76 ] 77 NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0] 78 NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0] 79 BOOL_DATA = [ 80 np.array([], dtype=np.bool).reshape(2, 0), 81 np.array([], dtype=np.bool).reshape(0, 3), 82 np.array([[False, True, False], [True, True, False]]), 83 ] 84 85 def testReduceSumF32(self): 86 self._testReduction(math_ops.reduce_sum, np.sum, np.float32, 87 self.FLOAT_DATA) 88 89 def testReduceSumC64(self): 90 self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, 91 self.COMPLEX_DATA) 92 93 def testReduceProdF32(self): 94 self._testReduction(math_ops.reduce_prod, np.prod, np.float32, 95 self.FLOAT_DATA) 96 97 def testReduceProdC64(self): 98 self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, 99 self.COMPLEX_DATA) 100 101 def testReduceMin(self): 102 103 def reference_min(inp, axis): 104 """Wrapper around np.amin that returns +infinity for an empty input.""" 105 if inp.shape[axis] == 0: 106 return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf')) 107 return np.amin(inp, axis) 108 109 self._testReduction(math_ops.reduce_min, reference_min, np.float32, 110 self.FLOAT_DATA) 111 112 def testReduceMax(self): 113 114 def reference_max(inp, axis): 115 """Wrapper around np.amax that returns -infinity for an empty input.""" 116 if inp.shape[axis] == 0: 117 return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('-inf')) 118 return np.amax(inp, axis) 119 120 self._testReduction(math_ops.reduce_max, reference_max, np.float32, 121 self.FLOAT_DATA) 122 123 def testReduceMeanF32(self): 124 # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when 125 # reducing across zero inputs. 126 self._testReduction(math_ops.reduce_mean, np.mean, np.float32, 127 self.NONEMPTY_FLOAT_DATA) 128 129 def testReduceMeanC64(self): 130 self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, 131 self.NONEMPTY_COMPLEX_DATA) 132 133 def testReduceAll(self): 134 self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA) 135 136 def testReduceAny(self): 137 self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) 138 139 140if __name__ == '__main__': 141 googletest.main() 142