1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
17
18#include <vector>
19
20#include "tensorflow/compiler/xla/service/hlo_computation.h"
21#include "tensorflow/compiler/xla/service/hlo_instruction.h"
22#include "tensorflow/compiler/xla/service/hlo_module.h"
23#include "tensorflow/compiler/xla/test.h"
24#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25#include "tensorflow/compiler/xla/util.h"
26
27#include "tensorflow/compiler/xla/test_helpers.h"
28
29namespace xla {
30namespace cpu {
31
32using ::testing::ElementsAre;
33
34class ConvCanonicalizationTest : public HloTestBase {
35 public:
36  ConvCanonicalizationTest() {
37    for (int i = 0; i < 2; ++i) {
38      auto dim = conv_window_.add_dimensions();
39      dim->set_size(kWindowSize);
40      dim->set_stride(1);
41      dim->set_padding_low(0);
42      dim->set_padding_high(0);
43      dim->set_window_dilation(1);
44      dim->set_base_dilation(1);
45    }
46  }
47
48 protected:
49  Window conv_window_;
50
51  static constexpr int kBatchSize = 50;
52  static constexpr int kInputSize = 28;
53  static constexpr int kWindowSize = 5;
54  static constexpr int kInputFeatureCount = 32;
55  static constexpr int kOutputFeatureCount = 64;
56};
57
58TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
59  auto builder = HloComputation::Builder(TestName());
60  // The input dimensions are in CNHW order.
61  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
62      Literal::CreateR4FromArray4D(Array4D<float>(
63          kInputFeatureCount, kBatchSize, kInputSize, kInputSize))));
64  // The kernel dimensions are in OIHW order.
65  auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
66      Literal::CreateR4FromArray4D(Array4D<float>(
67          kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize))));
68
69  ConvolutionDimensionNumbers dnums;
70  dnums.set_input_batch_dimension(1);
71  dnums.set_output_batch_dimension(1);
72  dnums.add_input_spatial_dimensions(2);
73  dnums.add_output_spatial_dimensions(2);
74  dnums.add_input_spatial_dimensions(3);
75  dnums.add_output_spatial_dimensions(3);
76  dnums.set_input_feature_dimension(0);
77  dnums.set_output_feature_dimension(0);
78  dnums.add_kernel_spatial_dimensions(2);
79  dnums.add_kernel_spatial_dimensions(3);
80  dnums.set_kernel_input_feature_dimension(1);
81  dnums.set_kernel_output_feature_dimension(0);
82  auto output_size = kInputSize - kWindowSize + 1;
83  builder.AddInstruction(HloInstruction::CreateConvolve(
84      ShapeUtil::MakeShape(
85          F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}),
86      input, kernel, conv_window_, dnums));
87
88  auto module = CreateNewModule();
89  HloComputation* entry_computation =
90      module->AddEntryComputation(builder.Build());
91
92  ConvCanonicalization conv_canonicalization;
93  EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
94
95  const HloInstruction* output_reshape = entry_computation->root_instruction();
96  EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
97  const HloInstruction* canonical_conv = output_reshape->operand(0);
98  EXPECT_EQ(HloOpcode::kConvolution, canonical_conv->opcode());
99  const HloInstruction* input_reshape = canonical_conv->operand(0);
100  EXPECT_EQ(HloOpcode::kTranspose, input_reshape->opcode());
101  const HloInstruction* kernel_reshape = canonical_conv->operand(1);
102  EXPECT_EQ(HloOpcode::kTranspose, kernel_reshape->opcode());
103
104  // The input is in CNHW order. input_reshape should produce
105  // NHWC for the convolution to hit the Eigen fast path.
106  EXPECT_THAT(input_reshape->dimensions(), ElementsAre(1, 2, 3, 0));
107  // The kernel is in OIHW order. kernel_reshape should produce
108  // HWIO for the convolution to hit the Eigen fast path.
109  EXPECT_THAT(kernel_reshape->dimensions(), ElementsAre(2, 3, 1, 0));
110  // The output of the canonical convolution is in NHWC order (the same as
111  // input_reshape's order). output_reshape should restore that order to the
112  // order of the computation root (CNHW).
113  EXPECT_THAT(output_reshape->dimensions(), ElementsAre(3, 0, 1, 2));
114}
115
116TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
117  auto builder = HloComputation::Builder(TestName());
118  // The input dimensions are in NHWC order.
119  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
120      Literal::CreateR4FromArray4D(Array4D<float>(
121          kBatchSize, kInputSize, kInputSize, kInputFeatureCount))));
122  // The kernel dimensions are in HWIO order.
123  auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
124      Literal::CreateR4FromArray4D(Array4D<float>(
125          kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount))));
126
127  ConvolutionDimensionNumbers dnums;
128  dnums.set_input_batch_dimension(0);
129  dnums.set_output_batch_dimension(0);
130  dnums.add_input_spatial_dimensions(1);
131  dnums.add_output_spatial_dimensions(1);
132  dnums.add_input_spatial_dimensions(2);
133  dnums.add_output_spatial_dimensions(2);
134  dnums.set_input_feature_dimension(3);
135  dnums.set_output_feature_dimension(3);
136  dnums.add_kernel_spatial_dimensions(0);
137  dnums.add_kernel_spatial_dimensions(1);
138  dnums.set_kernel_input_feature_dimension(2);
139  dnums.set_kernel_output_feature_dimension(3);
140  auto output_size = kInputSize - kWindowSize + 1;
141  builder.AddInstruction(HloInstruction::CreateConvolve(
142      ShapeUtil::MakeShape(
143          F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}),
144      input, kernel, conv_window_, dnums));
145
146  auto module = CreateNewModule();
147  module->AddEntryComputation(builder.Build());
148
149  ConvCanonicalization conv_canonicalization;
150  EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie());
151}
152
153}  // namespace cpu
154}  // namespace xla
155