1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include <memory> 17 18#include "tensorflow/compiler/xla/array2d.h" 19#include "tensorflow/compiler/xla/client/computation_builder.h" 20#include "tensorflow/compiler/xla/client/local_client.h" 21#include "tensorflow/compiler/xla/reference_util.h" 22#include "tensorflow/compiler/xla/tests/client_library_test_base.h" 23#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 24#include "tensorflow/compiler/xla/tests/literal_test_util.h" 25#include "tensorflow/compiler/xla/tests/test_macros.h" 26#include "tensorflow/compiler/xla/xla_data.pb.h" 27#include "tensorflow/core/platform/test.h" 28 29namespace xla { 30namespace { 31 32class TransposeTest : public ClientLibraryTestBase { 33 public: 34 ErrorSpec error_spec_{0.0001}; 35 36 protected: 37 void TestTransposeConstant021(size_t n1, size_t n2, size_t n3); 38}; 39 40XLA_TEST_F(TransposeTest, Transpose0x0) { 41 ComputationBuilder builder(client_, "Transpose"); 42 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0)); 43 auto result = builder.Transpose(lhs, {1, 0}); 44 45 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_); 46} 47 48XLA_TEST_F(TransposeTest, Transpose0x42) { 49 ComputationBuilder builder(client_, "Transpose"); 50 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 42)); 51 auto result = builder.Transpose(lhs, {1, 0}); 52 53 ComputeAndCompareR2<float>(&builder, Array2D<float>(42, 0), {}, error_spec_); 54} 55 56XLA_TEST_F(TransposeTest, Transpose7x0) { 57 ComputationBuilder builder(client_, "Transpose"); 58 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(7, 0)); 59 auto result = builder.Transpose(lhs, {1, 0}); 60 61 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 7), {}, error_spec_); 62} 63 64TEST_F(TransposeTest, Transpose2x2) { 65 ComputationBuilder builder(client_, "Transpose"); 66 auto lhs = builder.ConstantR2<float>({ 67 {1.0, 2.0}, {3.0, 4.0}, 68 }); 69 auto result = builder.Transpose(lhs, {1, 0}); 70 71 Array2D<float> expected({{1.0f, 3.0f}, {2.0f, 4.0f}}); 72 73 ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); 74} 75 76XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { 77 ComputationBuilder builder(client_, "Transpose"); 78 auto operand = builder.ConstantR3FromArray3D<int32>(Array3D<int32>(0, 2, 3)); 79 auto result = builder.Transpose(operand, {1, 2, 0}); 80 81 ComputeAndCompareR3<int32>(&builder, Array3D<int32>(2, 3, 0), {}); 82} 83 84TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { 85 ComputationBuilder builder(client_, "Transpose"); 86 auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}}); 87 auto result = builder.Transpose(operand, {1, 2, 0}); 88 89 Array3D<int32> expected({{{1}, {2}, {3}}, {{4}, {5}, {6}}}); 90 91 ComputeAndCompareR3<int32>(&builder, expected, {}); 92} 93 94TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { 95 ComputationBuilder builder(client_, "Transpose"); 96 auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}}); 97 auto result = builder.Transpose(operand, {2, 1, 0}); 98 99 Array3D<int32> expected({{{1}, {4}}, {{2}, {5}}, {{3}, {6}}}); 100 101 ComputeAndCompareR3<int32>(&builder, expected, {}); 102} 103 104TEST_F(TransposeTest, Transpose1x2x3_1x2x3) { 105 ComputationBuilder builder(client_, "Transpose"); 106 auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}}); 107 auto result = builder.Transpose(operand, {0, 1, 2}); 108 109 Array3D<int32> expected({{{1, 2, 3}, {4, 5, 6}}}); 110 111 ComputeAndCompareR3<int32>(&builder, expected, {}); 112} 113 114TEST_F(TransposeTest, MultiTranspose3x2) { 115 Array2D<float> input({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}); 116 Array2D<float> transposed({{1.0f, 3.0f, 5.0f}, {2.0f, 4.0f, 6.0f}}); 117 118 for (int transposes = 0; transposes <= 10; ++transposes) { 119 ComputationBuilder builder(client_, "Transpose"); 120 auto computed = builder.ConstantR2FromArray2D<float>(input); 121 for (int i = 0; i < transposes; ++i) { 122 computed = builder.Transpose(computed, {1, 0}); 123 } 124 const Array2D<float>& expected = transposes % 2 == 0 ? input : transposed; 125 ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); 126 } 127} 128 129// Test for transposing [1x1] matrix. 130TEST_F(TransposeTest, Small_1x1) { 131 auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1); 132 133 ComputationBuilder builder(client_, "transpose_1x1"); 134 auto operand = builder.ConstantR2FromArray2D<float>(*aoperand); 135 builder.Transpose(operand, {1, 0}); 136 137 auto expected = ReferenceUtil::TransposeArray2D(*aoperand); 138 ComputeAndCompareR2<float>(&builder, *expected, {}, ErrorSpec(1e-4)); 139} 140 141// Test for transposing [2x2] matrix. 142TEST_F(TransposeTest, Small_2x2) { 143 auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2); 144 145 ComputationBuilder builder(client_, "transpose_2x2"); 146 auto operand = builder.ConstantR2FromArray2D<float>(*aoperand); 147 builder.Transpose(operand, {1, 0}); 148 149 auto expected = ReferenceUtil::TransposeArray2D(*aoperand); 150 ComputeAndCompareR2<float>(&builder, *expected, {}, ErrorSpec(1e-4)); 151} 152 153void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) { 154 Array3D<int32> aoperand(n1, n2, n3); 155 Array3D<int32> expected(n1, n3, n2); 156 for (size_t i = 0; i < n1; ++i) { 157 for (size_t j = 0; j < n2; ++j) { 158 for (size_t k = 0; k < n3; ++k) { 159 aoperand(i, j, k) = i * n3 * n2 + j * n3 + k; 160 expected(i, k, j) = aoperand(i, j, k); 161 } 162 } 163 } 164 165 ComputationBuilder builder(client_, TestName()); 166 auto operand = builder.ConstantR3FromArray3D(aoperand); 167 builder.Transpose(operand, {0, 2, 1}); 168 169 ComputeAndCompareR3<int32>(&builder, expected, {}); 170} 171 172TEST_F(TransposeTest, TransposeConstant021_SingleIncompleteTilePerLayer) { 173 TestTransposeConstant021(2, 2, 3); 174} 175 176TEST_F(TransposeTest, TransposeConstant021_SingleCompleteTilePerLayer) { 177 TestTransposeConstant021(2, 32, 32); 178} 179 180TEST_F(TransposeTest, TransposeConstant021_MultipleTilesPerLayer) { 181 TestTransposeConstant021(2, 70, 35); 182} 183 184} // namespace 185} // namespace xla 186