1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
17
18#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
19#include "tensorflow/compiler/xla/service/hlo_computation.h"
20#include "tensorflow/compiler/xla/service/hlo_instruction.h"
21#include "tensorflow/compiler/xla/service/hlo_matchers.h"
22#include "tensorflow/compiler/xla/service/hlo_module.h"
23#include "tensorflow/compiler/xla/service/hlo_opcode.h"
24#include "tensorflow/compiler/xla/service/shape_inference.h"
25#include "tensorflow/compiler/xla/test.h"
26#include "tensorflow/compiler/xla/test_helpers.h"
27#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
28#include "tensorflow/core/platform/test.h"
29
30namespace xla {
31namespace gpu {
32namespace {
33
34namespace op = xla::testing::opcode_matchers;
35
// Test fixture for CudnnConvolutionRewriter.  Provides a default 2D
// convolution window and TensorFlow-style convolution dimension numbers that
// the individual tests copy and tweak.
class CudnnConvolutionRewriterTest : public HloTestBase {
 public:
  CudnnConvolutionRewriterTest() {
    // Build two window dimensions with size 1, stride 1, zero padding, and no
    // dilation; each test overrides only the fields it cares about.
    for (int i = 0; i < 2; ++i) {
      WindowDimension* window_dim = default_conv_window_.add_dimensions();
      window_dim->set_size(1);
      window_dim->set_stride(1);
      window_dim->set_padding_low(0);
      window_dim->set_padding_high(0);
      window_dim->set_window_dilation(1);
      window_dim->set_base_dilation(1);
    }
    // TF data shapes are by default in the NHWC order, and filter shape is by
    // default in HWIO order. For backward filter convolution, we need to swap
    // the batch and feature dimension in the activations, and treat the batch
    // dimension in gradients as the input feature dimension in the filter.
    //
    // TODO(jingyue): Add more tests on NCHW input order, which TF also
    // supports.
    tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3);
    tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0);
    tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1);
    tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2);
    tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0);
    tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension(
        3);
    tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1);
    tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2);
    tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0);
    tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1);
    tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2);
    tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3);

    // For backward input convolution: activations/gradients are NHWC (batch
    // at dimension 0, feature at 3), and the kernel is HWOI (output feature
    // at dimension 2, input feature at 3).
    tf_default_dnums_for_backward_input_.set_input_batch_dimension(0);
    tf_default_dnums_for_backward_input_.set_output_batch_dimension(0);
    tf_default_dnums_for_backward_input_.set_input_feature_dimension(3);
    tf_default_dnums_for_backward_input_.set_output_feature_dimension(3);
    tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1);
    tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1);
    tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2);
    tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2);
    tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3);
    tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2);
    tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0);
    tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1);
  }

 protected:
  // Runs CudnnConvolutionRewriter on `module` and returns true iff the pass
  // changed the HLO (dies on pass failure).
  bool RunPass(HloModule* module) {
    return CudnnConvolutionRewriter().Run(module).ValueOrDie();
  }

  // A 2D convolution window with stride 1, zero padding, and no dilation.
  // The sizes default to 1 and are overridden by each test as needed.
  Window default_conv_window_;
  // TF-default dimension numbers for the backward filter pattern.
  ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_;
  // TF-default dimension numbers for the backward input pattern.
  ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
};
94
95TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
96  HloComputation::Builder builder(TestName());
97  HloInstruction* activations =
98      builder.AddInstruction(HloInstruction::CreateParameter(
99          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
100  HloInstruction* gradients =
101      builder.AddInstruction(HloInstruction::CreateParameter(
102          1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients"));
103  Window conv_window = default_conv_window_;
104  conv_window.mutable_dimensions(1)->set_size(2);
105  conv_window.mutable_dimensions(1)->set_window_dilation(2);
106  builder.AddInstruction(HloInstruction::CreateConvolve(
107      ShapeInference::InferConvolveShape(activations->shape(),
108                                         gradients->shape(), conv_window,
109                                         tf_default_dnums_for_backward_filter_)
110          .ConsumeValueOrDie(),
111      activations, gradients, conv_window,
112      tf_default_dnums_for_backward_filter_));
113
114  auto module = CreateNewModule();
115  HloComputation* entry_computation =
116      module->AddEntryComputation(builder.Build());
117  EXPECT_TRUE(RunPass(module.get()));
118  EXPECT_THAT(entry_computation->root_instruction(),
119              op::GetTupleElement(
120                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
121}
122
123TEST_F(CudnnConvolutionRewriterTest,
124       BackwardFilterConvolveEquivalentToForwardConvolution) {
125  HloComputation::Builder builder(TestName());
126  HloInstruction* activations =
127      builder.AddInstruction(HloInstruction::CreateParameter(
128          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
129  HloInstruction* gradients =
130      builder.AddInstruction(HloInstruction::CreateParameter(
131          1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients"));
132  Window conv_window = default_conv_window_;
133  conv_window.mutable_dimensions(1)->set_size(3);
134  builder.AddInstruction(HloInstruction::CreateConvolve(
135      ShapeInference::InferConvolveShape(activations->shape(),
136                                         gradients->shape(), conv_window,
137                                         tf_default_dnums_for_backward_filter_)
138          .ConsumeValueOrDie(),
139      activations, gradients, conv_window,
140      tf_default_dnums_for_backward_filter_));
141
142  auto module = CreateNewModule();
143  HloComputation* entry_computation =
144      module->AddEntryComputation(builder.Build());
145  EXPECT_TRUE(RunPass(module.get()));
146  EXPECT_THAT(entry_computation->root_instruction(),
147              op::GetTupleElement(
148                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
149}
150
151// Extracted from block35 training.
152TEST_F(CudnnConvolutionRewriterTest,
153       BackwardFilterConvolveWithPaddedActivations) {
154  auto builder = HloComputation::Builder(TestName());
155  HloInstruction* activations =
156      builder.AddInstruction(HloInstruction::CreateParameter(
157          0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
158  HloInstruction* gradients =
159      builder.AddInstruction(HloInstruction::CreateParameter(
160          1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
161
162  Window conv_window = default_conv_window_;
163  for (int i = 0; i < 2; ++i) {
164    conv_window.mutable_dimensions(i)->set_size(35);
165    conv_window.mutable_dimensions(i)->set_padding_low(1);
166    conv_window.mutable_dimensions(i)->set_padding_high(1);
167  }
168  builder.AddInstruction(HloInstruction::CreateConvolve(
169      ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
170      conv_window, tf_default_dnums_for_backward_filter_));
171
172  auto module = CreateNewModule();
173  HloComputation* entry_computation =
174      module->AddEntryComputation(builder.Build());
175  EXPECT_TRUE(RunPass(module.get()));
176  EXPECT_THAT(entry_computation->root_instruction(),
177              op::GetTupleElement(
178                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
179}
180
181// Extracted from inception v3 training.
182TEST_F(CudnnConvolutionRewriterTest,
183       BackwardFilterConvolveWithPaddedGradients) {
184  auto builder = HloComputation::Builder(TestName());
185  HloInstruction* activations =
186      builder.AddInstruction(HloInstruction::CreateParameter(
187          0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations"));
188  HloInstruction* gradients =
189      builder.AddInstruction(HloInstruction::CreateParameter(
190          1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients"));
191
192  Window conv_window = default_conv_window_;
193  for (int i = 0; i < 2; ++i) {
194    conv_window.mutable_dimensions(i)->set_size(4);
195    conv_window.mutable_dimensions(i)->set_padding_high(-1);
196    conv_window.mutable_dimensions(i)->set_window_dilation(2);
197  }
198  builder.AddInstruction(HloInstruction::CreateConvolve(
199      ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
200      conv_window, tf_default_dnums_for_backward_filter_));
201
202  auto module = CreateNewModule();
203  HloComputation* entry_computation =
204      module->AddEntryComputation(builder.Build());
205  EXPECT_TRUE(RunPass(module.get()));
206  EXPECT_THAT(entry_computation->root_instruction(),
207              op::GetTupleElement(
208                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
209}
210
211TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
212  auto builder = HloComputation::Builder(TestName());
213  HloInstruction* activations =
214      builder.AddInstruction(HloInstruction::CreateParameter(
215          0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
216  HloInstruction* gradients =
217      builder.AddInstruction(HloInstruction::CreateParameter(
218          1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
219
220  Window conv_window = default_conv_window_;
221  for (int i = 0; i < 2; ++i) {
222    conv_window.mutable_dimensions(i)->set_size(35);
223    // Uneven padding: padding_low=0, padding_high=1
224    conv_window.mutable_dimensions(i)->set_padding_high(1);
225  }
226  builder.AddInstruction(HloInstruction::CreateConvolve(
227      ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
228      conv_window, tf_default_dnums_for_backward_filter_));
229
230  auto module = CreateNewModule();
231  HloComputation* entry_computation =
232      module->AddEntryComputation(builder.Build());
233  EXPECT_TRUE(RunPass(module.get()));
234  EXPECT_THAT(entry_computation->root_instruction(),
235              op::GetTupleElement(
236                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
237}
238
239TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) {
240  auto builder = HloComputation::Builder(TestName());
241  HloInstruction* output =
242      builder.AddInstruction(HloInstruction::CreateParameter(
243          0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output"));
244  HloInstruction* kernel =
245      builder.AddInstruction(HloInstruction::CreateParameter(
246          1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel"));
247  HloInstruction* reverse_kernel = builder.AddInstruction(
248      HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3}));
249
250  Window conv_window = default_conv_window_;
251  for (int i = 0; i < 2; ++i) {
252    conv_window.mutable_dimensions(i)->set_size(7);
253    conv_window.mutable_dimensions(i)->set_padding_low(3);
254    conv_window.mutable_dimensions(i)->set_padding_high(3);
255  }
256  ConvolutionDimensionNumbers conv_dnums;
257  conv_dnums.set_input_batch_dimension(0);
258  conv_dnums.set_output_batch_dimension(0);
259  conv_dnums.set_input_feature_dimension(1);
260  conv_dnums.set_output_feature_dimension(1);
261  conv_dnums.add_input_spatial_dimensions(2);
262  conv_dnums.add_output_spatial_dimensions(2);
263  conv_dnums.add_input_spatial_dimensions(3);
264  conv_dnums.add_output_spatial_dimensions(3);
265  conv_dnums.set_kernel_input_feature_dimension(0);
266  conv_dnums.set_kernel_output_feature_dimension(1);
267  conv_dnums.add_kernel_spatial_dimensions(2);
268  conv_dnums.add_kernel_spatial_dimensions(3);
269
270  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
271      ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
272      /*rhs=*/reverse_kernel, conv_window, conv_dnums));
273  // Verify the convolution's shape is consistent with ShapeInference.
274  CHECK(ShapeUtil::Compatible(
275      conv->shape(),
276      ShapeInference::InferConvolveShape(
277          output->shape(), reverse_kernel->shape(), conv_window, conv_dnums)
278          .ValueOrDie()));
279
280  auto module = CreateNewModule();
281  HloComputation* entry_computation =
282      module->AddEntryComputation(builder.Build());
283  EXPECT_TRUE(RunPass(module.get()));
284
285  ASSERT_THAT(entry_computation->root_instruction(),
286              op::GetTupleElement(
287                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
288  const HloInstruction* custom_call =
289      entry_computation->root_instruction()->operand(0);
290  for (int i = 0; i < 2; ++i) {
291    const WindowDimension& window_dim = custom_call->window().dimensions(i);
292    // Low padding of the backward input convolution
293    //   = kernel_size - 1 - low padding on gradients.
294    EXPECT_EQ(3, window_dim.padding_low());
295    EXPECT_EQ(3, window_dim.padding_high());
296    EXPECT_EQ(1, window_dim.stride());
297  }
298}
299
300// Convolve([abc], [x], base_dilation=2)
301//   = Convolve([abc], Reverse([x]), base_dilation=2)
302//   = BackwardInputConvolve([abc], [x], stride=2)
303TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) {
304  auto builder = HloComputation::Builder(TestName());
305  // NHWC dimension order.
306  HloInstruction* output =
307      builder.AddInstruction(HloInstruction::CreateParameter(
308          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
309  // HWOI dimension order.
310  HloInstruction* kernel =
311      builder.AddInstruction(HloInstruction::CreateParameter(
312          1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
313
314  Window conv_window = default_conv_window_;
315  conv_window.mutable_dimensions(1)->set_base_dilation(2);
316
317  builder.AddInstruction(HloInstruction::CreateConvolve(
318      ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
319                                         conv_window,
320                                         tf_default_dnums_for_backward_input_)
321          .ConsumeValueOrDie(),
322      /*lhs=*/output, /*rhs=*/kernel, conv_window,
323      tf_default_dnums_for_backward_input_));
324
325  auto module = CreateNewModule();
326  HloComputation* entry_computation =
327      module->AddEntryComputation(builder.Build());
328  EXPECT_TRUE(RunPass(module.get()));
329  EXPECT_THAT(entry_computation->root_instruction(),
330              op::GetTupleElement(
331                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
332}
333
334// BackwardInputConvolve([abc], [x], stride=1) is equivalent to
335// ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input
336// convolution.
337TEST_F(CudnnConvolutionRewriterTest,
338       BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) {
339  auto builder = HloComputation::Builder(TestName());
340  // NHWC dimension order.
341  HloInstruction* output =
342      builder.AddInstruction(HloInstruction::CreateParameter(
343          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
344  // HWOI dimension order.
345  HloInstruction* kernel =
346      builder.AddInstruction(HloInstruction::CreateParameter(
347          1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
348
349  builder.AddInstruction(HloInstruction::CreateConvolve(
350      ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
351                                         default_conv_window_,
352                                         tf_default_dnums_for_backward_input_)
353          .ConsumeValueOrDie(),
354      /*lhs=*/output, /*rhs=*/kernel, default_conv_window_,
355      tf_default_dnums_for_backward_input_));
356
357  auto module = CreateNewModule();
358  HloComputation* entry_computation =
359      module->AddEntryComputation(builder.Build());
360  EXPECT_TRUE(RunPass(module.get()));
361  EXPECT_THAT(
362      entry_computation->root_instruction(),
363      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
364}
365
366// Extracted from Inception V3 training.
367//
368//                                  filter(HWIO)
369//                                  3x3x192x320
370//                                      |
371//                                      v
372//      gradients(NHWC)              reverse
373//        20x4x4x320               3x3x192x320
374//                    \            /
375//                     \          /
376//  conv (NHWC) with padding (low=2,high=3,interior=1)
377//                     20x10x10x192
378//
379// Gradients are padded unevenly.
380TEST_F(CudnnConvolutionRewriterTest,
381       BackwardInputConvolveUnevenPaddingOnGradients) {
382  auto builder = HloComputation::Builder(TestName());
383  HloInstruction* output =
384      builder.AddInstruction(HloInstruction::CreateParameter(
385          0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
386  HloInstruction* kernel =
387      builder.AddInstruction(HloInstruction::CreateParameter(
388          1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
389  HloInstruction* reverse_kernel = builder.AddInstruction(
390      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
391
392  Window conv_window = default_conv_window_;
393  for (int i = 0; i < 2; ++i) {
394    conv_window.mutable_dimensions(i)->set_size(3);
395    conv_window.mutable_dimensions(i)->set_padding_low(2);
396    conv_window.mutable_dimensions(i)->set_padding_high(3);
397    // Interior padding = 1.
398    conv_window.mutable_dimensions(i)->set_base_dilation(2);
399  }
400  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
401      ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
402      conv_window, tf_default_dnums_for_backward_input_));
403  // Verify the convolution's shape is consistent with ShapeInference.
404  CHECK(ShapeUtil::Compatible(
405      conv->shape(), ShapeInference::InferConvolveShape(
406                         output->shape(), reverse_kernel->shape(), conv_window,
407                         tf_default_dnums_for_backward_input_)
408                         .ValueOrDie()));
409
410  auto module = CreateNewModule();
411  HloComputation* entry_computation =
412      module->AddEntryComputation(builder.Build());
413  EXPECT_TRUE(RunPass(module.get()));
414  ASSERT_THAT(entry_computation->root_instruction(),
415              op::GetTupleElement(
416                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
417  const HloInstruction* custom_call =
418      entry_computation->root_instruction()->operand(0);
419  for (int i = 0; i < 2; ++i) {
420    const WindowDimension& window_dim = custom_call->window().dimensions(i);
421    EXPECT_EQ(0, window_dim.padding_low());
422    EXPECT_EQ(0, window_dim.padding_high());
423    EXPECT_EQ(2, window_dim.stride());
424  }
425}
426
427// Similar to BackwardInputConvolveUnevenPadding, but the low padding of the
428// gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused.
429TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
430  auto builder = HloComputation::Builder(TestName());
431  HloInstruction* output =
432      builder.AddInstruction(HloInstruction::CreateParameter(
433          0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
434  HloInstruction* kernel =
435      builder.AddInstruction(HloInstruction::CreateParameter(
436          1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
437  HloInstruction* reverse_kernel = builder.AddInstruction(
438      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
439
440  Window conv_window = default_conv_window_;
441  for (int i = 0; i < 2; ++i) {
442    conv_window.mutable_dimensions(i)->set_size(3);
443    conv_window.mutable_dimensions(i)->set_padding_low(3);
444    conv_window.mutable_dimensions(i)->set_padding_high(2);
445    conv_window.mutable_dimensions(i)->set_base_dilation(2);
446  }
447  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
448      ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
449      conv_window, tf_default_dnums_for_backward_input_));
450  // Verify the convolution's shape is consistent with ShapeInference.
451  CHECK(ShapeUtil::Compatible(
452      conv->shape(), ShapeInference::InferConvolveShape(
453                         output->shape(), reverse_kernel->shape(), conv_window,
454                         tf_default_dnums_for_backward_input_)
455                         .ValueOrDie()));
456
457  auto module = CreateNewModule();
458  HloComputation* entry_computation =
459      module->AddEntryComputation(builder.Build());
460  EXPECT_TRUE(RunPass(module.get()));
461  EXPECT_THAT(
462      entry_computation->root_instruction(),
463      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
464}
465
466// Extracted from //learning/brain/google/xla/benchmarks/resnet.py
467//
468// For simplicity, we focus on the column dimension and ignore other dimensions.
469// We use [?] to represent the shape instead of the content.
470//
471// Suppose operator FC does
472//   [4] = conv([14], [3], stride=2, padding_high=1)  // Padding::kSame
473//
474// BC = BackwardInput(FC) does:
475//   [14] = conv([7], reverse([3]),
476//               padding_low=2, padding_high=1, base_dilation=2)
477//
478// We should fuse BC even though padding on activations is uneven, because
479// PadInsertion will canonicalize the fusion HLO.
480TEST_F(CudnnConvolutionRewriterTest,
481       BackwardInputConvolveUnevenPaddingOnActivations) {
482  auto builder = HloComputation::Builder(TestName());
483  // The gradients are in NCHW layout.
484  HloInstruction* output =
485      builder.AddInstruction(HloInstruction::CreateParameter(
486          0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output"));
487  // The kernel is in HWIO layout.
488  HloInstruction* kernel =
489      builder.AddInstruction(HloInstruction::CreateParameter(
490          1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel"));
491  HloInstruction* reverse_kernel = builder.AddInstruction(
492      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
493
494  Window conv_window = default_conv_window_;
495  WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
496  forward_conv_col_dim->set_size(3);
497  forward_conv_col_dim->set_padding_low(2);
498  forward_conv_col_dim->set_padding_high(1);
499  forward_conv_col_dim->set_base_dilation(2);
500  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
501      ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
502      conv_window, tf_default_dnums_for_backward_input_));
503  // Verify the convolution's shape is consistent with ShapeInference.
504  CHECK(ShapeUtil::Compatible(
505      conv->shape(), ShapeInference::InferConvolveShape(
506                         output->shape(), reverse_kernel->shape(), conv_window,
507                         tf_default_dnums_for_backward_input_)
508                         .ValueOrDie()));
509
510  auto module = CreateNewModule();
511  const HloComputation* entry_computation =
512      module->AddEntryComputation(builder.Build());
513  EXPECT_TRUE(RunPass(module.get()));
514  ASSERT_THAT(entry_computation->root_instruction(),
515              op::GetTupleElement(
516                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
517  const WindowDimension& backward_conv_col_dim =
518      entry_computation->root_instruction()->operand(0)->window().dimensions(1);
519  EXPECT_EQ(0, backward_conv_col_dim.padding_low());
520  EXPECT_EQ(1, backward_conv_col_dim.padding_high());
521}
522
523// For simplicity, we focus on the column dimension and ignore other dimensions.
524// We use [?] to represent the shape instead of the content.
525//
526// Suppose operator FC does
527//   [3] = conv([4], [2], padding_low=1, padding_high=-1)
528//
529// BC = BackwardInput(FC) does:
530//   [4] = conv([3], reverse([2]), padding_high=2)
531//
532// We currently don't fuse BC because PadInsertion doesn't support negative
533// padding on the gradients of backward convolution (b/32744257).
534TEST_F(CudnnConvolutionRewriterTest,
535       BackwardInputConvolveNegativePaddingHighOnActivations) {
536  auto builder = HloComputation::Builder(TestName());
537  // The gradients are in NCHW layout.
538  HloInstruction* output =
539      builder.AddInstruction(HloInstruction::CreateParameter(
540          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
541  // The kernel is in HWIO layout.
542  HloInstruction* kernel =
543      builder.AddInstruction(HloInstruction::CreateParameter(
544          1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel"));
545  HloInstruction* reverse_kernel = builder.AddInstruction(
546      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
547
548  Window conv_window = default_conv_window_;
549  WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
550  forward_conv_col_dim->set_size(2);
551  forward_conv_col_dim->set_padding_high(2);
552  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
553      ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
554      conv_window, tf_default_dnums_for_backward_input_));
555  // Verify the convolution's shape is consistent with ShapeInference.
556  CHECK(ShapeUtil::Compatible(
557      conv->shape(), ShapeInference::InferConvolveShape(
558                         output->shape(), reverse_kernel->shape(), conv_window,
559                         tf_default_dnums_for_backward_input_)
560                         .ValueOrDie()));
561
562  auto module = CreateNewModule();
563  HloComputation* entry_computation =
564      module->AddEntryComputation(builder.Build());
565  EXPECT_TRUE(RunPass(module.get()));
566  EXPECT_THAT(
567      entry_computation->root_instruction(),
568      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
569}
570
571}  // anonymous namespace
572}  // namespace gpu
573}  // namespace xla
574