1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" 17 18#include "tensorflow/compiler/xla/literal_util.h" 19#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" 20#include "tensorflow/compiler/xla/service/shape_inference.h" 21#include "tensorflow/compiler/xla/util.h" 22#include "tensorflow/compiler/xla/window_util.h" 23#include "tensorflow/compiler/xla/xla_data.pb.h" 24 25namespace xla { 26namespace gpu { 27 28namespace { 29bool IsForwardConvolutionCanonical(const HloInstruction& conv) { 30 CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget); 31 return window_util::HasSymmetricPadding(conv.window()) && 32 !window_util::HasNegativePadding(conv.window()) && 33 !window_util::HasDilation(conv.window()); 34} 35 36// If the (positive and negative) padding on the input operand of a convolution 37// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and 38// dilation), returns kPad and/or kSlice instructions that explicitly apply the 39// padding; otherwise returns the original input operand. When there is both 40// positive padding (including dilation) and negative padding, we insert both 41// kPad and kSlice. 
HloInstruction* MaybePaddedAndSlicedInput(
    const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums,
    HloInstruction* input) {
  HloComputation* computation = input->parent();
  if (!window_util::HasSymmetricPadding(conv_window) ||
      window_util::HasBaseDilation(conv_window)) {
    // If padding is uneven or has dilation, we insert a kPad instruction that
    // applies positive padding and dilation.
    //
    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
    // moving all the padding into an explicit pad op, we should keep as much
    // padding inside of cudnn as possible, on the assumption that padding
    // within cudnn is basically free, whereas a kPad's cost increases as the
    // amount of padding increases.
    PaddingConfig padding_config =
        MakeNoPaddingConfig(input->shape().dimensions_size());
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64 dim = conv_dnums.input_spatial_dimensions(i);
      // Only positive padding goes into the kPad; negative padding is handled
      // by the kSlice inserted below.
      padding_config.mutable_dimensions(dim)->set_edge_padding_low(
          std::max<int64>(0LL, conv_window.dimensions(i).padding_low()));
      padding_config.mutable_dimensions(dim)->set_edge_padding_high(
          std::max<int64>(0LL, conv_window.dimensions(i).padding_high()));
      // Base dilation of N is expressed as N-1 interior padding elements.
      padding_config.mutable_dimensions(dim)->set_interior_padding(
          conv_window.dimensions(i).base_dilation() - 1);
    }
    PrimitiveType element_type = input->shape().element_type();
    HloInstruction* padding =
        computation->AddInstruction(HloInstruction::CreateConstant(
            MakeUnique<Literal>(Literal::Zero(element_type))));
    input = computation->AddInstruction(HloInstruction::CreatePad(
        ShapeInference::InferPadShape(
            /*operand_shape=*/input->shape(),
            /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}),
            padding_config)
            .ConsumeValueOrDie(),
        input, padding, padding_config));
  }

  if (window_util::HasNegativePadding(conv_window)) {
    // If the window has negative padding, insert a kSlice that explicitly
    // applies negative padding.
    //
    // For each dimension, initialize the start index to 0 and the limit index
    // to the size of that dimension.
    std::vector<int64> start_indices(input->shape().dimensions_size(), 0);
    std::vector<int64> limit_indices(input->shape().dimensions().begin(),
                                     input->shape().dimensions().end());
    std::vector<int64> strides(input->shape().dimensions_size(), 1);
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64 dim = conv_dnums.input_spatial_dimensions(i);
      // If dimension "dim" has negative padding, increase the start index or
      // decrement the limit index by the amount of negative padding.
      start_indices[dim] +=
          std::max<int64>(0LL, -conv_window.dimensions(i).padding_low());
      limit_indices[dim] -=
          std::max<int64>(0LL, -conv_window.dimensions(i).padding_high());
    }

    input = computation->AddInstruction(HloInstruction::CreateSlice(
        ShapeInference::InferSliceShape(input->shape(), start_indices,
                                        limit_indices, strides)
            .ConsumeValueOrDie(),
        input, start_indices, limit_indices, strides));
  }

  return input;
}

