/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conv_3d.h"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using perftools::gputools::dnn::DimIndex;
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// TODO(mjanusz): Get rid of the macro and return shapes directly.
#define EXTRACT_AND_VERIFY_DIMENSIONS(label)                                   \
  const Tensor& out_backprop = context->input(2);                              \
  OP_REQUIRES(                                                                 \
      context, input_shape.dims() == 5,                                        \
      errors::InvalidArgument(label, ": input must be 5-dimensional"));        \
  OP_REQUIRES(                                                                 \
      context, filter_shape.dims() == 5,                                       \
      errors::InvalidArgument(label, ": filter must be 5-dimensional"));       \
  OP_REQUIRES(                                                                 \
      context, out_backprop.dims() == 5,                                       \
      errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \
  const int64 batch = input_shape.dim_size(0);                                 \
  OP_REQUIRES(                                                                 \
      context, batch == out_backprop.dim_size(0),                              \
      errors::InvalidArgument(                                                 \
          label, ": input and out_backprop must have the same batch size"));   \
  const std::array<int64, 3> input_size = {                                    \
      {GetTensorDim(input_shape, data_format_, '0'),                           \
       GetTensorDim(input_shape, data_format_, '1'),                           \
       GetTensorDim(input_shape, data_format_, '2')}};                         \
  const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C');         \
  const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0),         \
                                             filter_shape.dim_size(1),         \
                                             filter_shape.dim_size(2)}};       \
  const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2');     \
  const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1');     \
  const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0');   \
  OP_REQUIRES(context, in_depth == filter_shape.dim_size(3),                   \
              errors::InvalidArgument(                                         \
                  label, ": input and filter must have the same depth"));      \
  const int64 out_depth = filter_shape.dim_size(4);                            \
  OP_REQUIRES(                                                                 \
      context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'),     \
      errors::InvalidArgument(                                                 \
          label, ": filter and out_backprop must have the same out_depth"));   \
  const std::array<int64, 3> strides = {                                       \
      {GetTensorDim(stride_, data_format_, '0'),                               \
       GetTensorDim(stride_, data_format_, '1'),                               \
       GetTensorDim(stride_, data_format_, '2')}};                             \
  std::array<int64, 3> out, padding;                                           \
  OP_REQUIRES_OK(context, Get3dOutputSize(input_size, filter_size, strides,    \
                                          padding_, &out, &padding));          \
  OP_REQUIRES(context, output_planes == out[0],                                \
              errors::InvalidArgument(                                         \
                  label,                                                       \
                  ": Number of planes of out_backprop doesn't match "          \
                  "computed: actual = ",                                       \
                  output_planes, ", computed = ", out[0]));                    \
  OP_REQUIRES(                                                                 \
      context, output_rows == out[1],                                          \
      errors::InvalidArgument(                                                 \
          label, ": Number of rows of out_backprop doesn't match computed: ",  \
          "actual = ", output_rows, ", computed = ", out[1]));                 \
  OP_REQUIRES(                                                                 \
      context, output_cols == out[2],                                          \
      errors::InvalidArgument(                                                 \
          label, ": Number of cols of out_backprop doesn't match computed: ",  \
          "actual = ", output_cols, ", computed = ", out[2]));                 \
  const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1;       \
  const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1;           \
  const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1;           \
  const auto padded_out_planes = input_size[0] + filter_size[0] - 1;           \
  const auto padded_out_rows = input_size[1] + filter_size[1] - 1;             \
  const auto padded_out_cols = input_size[2] + filter_size[2] - 1;             \
  const auto top_pad_planes = filter_size[0] - 1 - padding[0];                 \
  const auto top_pad_rows = filter_size[1] - 1 - padding[1];                   \
  const auto left_pad_cols = filter_size[2] - 1 - padding[2];                  \
  const auto bottom_pad_planes =                                               \
      padded_out_planes - expanded_out_planes - top_pad_planes;                \
  const auto bottom_pad_rows =                                                 \
      padded_out_rows - expanded_out_rows - top_pad_rows;                      \
  const auto right_pad_cols =                                                  \
      padded_out_cols - expanded_out_cols - left_pad_cols;                     \
  VLOG(2) << "Conv3d: " << label                                               \
          << ", expanded_out_planes = " << expanded_out_planes                 \
          << ", expanded_out_rows = " << expanded_out_rows                     \
          << ", expanded_out_cols = " << expanded_out_cols                     \
          << ", padded_out_planes = " << padded_out_planes                     \
          << ", padded_out_rows = " << padded_out_rows                         \
          << ", padded_out_cols = " << padded_out_cols                         \
          << ", top_pad_planes = " << top_pad_planes                           \
          << ", top_pad_rows = " << top_pad_rows                               \
          << ", left_pad_cols = " << left_pad_cols                             \
          << ", bottom_pad_planes = " << bottom_pad_planes                     \
          << ", bottom_pad_rows = " << bottom_pad_rows                         \
          << ", right_pad_cols = " << right_pad_cols

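// An informal note on the arithmetic above, with a made-up 1-D example (the
// numbers are not taken from the code): for input size 5, filter size 3,
// stride 2 and VALID padding, the forward output has 2 elements, so
// expanded_out = (2 - 1) * 2 + 1 = 3 and padded_out = 5 + 3 - 1 = 7. The
// top/left pad is filter_size - 1 - padding = 2 and the bottom/right pad is
// 7 - 3 - 2 = 2, which positions the dilated out_backprop so that a stride-1
// VALID convolution with the reversed filter yields an input-sized gradient.
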
// Backprop for input.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
 public:
  explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();
    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  input_sizes.vec<int32>(), &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }
    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
        {0, 0},
        {top_pad_planes, bottom_pad_planes},
        {top_pad_rows, bottom_pad_rows},
        {left_pad_cols, right_pad_cols},
        {0, 0}};
    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // Fill out a padded out_backprop.
    TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows,
                                  padded_out_cols, out_depth});
    Tensor padded_output;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::v(),
                                          padded_out_shape, &padded_output));
    Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4};
    Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
                                                      strides[2], 1};
    functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
        eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>());
    const Tensor& padded_output_cref = padded_output;
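    // A small 1-D sketch of what the inflate-and-pad above produces, using
    // illustrative values that do not come from the surrounding code: with
    // stride 2, out_backprop values [a, b] are inflated to [a, 0, b] by
    // inserting stride - 1 zeros between entries, then zero-padded on both
    // sides to, e.g., [0, 0, a, 0, b, 0, 0], so that the stride-1 VALID
    // convolution below visits each input position exactly once.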

    // Fill a new "reversed" filter. We need to transpose the in_depth and
    // out_depth for the filter and reverse the planes, rows, and cols.
    TensorShape r_filter_shape(
        {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth});
    Tensor r_filter;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
                                                   r_filter_shape, &r_filter));
    Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3};
    Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order,
        filter_rev_dims, r_filter.tensor<T, 5>());
    const Tensor& r_filter_cref = r_filter;

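    // Background: the gradient w.r.t. the input of a strided convolution is
    // itself a convolution: a stride-1 VALID convolution of the inflated,
    // padded out_backprop with the spatially reversed, depth-transposed
    // filter. The two temporaries above hold exactly those operands.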
    // Now we can call conv_3d directly.
    functor::CuboidConvolution<Device, T>()(
        context->eigen_device<Device>(), in_backprop->tensor<T, 5>(),
        padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1,
        1, BrainPadding2EigenPadding(VALID));
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
};

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      Conv3DBackpropInputOp<CPUDevice, T>);                                    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DBackpropInputOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// Backprop for filter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
 public:
  explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    TensorShape filter_shape;

    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
        {0, 0},
        {top_pad_planes, bottom_pad_planes},
        {top_pad_rows, bottom_pad_rows},
        {left_pad_cols, right_pad_cols},
        {0, 0}};
    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    if (input_shape.num_elements() == 0) {
      filter_backprop->template flat<T>().setZero();
      return;
    }

    // For the backprop of the filter, we also need to transpose the
    // out_backprop. Its shape is
    //   [batch, out_z, out_y, out_x, out_depth]
    // and we need to change it to
    //   [out_depth, out_x, out_y, out_z, batch].
    Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
    TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
                                  padded_out_cols, batch});
    Tensor padded_output;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::v(),
                                          padded_out_shape, &padded_output));
    Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
                                                      strides[2], 1};
    functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
        eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
    const Tensor& padded_output_cref = padded_output;

    // For the backprop of the filter, we need to transpose the input.
    // Its shape is
    //   [batch, in_z, in_y, in_x, in_depth]
    // and we need to change it to
    //   [in_z, in_y, in_x, batch, in_depth].
    Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
    TensorShape in_shuffle_shape(
        {input_size[0], input_size[1], input_size[2], batch, in_depth});
    Tensor in_shuffle;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::v(),
                                          in_shuffle_shape, &in_shuffle));
    // No need for reversing this time.
    Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
        no_reverse, in_shuffle.tensor<T, 5>());
    const Tensor& in_shuffle_cref = in_shuffle;

    // The output of the conv_3d would be
    //   [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth],
    // and we need to shuffle it back to
    //   [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth]
    // and reverse the filter backprops. So we need to allocate (sigh) yet
    // another piece of memory to hold the output.
    TensorShape filter_shuffle_shape(
        {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
    Tensor filter_shuffle;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DataTypeToEnum<T>::v(),
                                        filter_shuffle_shape, &filter_shuffle));
    functor::CuboidConvolution<Device, T>()(
        context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
        padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
        1, BrainPadding2EigenPadding(VALID));
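    // Background: the filter gradient is the correlation of the input with
    // out_backprop. With the shuffles above placing batch on the contracted
    // dimension, it can be evaluated as one more stride-1 VALID cuboid
    // convolution, which is what the call above computes; only the shuffle
    // and reverse back to filter layout remain.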

    // Now copy the filter_backprop back to the destination.
    Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
    Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
    const Tensor& filter_shuffle_cref = filter_shuffle;
    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
        filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
};

#define REGISTER_CPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DBackpropFilterOp<CPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DBackpropFilterOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// GPU definitions of both ops.
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
// This ensures that the custom implementation is used instead of the default
// Eigen one (which is used for CPU).
namespace functor {
#define DECLARE_GPU_SPEC(T)                                           \
  template <>                                                         \
  void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
      typename TTypes<T, 5, int>::Tensor out);                        \
  template <>                                                         \
  void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
      const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
      typename TTypes<T, 5>::Tensor out);                             \
  template <>                                                         \
  void PadInput<GPUDevice, T, int, 5>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
      const std::array<int, 3>& padding_left,                         \
      const std::array<int, 3>& padding_right,                        \
      typename TTypes<T, 5, int>::Tensor out, TensorFormat format);

DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// A dummy type to group backward data autotune results together.
struct Conv3dBackwardDataAutoTuneGroup {
  static string name() { return "Conv3dBwdData"; }
};
typedef AutoTuneSingleton<Conv3dBackwardDataAutoTuneGroup, ConvParameters,
                          perftools::gputools::dnn::AlgorithmConfig>
    AutoTuneConv3dBwdData;

template <typename T>
class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
 public:
  explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();
    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  input_sizes.vec<int32>(), &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }
    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
        stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1 &&
        data_format_ == FORMAT_NHWC) {
      const uint64 m = batch * input_size[0] * input_size[1] * input_size[2];
      const uint64 k = out_depth;
      const uint64 n = in_depth;
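
      // With a 1x1x1 filter and unit strides, the convolution is just a
      // matrix product, so the input backprop reduces to a single GEMM:
      //   in_backprop[m, n] = out_backprop[m, k] * filter[n, k]^T,
      // where the filter is viewed as an in_depth x out_depth matrix. The
      // transpose flags below account for cuBLAS's column-major convention.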

      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                  in_backprop->template flat<T>().size());

      auto transpose = perftools::gputools::blas::Transpose::kTranspose;
      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                             a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    } else if (filter_size[0] == input_size[0] &&
               filter_size[1] == input_size[1] &&
               filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
               data_format_ == FORMAT_NHWC) {
      const uint64 m = batch;
      const uint64 k = out_depth;
      const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth;
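
      // When the filter spans the entire input under VALID padding, every
      // batch element maps to a single output position, so this is again one
      // GEMM of the same form as above, with m = batch and the whole filter
      // volume folded into n.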

      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                  in_backprop->template flat<T>().size());

      auto transpose = perftools::gputools::blas::Transpose::kTranspose;
      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                             a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    }

    int padding_rows = 0, padding_cols = 0, padding_planes = 0;

    if (padding_ == Padding::SAME) {
      padding_planes = std::max<int>(
          0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
      padding_cols = std::max<int>(
          0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
      padding_rows = std::max<int>(
          0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
    }
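    // The SAME-padding arithmetic above follows the usual formula
    //   total_padding = max(0, (out - 1) * stride + filter - in).
    // For illustration (values not from the code): in = 5, filter = 3,
    // stride = 2 gives out = 3 and total_padding = (3 - 1) * 2 + 3 - 5 = 2,
    // split evenly across both sides. An odd total cannot be split evenly,
    // which is what the *_odd flags below detect and work around.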
    const bool rows_odd = (padding_rows % 2 != 0);
    const bool cols_odd = (padding_cols % 2 != 0);
    const bool planes_odd = (padding_planes % 2 != 0);

    TensorShape compatible_input_shape;
    if (rows_odd || cols_odd || planes_odd) {
      // cuDNN only supports the same amount of padding on both sides.
      compatible_input_shape = {
          batch,
          in_depth,
          input_size[0] + planes_odd,
          input_size[1] + rows_odd,
          input_size[2] + cols_odd,
      };
    } else {
      compatible_input_shape = {batch, in_depth, input_size[0], input_size[1],
                                input_size[2]};
    }

    CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
        << "Negative paddings: (" << padding_rows << ", " << padding_cols
        << ", " << padding_planes << ")";
    perftools::gputools::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
        .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
        .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
        .set_feature_map_count(in_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X, output_cols)
        .set_spatial_dim(DimIndex::Y, output_rows)
        .set_spatial_dim(DimIndex::Z, output_planes)
        .set_feature_map_count(out_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
        .set_spatial_dim(DimIndex::Y, filter_size[1])
        .set_spatial_dim(DimIndex::Z, filter_size[0])
        .set_input_feature_map_count(in_depth)
        .set_output_feature_map_count(out_depth);
    perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_filter_stride(DimIndex::X, strides[2])
        .set_filter_stride(DimIndex::Y, strides[1])
        .set_filter_stride(DimIndex::Z, strides[0])
        .set_zero_padding(DimIndex::X, padding_cols / 2)
        .set_zero_padding(DimIndex::Y, padding_rows / 2)
        .set_zero_padding(DimIndex::Z, padding_planes / 2);

    // Shape: out, in, z, y, x.
    Tensor transformed_filter;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(DataTypeToEnum<T>::value,
                               TensorShape({out_depth, in_depth, filter_size[0],
                                            filter_size[1], filter_size[2]}),
                               &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 5>()(
        context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
        To32Bit(transformed_filter.tensor<T, 5>()));

    // Shape: batch, filters, z, y, x.
    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
                                output_cols};
      if (out_depth > 1) {
        OP_REQUIRES_OK(context, context->allocate_temp(
                                    DataTypeToEnum<T>::value, nchw_shape,
                                    &transformed_out_backprop));
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
            transformed_out_backprop.tensor<T, 5>());
      } else {
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }
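    // Note that when out_depth == 1 the NHWC and NCHW layouts store the same
    // bytes in the same order, so the CopyFrom above only relabels the shape
    // instead of launching a transpose kernel.
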
    // Shape: batch, filters, z, y, x.
    Tensor pre_transformed_in_backprop;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(DataTypeToEnum<T>::value, compatible_input_shape,
                               &pre_transformed_in_backprop));

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_ptr =
        AsDeviceMemory(transformed_filter.template flat<T>().data(),
                       transformed_filter.template flat<T>().size());
    auto in_backprop_ptr =
        AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                       pre_transformed_in_backprop.template flat<T>().size());

    static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

    const int device_id = stream->parent()->device_ordinal();
    DataType dtype = context->input(0).dtype();
    const ConvParameters conv_parameters = {
        batch,
        in_depth,
        {{input_size[0], input_size[1], input_size[2]}},
        out_depth,
        {{filter_size[0], filter_size[1], filter_size[2]}},
        // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
        // conv is supported.
        /*dilation=*/{{1, 1, 1}},
        {{strides[0], strides[1], strides[2]}},
        {{padding_planes, padding_rows, padding_cols}},
        dtype,
        device_id,
    };

    using perftools::gputools::dnn::AlgorithmConfig;
    using perftools::gputools::dnn::AlgorithmDesc;
    using perftools::gputools::dnn::ProfileResult;
    AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmDesc> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times for better
        // accuracy.
        CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                                context);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardDataWithAlgorithm(
                    filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                    conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                    AlgorithmConfig(profile_algorithm), &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
      AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
                                                   algorithm_config);
    }
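    // The result of autotuning is cached per ConvParameters, so the profiling
    // loop above runs at most once for each distinct convolution
    // configuration on each device; later calls reuse the stored
    // algorithm_config.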
    CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                            context);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardDataWithAlgorithm(
                filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(errors::Internal(
          "cuDNN Backward Data function launch failure : input shape(",
          input_shape.DebugString(), ") filter shape(",
          filter_shape.DebugString(), ")"));
    }

    if (rows_odd || cols_odd || planes_odd) {
      Tensor in_backprop_remove_padding;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<T>::value,
                                            {batch, in_depth, input_size[0],
                                             input_size[1], input_size[2]},
                                            &in_backprop_remove_padding));

      // Remove the padding for odd spatial dimensions.
      functor::PadInput<GPUDevice, T, int, 5>()(
          context->eigen_device<GPUDevice>(),
          To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                      .tensor<T, 5>()),
          {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
          To32Bit(in_backprop_remove_padding.tensor<T, 5>()), FORMAT_NCHW);

      pre_transformed_in_backprop = in_backprop_remove_padding;
    }

    if (data_format_ == FORMAT_NHWC) {
      auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
      functor::NCHWToNHWC<GPUDevice, T, 5>()(
          context->eigen_device<GPUDevice>(),
          toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
          in_backprop->tensor<T, 5>());
    } else {
      *in_backprop = pre_transformed_in_backprop;
    }
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
  bool cudnn_use_autotune_;
};

// A dummy type to group backward filter autotune results together.
struct Conv3dBackwardFilterAutoTuneGroup {
  static string name() { return "Conv3dBwdFilter"; }
};
typedef AutoTuneSingleton<Conv3dBackwardFilterAutoTuneGroup, ConvParameters,
                          perftools::gputools::dnn::AlgorithmConfig>
    AutoTuneConv3dBwdFilter;

template <typename T>
class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
 public:
  explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    TensorShape filter_shape;
    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
        strides[0] == 1 && strides[1] == 1 && strides[2] == 1 &&
        data_format_ == FORMAT_NHWC) {
      const uint64 m = in_depth;
      const uint64 k = batch * input_size[0] * input_size[1] * input_size[2];
      const uint64 n = out_depth;

      // The shape of out_backprop is
      //   [batch, out_z, out_y, out_x, out_depth].
      // From cublas's perspective, it is: n x k.
      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());

      // The shape of input is
      //   [batch, in_z, in_y, in_x, in_depth].
      // From cublas's perspective, it is: m x k.
      auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());

      // The shape of the filter backprop is
      //   [1, 1, 1, in_depth, out_depth].
      // From cublas's perspective, it is: n x m.
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                             perftools::gputools::blas::Transpose::kTranspose,
                             n, m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    } else if (filter_size[0] == input_size[0] &&
               filter_size[1] == input_size[1] &&
               filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
               data_format_ == FORMAT_NHWC) {
      const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth;
      const uint64 k = batch;
      const uint64 n = out_depth;
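
      // As in the matching fast path of the input backprop, a filter that
      // covers the whole input under VALID padding turns the computation into
      // a single GEMM: filter_backprop[m, n] = input[k, m]^T *
      // out_backprop[k, n], contracting over the batch dimension k.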

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                             perftools::gputools::blas::Transpose::kTranspose,
                             n, m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    }

    int padding_rows = 0, padding_cols = 0, padding_planes = 0;

    if (padding_ == Padding::SAME) {
      padding_planes = std::max<int>(
          0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
      padding_cols = std::max<int>(
          0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
      padding_rows = std::max<int>(
          0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
    }
    bool rows_odd = (padding_rows % 2 != 0);
    bool cols_odd = (padding_cols % 2 != 0);
    bool planes_odd = (padding_planes % 2 != 0);

    Tensor compatible_input;
    if (rows_odd || cols_odd || planes_odd) {
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(data_format_, batch,
                                                  {{input_size[0] + planes_odd,
                                                    input_size[1] + rows_odd,
                                                    input_size[2] + cols_odd}},
                                                  in_depth),
                                  &compatible_input));
      functor::PadInput<GPUDevice, T, int, 5>()(
          context->template eigen_device<GPUDevice>(),
          To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
          {{planes_odd, rows_odd, cols_odd}},
          To32Bit(compatible_input.tensor<T, 5>()), data_format_);
    } else {
      compatible_input = input;
    }

    CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
        << "Negative paddings: (" << padding_rows << ", " << padding_cols
        << ", " << padding_planes << ")";
    perftools::gputools::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X,
                         GetTensorDim(compatible_input, data_format_, '2'))
        .set_spatial_dim(DimIndex::Y,
                         GetTensorDim(compatible_input, data_format_, '1'))
        .set_spatial_dim(DimIndex::Z,
                         GetTensorDim(compatible_input, data_format_, '0'))
        .set_feature_map_count(in_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X, output_cols)
        .set_spatial_dim(DimIndex::Y, output_rows)
        .set_spatial_dim(DimIndex::Z, output_planes)
        .set_feature_map_count(out_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
        .set_spatial_dim(DimIndex::Y, filter_size[1])
        .set_spatial_dim(DimIndex::Z, filter_size[0])
        .set_input_feature_map_count(in_depth)
        .set_output_feature_map_count(out_depth);
    perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_filter_stride(DimIndex::X, strides[2])
        .set_filter_stride(DimIndex::Y, strides[1])
        .set_filter_stride(DimIndex::Z, strides[0])
        .set_zero_padding(DimIndex::X, padding_cols / 2)
        .set_zero_padding(DimIndex::Y, padding_rows / 2)
        .set_zero_padding(DimIndex::Z, padding_planes / 2);

    Tensor pre_transformed_filter_backprop;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(DataTypeToEnum<T>::value,
                               TensorShape({out_depth, in_depth, filter_size[0],
                                            filter_size[1], filter_size[2]}),
                               &pre_transformed_filter_backprop));

    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
                                output_cols};
      OP_REQUIRES_OK(
          context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                          &transformed_out_backprop));
      if (out_depth > 1) {
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
            transformed_out_backprop.tensor<T, 5>());
      } else {
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }
    Tensor transformed_input;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1),
                                compatible_input.dim_size(2),
                                compatible_input.dim_size(3)};
      if (in_depth > 1) {
        OP_REQUIRES_OK(context,
                       context->allocate_temp(DataTypeToEnum<T>::value,
                                              nchw_shape, &transformed_input));
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(),
            const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
            transformed_input.tensor<T, 5>());
      } else {
        CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
      }
    } else {
      transformed_input = compatible_input;
    }

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_backprop_ptr = AsDeviceMemory(
        pre_transformed_filter_backprop.template flat<T>().data(),
        pre_transformed_filter_backprop.template flat<T>().size());
    auto input_ptr =
        AsDeviceMemory(transformed_input.template flat<T>().data(),
                       transformed_input.template flat<T>().size());

    static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

    const int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
    const ConvParameters conv_parameters = {
        batch,
        in_depth,
        {{input_size[0], input_size[1], input_size[2]}},
        out_depth,
        {{filter_size[0], filter_size[1], filter_size[2]}},
        {{1, 1, 1}},
        {{strides[0], strides[1], strides[2]}},
        {{padding_planes, padding_rows, padding_cols}},
        dtype,
        device_id,
    };

    using perftools::gputools::dnn::AlgorithmConfig;
    using perftools::gputools::dnn::AlgorithmDesc;
    using perftools::gputools::dnn::ProfileResult;
    AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmDesc> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times for better
        // accuracy.
        CudnnScratchAllocator scratch_allocator(
            ConvolveBackwardFilterScratchSize, context);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardFilterWithAlgorithm(
                    input_desc, input_ptr, output_desc, out_backprop_ptr,
                    conv_desc, filter_desc, &filter_backprop_ptr,
                    &scratch_allocator, AlgorithmConfig(profile_algorithm),
                    &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
      AutoTuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters,
                                                     algorithm_config);
    }
    CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                            context);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardFilterWithAlgorithm(
                input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
                filter_desc, &filter_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(errors::Internal(
          "cuDNN Backward Filter function launch failure : input shape(",
          input_shape.DebugString(), ") filter shape(",
          filter_shape.DebugString(), ")"));
    }

    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::ReverseTransformFilter<GPUDevice, T, 5>()(
        context->eigen_device<GPUDevice>(),
        toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
        filter_backprop->tensor<T, 5>());
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
  bool cudnn_use_autotune_;
};

#define REGISTER_GPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
      Conv3DBackpropInputOp<GPUDevice, T>);                                   \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                       \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .HostMemory("input_sizes"),                     \
                          Conv3DBackpropInputOp<GPUDevice, T>);               \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      Conv3DBackpropFilterOp<GPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .HostMemory("filter_sizes"),                    \
                          Conv3DBackpropFilterOp<GPUDevice, T>);
TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow