1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/transpose_folding.h"
17
18#include <memory>
19#include <unordered_set>
20#include <vector>
21
22#include "tensorflow/compiler/xla/client/computation_builder.h"
23#include "tensorflow/compiler/xla/literal_util.h"
24#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
25#include "tensorflow/compiler/xla/service/hlo_computation.h"
26#include "tensorflow/compiler/xla/service/hlo_instruction.h"
27#include "tensorflow/compiler/xla/service/hlo_module.h"
28#include "tensorflow/compiler/xla/service/hlo_opcode.h"
29#include "tensorflow/compiler/xla/service/shape_inference.h"
30#include "tensorflow/compiler/xla/shape_util.h"
31#include "tensorflow/compiler/xla/test.h"
32#include "tensorflow/compiler/xla/test_helpers.h"
33#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
34#include "tensorflow/compiler/xla/xla_data.pb.h"
35#include "tensorflow/core/platform/logging.h"
36
37namespace xla {
38namespace {
39
40class TransposeFoldingTest : public HloTestBase {
41 protected:
42  void FoldTranspose(HloModule* module) {
43    TransposeFolding transpose_folding(
44        [](const HloInstruction& dot,
45           const TransposeFolding::OperandIndices& candidate_operands) {
46          return candidate_operands;
47        },
48        [](const HloInstruction& convolution,
49           const TransposeFolding::OperandIndices& candidate_operands) {
50          return candidate_operands;
51        });
52    EXPECT_IS_OK(transpose_folding.Run(module).status());
53  }
54};
55
56TEST_F(TransposeFoldingTest, FoldDotTranspose) {
57  auto builder = HloComputation::Builder("entry_computation");
58  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
59      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}),
60      /*name=*/"x"));
61  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
62      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}),
63      /*name=*/"y"));
64  HloInstruction* transpose_y =
65      builder.AddInstruction(HloInstruction::CreateTranspose(
66          ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
67  DotDimensionNumbers dot_dnums;
68  dot_dnums.add_lhs_contracting_dimensions(1);
69  dot_dnums.add_rhs_contracting_dimensions(0);
70  HloInstruction* dot = builder.AddInstruction(
71      HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
72                                /*rhs=*/transpose_y, dot_dnums));
73
74  HloModule module("test_module");
75  HloComputation* entry_computation =
76      module.AddEntryComputation(builder.Build(dot));
77  FoldTranspose(&module);
78
79  // Instructions after folding: x, y, and the fusion.
80  std::unordered_set<HloInstruction*> instruction_set(
81      entry_computation->instructions().begin(),
82      entry_computation->instructions().end());
83  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
84  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
85  CHECK_EQ(1, instruction_set.size())
86      << "entry_computation should contain exactly 3 instructions.";
87  HloInstruction* fusion = *instruction_set.begin();
88  EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());
89
90  // The fusion instruction should contain two parameters, one transpose and
91  // one dot.
92  EXPECT_EQ(4, fusion->fused_instruction_count());
93}
94
95TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) {
96  auto builder = HloComputation::Builder("entry_computation");
97  // 2x1
98  HloInstruction* const0 = builder.AddInstruction(
99      HloInstruction::CreateConstant(Literal::CreateR2<float>({{1}, {2}})));
100  // 3x2
101  HloInstruction* const1 =
102      builder.AddInstruction(HloInstruction::CreateConstant(
103          Literal::CreateR2<float>({{1, 2}, {3, 4}, {5, 6}})));
104  HloInstruction* transpose0 =
105      builder.AddInstruction(HloInstruction::CreateTranspose(
106          ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0}));
107  HloInstruction* transpose1 =
108      builder.AddInstruction(HloInstruction::CreateTranspose(
109          ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0}));
110  DotDimensionNumbers dot_dnums;
111  dot_dnums.add_lhs_contracting_dimensions(1);
112  dot_dnums.add_rhs_contracting_dimensions(0);
113  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
114      ShapeUtil::MakeShape(F32, {1, 3}),
115      /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums));
116
117  HloModule module("test_module");
118  HloComputation* entry_computation =
119      module.AddEntryComputation(builder.Build(dot));
120  FoldTranspose(&module);
121
122  for (auto* instruction : entry_computation->instructions()) {
123    if (instruction->opcode() == HloOpcode::kFusion) {
124      CHECK_EQ(2, instruction->operand_count());
125      EXPECT_EQ(const0, instruction->operand(0));
126      EXPECT_EQ(const1, instruction->operand(1));
127    }
128  }
129
130  // The created fusion instruction should contain two parameters, two
131  // transposes (one for each parameter) and one dot.
132  EXPECT_EQ(5,
133            entry_computation->root_instruction()->fused_instruction_count());
134}
135
136TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
137  auto builder = HloComputation::Builder("entry");
138  // (1.0 + 2.0) * (2.0 - 3.0)
139  HloInstruction* const1 = builder.AddInstruction(
140      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
141  HloInstruction* const2 = builder.AddInstruction(
142      HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
143  HloInstruction* const3 = builder.AddInstruction(
144      HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
145  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
146      const1->shape(), HloOpcode::kAdd, const1, const2));
147  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
148      const2->shape(), HloOpcode::kSubtract, const2, const3));
149  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
150      add->shape(), HloOpcode::kMultiply, add, sub));
151
152  HloModule module("fuse_with_constant_operands");
153  HloComputation* entry_computation =
154      module.AddEntryComputation(builder.Build(mul));
155  HloInstruction* call = module.OutlineExpressionFromComputation(
156      {add, sub, mul}, "", entry_computation);
157  EXPECT_EQ(call, entry_computation->root_instruction());
158  HloComputation* callee_computation = call->to_apply();
159  // The arguments to the call should be const1, const2, and const3.
160  EXPECT_THAT(call->operands(),
161              ::testing::UnorderedElementsAre(const1, const2, const3));
162
163  // The callee should contain 3 parameters and 3 binary operators.
164  EXPECT_EQ(6, callee_computation->instruction_count());
165}
166
167TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) {
168  auto builder = HloComputation::Builder("entry_computation");
169  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
170      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}),
171      /*name=*/"x"));
172  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
173      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}),
174      /*name=*/"y"));
175  HloInstruction* transpose_y =
176      builder.AddInstruction(HloInstruction::CreateTranspose(
177          ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
178  DotDimensionNumbers dot_dnums;
179  dot_dnums.add_lhs_contracting_dimensions(1);
180  dot_dnums.add_rhs_contracting_dimensions(0);
181  HloInstruction* dot = builder.AddInstruction(
182      HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
183                                /*rhs=*/transpose_y, dot_dnums));
184
185  HloModule module("test_module");
186  HloComputation* entry_computation =
187      module.AddEntryComputation(builder.Build(dot));
188
189  HloInstruction* call = module.OutlineExpressionFromComputation(
190      {transpose_y, dot}, "outlined", entry_computation);
191
192  FoldTranspose(&module);
193
194  // Instructions after folding: x, y, and the fusion.
195  std::unordered_set<HloInstruction*> instruction_set(
196      entry_computation->instructions().begin(),
197      entry_computation->instructions().end());
198  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
199  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
200  CHECK_EQ(1, instruction_set.erase(call))
201      << "call is not in entry_computation.";
202  CHECK(instruction_set.empty())
203      << "entry_computation should contain exactly 3 instructions.";
204  HloInstruction* fusion =
205      call->called_computations().front()->root_instruction();
206  EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());
207
208  // The fusion instruction should contain two parameters, one transpose and
209  // one dot.
210  EXPECT_EQ(4, fusion->fused_instruction_count());
211}
212
213// Test that a two dimension swap of the kernel gets folded into convolution.
214TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
215  auto builder = HloComputation::Builder("entry_computation");
216  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
217      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
218      /*name=*/"x"));
219  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
220      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
221      /*name=*/"y"));
222  HloInstruction* transpose_y =
223      builder.AddInstruction(HloInstruction::CreateTranspose(
224          ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3}));
225  auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
226  Window window;
227  for (int i = 0; i < 2; ++i) {
228    WindowDimension* dim = window.add_dimensions();
229    dim->set_padding_low(0);
230    dim->set_padding_high(0);
231    dim->set_base_dilation(1);
232    dim->set_window_dilation(1);
233    dim->set_stride(1);
234    dim->set_size(
235        transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
236  }
237  StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
238      x->shape(), transpose_y->shape(), window, dnums);
239  EXPECT_IS_OK(conv_shape);
240  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
241      conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
242
243  HloModule module("test_module");
244  HloComputation* entry_computation =
245      module.AddEntryComputation(builder.Build(conv));
246  FoldTranspose(&module);
247
248  // Instructions after folding: x, y, and the convolution.
249  std::unordered_set<HloInstruction*> instruction_set(
250      entry_computation->instructions().begin(),
251      entry_computation->instructions().end());
252  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
253  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
254  CHECK_EQ(1, instruction_set.size())
255      << "entry_computation should contain exactly 3 instructions.";
256  HloInstruction* new_conv = *instruction_set.begin();
257  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
258  EXPECT_EQ(dnums.kernel_input_feature_dimension(),
259            new_conv->convolution_dimension_numbers()
260                .kernel_output_feature_dimension());
261  EXPECT_EQ(dnums.kernel_output_feature_dimension(),
262            new_conv->convolution_dimension_numbers()
263                .kernel_input_feature_dimension());
264}
265
266// Test that a complex transpose of the kernel gets folded into convolution.
267TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
268  auto builder = HloComputation::Builder("entry_computation");
269  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
270      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
271      /*name=*/"x"));
272  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
273      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {1, 2, 1, 3}),
274      /*name=*/"y"));
275  HloInstruction* transpose_y =
276      builder.AddInstruction(HloInstruction::CreateTranspose(
277          ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2}));
278  auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
279  Window window;
280  for (int i = 0; i < 2; ++i) {
281    WindowDimension* dim = window.add_dimensions();
282    dim->set_padding_low(0);
283    dim->set_padding_high(0);
284    dim->set_base_dilation(1);
285    dim->set_window_dilation(1);
286    dim->set_stride(1);
287    dim->set_size(
288        transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
289  }
290  StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
291      x->shape(), transpose_y->shape(), window, dnums);
292  EXPECT_IS_OK(conv_shape);
293  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
294      conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
295
296  HloModule module("test_module");
297  HloComputation* entry_computation =
298      module.AddEntryComputation(builder.Build(conv));
299  FoldTranspose(&module);
300
301  // Instructions after folding: x, y, and the convolution.
302  std::unordered_set<HloInstruction*> instruction_set(
303      entry_computation->instructions().begin(),
304      entry_computation->instructions().end());
305  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
306  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
307  CHECK_EQ(1, instruction_set.size())
308      << "entry_computation should contain exactly 3 instructions.";
309  HloInstruction* new_conv = *instruction_set.begin();
310  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
311  EXPECT_EQ(dnums.kernel_input_feature_dimension(),
312            new_conv->convolution_dimension_numbers()
313                .kernel_output_feature_dimension());
314  EXPECT_EQ(dnums.kernel_spatial_dimensions(1),
315            new_conv->convolution_dimension_numbers()
316                .kernel_input_feature_dimension());
317  EXPECT_EQ(
318      dnums.kernel_output_feature_dimension(),
319      new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(0));
320  EXPECT_EQ(
321      dnums.kernel_spatial_dimensions(0),
322      new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1));
323}
324
325// Test that a transpose of the activations gets folded into convolution.
326TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
327  auto builder = HloComputation::Builder("entry_computation");
328  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
329      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
330      /*name=*/"x"));
331  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
332      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
333      /*name=*/"y"));
334  HloInstruction* transpose_x =
335      builder.AddInstruction(HloInstruction::CreateTranspose(
336          ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3}));
337  auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
338  Window window;
339  for (int i = 0; i < 2; ++i) {
340    WindowDimension* dim = window.add_dimensions();
341    dim->set_padding_low(0);
342    dim->set_padding_high(0);
343    dim->set_base_dilation(1);
344    dim->set_window_dilation(1);
345    dim->set_stride(1);
346    dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
347  }
348  StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
349      transpose_x->shape(), y->shape(), window, dnums);
350  EXPECT_IS_OK(conv_shape);
351  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
352      conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
353
354  HloModule module("test_module");
355  HloComputation* entry_computation =
356      module.AddEntryComputation(builder.Build(conv));
357  FoldTranspose(&module);
358
359  // Instructions after folding: x, y, and the convolution.
360  std::unordered_set<HloInstruction*> instruction_set(
361      entry_computation->instructions().begin(),
362      entry_computation->instructions().end());
363  EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
364  EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
365  EXPECT_EQ(1, instruction_set.size())
366      << "entry_computation should contain exactly 3 instructions.";
367  HloInstruction* new_conv = *instruction_set.begin();
368  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
369  EXPECT_EQ(dnums.input_feature_dimension(),
370            new_conv->convolution_dimension_numbers().input_batch_dimension());
371  EXPECT_EQ(
372      dnums.input_batch_dimension(),
373      new_conv->convolution_dimension_numbers().input_feature_dimension());
374  EXPECT_EQ(
375      dnums.input_spatial_dimensions(0),
376      new_conv->convolution_dimension_numbers().input_spatial_dimensions(0));
377  EXPECT_EQ(
378      dnums.input_spatial_dimensions(1),
379      new_conv->convolution_dimension_numbers().input_spatial_dimensions(1));
380  EXPECT_EQ(
381      dnums.output_spatial_dimensions(0),
382      new_conv->convolution_dimension_numbers().output_spatial_dimensions(0));
383  EXPECT_EQ(
384      dnums.output_spatial_dimensions(1),
385      new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
386}
387
388// Test that a transpose of every dimension in the activations gets folded into
389// convolution.
390TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
391  auto builder = HloComputation::Builder("entry_computation");
392  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
393      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
394      /*name=*/"x"));
395  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
396      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
397      /*name=*/"y"));
398  HloInstruction* transpose_x =
399      builder.AddInstruction(HloInstruction::CreateTranspose(
400          ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2}));
401  auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
402  Window window;
403  for (int i = 0; i < 2; ++i) {
404    WindowDimension* dim = window.add_dimensions();
405    dim->set_padding_low(0);
406    dim->set_padding_high(0);
407    dim->set_base_dilation(1);
408    dim->set_window_dilation(1);
409    dim->set_stride(1);
410    dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
411  }
412  StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
413      transpose_x->shape(), y->shape(), window, dnums);
414  EXPECT_IS_OK(conv_shape);
415  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
416      conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
417
418  HloModule module("test_module");
419  HloComputation* entry_computation =
420      module.AddEntryComputation(builder.Build(conv));
421  FoldTranspose(&module);
422
423  // Instructions after folding: x, y, and the convolution.
424  std::unordered_set<HloInstruction*> instruction_set(
425      entry_computation->instructions().begin(),
426      entry_computation->instructions().end());
427  EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
428  EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
429  EXPECT_EQ(1, instruction_set.size())
430      << "entry_computation should contain exactly 3 instructions.";
431  HloInstruction* new_conv = *instruction_set.begin();
432  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
433  EXPECT_EQ(dnums.input_feature_dimension(),
434            new_conv->convolution_dimension_numbers().input_batch_dimension());
435  EXPECT_EQ(
436      dnums.input_batch_dimension(),
437      new_conv->convolution_dimension_numbers().input_feature_dimension());
438  EXPECT_EQ(
439      dnums.input_spatial_dimensions(0),
440      new_conv->convolution_dimension_numbers().input_spatial_dimensions(1));
441  EXPECT_EQ(
442      dnums.input_spatial_dimensions(1),
443      new_conv->convolution_dimension_numbers().input_spatial_dimensions(0));
444  EXPECT_EQ(
445      dnums.output_spatial_dimensions(0),
446      new_conv->convolution_dimension_numbers().output_spatial_dimensions(0));
447  EXPECT_EQ(
448      dnums.output_spatial_dimensions(1),
449      new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
450}
451
452}  // namespace
453}  // namespace xla
454