// If the padding on the kernel operand of a convolution can't be folded into a
// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
// explicitly applies the padding; otherwise returns the original kernel
// operand.
HloInstruction* MaybePaddedKernel(const Window& conv_window,
                                  const ConvolutionDimensionNumbers& conv_dnums,
                                  HloInstruction* kernel) {
  // Window dilation is the only kernel-side property made explicit here; a
  // kernel without it is returned unchanged.
  if (!window_util::HasWindowDilation(conv_window)) {
    return kernel;
  }

  // Compute the shape and padding config of the pad to be inserted.
  PaddingConfig padding_config;
  for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
    padding_config.add_dimensions();
  }
  for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
    int64 dim = conv_dnums.kernel_spatial_dimensions(i);
    // Window dilation of N is expressed as N-1 interior padding elements.
    padding_config.mutable_dimensions(dim)->set_interior_padding(
        conv_window.dimensions(i).window_dilation() - 1);
  }

  HloComputation* computation = kernel->parent();
  PrimitiveType element_type = kernel->shape().element_type();
  HloInstruction* padding =
      computation->AddInstruction(HloInstruction::CreateConstant(
          MakeUnique<Literal>(Literal::Zero(element_type))));
  return computation->AddInstruction(HloInstruction::CreatePad(
      ShapeInference::InferPadShape(
          /*operand_shape=*/kernel->shape(),
          /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}),
          padding_config)
          .ConsumeValueOrDie(),
      kernel, padding, padding_config));
}
}  // namespace

// Rewrites a non-canonical forward convolution (uneven or negative padding, or
// dilation) into a canonical one by making the offending padding/dilation
// explicit with kPad/kSlice ops on the operands and replacing the custom call.
// Returns true iff the HLO graph was changed.
bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
  if (IsForwardConvolutionCanonical(*conv)) {
    return false;
  }

  // Insert slices and/or pads between the convolution and its input and/or
  // kernel operand.
  HloInstruction* new_input = MaybePaddedAndSlicedInput(
      conv->window(), conv->convolution_dimension_numbers(),
      conv->mutable_operand(0));
  HloInstruction* new_kernel =
      MaybePaddedKernel(conv->window(), conv->convolution_dimension_numbers(),
                        conv->mutable_operand(1));

  // Remove the padding from convolution's window field. These paddings are
  // made explicit with the inserted pads.
  Window new_conv_window = conv->window();
  for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
    WindowDimension* dim = new_conv_window.mutable_dimensions(i);

    // The size of the kernel may have changed so update the Window to match.
    dim->set_size(new_kernel->shape().dimensions(
        conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);
  }

  // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
  // out the shape of conv_result.
  Shape old_conv_shape = conv->shape().tuple_shapes(0);

  VLOG(1) << "Canonicalizing forward conv";
  auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel,
                                         new_conv_window,
                                         conv->convolution_dimension_numbers());
  VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
          << new_conv->ToString();
  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
  return true;
}

namespace {
// Adds `delta` (which may be negative) to the low padding of `window_dim`.
void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) {
  window_dim->set_padding_low(window_dim->padding_low() + delta);
}

// Adds `delta` (which may be negative) to the high padding of `window_dim`.
void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
  window_dim->set_padding_high(window_dim->padding_high() + delta);
}
}  // namespace

// Rewrites a backward-filter convolution with uneven padding into one with
// even padding by explicitly padding the activations with a kPad beforehand.
// Bails out (returns false) on negative padding. Returns true iff the HLO
// graph was changed.
bool PadInsertion::CanonicalizeBackwardFilterConvolution(
    HloInstruction* backward_conv) {
  CHECK_EQ(backward_conv->custom_call_target(),
           kCudnnConvBackwardFilterCallTarget);
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  // A backward filter convolution with uneven padding can be canonicalized to
  // one with even padding by padding the activations (input) beforehand. For
  // example,
  //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
  // is equivalent to
  //   ABCD0 = Pad(ABCD, padding_high=1)
  //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
  // We choose the lesser of padding_low and padding_high as the new padding.
  HloInstruction* input = backward_conv->mutable_operand(0);
  Window new_backward_conv_window = backward_conv->window();
  // input_padding_config is the config of the kPad to be inserted.
  PaddingConfig input_padding_config =
      MakeNoPaddingConfig(ShapeUtil::Rank(input->shape()));
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // Compute the new, even padding for the backward conv operation.
    int64 new_conv_padding = std::min(padding_low, padding_high);
    int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
    // The leftover (uneven) part of the padding goes into the explicit kPad
    // on the activations.
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
        padding_low - new_conv_padding);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
        padding_high - new_conv_padding);

    // Since we move some padding from the backward convolution to the kPad, we
    // need to accordingly reduce the padding amount of the backward convolution
    // and its inner forward convolution.
    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
    new_dim->set_padding_low(new_conv_padding);
    new_dim->set_padding_high(new_conv_padding);
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(1);
  HloInstruction* padding =
      computation->AddInstruction(HloInstruction::CreateConstant(
          MakeUnique<Literal>(Literal::Zero(input->shape().element_type()))));
  HloInstruction* padded_input =
      computation->AddInstruction(HloInstruction::CreatePad(
          ShapeInference::InferPadShape(input->shape(), padding->shape(),
                                        input_padding_config)
              .ConsumeValueOrDie(),
          input, padding, input_padding_config));

  // The shape of the backward_conv CustomCall is a tuple (conv_result,
  // scratch_buffer). Extract out the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
  HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
      backward_conv_shape, padded_input, output, new_backward_conv_window,
      backward_conv_dnums);

  VLOG(1) << "Canonicalizing backward filter conv";
  VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
          << new_backward_conv->ToString();

  TF_CHECK_OK(
      computation->ReplaceInstruction(backward_conv, new_backward_conv));
  return true;
}

