fused_conv2d_bias_activation_op.cc revision 5eaefbabce16bffeeb4b19cee9890b1aeccabb09
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/activation_mode.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

namespace {
typedef Eigen::GpuDevice GPUDevice;

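// Maps the quantized type to its underlying storage type (e.g.
// RawType<qint8>::type is int8); used further below when handing raw device
// pointers to stream_executor via AsDeviceMemory().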
template <typename T>
struct RawType {
  using type = T;
};

template <>
struct RawType<qint8> {
  using type = int8;
};

// Template struct to convert int8x4 to int32.
// (for NCHW_VECT_C with element type int8, we can consider it to be
// an NCHW layout with element type int32 for operations like padding).
template <typename T>
struct Int8x4ToInt32 {
  // By default, do not change T.
  using type = T;
};

template <>
struct Int8x4ToInt32<int8> {
  using type = int32;
};
}  // namespace

// T is the element type of the conv_input, filter and side_input tensors.
// BiasType is the element type of the bias tensor, which can be different.
// ScaleType is the type used for conv_input_scale, side_input_scale.
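// For example, the kernels registered at the bottom of this file are
// instantiated as <GPUDevice, float, float, float> and
// <GPUDevice, qint8, float, float>.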
template <typename Device, typename T, typename BiasType, typename ScaleType>
class FusedConv2DBiasActivationOp : public OpKernel {
 public:
  enum InputIndexes {
    kConvInput = 0,
    kFilter,
    kBias,
    kSideInput,
    kConvInputScale,
    kSideInputScale,
    kNumInputs
  };

  explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format_str, filter_format_str;
    CHECK_EQ(kNumInputs, context->num_inputs());
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context,
                   context->GetAttr("filter_format", &filter_format_str));
    OP_REQUIRES(context,
                FilterFormatFromString(filter_format_str, &filter_format_),
                errors::InvalidArgument("Invalid filter format"));

    std::vector<int32> strides;
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides));
    OP_REQUIRES(context, strides.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));

    stride_rows_ = GetTensorDim(strides, data_format_, 'H');
    stride_cols_ = GetTensorDim(strides, data_format_, 'W');
    OP_REQUIRES(
        context,
        (GetTensorDim(strides, data_format_, 'N') == 1 &&
         GetTensorDim(strides, data_format_, 'C') == 1),
        errors::InvalidArgument("Convolutional strides are not supported in "
                                "the batch or depth dimensions."));

    // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
    constexpr bool is_int8x4 = std::is_same<T, qint8>::value;

    // Note: Only NCHW_VECT_C format is supported for int8.
    // This is because it is expected to be the fastest, and our previous tests
    // found cudnn 6 does not fully support the other formats for int8 mode.
    OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
                errors::InvalidArgument(
                    "qint8 should be used with data_format NCHW_VECT_C."));

    OP_REQUIRES(context, (is_int8x4 == (filter_format_ == FORMAT_OIHW_VECT_I)),
                errors::InvalidArgument(
                    "qint8 should be used with filter_format OIHW_VECT_I."));
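
    // For reference: NCHW_VECT_C stores an int8 [N, C, H, W] tensor as
    // [N, C/4, H, W, 4], and OIHW_VECT_I stores an int8 [O, I, H, W] filter
    // as [O, I/4, H, W, 4], which is why CheckShape() below requires any 5th
    // dimension to have size 4.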

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_type_));
    eigen_padding_type_ = BrainPadding2EigenPadding(padding_type_);
    string activation_mode_str;
    OP_REQUIRES_OK(context,
                   context->GetAttr("activation_mode", &activation_mode_str));
    OP_REQUIRES_OK(context, GetActivationModeFromString(activation_mode_str,
                                                        &activation_mode_));
    OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU,
                errors::InvalidArgument("Current implementation only supports "
                                        "RELU as the activation function."));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  Status CheckShape(const Tensor& tensor, const string& tensor_name) {
    const int num_dims = tensor.dims();
    for (int i = 0; i < num_dims; i++) {
      if (!FastBoundsCheck(tensor.dim_size(i),
                           std::numeric_limits<int32>::max())) {
        return errors::InvalidArgument(tensor_name, " dimension ", i,
                                       " too large");
      }
    }
    // If there is a 5th dimension, it is the VECT_C or VECT_I dimension.
    if (num_dims == 5 && tensor.dim_size(4) != 4) {
      return errors::InvalidArgument("The last dimension of ", tensor_name,
                                     " must be of size 4 for qint8.");
    }
    return Status::OK();
  }

  void Compute(OpKernelContext* context) override {
    // The conv_input tensor is one of the following formats:
    // NHWC, NCHW, NCHW_VECT_C.
    const Tensor& conv_input = context->input(kConvInput);
    OP_REQUIRES_OK(context, CheckShape(conv_input, "conv_input"));

    // The filter tensor is one of the following formats:
    // HWIO, OIHW, OIHW_VECT_I.
    const Tensor& filter = context->input(kFilter);
    OP_REQUIRES_OK(context, CheckShape(filter, "filter"));

    // Input bias is a 1-D tensor, with size matching output depth.
    const Tensor& bias = context->input(kBias);
    OP_REQUIRES_OK(context, CheckShape(bias, "bias"));

    const Tensor& conv_input_scale_tensor = context->input(kConvInputScale);
    const Tensor& side_input_scale_tensor = context->input(kSideInputScale);

    auto conv_input_scale = *reinterpret_cast<const ScaleType*>(
        conv_input_scale_tensor.tensor_data().data());
    auto side_input_scale = *reinterpret_cast<const ScaleType*>(
        side_input_scale_tensor.tensor_data().data());

    // If side_input_scale != 0, then side_input is not ignored and
    // has the same type and dimensions as the output.
    const Tensor& side_input = context->input(kSideInput);
    if (side_input_scale != 0) {
      OP_REQUIRES_OK(context, CheckShape(side_input, "side_input"));
    }

    // TODO(pauldonnelly): Switch to a more efficient mechanism to access
    // dimension indexes and per-dimension attributes.
    const int32 filter_rows = GetFilterDim(filter, filter_format_, 'H');
    const int32 filter_cols = GetFilterDim(filter, filter_format_, 'W');
    const int32 output_depth = GetFilterDim(filter, filter_format_, 'O');

    const int32 batch_size = GetTensorDim(conv_input, data_format_, 'N');
    const int32 conv_input_rows = GetTensorDim(conv_input, data_format_, 'H');
    const int32 conv_input_cols = GetTensorDim(conv_input, data_format_, 'W');

    int64 output_rows = 0, output_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_rows, filter_rows,
                                                  stride_rows_, padding_type_,
                                                  &output_rows, &pad_rows));
    OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_cols, filter_cols,
                                                  stride_cols_, padding_type_,
                                                  &output_cols, &pad_cols));
    // Initialize the output tensor shape according to data_format_
    TensorShape output_shape = ShapeFromFormat(
        data_format_, batch_size, output_rows, output_cols, output_depth);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    VLOG(2) << "FusedConv2DBiasActivation: conv_input_cols = "
            << conv_input_cols << ", conv_input_rows = " << conv_input_rows
            << ", filter_cols = " << filter_cols
            << ", filter_rows = " << filter_rows
            << ", stride_cols = " << stride_cols_
            << ", stride_rows = " << stride_rows_
            << ", output_depth = " << output_depth
            << ", output_cols = " << output_cols
            << ", output_rows = " << output_rows
            << ", output_shape.num_elements = " << output_shape.num_elements();

    // If there is nothing to compute, return.
    if (output_shape.num_elements() == 0) {
      return;
    }

    launcher_.launch(context, cudnn_use_autotune_, conv_input, conv_input_scale,
                     filter, stride_rows_, stride_cols_, eigen_padding_type_,
                     side_input, side_input_scale, bias, activation_mode_,
                     data_format_, filter_format_, output);
  }

 private:
  int32 stride_rows_, stride_cols_;
  Padding padding_type_;
  Eigen::PaddingType eigen_padding_type_;
  ActivationMode activation_mode_;
  TensorFormat data_format_;
  FilterTensorFormat filter_format_;
  LaunchFusedConv2DBiasActivationOp<Device, T, BiasType, ScaleType> launcher_;
  bool cudnn_use_autotune_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DBiasActivationOp);
};

#if GOOGLE_CUDA
namespace dnn = ::perftools::gputools::dnn;

// A dummy type to group forward convolution autotune results together.
struct ConvBiasActivationAutoTuneGroup {
  static string name() { return "ConvBiasActivation"; }
};
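
// Caches the best-performing AlgorithmConfig found by autotuning, keyed by
// FusedConvParameters.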
typedef AutoTuneSingleton<ConvBiasActivationAutoTuneGroup, FusedConvParameters,
                          dnn::AlgorithmConfig>
    AutoTuneConvBiasActivation;

// Allocates 'transformed_tensor' and transforms 'nhwc_tensor' into it
// using the specified 'batch_size', 'rows', 'cols', and 'depth' dimensions.
template <typename T, size_t NDIMS>
Status TransformNHWCToNCHW(OpKernelContext* ctx, const Tensor& nhwc_tensor,
                           int batch_size, int rows, int cols, int depth,
                           Tensor* transformed_tensor, const Tensor** result) {
  TensorShape nchw_shape =
      ShapeFromFormat(FORMAT_NCHW, batch_size, rows, cols, depth);
  if (depth > 1) {
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                          transformed_tensor));
    functor::NHWCToNCHW<GPUDevice, T, NDIMS>()(
        ctx->eigen_device<GPUDevice>(), nhwc_tensor.tensor<T, NDIMS>(),
        transformed_tensor->tensor<T, NDIMS>());
  } else {
    // If depth <= 1, then just reshape.
    CHECK(transformed_tensor->CopyFrom(nhwc_tensor, nchw_shape));
  }
  *result = transformed_tensor;
  return Status::OK();
}

template <typename T, typename BiasType, typename ScaleType>
void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
    launch(OpKernelContext* ctx, bool cudnn_use_autotune,
           const Tensor& conv_input_param, ScaleType conv_input_scale,
           const Tensor& filter_param, int32 row_stride, int32 col_stride,
           const Eigen::PaddingType& padding, const Tensor& side_input_param,
           ScaleType side_input_scale, const Tensor& bias,
           ActivationMode activation_mode, TensorFormat data_format,
           FilterTensorFormat filter_format, Tensor* output_param) {
  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  // TODO(yangzihao): refactor all the complicated/duplicated code in regular
  // conv ops to a shared conv utility.

  // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
  constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
  constexpr int rank = is_int8x4 ? 5 : 4;
  constexpr int vect = is_int8x4 ? 4 : 1;

  const int batch_size = GetTensorDim(conv_input_param, data_format, 'N');
  int conv_input_rows = GetTensorDim(conv_input_param, data_format, 'H');
  int conv_input_cols = GetTensorDim(conv_input_param, data_format, 'W');

  const int conv_input_depth =
      GetTensorDim(conv_input_param, data_format, 'C') * vect;
  const int output_rows = GetTensorDim(*output_param, data_format, 'H');
  const int output_cols = GetTensorDim(*output_param, data_format, 'W');
  const int output_depth = GetFilterDim(filter_param, filter_format, 'O');
  const int filter_rows = GetFilterDim(filter_param, filter_format, 'H');
  const int filter_cols = GetFilterDim(filter_param, filter_format, 'W');
  int padding_rows = 0;
  int padding_cols = 0;
  const Tensor* conv_input = &conv_input_param;

  Tensor maybe_padded_conv_input;
  if (padding == Eigen::PADDING_SAME) {
    // Total padding on rows and cols is
    // Pr = (R' - 1) * S + Kr - R
    // Pc = (C' - 1) * S + Kc - C
    // where (R', C') are output dimensions, (R, C) are input dimensions, S
    // is stride, (Kr, Kc) are filter dimensions.
    // We pad Pr/2 on the top and Pr - Pr/2 on the bottom, Pc/2 on the left
    // and Pc - Pc/2 on the right.  When Pr or Pc is odd, this means
    // we pad more on the bottom and right than on the top and left.
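    // For example, with 4 input rows, a 3-row filter, stride 2 and 'SAME'
    // padding, there are 2 output rows and Pr = (2 - 1) * 2 + 3 - 4 = 1, i.e.
    // zero rows of padding on top and one on the bottom. cudnn only supports
    // symmetric padding, so odd totals are handled below by explicitly padding
    // the input with one extra row/column at the bottom/right and then using
    // Pr/2 (and Pc/2) as the symmetric padding in the convolution descriptor.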
    padding_rows = std::max<int>(
        0, (output_rows - 1) * row_stride + filter_rows - conv_input_rows);
    padding_cols = std::max<int>(
        0, (output_cols - 1) * col_stride + filter_cols - conv_input_cols);
    const int padding_rows_parity = padding_rows & 1;
    const int padding_cols_parity = padding_cols & 1;
    if ((padding_rows_parity | padding_cols_parity) != 0) {
      const int new_conv_input_rows = conv_input_rows + padding_rows_parity;
      const int new_conv_input_cols = conv_input_cols + padding_cols_parity;

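      // Reuse the rank-4 PadInput functor for the int8x4 case by viewing the
      // [N, C/4, H, W, 4] int8 tensor as an [N, C/4, H, W] int32 tensor (see
      // Int8x4ToInt32 above) and padding it in NCHW layout.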
      using VectT = typename Int8x4ToInt32<typename RawType<T>::type>::type;
      auto pad_data_format = is_int8x4 ? FORMAT_NCHW : data_format;

      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(
                   DataTypeToEnum<T>::value,
                   ShapeFromFormat(data_format, batch_size, new_conv_input_rows,
                                   new_conv_input_cols, conv_input_depth),
                   &maybe_padded_conv_input));

      auto conv_input_eigen_tensor =
          To32Bit(conv_input_param.reinterpret_last_dimension<VectT, 4>());
      auto padded_conv_input_eigen_tensor = To32Bit(
          maybe_padded_conv_input.reinterpret_last_dimension<VectT, 4>());

      functor::PadInput<GPUDevice, VectT, int, 4>()(
          ctx->eigen_device<GPUDevice>(), conv_input_eigen_tensor, {{0, 0}},
          {{padding_rows_parity, padding_cols_parity}},
          padded_conv_input_eigen_tensor, pad_data_format);

      conv_input = &maybe_padded_conv_input;
      conv_input_rows = new_conv_input_rows;
      conv_input_cols = new_conv_input_cols;
    }
  }

  Tensor maybe_transformed_conv_input, maybe_transformed_side_input;
  Tensor maybe_transformed_output;
  const Tensor* side_input = &side_input_param;
  Tensor* output = output_param;

  // NOTE: Here and elsewhere, checking 'is_int8x4' may look unnecessary
  // and inefficient, but it is actually both a time and code size optimization,
  // since 'is_int8x4' is a constexpr determined by the template parameter.
  if (!is_int8x4 && data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(ctx, (TransformNHWCToNCHW<T, rank>(
                            ctx, *conv_input, batch_size, conv_input_rows,
                            conv_input_cols, conv_input_depth,
                            &maybe_transformed_conv_input, &conv_input)));
    if (side_input_scale != 0) {
      OP_REQUIRES_OK(
          ctx, (TransformNHWCToNCHW<T, rank>(
                   ctx, side_input_param, batch_size, output_rows, output_cols,
                   output_depth, &maybe_transformed_side_input, &side_input)));
    }
    if (output_depth > 1) {
      // Allocate a tensor for the NCHW output of the kernel and point output
      // to it. Afterwards, we will transform it to NHWC while copying back to
      // 'output_param'.
      TensorShape nchw_shape = ShapeFromFormat(
          FORMAT_NCHW, batch_size, output_rows, output_cols, output_depth);
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                        &maybe_transformed_output));
      output = &maybe_transformed_output;
    }
  }

  constexpr auto data_layout = is_int8x4 ? dnn::DataLayout::kBatchDepthYX4
                                         : dnn::DataLayout::kBatchDepthYX;
  constexpr auto filter_layout = is_int8x4 ? dnn::FilterLayout::kOutputInputYX4
                                           : dnn::FilterLayout::kOutputInputYX;
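  // kBatchDepthYX / kOutputInputYX correspond to NCHW / OIHW; the "4" variants
  // are the int8x4 vectorized layouts (NCHW_VECT_C / OIHW_VECT_I).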

  dnn::BatchDescriptor conv_input_desc;
  conv_input_desc.set_count(batch_size)
      .set_feature_map_count(conv_input_depth)
      .set_height(conv_input_rows)
      .set_width(conv_input_cols)
      .set_layout(data_layout);
  dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(filter_rows)
      .set_input_filter_width(filter_cols)
      .set_input_feature_map_count(conv_input_depth)
      .set_output_feature_map_count(output_depth)
      .set_layout(filter_layout);
  dnn::BatchDescriptor side_input_desc;
  side_input_desc.set_count(batch_size)
      .set_height(output_rows)
      .set_width(output_cols)
      .set_feature_map_count(output_depth)
      .set_layout(data_layout);
  dnn::BatchDescriptor bias_desc;
  bias_desc.set_count(1)
      .set_height(1)
      .set_width(1)
      .set_feature_map_count(output_depth)
      .set_layout(dnn::DataLayout::kBatchDepthYX);
  dnn::BatchDescriptor output_desc;
  output_desc.set_count(batch_size)
      .set_height(output_rows)
      .set_width(output_cols)
      .set_feature_map_count(output_depth)
      .set_layout(data_layout);
  dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_filter_stride(row_stride)
      .set_horizontal_filter_stride(col_stride)
      .set_zero_padding_height(padding_rows / 2)
      .set_zero_padding_width(padding_cols / 2);

  Tensor maybe_transformed_filter;
  // OIHW filters, and the OIHW_VECT_I filters required in the int8x4 case
  // (checked in the constructor), are already in the layout cudnn expects;
  // only HWIO filters need to be shuffled to OIHW.
  const Tensor* filter = &filter_param;
  if (!is_int8x4 && filter_format == FORMAT_HWIO) {
    // Shuffle filter tensor from HWIO to OIHW:
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            ShapeFromFilterFormat(
                                FORMAT_OIHW, filter_param.shape(), FORMAT_HWIO),
                            &maybe_transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 4>()(
        ctx->eigen_device<GPUDevice>(), To32Bit(filter_param.tensor<T, 4>()),
        To32Bit(maybe_transformed_filter.tensor<T, 4>()));
    filter = &maybe_transformed_filter;
  }
453
454  auto conv_input_ptr =
455      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
456                         conv_input->template flat<T>().data()),
457                     conv_input->template flat<T>().size());
458  auto filter_ptr =
459      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
460                         filter->template flat<T>().data()),
461                     filter->template flat<T>().size());
462  auto side_input_ptr =
463      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
464                         side_input->template flat<T>().data()),
465                     side_input->template flat<T>().size());
466  auto output_ptr =
467      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
468                         output->template flat<T>().data()),
469                     output->template flat<T>().size());
470  auto bias_ptr = AsDeviceMemory(bias.template flat<BiasType>().data(),
471                                 bias.template flat<BiasType>().size());
472
473  static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
474      // default value is in bytes despite the name of the environment variable
475      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
476  );
477
478  int device_id = stream->parent()->device_ordinal();
479  FusedConvParameters fused_conv_parameters = {
480      batch_size,
481      conv_input_depth,
482      {{conv_input_rows, conv_input_cols}},
483      output_depth,
484      {{filter_rows, filter_cols}},
485      {{row_stride, col_stride}},
486      {{padding_rows, padding_cols}},
487      conv_input->dtype(),
488      device_id,
489      (side_input_scale != 0),
490      activation_mode,
491  };
492
493  dnn::AlgorithmConfig algorithm_config;
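  // If autotuning is enabled and no algorithm is cached yet for this
  // combination of parameters, profile each candidate algorithm once and
  // remember the fastest result (with and without scratch memory).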
  if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find(
                                fused_conv_parameters, &algorithm_config)) {
    std::vector<dnn::AlgorithmDesc> algorithms;
    CHECK(stream->parent()->GetConvolveAlgorithms(
        fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(),
        &algorithms));
    dnn::ProfileResult best_result;
    dnn::ProfileResult best_result_no_scratch;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times for better
      // accuracy.
      CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
      dnn::ProfileResult profile_result;
      bool cudnn_launch_status =
          stream
              ->ThenFusedConvolveWithAlgorithm(
                  conv_input_desc, conv_input_ptr, conv_input_scale,
                  filter_desc, filter_ptr, conv_desc, side_input_ptr,
                  side_input_scale, bias_desc, bias_ptr,
                  dnn::ActivationMode::kRelu, output_desc, &output_ptr,
                  &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm),
                  &profile_result)
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          if (profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
            best_result = profile_result;
          }
          if (scratch_allocator.TotalByteSize() == 0 &&
              profile_result.elapsed_time_in_ms() <
                  best_result_no_scratch.elapsed_time_in_ms()) {
            best_result_no_scratch = profile_result;
          }
        }
      }
    }
    OP_REQUIRES(ctx,
                best_result.is_valid() || best_result_no_scratch.is_valid(),
                errors::NotFound("No algorithm worked!"));
    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    if (best_result_no_scratch.is_valid()) {
      algorithm_config.set_algorithm_no_scratch(
          best_result_no_scratch.algorithm());
    }
    AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters,
                                                      algorithm_config);
  }

  CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
  bool cudnn_launch_status =
      stream
          ->ThenFusedConvolveWithAlgorithm(
              conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc,
              filter_ptr, conv_desc, side_input_ptr, side_input_scale,
              bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc,
              &output_ptr, &scratch_allocator, algorithm_config,
              /*output_profile_result=*/nullptr)
          .ok();

  if (!cudnn_launch_status) {
    ctx->SetStatus(errors::Internal("cuDNN launch failure : conv_input shape(",
                                    conv_input->shape().DebugString(),
                                    ") filter shape(",
                                    filter->shape().DebugString(), ")"));
  }

  // Convert the output tensor back from NCHW to NHWC if necessary.
  if (!is_int8x4 && (data_format == FORMAT_NHWC) && (output_depth > 1)) {
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        const_cast<const Tensor*>(output)->tensor<T, 4>(),
        output_param->tensor<T, 4>());
  }
}

// Forward declarations of the functor specializations for GPU used above.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                              \
  template <>                                                            \
  void PadInput<GPUDevice, T, int, 4>::operator()(                       \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
      const std::array<int, 2>& padding_left,                            \
      const std::array<int, 2>& padding_right,                           \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
  extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.

REGISTER_KERNEL_BUILDER(
    Name("FusedConv2DBiasActivation")
        .Device(DEVICE_GPU)
        .TypeConstraint<float>("T")
        .TypeConstraint<float>("Tbias")
        .HostMemory("conv_input_scale")
        .HostMemory("side_input_scale"),
    FusedConv2DBiasActivationOp<GPUDevice, float, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedConv2DBiasActivation")
        .Device(DEVICE_GPU)
        .TypeConstraint<qint8>("T")
        .TypeConstraint<float>("Tbias")
        .HostMemory("conv_input_scale")
        .HostMemory("side_input_scale"),
    FusedConv2DBiasActivationOp<GPUDevice, qint8, float, float>);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow