1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 26a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 36a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 46a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFloweryou may not use this file except in compliance with the License. 56a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerYou may obtain a copy of the License at 66a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 76a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 86a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 96a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 106a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 116a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 126a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerSee the License for the specific language governing permissions and 136a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerlimitations under the License. 146a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower==============================================================================*/ 156a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 166a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#define USE_EIGEN_TENSOR 176a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#define EIGEN_USE_THREADS 186a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 196a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/kernels/conv_2d.h" 206a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/kernels/conv_3d.h" 216a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 226a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/framework/numeric_op.h" 236a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/framework/op_kernel.h" 24aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower#include "tensorflow/core/framework/register_types.h" 256a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/framework/tensor.h" 266a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/framework/tensor_shape.h" 276a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/framework/tensor_slice.h" 286a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/kernels/conv_ops_gpu.h" 296a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/kernels/ops_util.h" 306a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/lib/core/errors.h" 316a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/util/padding.h" 326a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/util/tensor_format.h" 33bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower#include "tensorflow/core/util/use_cudnn.h" 346a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 356a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#if GOOGLE_CUDA 366a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#include "tensorflow/core/platform/stream_executor.h" 376a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerusing perftools::gputools::dnn::DimIndex; 386a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#endif 396a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 406a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowernamespace tensorflow { 416a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 426a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowertypedef Eigen::ThreadPoolDevice CPUDevice; 436a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowertypedef Eigen::GpuDevice GPUDevice; 446a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 456a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowertemplate <typename Device, typename T> 466a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerstruct LaunchConvOp; 476a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 486a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowertemplate <typename T> 496a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerstruct LaunchConvOp<CPUDevice, T> { 50bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower static void launch(OpKernelContext* context, bool cudnn_use_autotune, 51bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower const Tensor& input, const Tensor& filter, 52bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower const std::array<int64, 3>& strides, const Padding padding, 5361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower TensorFormat data_format, Tensor* output) { 5461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower OP_REQUIRES(context, data_format == FORMAT_NHWC, 5561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower errors::InvalidArgument("CPU implementation of Conv3D " 5661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower "currently only supports the NHWC " 5761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower "tensor format.")); 586a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower functor::CuboidConvolution<CPUDevice, T>()( 596a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower context->eigen_device<CPUDevice>(), output->tensor<T, 5>(), 60374f3ce9528217b4176af1a3fafc02bb3af00e96A. Unique TensorFlower input.tensor<T, 5>(), filter.tensor<T, 5>(), strides[2], strides[1], 61374f3ce9528217b4176af1a3fafc02bb3af00e96A. Unique TensorFlower strides[0], BrainPadding2EigenPadding(padding)); 626a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower } 636a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower}; 646a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 656a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowertemplate <typename Device, typename T> 666a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerclass Conv3DOp : public BinaryOp<T> { 676a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower public: 686a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower explicit Conv3DOp(OpKernelConstruction* context) : BinaryOp<T>(context) { 6961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower string data_format; 7061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); 7161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower OP_REQUIRES(context, FormatFromString(data_format, &data_format_), 7261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower errors::InvalidArgument("Invalid data format")); 736a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); 746a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES(context, stride_.size() == 5, 756a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower errors::InvalidArgument("Sliding window strides field must " 766a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower "specify 5 dimensions")); 776a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES( 7861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower context, 7961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower (GetTensorDim(stride_, data_format_, 'N') == 1 && 8061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower GetTensorDim(stride_, data_format_, 'C') == 1), 816a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower errors::InvalidArgument("Current implementation does not yet support " 826a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower "strides in the batch and depth dimensions.")); 836a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 84bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower cudnn_use_autotune_ = CudnnUseAutotune(); 856a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower } 866a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 876a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower void Compute(OpKernelContext* context) override { 886a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // Input tensor is of the following dimensions: 896a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // [ batch, in_z, in_y, in_x, in_channels ] 906a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const Tensor& input = context->input(0); 916a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 926a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // Input filter is of the following dimensions: 936a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // [ filter_z, filter_y, filter_x, in_channels, out_channels] 946a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const Tensor& filter = context->input(1); 956a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 966a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // NOTE: The ordering of the spatial dimensions is arbitrary, but has to be 976a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // kept consistent between input/filter/output. 986a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES(context, input.dims() == 5, 996a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower errors::InvalidArgument("input must be 5-dimensional")); 1006a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES(context, filter.dims() == 5, 1016a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower errors::InvalidArgument("filter must be 5-dimensional")); 1026a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 10361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower const int64 in_depth = GetTensorDim(input, data_format_, 'C'); 10461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower const int64 in_batch = GetTensorDim(input, data_format_, 'N'); 1056a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 1066a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const int64 out_depth = filter.dim_size(4); 1076a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES( 1086a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower context, in_depth == filter.dim_size(3), 1096a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower errors::InvalidArgument("input and filter must have the same depth")); 1106a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 111374f3ce9528217b4176af1a3fafc02bb3af00e96A. Unique TensorFlower // Dimension order for these arrays is: z, y, x. 1126a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower std::array<int64, 3> input_size = { 11361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower {GetTensorDim(input, data_format_, '0'), 11461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower GetTensorDim(input, data_format_, '1'), 11561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower GetTensorDim(input, data_format_, '2')}}; 1166a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower std::array<int64, 3> filter_size = { 1176a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}}; 11861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower std::array<int64, 3> strides = {{GetTensorDim(stride_, data_format_, '0'), 11961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower GetTensorDim(stride_, data_format_, '1'), 12061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower GetTensorDim(stride_, data_format_, '2')}}; 1216a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower std::array<int64, 3> out, padding; 1226a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 1236a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES_OK(context, Get3dOutputSize(input_size, filter_size, strides, 1246a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower padding_, &out, &padding)); 12561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower TensorShape out_shape = ShapeFromFormat( 12661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower data_format_, in_batch, {{out[0], out[1], out[2]}}, out_depth); 1276a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower Tensor* output; 1286a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); 1296a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 1306a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // Return early if nothing to do. 1316a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower if (out_shape.num_elements() == 0) return; 1326a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 133bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower LaunchConvOp<Device, T>::launch(context, cudnn_use_autotune_, input, filter, 13461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower strides, padding_, data_format_, output); 1356a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower } 1366a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 1376a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower private: 1386a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower std::vector<int32> stride_; 1396a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower Padding padding_; 14061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower TensorFormat data_format_; 141bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower bool cudnn_use_autotune_; 1426a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower}; 1436a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 144aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower#define REGISTER_CPU_KERNEL(T) \ 145aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower REGISTER_KERNEL_BUILDER( \ 146aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 147aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower Conv3DOp<CPUDevice, T>); 148b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei FengTF_CALL_half(REGISTER_CPU_KERNEL); 149aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlowerTF_CALL_float(REGISTER_CPU_KERNEL); 150aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlowerTF_CALL_double(REGISTER_CPU_KERNEL); 151aba8beebab0b363f03492b3d5653ec14d148f3c3A. Unique TensorFlower#undef REGISTER_CPU_KERNEL 1526a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 1536a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#if GOOGLE_CUDA 1546a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 155bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower// A dummy type to group forward convolution autotune results together. 156be81281b3c9aba3749719e7b3b08cfb51ed55b42Xiaoqiang Zhengstruct Conv3dAutoTuneGroup { 157be81281b3c9aba3749719e7b3b08cfb51ed55b42Xiaoqiang Zheng static string name() { return "Conv3d"; } 158be81281b3c9aba3749719e7b3b08cfb51ed55b42Xiaoqiang Zheng}; 159bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlowertypedef AutoTuneSingleton<Conv3dAutoTuneGroup, ConvParameters, 160bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower perftools::gputools::dnn::AlgorithmConfig> 161bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower AutoTuneConv3d; 162bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower 1636a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower// TODO(mjanusz): Share logic with 2d implementation as much as possible. 1646a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowertemplate <typename T> 1656a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerstruct LaunchConvOp<GPUDevice, T> { 166bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower static void launch(OpKernelContext* ctx, bool cudnn_use_autotune, 167bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower const Tensor& input_param, const Tensor& filter, 168bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower const std::array<int64, 3>& strides, const Padding padding, 16961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower TensorFormat data_format, Tensor* output) { 1706a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower auto* stream = ctx->op_device_context()->stream(); 1716a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); 1726a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 1736a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower Tensor input = input_param; 1746a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 17561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower const int64 in_batch = GetTensorDim(input, data_format, 'N'); 17661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower int64 in_planes = GetTensorDim(input, data_format, '0'); 17761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower int64 in_rows = GetTensorDim(input, data_format, '1'); 17861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower int64 in_cols = GetTensorDim(input, data_format, '2'); 17961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower const int64 in_depth = GetTensorDim(input, data_format, 'C'); 1806a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 1816a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const int64 filter_planes = filter.dim_size(0); 1826a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const int64 filter_rows = filter.dim_size(1); 1836a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const int64 filter_cols = filter.dim_size(2); 1846a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const int64 out_depth = filter.dim_size(4); 1856a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 1866a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower int64 pad_planes = 0, pad_rows = 0, pad_cols = 0; 18761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower int64 out_planes = GetTensorDim(*output, data_format, '0'); 18861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower int64 out_rows = GetTensorDim(*output, data_format, '1'); 18961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower int64 out_cols = GetTensorDim(*output, data_format, '2'); 1906a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 1916a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower if (padding == Padding::SAME) { 1921d9ada31b2d046bb9586a3a276155280ed8f5b6bXiaoqiang Zheng pad_planes = std::max<int64>( 1931d9ada31b2d046bb9586a3a276155280ed8f5b6bXiaoqiang Zheng 0, (out_planes - 1) * strides[0] + filter_planes - in_planes); 19478cf08951e95cc3afd5d2e6677db6bc85b06d43eXiaoqiang Zheng pad_rows = std::max<int64>( 19578cf08951e95cc3afd5d2e6677db6bc85b06d43eXiaoqiang Zheng 0, (out_rows - 1) * strides[1] + filter_rows - in_rows); 19678cf08951e95cc3afd5d2e6677db6bc85b06d43eXiaoqiang Zheng pad_cols = std::max<int64>( 19778cf08951e95cc3afd5d2e6677db6bc85b06d43eXiaoqiang Zheng 0, (out_cols - 1) * strides[2] + filter_cols - in_cols); 1986a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower } 1996a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 2006a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // NOTE: This only works in NHWC. 2016a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower if (filter_planes == 1 && filter_rows == 1 && filter_cols == 1 && 20261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower strides[0] == 1 && strides[1] == 1 && strides[2] == 1 && 20361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower data_format == FORMAT_NHWC) { 2046a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // 1x1 filter, so call cublas directly. 2050318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower const uint64 m = in_batch * in_planes * in_rows * in_cols; 2066a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const uint64 k = in_depth; 2076a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const uint64 n = out_depth; 2086a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 2096a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower auto a_ptr = AsDeviceMemory(input.template flat<T>().data(), 2106a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower input.template flat<T>().size()); 2110318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(), 2120318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower filter.template flat<T>().size()); 2130318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower auto c_ptr = AsDeviceMemory(output->template flat<T>().data(), 2140318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower output->template flat<T>().size()); 2150318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower 2160318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; 2170318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower bool blas_launch_status = 2180318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower stream 2190318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, 2200318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower n, a_ptr, k, 0.0f, &c_ptr, n) 2210318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower .ok(); 2220318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower if (!blas_launch_status) { 2230318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, 2240318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower ", n=", n, ", k=", k)); 2250318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower } 2260318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower return; 2270318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower } else if (filter_planes == in_planes && filter_rows == in_rows && 22861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower filter_cols == in_cols && padding == Padding::VALID && 22961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower data_format == FORMAT_NHWC) { 2300318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower // The input data and filter have the same planes/height/width, so call 2310318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower // cublas directly. 2320318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower const uint64 m = in_batch; 2330318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower const uint64 k = in_planes * in_rows * in_cols * in_depth; 2340318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower const uint64 n = out_depth; 2350318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower 2360318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower auto a_ptr = AsDeviceMemory(input.template flat<T>().data(), 2370318cf082ee88ff0e226a5bf7da0487f44d82182A. Unique TensorFlower input.template flat<T>().size()); 2386a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(), 2396a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower filter.template flat<T>().size()); 2406a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower auto c_ptr = AsDeviceMemory(output->template flat<T>().data(), 2416a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower output->template flat<T>().size()); 2426a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 2436a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; 2446a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower bool blas_launch_status = 2456a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower stream 2466a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, 2476a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower n, a_ptr, k, 0.0f, &c_ptr, n) 2486a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .ok(); 2496a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower if (!blas_launch_status) { 2506a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, 2516a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower ", n=", n, ", k=", k)); 2526a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower } 2536a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower return; 2546a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower } 2556a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 2566a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower if (padding == Padding::SAME) { 2576a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const bool rows_odd = (pad_rows % 2 != 0); 2586a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const bool cols_odd = (pad_cols % 2 != 0); 2596a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const bool planes_odd = (pad_planes % 2 != 0); 2606a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 2616a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // Necessary because cuDNN only supports symmetric padding. 2626a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // TODO(mjanusz): Consider making this optional? This would save some 2636a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // overhead and would work as long as an op trained this way is only 2646a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // used on GPU. 2656a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower if (rows_odd || cols_odd || planes_odd) { 26661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower const int64 new_in_rows = in_rows + rows_odd; 26761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower const int64 new_in_cols = in_cols + cols_odd; 26861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower const int64 new_in_planes = in_planes + planes_odd; 2696a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 27061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower Tensor transformed_input; 27161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower TensorShape transformed_shape = ShapeFromFormat( 27261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower data_format, in_batch, {{new_in_planes, new_in_rows, new_in_cols}}, 27361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower in_depth); 2746a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES_OK( 2756a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, transformed_shape, 2766a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower &transformed_input)); 2776a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 2786a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower functor::PadInput<GPUDevice, T, int, 5>()( 2796a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 5>()), 2806a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower {{0, 0, 0}}, {{planes_odd, rows_odd, cols_odd}}, 28161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower To32Bit(transformed_input.tensor<T, 5>()), data_format); 2826a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower input = transformed_input; 2836a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower in_rows = new_in_rows; 2846a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower in_cols = new_in_cols; 2856a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower in_planes = new_in_planes; 2866a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower } 2876a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower } 2886a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 28961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower if (data_format == FORMAT_NHWC) { 29061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower const TensorShape nchw_shape = ShapeFromFormat( 29161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower FORMAT_NCHW, in_batch, {{in_planes, in_rows, in_cols}}, in_depth); 29261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower if (in_depth > 1) { 29361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower Tensor transformed_input; 29461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, 29561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower nchw_shape, &transformed_input)); 29661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower // input: [b, x, y, z, d] 29761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower // t_input: [b, d, x, y, z] 29861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower // NCDHW is the only format universally supported by cuDNN. 29961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower functor::NHWCToNCHW<GPUDevice, T, 5>()( 30061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower ctx->eigen_device<GPUDevice>(), 30161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower const_cast<const Tensor&>(input).tensor<T, 5>(), 30261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower transformed_input.tensor<T, 5>()); 30361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower input = transformed_input; 30461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower } else { 30561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower CHECK(input.CopyFrom(input, nchw_shape)); 30661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower } 30761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower } 3086a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 3091d9ada31b2d046bb9586a3a276155280ed8f5b6bXiaoqiang Zheng CHECK(pad_rows >= 0 && pad_cols >= 0 && pad_planes >= 0) 3101d9ada31b2d046bb9586a3a276155280ed8f5b6bXiaoqiang Zheng << "Negative paddings: (" << pad_rows << ", " << pad_cols << ", " 3111d9ada31b2d046bb9586a3a276155280ed8f5b6bXiaoqiang Zheng << pad_planes << ")"; 3126a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower perftools::gputools::dnn::BatchDescriptor input_desc(3); 3136a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower input_desc.set_count(in_batch) 3146a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_feature_map_count(in_depth) 3156a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_spatial_dim(DimIndex::X, in_cols) 3166a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_spatial_dim(DimIndex::Y, in_rows) 3176a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_spatial_dim(DimIndex::Z, in_planes) 3186a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); 3196a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower perftools::gputools::dnn::BatchDescriptor output_desc(3); 3206a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower output_desc.set_count(in_batch) 3216a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_spatial_dim(DimIndex::X, out_cols) 3226a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_spatial_dim(DimIndex::Y, out_rows) 3236a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_spatial_dim(DimIndex::Z, out_planes) 3246a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_feature_map_count(out_depth) 3256a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); 3266a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower perftools::gputools::dnn::FilterDescriptor filter_desc(3); 3276a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower filter_desc.set_spatial_dim(DimIndex::X, filter_cols) 3286a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_spatial_dim(DimIndex::Y, filter_rows) 3296a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_spatial_dim(DimIndex::Z, filter_planes) 3306a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_input_feature_map_count(in_depth) 3316a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_output_feature_map_count(out_depth); 3326a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3); 3336a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower conv_desc.set_filter_stride(DimIndex::X, strides[2]) 3346a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_filter_stride(DimIndex::Y, strides[1]) 3356a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_filter_stride(DimIndex::Z, strides[0]) 3366a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_zero_padding(DimIndex::X, pad_cols / 2) 3376a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_zero_padding(DimIndex::Y, pad_rows / 2) 3386a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .set_zero_padding(DimIndex::Z, pad_planes / 2); 3396a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 3406a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower Tensor transformed_filter; 3416a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES_OK( 3426a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, 3436a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower TensorShape({out_depth, in_depth, filter_planes, 3446a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower filter_rows, filter_cols}), 3456a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower &transformed_filter)); 3466a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // filter: [x, y, z, in, out] 3476a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower // t_filter: [out, in, x, y, z] 3486a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower functor::TransformFilter<GPUDevice, T, int, 5>()( 3496a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()), 3506a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower To32Bit(transformed_filter.tensor<T, 5>())); 3516a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 3526a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower Tensor transformed_output; 3536a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower OP_REQUIRES_OK( 35461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower ctx, ctx->allocate_temp( 35561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower DataTypeToEnum<T>::value, 35661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower ShapeFromFormat(FORMAT_NCHW, in_batch, 35761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower {{out_planes, out_rows, out_cols}}, out_depth), 35861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower &transformed_output)); 3596a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 3606a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower auto input_ptr = AsDeviceMemory(input.template flat<T>().data(), 3616a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower input.template flat<T>().size()); 3626a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower auto filter_ptr = 3636a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower AsDeviceMemory(transformed_filter.template flat<T>().data(), 3646a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower transformed_filter.template flat<T>().size()); 3656a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower auto output_ptr = 3666a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower AsDeviceMemory(transformed_output.template flat<T>().data(), 3676a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower transformed_output.template flat<T>().size()); 3686a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 3696a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( 3706a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default 371bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower 372bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower int device_id = stream->parent()->device_ordinal(); 373f4b237f8cdd25a45dc26adc61c3086c2575f5396Yangzihao Wang DataType dtype = input.dtype(); 374bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower ConvParameters conv_parameters = { 375bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower in_batch, 376bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower in_depth, 377bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower {{in_planes, in_rows, in_cols}}, 378bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower out_depth, 379bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower {{filter_planes, filter_rows, filter_cols}}, 380cb4ef362e4a18b3c42a2c90bdad8754d5ead4cafYangzihao Wang // TODO(yangzihao): Send in arbitrary dilation rates after the dilated 381cb4ef362e4a18b3c42a2c90bdad8754d5ead4cafYangzihao Wang // conv is supported. 382309f7e29a6f19ac060e9cf5f02e7de0eeac522deA. Unique TensorFlower /*dilation=*/{{1, 1, 1}}, 383bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower {{strides[0], strides[1], strides[2]}}, 384bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower {{pad_planes, pad_rows, pad_cols}}, 385f4b237f8cdd25a45dc26adc61c3086c2575f5396Yangzihao Wang dtype, 386bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower device_id, 387bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower }; 388bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower 389bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower using perftools::gputools::dnn::AlgorithmConfig; 390e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai using perftools::gputools::dnn::AlgorithmDesc; 391bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower using perftools::gputools::dnn::ProfileResult; 392bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower 393bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower AlgorithmConfig algorithm_config; 394bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower 395bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find( 396bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower conv_parameters, &algorithm_config)) { 3975eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen std::vector<AlgorithmDesc> algorithms; 398e78e5ec8a8c862e65b6a194e9caea377120d7207Yangzihao Wang CHECK(stream->parent()->GetConvolveAlgorithms( 399e78e5ec8a8c862e65b6a194e9caea377120d7207Yangzihao Wang conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms)); 400bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower ProfileResult best_result; 401bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower ProfileResult best_result_no_scratch; 4025eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen for (auto profile_algorithm : algorithms) { 4035eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen // TODO(zhengxq): profile each algorithm multiple times to better 4045eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen // accuracy. 4055eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); 4065eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen ProfileResult profile_result; 4075eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen bool cudnn_launch_status = 4085eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen stream 4095eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen ->ThenConvolveWithAlgorithm( 4105eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen input_desc, input_ptr, filter_desc, filter_ptr, conv_desc, 4115eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen output_desc, &output_ptr, &scratch_allocator, 4125eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen AlgorithmConfig(profile_algorithm), &profile_result) 4135eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen .ok(); 4145eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen if (cudnn_launch_status) { 4155eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen if (profile_result.is_valid()) { 4165eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen if (profile_result.elapsed_time_in_ms() < 4175eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen best_result.elapsed_time_in_ms()) { 4185eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen best_result = profile_result; 4195eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen } 4205eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen if (scratch_allocator.TotalByteSize() == 0 && 4215eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen profile_result.elapsed_time_in_ms() < 4225eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen best_result_no_scratch.elapsed_time_in_ms()) { 4235eaefbabce16bffeeb4b19cee9890b1aeccabb09Frank Chen best_result_no_scratch = profile_result; 424bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower } 425bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower } 426bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower } 427bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower } 428bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower OP_REQUIRES(ctx, 429db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang best_result.is_valid() || best_result_no_scratch.is_valid(), 430bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower errors::NotFound("No algorithm worked!")); 431db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang if (best_result.is_valid()) { 432db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang algorithm_config.set_algorithm(best_result.algorithm()); 433db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang } 434db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang if (best_result_no_scratch.is_valid()) { 435db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang algorithm_config.set_algorithm_no_scratch( 436db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang best_result_no_scratch.algorithm()); 437db596594b5653b43fcb558a4753b39904bb62cbdYangzihao Wang } 438bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower AutoTuneConv3d::GetInstance()->Insert(conv_parameters, algorithm_config); 439bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower } 440bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower 4416a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); 4426a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower bool cudnn_launch_status = 4436a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower stream 444bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc, 445bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower filter_ptr, conv_desc, output_desc, 446bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower &output_ptr, &scratch_allocator, 447bc6e0c471c4d7d6cd150149f2830e9d23a0040bcA. Unique TensorFlower algorithm_config, nullptr) 4486a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower .ok(); 4496a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 4506a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower if (!cudnn_launch_status) { 4516a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower ctx->SetStatus(errors::Internal( 4526a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower "cuDNN launch failure : input shape(", input.shape().DebugString(), 4536a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower ") filter shape(", filter.shape().DebugString(), ")")); 4546a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower } 4556a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 45661f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower if (data_format == FORMAT_NHWC) { 45761f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower // t_output: [b, out, x, y, z] 45861f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower // output: [b, x, y, z, out] 45961f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower functor::NCHWToNHWC<GPUDevice, T, 5>()( 46061f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower ctx->eigen_device<GPUDevice>(), 46161f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower const_cast<const Tensor&>(transformed_output).tensor<T, 5>(), 46261f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower output->tensor<T, 5>()); 46361f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower } else { 46461f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower *output = transformed_output; 46561f30222eba5e3f1f51dedb3c5493f5f8eb331c8A. Unique TensorFlower } 4666a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower } 4676a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower}; 4686a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 4696a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower// Forward declarations of the functor specializations for GPU. 4706a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower// This ensures that the custom implementation is used instead of the default 4716a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower// Eigen one (which is used for CPU). 4726a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowernamespace functor { 4736a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#define DECLARE_GPU_SPEC(T) \ 4746a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower template <> \ 4756a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower void TransformFilter<GPUDevice, T, int, 5>::operator()( \ 4766a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \ 4776a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower typename TTypes<T, 5, int>::Tensor out); \ 4786a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower template <> \ 4796a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \ 4806a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in, \ 4816a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower typename TTypes<T, 5>::Tensor out); \ 4826a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower template <> \ 4836a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower void PadInput<GPUDevice, T, int, 5>::operator()( \ 4846a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \ 4856a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const std::array<int, 3>& padding_left, \ 4866a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower const std::array<int, 3>& padding_right, \ 4876a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower typename TTypes<T, 5, int>::Tensor out, TensorFormat format); 4886a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 489b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei FengDECLARE_GPU_SPEC(Eigen::half); 4906a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerDECLARE_GPU_SPEC(float); 4916a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#undef DECLARE_GPU_SPEC 4926a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 4936a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower} // namespace functor 4946a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 4956a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower// Registration of the GPU implementations. 4966a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlowerREGISTER_KERNEL_BUILDER( 497b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), 498b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Conv3DOp<GPUDevice, Eigen::half>); 499b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei FengREGISTER_KERNEL_BUILDER( 5006a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<float>("T"), 5016a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower Conv3DOp<GPUDevice, float>); 5026a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower#endif // GOOGLE_CUDA 5036a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower 5046a187ccddaebb741ea77fc3201c6e36625f0aadbA. Unique TensorFlower} // namespace tensorflow 505