1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
26a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
36a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
46a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFloweryou may not use this file except in compliance with the License.
56a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerYou may obtain a copy of the License at
66a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
76a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
86a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
96a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
106a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
116a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerSee the License for the specific language governing permissions and
136a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerlimitations under the License.
146a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower==============================================================================*/
156a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
166a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#define USE_EIGEN_TENSOR
176a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#define EIGEN_USE_THREADS
186a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
196a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/kernels/conv_2d.h"
206a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/kernels/conv_3d.h"
216a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
226a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/framework/numeric_op.h"
236a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/framework/op_kernel.h"
24aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower#include "tensorflow/core/framework/register_types.h"
256a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/framework/tensor.h"
266a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/framework/tensor_shape.h"
276a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/framework/tensor_slice.h"
286a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/kernels/conv_ops_gpu.h"
296a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/kernels/ops_util.h"
306a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/lib/core/errors.h"
316a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/util/padding.h"
326a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/util/tensor_format.h"
33bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower#include "tensorflow/core/util/use_cudnn.h"
346a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
356a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#if GOOGLE_CUDA
366a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/platform/stream_executor.h"
376a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerusing perftools::gputools::dnn::DimIndex;
386a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#endif
396a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
406a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowernamespace tensorflow {
416a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
426a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowertypedef Eigen::ThreadPoolDevice CPUDevice;
436a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowertypedef Eigen::GpuDevice GPUDevice;
446a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
// Primary template, declared but never defined: the CPUDevice and GPUDevice
// specializations in this file each supply a static `launch` entry point.
template <class Device, class T>
struct LaunchConvOp;
476a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
486a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowertemplate <typename T>
496a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerstruct LaunchConvOp<CPUDevice, T> {
50bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower  static void launch(OpKernelContext* context, bool cudnn_use_autotune,
51bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower                     const Tensor& input, const Tensor& filter,
52bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower                     const std::array<int64, 3>& strides, const Padding padding,
5361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                     TensorFormat data_format, Tensor* output) {
5461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    OP_REQUIRES(context, data_format == FORMAT_NHWC,
5561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                errors::InvalidArgument("CPU implementation of Conv3D "
5661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                                        "currently only supports the NHWC "
5761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                                        "tensor format."));
586a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    functor::CuboidConvolution<CPUDevice, T>()(
596a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        context->eigen_device<CPUDevice>(), output->tensor<T, 5>(),
60374f3ce9528217b4176af1a3fafc02bb3af00e96A. Unique TensorFlower        input.tensor<T, 5>(), filter.tensor<T, 5>(), strides[2], strides[1],
61374f3ce9528217b4176af1a3fafc02bb3af00e96A. Unique TensorFlower        strides[0], BrainPadding2EigenPadding(padding));
626a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  }
636a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower};
646a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
656a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowertemplate <typename Device, typename T>
666a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerclass Conv3DOp : public BinaryOp<T> {
676a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower public:
686a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  explicit Conv3DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
6961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    string data_format;
7061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
7161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
7261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                errors::InvalidArgument("Invalid data format"));
736a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
746a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES(context, stride_.size() == 5,
756a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                errors::InvalidArgument("Sliding window strides field must "
766a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                        "specify 5 dimensions"));
776a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES(
7861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        context,
7961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        (GetTensorDim(stride_, data_format_, 'N') == 1 &&
8061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower         GetTensorDim(stride_, data_format_, 'C') == 1),
816a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        errors::InvalidArgument("Current implementation does not yet support "
826a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                "strides in the batch and depth dimensions."));
836a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
84bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower    cudnn_use_autotune_ = CudnnUseAutotune();
856a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  }
866a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
876a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  void Compute(OpKernelContext* context) override {
886a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    // Input tensor is of the following dimensions:
896a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    // [ batch, in_z, in_y, in_x, in_channels ]
906a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    const Tensor& input = context->input(0);
916a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
926a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    // Input filter is of the following dimensions:
936a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    // [ filter_z, filter_y, filter_x, in_channels, out_channels]
946a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    const Tensor& filter = context->input(1);
956a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
966a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    // NOTE: The ordering of the spatial dimensions is arbitrary, but has to be
976a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    // kept consistent between input/filter/output.
986a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES(context, input.dims() == 5,
996a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                errors::InvalidArgument("input must be 5-dimensional"));
1006a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES(context, filter.dims() == 5,
1016a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                errors::InvalidArgument("filter must be 5-dimensional"));
1026a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
10361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
10461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    const int64 in_batch = GetTensorDim(input, data_format_, 'N');
1056a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
1066a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    const int64 out_depth = filter.dim_size(4);
1076a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES(
1086a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        context, in_depth == filter.dim_size(3),
1096a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        errors::InvalidArgument("input and filter must have the same depth"));
1106a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
111374f3ce9528217b4176af1a3fafc02bb3af00e96A. Unique TensorFlower    // Dimension order for these arrays is: z, y, x.
1126a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    std::array<int64, 3> input_size = {
11361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        {GetTensorDim(input, data_format_, '0'),
11461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower         GetTensorDim(input, data_format_, '1'),
11561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower         GetTensorDim(input, data_format_, '2')}};
1166a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    std::array<int64, 3> filter_size = {
1176a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}};
11861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    std::array<int64, 3> strides = {{GetTensorDim(stride_, data_format_, '0'),
11961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                                     GetTensorDim(stride_, data_format_, '1'),
12061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                                     GetTensorDim(stride_, data_format_, '2')}};
1216a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    std::array<int64, 3> out, padding;
1226a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
1236a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES_OK(context, Get3dOutputSize(input_size, filter_size, strides,
1246a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                            padding_, &out, &padding));
12561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    TensorShape out_shape = ShapeFromFormat(
12661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        data_format_, in_batch, {{out[0], out[1], out[2]}}, out_depth);
1276a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    Tensor* output;
1286a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
1296a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
1306a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    // Return early if nothing to do.
1316a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    if (out_shape.num_elements() == 0) return;
1326a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
133bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower    LaunchConvOp<Device, T>::launch(context, cudnn_use_autotune_, input, filter,
13461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                                    strides, padding_, data_format_, output);
1356a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  }
1366a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
1376a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower private:
1386a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  std::vector<int32> stride_;
1396a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  Padding padding_;
14061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower  TensorFormat data_format_;
141bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower  bool cudnn_use_autotune_;
1426a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower};
1436a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
144aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower#define REGISTER_CPU_KERNEL(T)                                  \
145aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower  REGISTER_KERNEL_BUILDER(                                      \
146aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower      Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
147aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower      Conv3DOp<CPUDevice, T>);
148b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei FengTF_CALL_half(REGISTER_CPU_KERNEL);
149aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlowerTF_CALL_float(REGISTER_CPU_KERNEL);
150aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlowerTF_CALL_double(REGISTER_CPU_KERNEL);
151aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower#undef REGISTER_CPU_KERNEL
1526a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
1536a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#if GOOGLE_CUDA
1546a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
// Tag type with no state: it only supplies the group name under which the
// autotune results of forward 3-D convolutions are kept together.
struct Conv3dAutoTuneGroup {
  // Returns the name of this autotune group.
  static string name() {
    return string("Conv3d");
  }
};
159bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlowertypedef AutoTuneSingleton<Conv3dAutoTuneGroup, ConvParameters,
160bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower                          perftools::gputools::dnn::AlgorithmConfig>
161bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower    AutoTuneConv3d;
162bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower
1636a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower// TODO(mjanusz): Share logic with 2d implementation as much as possible.
1646a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowertemplate <typename T>
1656a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerstruct LaunchConvOp<GPUDevice, T> {
166bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower  static void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
167bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower                     const Tensor& input_param, const Tensor& filter,
168bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower                     const std::array<int64, 3>& strides, const Padding padding,
16961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                     TensorFormat data_format, Tensor* output) {
1706a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    auto* stream = ctx->op_device_context()->stream();
1716a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
1726a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
1736a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    Tensor input = input_param;
1746a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
17561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    const int64 in_batch = GetTensorDim(input, data_format, 'N');
17661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    int64 in_planes = GetTensorDim(input, data_format, '0');
17761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    int64 in_rows = GetTensorDim(input, data_format, '1');
17861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    int64 in_cols = GetTensorDim(input, data_format, '2');
17961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    const int64 in_depth = GetTensorDim(input, data_format, 'C');
1806a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
1816a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    const int64 filter_planes = filter.dim_size(0);
1826a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    const int64 filter_rows = filter.dim_size(1);
1836a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    const int64 filter_cols = filter.dim_size(2);
1846a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    const int64 out_depth = filter.dim_size(4);
1856a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
1866a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
18761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    int64 out_planes = GetTensorDim(*output, data_format, '0');
18861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    int64 out_rows = GetTensorDim(*output, data_format, '1');
18961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    int64 out_cols = GetTensorDim(*output, data_format, '2');
1906a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
1916a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    if (padding == Padding::SAME) {
1921d9ada31b2d046bb9586a3a276155280ed8f5b6bXiaoqiang Zheng      pad_planes = std::max<int64>(
1931d9ada31b2d046bb9586a3a276155280ed8f5b6bXiaoqiang Zheng          0, (out_planes - 1) * strides[0] + filter_planes - in_planes);
19478cf08951e95cc3afd5d2e6677db6bc85b06d43eXiaoqiang Zheng      pad_rows = std::max<int64>(
19578cf08951e95cc3afd5d2e6677db6bc85b06d43eXiaoqiang Zheng          0, (out_rows - 1) * strides[1] + filter_rows - in_rows);
19678cf08951e95cc3afd5d2e6677db6bc85b06d43eXiaoqiang Zheng      pad_cols = std::max<int64>(
19778cf08951e95cc3afd5d2e6677db6bc85b06d43eXiaoqiang Zheng          0, (out_cols - 1) * strides[2] + filter_cols - in_cols);
1986a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    }
1996a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
2006a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    // NOTE: This only works in NHWC.
2016a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    if (filter_planes == 1 && filter_rows == 1 && filter_cols == 1 &&
20261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        strides[0] == 1 && strides[1] == 1 && strides[2] == 1 &&
20361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        data_format == FORMAT_NHWC) {
2046a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      // 1x1 filter, so call cublas directly.
2050318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      const uint64 m = in_batch * in_planes * in_rows * in_cols;
2066a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      const uint64 k = in_depth;
2076a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      const uint64 n = out_depth;
2086a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
2096a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
2106a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                  input.template flat<T>().size());
2110318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
2120318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower                                  filter.template flat<T>().size());
2130318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
2140318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower                                  output->template flat<T>().size());
2150318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower
2160318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
2170318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      bool blas_launch_status =
2180318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower          stream
2190318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower              ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr,
2200318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower                             n, a_ptr, k, 0.0f, &c_ptr, n)
2210318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower              .ok();
2220318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      if (!blas_launch_status) {
2230318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower        ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
2240318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower                                        ", n=", n, ", k=", k));
2250318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      }
2260318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      return;
2270318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower    } else if (filter_planes == in_planes && filter_rows == in_rows &&
22861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower               filter_cols == in_cols && padding == Padding::VALID &&
22961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower               data_format == FORMAT_NHWC) {
2300318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      // The input data and filter have the same planes/height/width, so call
2310318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      // cublas directly.
2320318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      const uint64 m = in_batch;
2330318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      const uint64 k = in_planes * in_rows * in_cols * in_depth;
2340318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      const uint64 n = out_depth;
2350318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower
2360318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
2370318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower                                  input.template flat<T>().size());
2386a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
2396a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                  filter.template flat<T>().size());
2406a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
2416a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                  output->template flat<T>().size());
2426a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
2436a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
2446a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      bool blas_launch_status =
2456a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower          stream
2466a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower              ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr,
2476a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                             n, a_ptr, k, 0.0f, &c_ptr, n)
2486a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower              .ok();
2496a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      if (!blas_launch_status) {
2506a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
2516a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                        ", n=", n, ", k=", k));
2526a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      }
2536a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      return;
2546a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    }
2556a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
2566a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    if (padding == Padding::SAME) {
2576a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      const bool rows_odd = (pad_rows % 2 != 0);
2586a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      const bool cols_odd = (pad_cols % 2 != 0);
2596a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      const bool planes_odd = (pad_planes % 2 != 0);
2606a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
2616a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      // Necessary because cuDNN only supports symmetric padding.
2626a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      // TODO(mjanusz): Consider making this optional? This would save some
2636a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      // overhead and would work as long as an op trained this way is only
2646a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      // used on GPU.
2656a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      if (rows_odd || cols_odd || planes_odd) {
26661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        const int64 new_in_rows = in_rows + rows_odd;
26761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        const int64 new_in_cols = in_cols + cols_odd;
26861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        const int64 new_in_planes = in_planes + planes_odd;
2696a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
27061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        Tensor transformed_input;
27161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        TensorShape transformed_shape = ShapeFromFormat(
27261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower            data_format, in_batch, {{new_in_planes, new_in_rows, new_in_cols}},
27361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower            in_depth);
2746a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        OP_REQUIRES_OK(
2756a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, transformed_shape,
2766a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                    &transformed_input));
2776a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
2786a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        functor::PadInput<GPUDevice, T, int, 5>()(
2796a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower            ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 5>()),
2806a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower            {{0, 0, 0}}, {{planes_odd, rows_odd, cols_odd}},
28161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower            To32Bit(transformed_input.tensor<T, 5>()), data_format);
2826a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        input = transformed_input;
2836a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        in_rows = new_in_rows;
2846a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        in_cols = new_in_cols;
2856a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        in_planes = new_in_planes;
2866a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      }
2876a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    }
2886a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
28961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    if (data_format == FORMAT_NHWC) {
29061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower      const TensorShape nchw_shape = ShapeFromFormat(
29161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower          FORMAT_NCHW, in_batch, {{in_planes, in_rows, in_cols}}, in_depth);
29261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower      if (in_depth > 1) {
29361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        Tensor transformed_input;
29461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
29561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                                               nchw_shape, &transformed_input));
29661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        // input: [b, x, y, z, d]
29761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        // t_input: [b, d, x, y, z]
29861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        // NCDHW is the only format universally supported by cuDNN.
29961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        functor::NHWCToNCHW<GPUDevice, T, 5>()(
30061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower            ctx->eigen_device<GPUDevice>(),
30161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower            const_cast<const Tensor&>(input).tensor<T, 5>(),
30261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower            transformed_input.tensor<T, 5>());
30361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        input = transformed_input;
30461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower      } else {
30561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        CHECK(input.CopyFrom(input, nchw_shape));
30661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower      }
30761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    }
3086a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
3091d9ada31b2d046bb9586a3a276155280ed8f5b6bXiaoqiang Zheng    CHECK(pad_rows >= 0 && pad_cols >= 0 && pad_planes >= 0)
3101d9ada31b2d046bb9586a3a276155280ed8f5b6bXiaoqiang Zheng        << "Negative paddings: (" << pad_rows << ", " << pad_cols << ", "
3111d9ada31b2d046bb9586a3a276155280ed8f5b6bXiaoqiang Zheng        << pad_planes << ")";
3126a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    perftools::gputools::dnn::BatchDescriptor input_desc(3);
3136a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    input_desc.set_count(in_batch)
3146a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_feature_map_count(in_depth)
3156a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_spatial_dim(DimIndex::X, in_cols)
3166a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_spatial_dim(DimIndex::Y, in_rows)
3176a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_spatial_dim(DimIndex::Z, in_planes)
3186a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
3196a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    perftools::gputools::dnn::BatchDescriptor output_desc(3);
3206a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    output_desc.set_count(in_batch)
3216a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_spatial_dim(DimIndex::X, out_cols)
3226a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_spatial_dim(DimIndex::Y, out_rows)
3236a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_spatial_dim(DimIndex::Z, out_planes)
3246a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_feature_map_count(out_depth)
3256a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
3266a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    perftools::gputools::dnn::FilterDescriptor filter_desc(3);
3276a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    filter_desc.set_spatial_dim(DimIndex::X, filter_cols)
3286a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_spatial_dim(DimIndex::Y, filter_rows)
3296a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_spatial_dim(DimIndex::Z, filter_planes)
3306a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_input_feature_map_count(in_depth)
3316a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_output_feature_map_count(out_depth);
3326a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
3336a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    conv_desc.set_filter_stride(DimIndex::X, strides[2])
3346a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_filter_stride(DimIndex::Y, strides[1])
3356a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_filter_stride(DimIndex::Z, strides[0])
3366a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_zero_padding(DimIndex::X, pad_cols / 2)
3376a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_zero_padding(DimIndex::Y, pad_rows / 2)
3386a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        .set_zero_padding(DimIndex::Z, pad_planes / 2);
3396a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
3406a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    Tensor transformed_filter;
3416a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES_OK(
3426a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
3436a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                TensorShape({out_depth, in_depth, filter_planes,
3446a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                             filter_rows, filter_cols}),
3456a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                &transformed_filter));
3466a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    // filter: [x, y, z, in, out]
3476a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    // t_filter: [out, in, x, y, z]
3486a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    functor::TransformFilter<GPUDevice, T, int, 5>()(
3496a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
3506a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        To32Bit(transformed_filter.tensor<T, 5>()));
3516a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
3526a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    Tensor transformed_output;
3536a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    OP_REQUIRES_OK(
35461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower        ctx, ctx->allocate_temp(
35561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                 DataTypeToEnum<T>::value,
35661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                 ShapeFromFormat(FORMAT_NCHW, in_batch,
35761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                                 {{out_planes, out_rows, out_cols}}, out_depth),
35861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower                 &transformed_output));
3596a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
3606a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
3616a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                                    input.template flat<T>().size());
3626a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    auto filter_ptr =
3636a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        AsDeviceMemory(transformed_filter.template flat<T>().data(),
3646a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                       transformed_filter.template flat<T>().size());
3656a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    auto output_ptr =
3666a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        AsDeviceMemory(transformed_output.template flat<T>().data(),
3676a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower                       transformed_output.template flat<T>().size());
3686a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
3696a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
3706a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default
371bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower
372bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower    int device_id = stream->parent()->device_ordinal();
373f4b237f8cdd25a45dc26adc61c3086c2575f5396Yangzihao Wang    DataType dtype = input.dtype();
374bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower    ConvParameters conv_parameters = {
375bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower        in_batch,
376bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower        in_depth,
377bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower        {{in_planes, in_rows, in_cols}},
378bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower        out_depth,
379bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower        {{filter_planes, filter_rows, filter_cols}},
380cb4ef362e4a18b3c42a2c90bdad8754d5ead4cafYangzihao Wang        // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
381cb4ef362e4a18b3c42a2c90bdad8754d5ead4cafYangzihao Wang        // conv is supported.
382309f7e29a6f19ac060e9cf5f02e7de0eeac522deA. Unique TensorFlower        /*dilation=*/{{1, 1, 1}},
383bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower        {{strides[0], strides[1], strides[2]}},
384bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower        {{pad_planes, pad_rows, pad_cols}},
385f4b237f8cdd25a45dc26adc61c3086c2575f5396Yangzihao Wang        dtype,
386bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower        device_id,
387bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower    };
388bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower
389bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower    using perftools::gputools::dnn::AlgorithmConfig;
390e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai    using perftools::gputools::dnn::AlgorithmDesc;
391bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower    using perftools::gputools::dnn::ProfileResult;
392bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower
393bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower    AlgorithmConfig algorithm_config;
394bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower
395bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower    if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
396bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower                                  conv_parameters, &algorithm_config)) {
3975eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen      std::vector<AlgorithmDesc> algorithms;
398e78e5ec8a8c862e65b6a194e9caea377120d7207Yangzihao Wang      CHECK(stream->parent()->GetConvolveAlgorithms(
399e78e5ec8a8c862e65b6a194e9caea377120d7207Yangzihao Wang          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
400bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower      ProfileResult best_result;
401bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower      ProfileResult best_result_no_scratch;
4025eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen      for (auto profile_algorithm : algorithms) {
4035eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen        // TODO(zhengxq): profile each algorithm multiple times to better
4045eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen        // accuracy.
4055eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen        CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
4065eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen        ProfileResult profile_result;
4075eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen        bool cudnn_launch_status =
4085eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen            stream
4095eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen                ->ThenConvolveWithAlgorithm(
4105eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen                    input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
4115eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen                    output_desc, &output_ptr, &scratch_allocator,
4125eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen                    AlgorithmConfig(profile_algorithm), &profile_result)
4135eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen                .ok();
4145eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen        if (cudnn_launch_status) {
4155eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen          if (profile_result.is_valid()) {
4165eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen            if (profile_result.elapsed_time_in_ms() <
4175eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen                best_result.elapsed_time_in_ms()) {
4185eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen              best_result = profile_result;
4195eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen            }
4205eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen            if (scratch_allocator.TotalByteSize() == 0 &&
4215eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen                profile_result.elapsed_time_in_ms() <
4225eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen                    best_result_no_scratch.elapsed_time_in_ms()) {
4235eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen              best_result_no_scratch = profile_result;
424bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower            }
425bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower          }
426bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower        }
427bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower      }
428bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower      OP_REQUIRES(ctx,
429db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang                  best_result.is_valid() || best_result_no_scratch.is_valid(),
430bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower                  errors::NotFound("No algorithm worked!"));
431db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang      if (best_result.is_valid()) {
432db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang        algorithm_config.set_algorithm(best_result.algorithm());
433db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang      }
434db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang      if (best_result_no_scratch.is_valid()) {
435db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang        algorithm_config.set_algorithm_no_scratch(
436db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang            best_result_no_scratch.algorithm());
437db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang      }
438bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower      AutoTuneConv3d::GetInstance()->Insert(conv_parameters, algorithm_config);
439bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower    }
440bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower
4416a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
4426a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    bool cudnn_launch_status =
4436a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower        stream
444bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower            ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
445bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower                                        filter_ptr, conv_desc, output_desc,
446bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower                                        &output_ptr, &scratch_allocator,
447bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower                                        algorithm_config, nullptr)
4486a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower            .ok();
4496a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
4506a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    if (!cudnn_launch_status) {
4516a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      ctx->SetStatus(errors::Internal(
4526a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower          "cuDNN launch failure : input shape(", input.shape().DebugString(),
4536a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower          ") filter shape(", filter.shape().DebugString(), ")"));
4546a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    }
4556a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
45661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    if (data_format == FORMAT_NHWC) {
45761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower      // t_output: [b, out, x, y, z]
45861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower      // output: [b, x, y, z, out]
45961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower      functor::NCHWToNHWC<GPUDevice, T, 5>()(
46061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower          ctx->eigen_device<GPUDevice>(),
46161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower          const_cast<const Tensor&>(transformed_output).tensor<T, 5>(),
46261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower          output->tensor<T, 5>());
46361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    } else {
46461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower      *output = transformed_output;
46561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower    }
4666a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  }
4676a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower};
4686a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
4696a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower// Forward declarations of the functor specializations for GPU.
4706a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower// This ensures that the custom implementation is used instead of the default
4716a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower// Eigen one (which is used for CPU).
4726a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowernamespace functor {
4736a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#define DECLARE_GPU_SPEC(T)                                           \
4746a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  template <>                                                         \
4756a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
4766a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
4776a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      typename TTypes<T, 5, int>::Tensor out);                        \
4786a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  template <>                                                         \
4796a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
4806a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
4816a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      typename TTypes<T, 5>::Tensor out);                             \
4826a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  template <>                                                         \
4836a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower  void PadInput<GPUDevice, T, int, 5>::operator()(                    \
4846a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
4856a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      const std::array<int, 3>& padding_left,                         \
4866a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      const std::array<int, 3>& padding_right,                        \
4876a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower      typename TTypes<T, 5, int>::Tensor out, TensorFormat format);
4886a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
489b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei FengDECLARE_GPU_SPEC(Eigen::half);
4906a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerDECLARE_GPU_SPEC(float);
4916a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#undef DECLARE_GPU_SPEC
4926a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
4936a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower}  // namespace functor
4946a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
4956a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower// Registration of the GPU implementations.
4966a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerREGISTER_KERNEL_BUILDER(
497b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
498b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    Conv3DOp<GPUDevice, Eigen::half>);
499b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei FengREGISTER_KERNEL_BUILDER(
5006a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
5016a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower    Conv3DOp<GPUDevice, float>);
5026a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#endif  // GOOGLE_CUDA
5036a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower
5046a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower}  // namespace tensorflow
505