1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for reduction operators."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.compiler.tests.xla_test import XLATestCase
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import errors_impl
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.platform import googletest
29
30
class ReduceOpsTest(XLATestCase):
  """Compares XLA reduction ops against NumPy reference implementations."""

  def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs,
                     rtol=1e-4, atol=1e-4):
    """Tests that the output of 'tf_reduce_fn' matches numpy's output.

    For every input, the reduction axis is fed at run time through an int32
    placeholder. Axes [0], [1] and [-1] are checked against the NumPy result
    for axis 0, 1 and 1 respectively. Out-of-range axes [-33] and [2] must
    raise InvalidArgumentError. These checks therefore assume rank-2 inputs.

    Args:
      tf_reduce_fn: callable `(tensor, axes)` under test, e.g.
        `math_ops.reduce_sum`.
      np_reduce_fn: NumPy reference callable that accepts an `axis` keyword.
      dtype: element type of the placeholder the inputs are fed into.
      test_inputs: iterable of rank-2 arrays to reduce.
      rtol: relative tolerance passed to `assertAllClose`.
      atol: absolute tolerance passed to `assertAllClose`.
    """

    for test_input in test_inputs:
      with self.test_session() as sess:
        with self.test_scope():
          a = array_ops.placeholder(dtype)
          index = array_ops.placeholder(dtypes.int32)
          out = tf_reduce_fn(a, index)
        result = sess.run(out, {a: test_input, index: [0]})
        self.assertAllClose(result, np_reduce_fn(test_input, axis=0),
                            rtol=rtol, atol=atol)

        result = sess.run(out, {a: test_input, index: [1]})
        self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
                            rtol=rtol, atol=atol)

        # A negative axis counts from the end: -1 is axis 1 for rank 2.
        result = sess.run(out, {a: test_input, index: [-1]})
        self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
                            rtol=rtol, atol=atol)

        with self.assertRaisesWithPredicateMatch(
            errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
          sess.run(out, {a: test_input, index: [-33]})

        with self.assertRaisesWithPredicateMatch(
            errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
          sess.run(out, {a: test_input, index: [2]})

  # Rank-2 inputs, including empty ones along either axis, plus all-positive,
  # all-negative and mixed-sign values.
  FLOAT_DATA = [
      np.zeros(shape=(2, 0)),
      np.zeros(shape=(0, 30)),
      np.arange(1, 7).reshape(2, 3),
      np.arange(-10, -4).reshape(2, 3),
      np.arange(-4, 2).reshape(2, 3),
  ]
  # Each non-empty entry reinterprets 12 consecutive float32 values as 6
  # complex64 values (real/imag pairs) before reshaping to (2, 3).
  COMPLEX_DATA = [
      np.zeros(shape=(2, 0)).astype(np.complex64),
      np.zeros(shape=(0, 30)).astype(np.complex64),
      np.arange(1, 13, dtype=np.float32).view(np.complex64).reshape(2, 3),
      np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3),
      np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3),
  ]
  NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0]
  NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0]
  # `np.bool_` rather than the `np.bool` alias: NumPy deprecated `np.bool` in
  # 1.20 and removed it in 1.24. That broke class definition (and so module
  # import) here.
  BOOL_DATA = [
      np.array([], dtype=np.bool_).reshape(2, 0),
      np.array([], dtype=np.bool_).reshape(0, 3),
      np.array([[False, True, False], [True, True, False]]),
  ]

  def testReduceSumF32(self):
    self._testReduction(math_ops.reduce_sum, np.sum, np.float32,
                        self.FLOAT_DATA)

  def testReduceSumC64(self):
    self._testReduction(math_ops.reduce_sum, np.sum, np.complex64,
                        self.COMPLEX_DATA)

  def testReduceProdF32(self):
    self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
                        self.FLOAT_DATA)

  def testReduceProdC64(self):
    self._testReduction(math_ops.reduce_prod, np.prod, np.complex64,
                        self.COMPLEX_DATA)

  def testReduceMin(self):

    def reference_min(inp, axis):
      """Wrapper around np.amin that returns +infinity for an empty input."""
      # np.amin raises on an empty reduction axis; the expected result there
      # is the identity of min, i.e. +inf over the remaining dimensions.
      if inp.shape[axis] == 0:
        return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf'))
      return np.amin(inp, axis)

    self._testReduction(math_ops.reduce_min, reference_min, np.float32,
                        self.FLOAT_DATA)

  def testReduceMax(self):

    def reference_max(inp, axis):
      """Wrapper around np.amax that returns -infinity for an empty input."""
      # np.amax raises on an empty reduction axis; the expected result there
      # is the identity of max, i.e. -inf over the remaining dimensions.
      if inp.shape[axis] == 0:
        return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('-inf'))
      return np.amax(inp, axis)

    self._testReduction(math_ops.reduce_max, reference_max, np.float32,
                        self.FLOAT_DATA)

  def testReduceMeanF32(self):
    # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
    # reducing across zero inputs.
    self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
                        self.NONEMPTY_FLOAT_DATA)

  def testReduceMeanC64(self):
    # Empty inputs are excluded here for the same reason as in
    # testReduceMeanF32.
    self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,
                        self.NONEMPTY_COMPLEX_DATA)

  def testReduceAll(self):
    self._testReduction(math_ops.reduce_all, np.all, np.bool_, self.BOOL_DATA)

  def testReduceAny(self):
    self._testReduction(math_ops.reduce_any, np.any, np.bool_, self.BOOL_DATA)
138
139
def _main():
  """Script entry point: delegates to `googletest.main()`."""
  googletest.main()


if __name__ == '__main__':
  _main()
142