1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <memory>
17
18#include "tensorflow/compiler/xla/array2d.h"
19#include "tensorflow/compiler/xla/client/computation_builder.h"
20#include "tensorflow/compiler/xla/client/local_client.h"
21#include "tensorflow/compiler/xla/reference_util.h"
22#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
23#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24#include "tensorflow/compiler/xla/tests/literal_test_util.h"
25#include "tensorflow/compiler/xla/tests/test_macros.h"
26#include "tensorflow/compiler/xla/xla_data.pb.h"
27#include "tensorflow/core/platform/test.h"
28
29namespace xla {
30namespace {
31
32class TransposeTest : public ClientLibraryTestBase {
33 public:
34  ErrorSpec error_spec_{0.0001};
35
36 protected:
37  void TestTransposeConstant021(size_t n1, size_t n2, size_t n3);
38};
39
40XLA_TEST_F(TransposeTest, Transpose0x0) {
41  ComputationBuilder builder(client_, "Transpose");
42  auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
43  auto result = builder.Transpose(lhs, {1, 0});
44
45  ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_);
46}
47
48XLA_TEST_F(TransposeTest, Transpose0x42) {
49  ComputationBuilder builder(client_, "Transpose");
50  auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 42));
51  auto result = builder.Transpose(lhs, {1, 0});
52
53  ComputeAndCompareR2<float>(&builder, Array2D<float>(42, 0), {}, error_spec_);
54}
55
56XLA_TEST_F(TransposeTest, Transpose7x0) {
57  ComputationBuilder builder(client_, "Transpose");
58  auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(7, 0));
59  auto result = builder.Transpose(lhs, {1, 0});
60
61  ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 7), {}, error_spec_);
62}
63
64TEST_F(TransposeTest, Transpose2x2) {
65  ComputationBuilder builder(client_, "Transpose");
66  auto lhs = builder.ConstantR2<float>({
67      {1.0, 2.0}, {3.0, 4.0},
68  });
69  auto result = builder.Transpose(lhs, {1, 0});
70
71  Array2D<float> expected({{1.0f, 3.0f}, {2.0f, 4.0f}});
72
73  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
74}
75
76XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) {
77  ComputationBuilder builder(client_, "Transpose");
78  auto operand = builder.ConstantR3FromArray3D<int32>(Array3D<int32>(0, 2, 3));
79  auto result = builder.Transpose(operand, {1, 2, 0});
80
81  ComputeAndCompareR3<int32>(&builder, Array3D<int32>(2, 3, 0), {});
82}
83
84TEST_F(TransposeTest, Transpose1x2x3_2x3x1) {
85  ComputationBuilder builder(client_, "Transpose");
86  auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
87  auto result = builder.Transpose(operand, {1, 2, 0});
88
89  Array3D<int32> expected({{{1}, {2}, {3}}, {{4}, {5}, {6}}});
90
91  ComputeAndCompareR3<int32>(&builder, expected, {});
92}
93
94TEST_F(TransposeTest, Transpose1x2x3_3x2x1) {
95  ComputationBuilder builder(client_, "Transpose");
96  auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
97  auto result = builder.Transpose(operand, {2, 1, 0});
98
99  Array3D<int32> expected({{{1}, {4}}, {{2}, {5}}, {{3}, {6}}});
100
101  ComputeAndCompareR3<int32>(&builder, expected, {});
102}
103
104TEST_F(TransposeTest, Transpose1x2x3_1x2x3) {
105  ComputationBuilder builder(client_, "Transpose");
106  auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
107  auto result = builder.Transpose(operand, {0, 1, 2});
108
109  Array3D<int32> expected({{{1, 2, 3}, {4, 5, 6}}});
110
111  ComputeAndCompareR3<int32>(&builder, expected, {});
112}
113
114TEST_F(TransposeTest, MultiTranspose3x2) {
115  Array2D<float> input({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}});
116  Array2D<float> transposed({{1.0f, 3.0f, 5.0f}, {2.0f, 4.0f, 6.0f}});
117
118  for (int transposes = 0; transposes <= 10; ++transposes) {
119    ComputationBuilder builder(client_, "Transpose");
120    auto computed = builder.ConstantR2FromArray2D<float>(input);
121    for (int i = 0; i < transposes; ++i) {
122      computed = builder.Transpose(computed, {1, 0});
123    }
124    const Array2D<float>& expected = transposes % 2 == 0 ? input : transposed;
125    ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
126  }
127}
128
129// Test for transposing [1x1] matrix.
130TEST_F(TransposeTest, Small_1x1) {
131  auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1);
132
133  ComputationBuilder builder(client_, "transpose_1x1");
134  auto operand = builder.ConstantR2FromArray2D<float>(*aoperand);
135  builder.Transpose(operand, {1, 0});
136
137  auto expected = ReferenceUtil::TransposeArray2D(*aoperand);
138  ComputeAndCompareR2<float>(&builder, *expected, {}, ErrorSpec(1e-4));
139}
140
141// Test for transposing [2x2] matrix.
142TEST_F(TransposeTest, Small_2x2) {
143  auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2);
144
145  ComputationBuilder builder(client_, "transpose_2x2");
146  auto operand = builder.ConstantR2FromArray2D<float>(*aoperand);
147  builder.Transpose(operand, {1, 0});
148
149  auto expected = ReferenceUtil::TransposeArray2D(*aoperand);
150  ComputeAndCompareR2<float>(&builder, *expected, {}, ErrorSpec(1e-4));
151}
152
153void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) {
154  Array3D<int32> aoperand(n1, n2, n3);
155  Array3D<int32> expected(n1, n3, n2);
156  for (size_t i = 0; i < n1; ++i) {
157    for (size_t j = 0; j < n2; ++j) {
158      for (size_t k = 0; k < n3; ++k) {
159        aoperand(i, j, k) = i * n3 * n2 + j * n3 + k;
160        expected(i, k, j) = aoperand(i, j, k);
161      }
162    }
163  }
164
165  ComputationBuilder builder(client_, TestName());
166  auto operand = builder.ConstantR3FromArray3D(aoperand);
167  builder.Transpose(operand, {0, 2, 1});
168
169  ComputeAndCompareR3<int32>(&builder, expected, {});
170}
171
172TEST_F(TransposeTest, TransposeConstant021_SingleIncompleteTilePerLayer) {
173  TestTransposeConstant021(2, 2, 3);
174}
175
176TEST_F(TransposeTest, TransposeConstant021_SingleCompleteTilePerLayer) {
177  TestTransposeConstant021(2, 32, 32);
178}
179
180TEST_F(TransposeTest, TransposeConstant021_MultipleTilesPerLayer) {
181  TestTransposeConstant021(2, 70, 35);
182}
183
184}  // namespace
185}  // namespace xla
186