1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" 17 18#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" 19#include "tensorflow/compiler/xla/service/hlo_computation.h" 20#include "tensorflow/compiler/xla/service/hlo_instruction.h" 21#include "tensorflow/compiler/xla/service/hlo_matchers.h" 22#include "tensorflow/compiler/xla/service/hlo_module.h" 23#include "tensorflow/compiler/xla/service/hlo_opcode.h" 24#include "tensorflow/compiler/xla/service/shape_inference.h" 25#include "tensorflow/compiler/xla/test.h" 26#include "tensorflow/compiler/xla/test_helpers.h" 27#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 28#include "tensorflow/core/platform/test.h" 29 30namespace xla { 31namespace gpu { 32namespace { 33 34namespace op = xla::testing::opcode_matchers; 35 36class CudnnConvolutionRewriterTest : public HloTestBase { 37 public: 38 CudnnConvolutionRewriterTest() { 39 for (int i = 0; i < 2; ++i) { 40 WindowDimension* window_dim = default_conv_window_.add_dimensions(); 41 window_dim->set_size(1); 42 window_dim->set_stride(1); 43 window_dim->set_padding_low(0); 44 window_dim->set_padding_high(0); 45 window_dim->set_window_dilation(1); 46 window_dim->set_base_dilation(1); 47 } 48 // TF data shapes are by default in the NHWC order, and filter shape is by 49 // default in HWIO order. For backward filter convolution, we need to swap 50 // the batch and feature dimension in the activations, and treat the batch 51 // dimension in gradients as the input feature dimension in the filter. 52 // 53 // TODO(jingyue): Add more tests on NCHW input order, which TF also 54 // supports. 55 tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); 56 tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); 57 tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); 58 tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2); 59 tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); 60 tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension( 61 3); 62 tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); 63 tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); 64 tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0); 65 tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); 66 tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2); 67 tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3); 68 69 tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); 70 tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); 71 tf_default_dnums_for_backward_input_.set_input_feature_dimension(3); 72 tf_default_dnums_for_backward_input_.set_output_feature_dimension(3); 73 tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1); 74 tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1); 75 tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2); 76 tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2); 77 tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3); 78 tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2); 79 tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0); 80 tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1); 81 } 82 83 protected: 84 bool RunPass(HloModule* module) { 85 return CudnnConvolutionRewriter().Run(module).ValueOrDie(); 86 } 87 88 // A convolution window with stride 1 and zero padding. The size fields are 89 // not set. 90 Window default_conv_window_; 91 ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_; 92 ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; 93}; 94 95TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { 96 HloComputation::Builder builder(TestName()); 97 HloInstruction* activations = 98 builder.AddInstruction(HloInstruction::CreateParameter( 99 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations")); 100 HloInstruction* gradients = 101 builder.AddInstruction(HloInstruction::CreateParameter( 102 1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients")); 103 Window conv_window = default_conv_window_; 104 conv_window.mutable_dimensions(1)->set_size(2); 105 conv_window.mutable_dimensions(1)->set_window_dilation(2); 106 builder.AddInstruction(HloInstruction::CreateConvolve( 107 ShapeInference::InferConvolveShape(activations->shape(), 108 gradients->shape(), conv_window, 109 tf_default_dnums_for_backward_filter_) 110 .ConsumeValueOrDie(), 111 activations, gradients, conv_window, 112 tf_default_dnums_for_backward_filter_)); 113 114 auto module = CreateNewModule(); 115 HloComputation* entry_computation = 116 module->AddEntryComputation(builder.Build()); 117 EXPECT_TRUE(RunPass(module.get())); 118 EXPECT_THAT(entry_computation->root_instruction(), 119 op::GetTupleElement( 120 op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); 121} 122 123TEST_F(CudnnConvolutionRewriterTest, 124 BackwardFilterConvolveEquivalentToForwardConvolution) { 125 HloComputation::Builder builder(TestName()); 126 HloInstruction* activations = 127 builder.AddInstruction(HloInstruction::CreateParameter( 128 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations")); 129 HloInstruction* gradients = 130 builder.AddInstruction(HloInstruction::CreateParameter( 131 1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients")); 132 Window conv_window = default_conv_window_; 133 conv_window.mutable_dimensions(1)->set_size(3); 134 builder.AddInstruction(HloInstruction::CreateConvolve( 135 ShapeInference::InferConvolveShape(activations->shape(), 136 gradients->shape(), conv_window, 137 tf_default_dnums_for_backward_filter_) 138 .ConsumeValueOrDie(), 139 activations, gradients, conv_window, 140 tf_default_dnums_for_backward_filter_)); 141 142 auto module = CreateNewModule(); 143 HloComputation* entry_computation = 144 module->AddEntryComputation(builder.Build()); 145 EXPECT_TRUE(RunPass(module.get())); 146 EXPECT_THAT(entry_computation->root_instruction(), 147 op::GetTupleElement( 148 op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); 149} 150 151// Extracted from block35 training. 152TEST_F(CudnnConvolutionRewriterTest, 153 BackwardFilterConvolveWithPaddedActivations) { 154 auto builder = HloComputation::Builder(TestName()); 155 HloInstruction* activations = 156 builder.AddInstruction(HloInstruction::CreateParameter( 157 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations")); 158 HloInstruction* gradients = 159 builder.AddInstruction(HloInstruction::CreateParameter( 160 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients")); 161 162 Window conv_window = default_conv_window_; 163 for (int i = 0; i < 2; ++i) { 164 conv_window.mutable_dimensions(i)->set_size(35); 165 conv_window.mutable_dimensions(i)->set_padding_low(1); 166 conv_window.mutable_dimensions(i)->set_padding_high(1); 167 } 168 builder.AddInstruction(HloInstruction::CreateConvolve( 169 ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, 170 conv_window, tf_default_dnums_for_backward_filter_)); 171 172 auto module = CreateNewModule(); 173 HloComputation* entry_computation = 174 module->AddEntryComputation(builder.Build()); 175 EXPECT_TRUE(RunPass(module.get())); 176 EXPECT_THAT(entry_computation->root_instruction(), 177 op::GetTupleElement( 178 op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); 179} 180 181// Extracted from inception v3 training. 182TEST_F(CudnnConvolutionRewriterTest, 183 BackwardFilterConvolveWithPaddedGradients) { 184 auto builder = HloComputation::Builder(TestName()); 185 HloInstruction* activations = 186 builder.AddInstruction(HloInstruction::CreateParameter( 187 0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations")); 188 HloInstruction* gradients = 189 builder.AddInstruction(HloInstruction::CreateParameter( 190 1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients")); 191 192 Window conv_window = default_conv_window_; 193 for (int i = 0; i < 2; ++i) { 194 conv_window.mutable_dimensions(i)->set_size(4); 195 conv_window.mutable_dimensions(i)->set_padding_high(-1); 196 conv_window.mutable_dimensions(i)->set_window_dilation(2); 197 } 198 builder.AddInstruction(HloInstruction::CreateConvolve( 199 ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, 200 conv_window, tf_default_dnums_for_backward_filter_)); 201 202 auto module = CreateNewModule(); 203 HloComputation* entry_computation = 204 module->AddEntryComputation(builder.Build()); 205 EXPECT_TRUE(RunPass(module.get())); 206 EXPECT_THAT(entry_computation->root_instruction(), 207 op::GetTupleElement( 208 op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); 209} 210 211TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { 212 auto builder = HloComputation::Builder(TestName()); 213 HloInstruction* activations = 214 builder.AddInstruction(HloInstruction::CreateParameter( 215 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations")); 216 HloInstruction* gradients = 217 builder.AddInstruction(HloInstruction::CreateParameter( 218 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients")); 219 220 Window conv_window = default_conv_window_; 221 for (int i = 0; i < 2; ++i) { 222 conv_window.mutable_dimensions(i)->set_size(35); 223 // Uneven padding: padding_low=0, padding_high=1 224 conv_window.mutable_dimensions(i)->set_padding_high(1); 225 } 226 builder.AddInstruction(HloInstruction::CreateConvolve( 227 ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, 228 conv_window, tf_default_dnums_for_backward_filter_)); 229 230 auto module = CreateNewModule(); 231 HloComputation* entry_computation = 232 module->AddEntryComputation(builder.Build()); 233 EXPECT_TRUE(RunPass(module.get())); 234 EXPECT_THAT(entry_computation->root_instruction(), 235 op::GetTupleElement( 236 op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); 237} 238 239TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { 240 auto builder = HloComputation::Builder(TestName()); 241 HloInstruction* output = 242 builder.AddInstruction(HloInstruction::CreateParameter( 243 0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output")); 244 HloInstruction* kernel = 245 builder.AddInstruction(HloInstruction::CreateParameter( 246 1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel")); 247 HloInstruction* reverse_kernel = builder.AddInstruction( 248 HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3})); 249 250 Window conv_window = default_conv_window_; 251 for (int i = 0; i < 2; ++i) { 252 conv_window.mutable_dimensions(i)->set_size(7); 253 conv_window.mutable_dimensions(i)->set_padding_low(3); 254 conv_window.mutable_dimensions(i)->set_padding_high(3); 255 } 256 ConvolutionDimensionNumbers conv_dnums; 257 conv_dnums.set_input_batch_dimension(0); 258 conv_dnums.set_output_batch_dimension(0); 259 conv_dnums.set_input_feature_dimension(1); 260 conv_dnums.set_output_feature_dimension(1); 261 conv_dnums.add_input_spatial_dimensions(2); 262 conv_dnums.add_output_spatial_dimensions(2); 263 conv_dnums.add_input_spatial_dimensions(3); 264 conv_dnums.add_output_spatial_dimensions(3); 265 conv_dnums.set_kernel_input_feature_dimension(0); 266 conv_dnums.set_kernel_output_feature_dimension(1); 267 conv_dnums.add_kernel_spatial_dimensions(2); 268 conv_dnums.add_kernel_spatial_dimensions(3); 269 270 HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( 271 ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output, 272 /*rhs=*/reverse_kernel, conv_window, conv_dnums)); 273 // Verify the convolution's shape is consistent with ShapeInference. 274 CHECK(ShapeUtil::Compatible( 275 conv->shape(), 276 ShapeInference::InferConvolveShape( 277 output->shape(), reverse_kernel->shape(), conv_window, conv_dnums) 278 .ValueOrDie())); 279 280 auto module = CreateNewModule(); 281 HloComputation* entry_computation = 282 module->AddEntryComputation(builder.Build()); 283 EXPECT_TRUE(RunPass(module.get())); 284 285 ASSERT_THAT(entry_computation->root_instruction(), 286 op::GetTupleElement( 287 op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); 288 const HloInstruction* custom_call = 289 entry_computation->root_instruction()->operand(0); 290 for (int i = 0; i < 2; ++i) { 291 const WindowDimension& window_dim = custom_call->window().dimensions(i); 292 // Low padding of the backward input convolution 293 // = kernel_size - 1 - low padding on gradients. 294 EXPECT_EQ(3, window_dim.padding_low()); 295 EXPECT_EQ(3, window_dim.padding_high()); 296 EXPECT_EQ(1, window_dim.stride()); 297 } 298} 299 300// Convolve([abc], [x], base_dilation=2) 301// = Convolve([abc], Reverse([x]), base_dilation=2) 302// = BackwardInputConvolve([abc], [x], stride=2) 303TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { 304 auto builder = HloComputation::Builder(TestName()); 305 // NHWC dimension order. 306 HloInstruction* output = 307 builder.AddInstruction(HloInstruction::CreateParameter( 308 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); 309 // HWOI dimension order. 310 HloInstruction* kernel = 311 builder.AddInstruction(HloInstruction::CreateParameter( 312 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); 313 314 Window conv_window = default_conv_window_; 315 conv_window.mutable_dimensions(1)->set_base_dilation(2); 316 317 builder.AddInstruction(HloInstruction::CreateConvolve( 318 ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), 319 conv_window, 320 tf_default_dnums_for_backward_input_) 321 .ConsumeValueOrDie(), 322 /*lhs=*/output, /*rhs=*/kernel, conv_window, 323 tf_default_dnums_for_backward_input_)); 324 325 auto module = CreateNewModule(); 326 HloComputation* entry_computation = 327 module->AddEntryComputation(builder.Build()); 328 EXPECT_TRUE(RunPass(module.get())); 329 EXPECT_THAT(entry_computation->root_instruction(), 330 op::GetTupleElement( 331 op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); 332} 333 334// BackwardInputConvolve([abc], [x], stride=1) is equivalent to 335// ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input 336// convolution. 337TEST_F(CudnnConvolutionRewriterTest, 338 BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { 339 auto builder = HloComputation::Builder(TestName()); 340 // NHWC dimension order. 341 HloInstruction* output = 342 builder.AddInstruction(HloInstruction::CreateParameter( 343 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); 344 // HWOI dimension order. 345 HloInstruction* kernel = 346 builder.AddInstruction(HloInstruction::CreateParameter( 347 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); 348 349 builder.AddInstruction(HloInstruction::CreateConvolve( 350 ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), 351 default_conv_window_, 352 tf_default_dnums_for_backward_input_) 353 .ConsumeValueOrDie(), 354 /*lhs=*/output, /*rhs=*/kernel, default_conv_window_, 355 tf_default_dnums_for_backward_input_)); 356 357 auto module = CreateNewModule(); 358 HloComputation* entry_computation = 359 module->AddEntryComputation(builder.Build()); 360 EXPECT_TRUE(RunPass(module.get())); 361 EXPECT_THAT( 362 entry_computation->root_instruction(), 363 op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); 364} 365 366// Extracted from Inception V3 training. 367// 368// filter(HWIO) 369// 3x3x192x320 370// | 371// v 372// gradients(NHWC) reverse 373// 20x4x4x320 3x3x192x320 374// \ / 375// \ / 376// conv (NHWC) with padding (low=2,high=3,interior=1) 377// 20x10x10x192 378// 379// Gradients are padded unevenly. 380TEST_F(CudnnConvolutionRewriterTest, 381 BackwardInputConvolveUnevenPaddingOnGradients) { 382 auto builder = HloComputation::Builder(TestName()); 383 HloInstruction* output = 384 builder.AddInstruction(HloInstruction::CreateParameter( 385 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output")); 386 HloInstruction* kernel = 387 builder.AddInstruction(HloInstruction::CreateParameter( 388 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel")); 389 HloInstruction* reverse_kernel = builder.AddInstruction( 390 HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); 391 392 Window conv_window = default_conv_window_; 393 for (int i = 0; i < 2; ++i) { 394 conv_window.mutable_dimensions(i)->set_size(3); 395 conv_window.mutable_dimensions(i)->set_padding_low(2); 396 conv_window.mutable_dimensions(i)->set_padding_high(3); 397 // Interior padding = 1. 398 conv_window.mutable_dimensions(i)->set_base_dilation(2); 399 } 400 HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( 401 ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, 402 conv_window, tf_default_dnums_for_backward_input_)); 403 // Verify the convolution's shape is consistent with ShapeInference. 404 CHECK(ShapeUtil::Compatible( 405 conv->shape(), ShapeInference::InferConvolveShape( 406 output->shape(), reverse_kernel->shape(), conv_window, 407 tf_default_dnums_for_backward_input_) 408 .ValueOrDie())); 409 410 auto module = CreateNewModule(); 411 HloComputation* entry_computation = 412 module->AddEntryComputation(builder.Build()); 413 EXPECT_TRUE(RunPass(module.get())); 414 ASSERT_THAT(entry_computation->root_instruction(), 415 op::GetTupleElement( 416 op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); 417 const HloInstruction* custom_call = 418 entry_computation->root_instruction()->operand(0); 419 for (int i = 0; i < 2; ++i) { 420 const WindowDimension& window_dim = custom_call->window().dimensions(i); 421 EXPECT_EQ(0, window_dim.padding_low()); 422 EXPECT_EQ(0, window_dim.padding_high()); 423 EXPECT_EQ(2, window_dim.stride()); 424 } 425} 426 427// Similar to BackwardInputConvolveUnevenPadding, but the low padding of the 428// gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. 429TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { 430 auto builder = HloComputation::Builder(TestName()); 431 HloInstruction* output = 432 builder.AddInstruction(HloInstruction::CreateParameter( 433 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output")); 434 HloInstruction* kernel = 435 builder.AddInstruction(HloInstruction::CreateParameter( 436 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel")); 437 HloInstruction* reverse_kernel = builder.AddInstruction( 438 HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); 439 440 Window conv_window = default_conv_window_; 441 for (int i = 0; i < 2; ++i) { 442 conv_window.mutable_dimensions(i)->set_size(3); 443 conv_window.mutable_dimensions(i)->set_padding_low(3); 444 conv_window.mutable_dimensions(i)->set_padding_high(2); 445 conv_window.mutable_dimensions(i)->set_base_dilation(2); 446 } 447 HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( 448 ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, 449 conv_window, tf_default_dnums_for_backward_input_)); 450 // Verify the convolution's shape is consistent with ShapeInference. 451 CHECK(ShapeUtil::Compatible( 452 conv->shape(), ShapeInference::InferConvolveShape( 453 output->shape(), reverse_kernel->shape(), conv_window, 454 tf_default_dnums_for_backward_input_) 455 .ValueOrDie())); 456 457 auto module = CreateNewModule(); 458 HloComputation* entry_computation = 459 module->AddEntryComputation(builder.Build()); 460 EXPECT_TRUE(RunPass(module.get())); 461 EXPECT_THAT( 462 entry_computation->root_instruction(), 463 op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); 464} 465 466// Extracted from //learning/brain/google/xla/benchmarks/resnet.py 467// 468// For simplicity, we focus on the column dimension and ignore other dimensions. 469// We use [?] to represent the shape instead of the content. 470// 471// Suppose operator FC does 472// [4] = conv([14], [3], stride=2, padding_high=1) // Padding::kSame 473// 474// BC = BackwardInput(FC) does: 475// [14] = conv([7], reverse([3]), 476// padding_low=2, padding_high=1, base_dilation=2) 477// 478// We should fuse BC even though padding on activations is uneven, because 479// PadInsertion will canonicalize the fusion HLO. 480TEST_F(CudnnConvolutionRewriterTest, 481 BackwardInputConvolveUnevenPaddingOnActivations) { 482 auto builder = HloComputation::Builder(TestName()); 483 // The gradients are in NCHW layout. 484 HloInstruction* output = 485 builder.AddInstruction(HloInstruction::CreateParameter( 486 0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output")); 487 // The kernel is in HWIO layout. 488 HloInstruction* kernel = 489 builder.AddInstruction(HloInstruction::CreateParameter( 490 1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel")); 491 HloInstruction* reverse_kernel = builder.AddInstruction( 492 HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); 493 494 Window conv_window = default_conv_window_; 495 WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1); 496 forward_conv_col_dim->set_size(3); 497 forward_conv_col_dim->set_padding_low(2); 498 forward_conv_col_dim->set_padding_high(1); 499 forward_conv_col_dim->set_base_dilation(2); 500 HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( 501 ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, 502 conv_window, tf_default_dnums_for_backward_input_)); 503 // Verify the convolution's shape is consistent with ShapeInference. 504 CHECK(ShapeUtil::Compatible( 505 conv->shape(), ShapeInference::InferConvolveShape( 506 output->shape(), reverse_kernel->shape(), conv_window, 507 tf_default_dnums_for_backward_input_) 508 .ValueOrDie())); 509 510 auto module = CreateNewModule(); 511 const HloComputation* entry_computation = 512 module->AddEntryComputation(builder.Build()); 513 EXPECT_TRUE(RunPass(module.get())); 514 ASSERT_THAT(entry_computation->root_instruction(), 515 op::GetTupleElement( 516 op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); 517 const WindowDimension& backward_conv_col_dim = 518 entry_computation->root_instruction()->operand(0)->window().dimensions(1); 519 EXPECT_EQ(0, backward_conv_col_dim.padding_low()); 520 EXPECT_EQ(1, backward_conv_col_dim.padding_high()); 521} 522 523// For simplicity, we focus on the column dimension and ignore other dimensions. 524// We use [?] to represent the shape instead of the content. 525// 526// Suppose operator FC does 527// [3] = conv([4], [2], padding_low=1, padding_high=-1) 528// 529// BC = BackwardInput(FC) does: 530// [4] = conv([3], reverse([2]), padding_high=2) 531// 532// We currently don't fuse BC because PadInsertion doesn't support negative 533// padding on the gradients of backward convolution (b/32744257). 534TEST_F(CudnnConvolutionRewriterTest, 535 BackwardInputConvolveNegativePaddingHighOnActivations) { 536 auto builder = HloComputation::Builder(TestName()); 537 // The gradients are in NCHW layout. 538 HloInstruction* output = 539 builder.AddInstruction(HloInstruction::CreateParameter( 540 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); 541 // The kernel is in HWIO layout. 542 HloInstruction* kernel = 543 builder.AddInstruction(HloInstruction::CreateParameter( 544 1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel")); 545 HloInstruction* reverse_kernel = builder.AddInstruction( 546 HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); 547 548 Window conv_window = default_conv_window_; 549 WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1); 550 forward_conv_col_dim->set_size(2); 551 forward_conv_col_dim->set_padding_high(2); 552 HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( 553 ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, 554 conv_window, tf_default_dnums_for_backward_input_)); 555 // Verify the convolution's shape is consistent with ShapeInference. 556 CHECK(ShapeUtil::Compatible( 557 conv->shape(), ShapeInference::InferConvolveShape( 558 output->shape(), reverse_kernel->shape(), conv_window, 559 tf_default_dnums_for_backward_input_) 560 .ValueOrDie())); 561 562 auto module = CreateNewModule(); 563 HloComputation* entry_computation = 564 module->AddEntryComputation(builder.Build()); 565 EXPECT_TRUE(RunPass(module.get())); 566 EXPECT_THAT( 567 entry_computation->root_instruction(), 568 op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); 569} 570 571} // anonymous namespace 572} // namespace gpu 573} // namespace xla 574