1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test cases for segment reduction ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import numpy as np
23
24from tensorflow.compiler.tests.xla_test import XLATestCase
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.platform import googletest
28
29
class SegmentReductionOpsTest(XLATestCase):
  """Test cases for segment reduction ops."""

  def UnsortedSegmentSum(self, data, indices, num_segments):
    """Evaluates math_ops.unsorted_segment_sum under the XLA test scope.

    Both `data` and `indices` are fed through placeholders whose dtype and
    shape are taken from the fed values.

    Args:
      data: numpy array of values to sum into segments.
      indices: segment ids. numpy arrays and numpy scalars keep their own
        dtype and shape. Python ints and (nested) Python sequences, which
        have no dtype of their own, are converted to int32 numpy values.
      num_segments: number of output segments.

    Returns:
      The evaluated result of unsorted_segment_sum.
    """
    if not isinstance(indices, (np.ndarray, np.generic)):
      # Python scalars/lists have no .dtype/.shape; int32 matches the
      # dtype previously used for bare Python int indices.
      indices = np.asarray(indices, dtype=np.int32)
    with self.test_session() as sess, self.test_scope():
      d = array_ops.placeholder(data.dtype, shape=data.shape)
      i = array_ops.placeholder(indices.dtype, shape=indices.shape)
      return sess.run(
          math_ops.unsorted_segment_sum(d, i, num_segments),
          {d: data,
           i: indices})

  def testUnsortedSegmentSum0DIndices1DData(self):
    # A scalar segment id routes the whole data vector into one segment, so
    # the result has shape [num_segments] + data.shape.
    for dtype in self.numeric_types:
      self.assertAllClose(
          np.array(
              [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5],
               [0, 0, 0, 0, 0, 0]],
              dtype=dtype),
          self.UnsortedSegmentSum(
              np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4))

  def testUnsortedSegmentSum1DIndices1DData(self):
    # Segment 3 collects elements 0, 4 and 5 (0 + 4 + 5 = 9).
    for dtype in self.numeric_types:
      self.assertAllClose(
          np.array([1, 3, 2, 9], dtype=dtype),
          self.UnsortedSegmentSum(
              np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
              np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4))

  def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self):
    # Entries with a negative segment id (-1) are dropped: the data values
    # 1 and 5 do not contribute to any output segment.
    for dtype in self.numeric_types:
      self.assertAllClose(
          np.array([6, 3, 0, 6], dtype=dtype),
          self.UnsortedSegmentSum(
              np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype),
              np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4))

  def testUnsortedSegmentSum1DIndices2DDataDisjoint(self):
    # Each segment receives at most one row; segments with no rows are zero.
    for dtype in self.numeric_types:
      data = np.array(
          [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43],
           [50, 51, 52, 53]],
          dtype=dtype)
      indices = np.array([8, 1, 0, 3, 7], dtype=np.int32)
      num_segments = 10
      y = self.UnsortedSegmentSum(data, indices, num_segments)
      self.assertAllClose(
          np.array(
              [[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0],
               [40, 41, 42, 43], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0],
               [50, 51, 52, 53], [0, 1, 2, 3], [0, 0, 0, 0]],
              dtype=dtype), y)

  def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self):
    # Rows 0 and 3 sum into segment 0, and rows 1 and 4 into segment 1.
    for dtype in self.numeric_types:
      data = np.array(
          [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43],
           [50, 51, 52, 53]],
          dtype=dtype)
      indices = np.array([0, 1, 2, 0, 1], dtype=np.int32)
      num_segments = 4
      y = self.UnsortedSegmentSum(data, indices, num_segments)
      self.assertAllClose(
          np.array(
              [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33],
               [0, 0, 0, 0]],
              dtype=dtype), y)

  def testUnsortedSegmentSum2DIndices3DData(self):
    # indices has shape (4, 2), the leading dims of data's (4, 2, 3). Each
    # (i, j) id selects a length-3 row, so the output is (num_segments, 3).
    for dtype in self.numeric_types:
      data = np.array(
          [[[0, 1, 2], [10, 11, 12]],
           [[100, 101, 102], [110, 111, 112]],
           [[200, 201, 202], [210, 211, 212]],
           [[300, 301, 302], [310, 311, 312]]],
          dtype=dtype)
      indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32)
      num_segments = 8
      y = self.UnsortedSegmentSum(data, indices, num_segments)
      self.assertAllClose(
          np.array(
              [[210, 211, 212], [110, 111, 112], [310, 311, 312],
               [100, 102, 104], [0, 0, 0], [210, 212, 214],
               [300, 301, 302], [0, 0, 0]],
              dtype=dtype), y)

  def testUnsortedSegmentSum1DIndices3DData(self):
    # Each 1-D id moves a whole (2, 3) slab of data into its segment.
    for dtype in self.numeric_types:
      data = np.array(
          [[[0, 1, 2], [10, 11, 12]],
           [[100, 101, 102], [110, 111, 112]],
           [[200, 201, 202], [210, 211, 212]],
           [[300, 301, 302], [310, 311, 312]]],
          dtype=dtype)
      indices = np.array([3, 0, 2, 5], dtype=np.int32)
      num_segments = 6
      y = self.UnsortedSegmentSum(data, indices, num_segments)
      self.assertAllClose(
          np.array(
              [[[100, 101, 102], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]],
               [[200, 201, 202], [210, 211, 212]], [[0, 1, 2], [10, 11, 12]],
               [[0, 0, 0], [0, 0, 0]], [[300, 301, 302], [310, 311, 312]]],
              dtype=dtype), y)

  def testUnsortedSegmentSumShapeError(self):
    # indices' shape (3, 2) is not a prefix of data's shape (4, 8, 7), so
    # building the op is expected to raise ValueError.
    for dtype in self.numeric_types:
      data = np.ones((4, 8, 7), dtype=dtype)
      indices = np.ones((3, 2), dtype=np.int32)
      num_segments = 4
      self.assertRaises(ValueError,
                        functools.partial(self.UnsortedSegmentSum, data,
                                          indices, num_segments))
144
145
if __name__ == '__main__':
  # Script entry point: hand control to googletest.main() to run this
  # module's test cases.
  googletest.main()
148