1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
17
18#include "tensorflow/compiler/xla/literal_util.h"
19#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
20#include "tensorflow/compiler/xla/service/shape_inference.h"
21#include "tensorflow/compiler/xla/util.h"
22#include "tensorflow/compiler/xla/window_util.h"
23#include "tensorflow/compiler/xla/xla_data.pb.h"
24
25namespace xla {
26namespace gpu {
27
28namespace {
29bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
30  CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget);
31  return window_util::HasSymmetricPadding(conv.window()) &&
32         !window_util::HasNegativePadding(conv.window()) &&
33         !window_util::HasDilation(conv.window());
34}
35
36// If the (positive and negative) padding on the input operand of a convolution
37// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
38// dilation), returns kPad and/or kSlice instructions that explicitly apply the
39// padding; otherwise returns the original input operand. When there is both
40// positive padding (including dilation) and negative padding, we insert both
41// kPad and kSlice.
42HloInstruction* MaybePaddedAndSlicedInput(
43    const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums,
44    HloInstruction* input) {
45  HloComputation* computation = input->parent();
46  if (!window_util::HasSymmetricPadding(conv_window) ||
47      window_util::HasBaseDilation(conv_window)) {
48    // If padding is uneven or has dilation, we insert a kPad instruction that
49    // applies positive padding and dilation.
50    //
51    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
52    // moving all the padding into an explicit pad op, we should keep as much
53    // padding inside of cudnn as possible, on the assumption that padding
54    // within cudnn is basically free, whereas a kPad's cost increases as the
55    // amount of padding increases.
56    PaddingConfig padding_config =
57        MakeNoPaddingConfig(input->shape().dimensions_size());
58    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
59      int64 dim = conv_dnums.input_spatial_dimensions(i);
60      padding_config.mutable_dimensions(dim)->set_edge_padding_low(
61          std::max<int64>(0LL, conv_window.dimensions(i).padding_low()));
62      padding_config.mutable_dimensions(dim)->set_edge_padding_high(
63          std::max<int64>(0LL, conv_window.dimensions(i).padding_high()));
64      padding_config.mutable_dimensions(dim)->set_interior_padding(
65          conv_window.dimensions(i).base_dilation() - 1);
66    }
67    PrimitiveType element_type = input->shape().element_type();
68    HloInstruction* padding =
69        computation->AddInstruction(HloInstruction::CreateConstant(
70            MakeUnique<Literal>(Literal::Zero(element_type))));
71    input = computation->AddInstruction(HloInstruction::CreatePad(
72        ShapeInference::InferPadShape(
73            /*operand_shape=*/input->shape(),
74            /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}),
75            padding_config)
76            .ConsumeValueOrDie(),
77        input, padding, padding_config));
78  }
79
80  if (window_util::HasNegativePadding(conv_window)) {
81    // If the window has negative padding, insert a kSlice that explicitly
82    // applies negative padding.
83    //
84    // For each dimension, initialize the start index to 0 and the limit index
85    // to the size of that dimension.
86    std::vector<int64> start_indices(input->shape().dimensions_size(), 0);
87    std::vector<int64> limit_indices(input->shape().dimensions().begin(),
88                                     input->shape().dimensions().end());
89    std::vector<int64> strides(input->shape().dimensions_size(), 1);
90    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
91      int64 dim = conv_dnums.input_spatial_dimensions(i);
92      // If dimension "dim" has negative padding, increase the start index or
93      // decrement the limit index by the amount of negative padding.
94      start_indices[dim] +=
95          std::max<int64>(0LL, -conv_window.dimensions(i).padding_low());
96      limit_indices[dim] -=
97          std::max<int64>(0LL, -conv_window.dimensions(i).padding_high());
98    }
99
100    input = computation->AddInstruction(HloInstruction::CreateSlice(
101        ShapeInference::InferSliceShape(input->shape(), start_indices,
102                                        limit_indices, strides)
103            .ConsumeValueOrDie(),
104        input, start_indices, limit_indices, strides));
105  }
106
107  return input;
108}
109
110// If the padding on the kernel operand of a convolution can't be folded into a
111// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
112// explicitly applies the padding; otherwise returns the original kernel
113// operand.
114HloInstruction* MaybePaddedKernel(const Window& conv_window,
115                                  const ConvolutionDimensionNumbers& conv_dnums,
116                                  HloInstruction* kernel) {
117  if (!window_util::HasWindowDilation(conv_window)) {
118    return kernel;
119  }
120
121  // Compute the shape and padding config of the pad to be inserted.
122  PaddingConfig padding_config;
123  for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
124    padding_config.add_dimensions();
125  }
126  for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
127    int64 dim = conv_dnums.kernel_spatial_dimensions(i);
128    padding_config.mutable_dimensions(dim)->set_interior_padding(
129        conv_window.dimensions(i).window_dilation() - 1);
130  }
131
132  HloComputation* computation = kernel->parent();
133  PrimitiveType element_type = kernel->shape().element_type();
134  HloInstruction* padding =
135      computation->AddInstruction(HloInstruction::CreateConstant(
136          MakeUnique<Literal>(Literal::Zero(element_type))));
137  return computation->AddInstruction(HloInstruction::CreatePad(
138      ShapeInference::InferPadShape(
139          /*operand_shape=*/kernel->shape(),
140          /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}),
141          padding_config)
142          .ConsumeValueOrDie(),
143      kernel, padding, padding_config));
144}
145}  // namespace
146
// Rewrites a non-canonical forward-convolution custom-call into a canonical
// one: any uneven/negative padding and dilation is made explicit with kPad and
// kSlice ops on the operands, and the call's window is reset to zero padding
// and unit dilation.  Returns true iff the HLO graph was changed.
bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
  if (IsForwardConvolutionCanonical(*conv)) {
    return false;
  }

  // Insert slices and/or pads between the convolution and its input and/or
  // kernel operand.
  HloInstruction* new_input = MaybePaddedAndSlicedInput(
      conv->window(), conv->convolution_dimension_numbers(),
      conv->mutable_operand(0));
  HloInstruction* new_kernel =
      MaybePaddedKernel(conv->window(), conv->convolution_dimension_numbers(),
                        conv->mutable_operand(1));

  // Remove the padding from convolution's window field. These paddings are
  // made explicit with the inserted pads.
  Window new_conv_window = conv->window();
  for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
    WindowDimension* dim = new_conv_window.mutable_dimensions(i);

    // The size of the kernel may have changed so update the Window to match.
    dim->set_size(new_kernel->shape().dimensions(
        conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
    // Padding and dilation are now expressed by the inserted kPad/kSlice ops,
    // so the window itself becomes canonical.
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);
  }

  // The conv CustomCall returns a tuple (conv_result, scratch_buffer).  Extract
  // out the shape of conv_result.
  Shape old_conv_shape = conv->shape().tuple_shapes(0);

  VLOG(1) << "Canonicalizing forward conv";
  // The result shape is unchanged: the new operands plus the cleaned window
  // compute the same convolution as before.
  auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel,
                                         new_conv_window,
                                         conv->convolution_dimension_numbers());
  VLOG(1) << "Replacing:\n  " << conv->ToString() << "\nwith:\n  "
          << new_conv->ToString();
  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
  return true;
}
189
190namespace {
191void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) {
192  window_dim->set_padding_low(window_dim->padding_low() + delta);
193}
194
195void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
196  window_dim->set_padding_high(window_dim->padding_high() + delta);
197}
198}  // namespace
199
// Rewrites a backward-filter convolution with uneven (but non-negative)
// padding into one with even padding by explicitly padding the activations
// with a kPad first.  Returns true iff the HLO graph was changed; returns
// false (leaving the op alone) when padding is already symmetric or when any
// padding is negative (see TODO below).
bool PadInsertion::CanonicalizeBackwardFilterConvolution(
    HloInstruction* backward_conv) {
  CHECK_EQ(backward_conv->custom_call_target(),
           kCudnnConvBackwardFilterCallTarget);
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  // A backward filter convolution with uneven padding can be canonicalized to
  // one with even padding by padding the activations (input) beforehand. For
  // example,
  //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
  // is equivalent to
  //   ABCD0 = Pad(ABCD, padding_high=1)
  //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
  // We choose the lesser of padding_low and padding_high as the new padding.
  HloInstruction* input = backward_conv->mutable_operand(0);
  Window new_backward_conv_window = backward_conv->window();
  // input_padding_config is the config of the kPad to be inserted.
  PaddingConfig input_padding_config =
      MakeNoPaddingConfig(ShapeUtil::Rank(input->shape()));
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // Compute the new, even padding for the backward conv operation.
    int64 new_conv_padding = std::min(padding_low, padding_high);
    int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
    // The excess over the even padding is moved into the explicit kPad.
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
        padding_low - new_conv_padding);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
        padding_high - new_conv_padding);

    // Since we move some padding from the backward convolution to the kPad, we
    // need to accordingly reduce the padding amount of the backward convolution
    // and its inner forward convolution.
    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
    new_dim->set_padding_low(new_conv_padding);
    new_dim->set_padding_high(new_conv_padding);
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(1);
  // Scalar zero of the input's element type serves as the pad value.
  HloInstruction* padding =
      computation->AddInstruction(HloInstruction::CreateConstant(
          MakeUnique<Literal>(Literal::Zero(input->shape().element_type()))));
  HloInstruction* padded_input =
      computation->AddInstruction(HloInstruction::CreatePad(
          ShapeInference::InferPadShape(input->shape(), padding->shape(),
                                        input_padding_config)
              .ConsumeValueOrDie(),
          input, padding, input_padding_config));

  // The shape of the backward_conv CustomCall is a tuple (conv_result,
  // scratch_buffer).  Extract out the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
  // The filter result shape is unchanged; only the padding moved.
  HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
      backward_conv_shape, padded_input, output, new_backward_conv_window,
      backward_conv_dnums);

  VLOG(1) << "Canonicalizing backward filter conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_backward_conv->ToString();

  TF_CHECK_OK(
      computation->ReplaceInstruction(backward_conv, new_backward_conv));
  return true;
}
276
// Rewrites a backward-input convolution with uneven (but non-negative)
// padding into one with even padding.  The larger side's excess padding is
// turned into extra output elements on the new convolution, which are then
// removed with a kSlice so the overall result shape is unchanged.  Returns
// true iff the HLO graph was changed.
bool PadInsertion::CanonicalizeBackwardInputConvolution(
    HloInstruction* backward_conv) {
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  Window new_backward_conv_window = backward_conv->window();
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
  // Get the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);

  Shape new_backward_conv_shape = backward_conv_shape;
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // If the backward convolution has uneven padding on the activations, we
    // move some padding on the larger end to "internal" padding, so that the
    // backward convolution produces larger activations which get sliced later.
    //
    // For example, suppose we have a non-canonical HLO
    //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
    // where the amount of padding low is larger, we can canonicalize it to
    //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
    //   [A] = Slice([B A])
    //
    // Note: the deltas below are negative, so these calls *decrease* the
    // larger side's padding down to the smaller side's value.
    if (padding_low > padding_high) {
      IncreasePaddingLowBy(padding_high - padding_low,
                           new_backward_conv_window.mutable_dimensions(i));
    } else if (padding_low < padding_high) {
      IncreasePaddingHighBy(padding_low - padding_high,
                            new_backward_conv_window.mutable_dimensions(i));
    }
    // Decreasing the padding by X *increases* the size of our output by X.
    int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
    new_backward_conv_shape.set_dimensions(
        dim, new_backward_conv_shape.dimensions(dim) +
                 std::abs(padding_low - padding_high));
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(0);
  HloInstruction* filter = backward_conv->mutable_operand(1);

  HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
      new_backward_conv_shape, output, filter, new_backward_conv_window,
      backward_conv_dnums);

  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
  // Extract out the two elements.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_shape, new_backward_conv_call, 0));
  HloInstruction* new_backward_conv_scratch =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_call->shape().tuple_shapes(1),
          new_backward_conv_call, 1));

  // Slice the new backward convolution.
  //
  // Initialize start_indices and limit_indices as no slicing.
  std::vector<int64> start_indices(new_backward_conv->shape().dimensions_size(),
                                   0LL);
  std::vector<int64> limit_indices(
      new_backward_conv->shape().dimensions().begin(),
      new_backward_conv->shape().dimensions().end());
  std::vector<int64> strides(new_backward_conv->shape().dimensions_size(), 1LL);
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
    if (padding_low > padding_high) {
      // If the amount of low padding (of the old backward convolution) is
      // larger, we internally pad the low end of the activations and slice
      // internal padding out here.
      start_indices[dim] += padding_low - padding_high;
    } else if (padding_low < padding_high) {
      // If the amount of high padding is larger, we slice out the internal
      // padding on the high end.
      limit_indices[dim] -= padding_high - padding_low;
    }
  }

  // Replace the old backward convolution with the slice.
  Shape slice_shape =
      ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
                                      limit_indices, strides)
          .ConsumeValueOrDie();
  // Sanity check: slicing must recover exactly the original result shape.
  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
      << ShapeUtil::HumanString(slice_shape) << " vs "
      << ShapeUtil::HumanString(backward_conv_shape);

  HloInstruction* slice = computation->AddInstruction(
      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
                                  start_indices, limit_indices, strides));
  // Re-wrap (sliced result, scratch) as a tuple so the replacement has the
  // same tuple shape as the original CustomCall.
  HloInstruction* new_tuple = computation->AddInstruction(
      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));

  VLOG(1) << "Canonicalizing backward input conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_tuple->ToString();

  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
  return true;
}
390
391StatusOr<bool> PadInsertion::Run(HloModule* module) {
392  bool changed = false;
393  for (HloInstruction* instruction :
394       module->entry_computation()->MakeInstructionPostOrder()) {
395    if (IsCustomCallToDnnConvolution(*instruction)) {
396      const auto& target = instruction->custom_call_target();
397      if (target == kCudnnConvForwardCallTarget) {
398        changed |= CanonicalizeForwardConvolution(instruction);
399      } else if (target == kCudnnConvBackwardFilterCallTarget) {
400        changed |= CanonicalizeBackwardFilterConvolution(instruction);
401      } else if (target == kCudnnConvBackwardInputCallTarget) {
402        changed |= CanonicalizeBackwardInputConvolution(instruction);
403      } else {
404        LOG(FATAL) << "Unknown custom call target for cudnn conv: "
405                   << instruction->ToString();
406      }
407    }
408  }
409  return changed;
410}
411
412}  // namespace gpu
413}  // namespace xla
414