// Rewrites a backward-input convolution with uneven padding into one with even
// padding whose (larger) result is then trimmed with a kSlice. Bails out
// (returns false) on negative padding. Returns true iff the HLO graph was
// changed.
bool PadInsertion::CanonicalizeBackwardInputConvolution(
    HloInstruction* backward_conv) {
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  Window new_backward_conv_window = backward_conv->window();
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
  // Get the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);

  Shape new_backward_conv_shape = backward_conv_shape;
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // If the backward convolution has uneven padding on the activations, we
    // move some padding on the larger end to "internal" padding, so that the
    // backward convolution produces larger activations which get sliced later.
    //
    // For example, suppose we have a non-canonical HLO
    //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
    // where the amount of padding low is larger, we can canonicalize it to
    //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
    //   [A] = Slice([B A])
    if (padding_low > padding_high) {
      // Note: the delta is negative here, i.e. the low padding is reduced
      // down to padding_high.
      IncreasePaddingLowBy(padding_high - padding_low,
                           new_backward_conv_window.mutable_dimensions(i));
    } else if (padding_low < padding_high) {
      // Likewise, the high padding is reduced down to padding_low.
      IncreasePaddingHighBy(padding_low - padding_high,
                            new_backward_conv_window.mutable_dimensions(i));
    }
    // Decreasing the padding by X *increases* the size of our output by X.
    int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
    new_backward_conv_shape.set_dimensions(
        dim, new_backward_conv_shape.dimensions(dim) +
                 std::abs(padding_low - padding_high));
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(0);
  HloInstruction* filter = backward_conv->mutable_operand(1);

  HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
      new_backward_conv_shape, output, filter, new_backward_conv_window,
      backward_conv_dnums);

  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
  // Extract out the two elements.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_shape, new_backward_conv_call, 0));
  HloInstruction* new_backward_conv_scratch =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_call->shape().tuple_shapes(1),
          new_backward_conv_call, 1));

  // Slice the new backward convolution.
  //
  // Initialize start_indices and limit_indices as no slicing.
  std::vector<int64> start_indices(new_backward_conv->shape().dimensions_size(),
                                   0LL);
  std::vector<int64> limit_indices(
      new_backward_conv->shape().dimensions().begin(),
      new_backward_conv->shape().dimensions().end());
  std::vector<int64> strides(new_backward_conv->shape().dimensions_size(), 1LL);
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
    if (padding_low > padding_high) {
      // If the amount of low padding (of the old backward convolution) is
      // larger, we internally pad the low end of the activations and slice
      // internal padding out here.
      start_indices[dim] += padding_low - padding_high;
    } else if (padding_low < padding_high) {
      // If the amount of high padding is larger, we slice out the internal
      // padding on the high end.
      limit_indices[dim] -= padding_high - padding_low;
    }
  }

  // Replace the old backward convolution with the slice.
  Shape slice_shape =
      ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
                                      limit_indices, strides)
          .ConsumeValueOrDie();
  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
      << ShapeUtil::HumanString(slice_shape) << " vs "
      << ShapeUtil::HumanString(backward_conv_shape);

  HloInstruction* slice = computation->AddInstruction(
      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
                                  start_indices, limit_indices, strides));
  HloInstruction* new_tuple = computation->AddInstruction(
      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));

  VLOG(1) << "Canonicalizing backward input conv";
  VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
          << new_tuple->ToString();

  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
  return true;
}

// Pass entry point: canonicalizes every cuDNN convolution custom call found in
// the module's entry computation (other computations are not visited here).
// Returns true iff any convolution was rewritten.
StatusOr<bool> PadInsertion::Run(HloModule* module) {
  bool changed = false;
  for (HloInstruction* instruction :
       module->entry_computation()->MakeInstructionPostOrder()) {
    if (IsCustomCallToDnnConvolution(*instruction)) {
      const auto& target = instruction->custom_call_target();
      if (target == kCudnnConvForwardCallTarget) {
        changed |= CanonicalizeForwardConvolution(instruction);
      } else if (target == kCudnnConvBackwardFilterCallTarget) {
        changed |= CanonicalizeBackwardFilterConvolution(instruction);
      } else if (target == kCudnnConvBackwardInputCallTarget) {
        changed |= CanonicalizeBackwardInputConvolution(instruction);
      } else {
        LOG(FATAL) << "Unknown custom call target for cudnn conv: "
                   << instruction->ToString();
      }
    }
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla