1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" 17 18#include <vector> 19 20#include "tensorflow/compiler/xla/service/hlo_computation.h" 21#include "tensorflow/compiler/xla/service/hlo_instruction.h" 22#include "tensorflow/compiler/xla/service/hlo_module.h" 23#include "tensorflow/compiler/xla/test.h" 24#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 25#include "tensorflow/compiler/xla/util.h" 26 27#include "tensorflow/compiler/xla/test_helpers.h" 28 29namespace xla { 30namespace cpu { 31 32using ::testing::ElementsAre; 33 34class ConvCanonicalizationTest : public HloTestBase { 35 public: 36 ConvCanonicalizationTest() { 37 for (int i = 0; i < 2; ++i) { 38 auto dim = conv_window_.add_dimensions(); 39 dim->set_size(kWindowSize); 40 dim->set_stride(1); 41 dim->set_padding_low(0); 42 dim->set_padding_high(0); 43 dim->set_window_dilation(1); 44 dim->set_base_dilation(1); 45 } 46 } 47 48 protected: 49 Window conv_window_; 50 51 static constexpr int kBatchSize = 50; 52 static constexpr int kInputSize = 28; 53 static constexpr int kWindowSize = 5; 54 static constexpr int kInputFeatureCount = 32; 55 static constexpr int kOutputFeatureCount = 64; 56}; 57 58TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { 59 auto builder = HloComputation::Builder(TestName()); 60 // The input dimensions are in CNHW order. 61 auto input = builder.AddInstruction(HloInstruction::CreateConstant( 62 Literal::CreateR4FromArray4D(Array4D<float>( 63 kInputFeatureCount, kBatchSize, kInputSize, kInputSize)))); 64 // The kernel dimensions are in OIHW order. 65 auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( 66 Literal::CreateR4FromArray4D(Array4D<float>( 67 kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize)))); 68 69 ConvolutionDimensionNumbers dnums; 70 dnums.set_input_batch_dimension(1); 71 dnums.set_output_batch_dimension(1); 72 dnums.add_input_spatial_dimensions(2); 73 dnums.add_output_spatial_dimensions(2); 74 dnums.add_input_spatial_dimensions(3); 75 dnums.add_output_spatial_dimensions(3); 76 dnums.set_input_feature_dimension(0); 77 dnums.set_output_feature_dimension(0); 78 dnums.add_kernel_spatial_dimensions(2); 79 dnums.add_kernel_spatial_dimensions(3); 80 dnums.set_kernel_input_feature_dimension(1); 81 dnums.set_kernel_output_feature_dimension(0); 82 auto output_size = kInputSize - kWindowSize + 1; 83 builder.AddInstruction(HloInstruction::CreateConvolve( 84 ShapeUtil::MakeShape( 85 F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), 86 input, kernel, conv_window_, dnums)); 87 88 auto module = CreateNewModule(); 89 HloComputation* entry_computation = 90 module->AddEntryComputation(builder.Build()); 91 92 ConvCanonicalization conv_canonicalization; 93 EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); 94 95 const HloInstruction* output_reshape = entry_computation->root_instruction(); 96 EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); 97 const HloInstruction* canonical_conv = output_reshape->operand(0); 98 EXPECT_EQ(HloOpcode::kConvolution, canonical_conv->opcode()); 99 const HloInstruction* input_reshape = canonical_conv->operand(0); 100 EXPECT_EQ(HloOpcode::kTranspose, input_reshape->opcode()); 101 const HloInstruction* kernel_reshape = canonical_conv->operand(1); 102 EXPECT_EQ(HloOpcode::kTranspose, kernel_reshape->opcode()); 103 104 // The input is in CNHW order. input_reshape should produce 105 // NHWC for the convolution to hit the Eigen fast path. 106 EXPECT_THAT(input_reshape->dimensions(), ElementsAre(1, 2, 3, 0)); 107 // The kernel is in OIHW order. kernel_reshape should produce 108 // HWIO for the convolution to hit the Eigen fast path. 109 EXPECT_THAT(kernel_reshape->dimensions(), ElementsAre(2, 3, 1, 0)); 110 // The output of the canonical convolution is in NHWC order (the same as 111 // input_reshape's order). output_reshape should restore that order to the 112 // order of the computation root (CNHW). 113 EXPECT_THAT(output_reshape->dimensions(), ElementsAre(3, 0, 1, 2)); 114} 115 116TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { 117 auto builder = HloComputation::Builder(TestName()); 118 // The input dimensions are in NHWC order. 119 auto input = builder.AddInstruction(HloInstruction::CreateConstant( 120 Literal::CreateR4FromArray4D(Array4D<float>( 121 kBatchSize, kInputSize, kInputSize, kInputFeatureCount)))); 122 // The kernel dimensions are in HWIO order. 123 auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( 124 Literal::CreateR4FromArray4D(Array4D<float>( 125 kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount)))); 126 127 ConvolutionDimensionNumbers dnums; 128 dnums.set_input_batch_dimension(0); 129 dnums.set_output_batch_dimension(0); 130 dnums.add_input_spatial_dimensions(1); 131 dnums.add_output_spatial_dimensions(1); 132 dnums.add_input_spatial_dimensions(2); 133 dnums.add_output_spatial_dimensions(2); 134 dnums.set_input_feature_dimension(3); 135 dnums.set_output_feature_dimension(3); 136 dnums.add_kernel_spatial_dimensions(0); 137 dnums.add_kernel_spatial_dimensions(1); 138 dnums.set_kernel_input_feature_dimension(2); 139 dnums.set_kernel_output_feature_dimension(3); 140 auto output_size = kInputSize - kWindowSize + 1; 141 builder.AddInstruction(HloInstruction::CreateConvolve( 142 ShapeUtil::MakeShape( 143 F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), 144 input, kernel, conv_window_, dnums)); 145 146 auto module = CreateNewModule(); 147 module->AddEntryComputation(builder.Build()); 148 149 ConvCanonicalization conv_canonicalization; 150 EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); 151} 152 153} // namespace cpu 154} // namespace xla 155