fused_conv2d_bias_activation_op.cc revision 5eaefbabce16bffeeb4b19cee9890b1aeccabb09
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/activation_mode.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

namespace {
typedef Eigen::GpuDevice GPUDevice;

template <typename T>
struct RawType {
  using type = T;
};

template <>
struct RawType<qint8> {
  using type = int8;
};

// Template struct to convert int8x4 to int32.
// (for NCHW_VECT_C with element type int8, we can consider it to be
// an NCHW layout with element type int32 for operations like padding).
template <typename T>
struct Int8x4ToInt32 {
  // By default, do not change T.
  using type = T;
};

template <>
struct Int8x4ToInt32<int8> {
  using type = int32;
};
}  // namespace

// T is the element type of the conv_input, filter and side_input tensors.
// BiasType is the element type of the bias tensor, which can be different.
// ScaleType is the type used for conv_input_scale, side_input_scale.
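//
// The kernel dispatches to cuDNN's fused convolution and computes,
// approximately (see the ThenFusedConvolveWithAlgorithm call below):
//   output = activation(conv_input_scale * conv(conv_input, filter) +
//                       side_input_scale * side_input + bias)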
template <typename Device, typename T, typename BiasType, typename ScaleType>
class FusedConv2DBiasActivationOp : public OpKernel {
 public:
  enum InputIndexes {
    kConvInput = 0,
    kFilter,
    kBias,
    kSideInput,
    kConvInputScale,
    kSideInputScale,
    kNumInputs
  };

  explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format_str, filter_format_str;
    CHECK_EQ(kNumInputs, context->num_inputs());
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context,
                   context->GetAttr("filter_format", &filter_format_str));
    OP_REQUIRES(context,
                FilterFormatFromString(filter_format_str, &filter_format_),
                errors::InvalidArgument("Invalid filter format"));

    std::vector<int32> strides;
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides));
    OP_REQUIRES(context, strides.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));

    stride_rows_ = GetTensorDim(strides, data_format_, 'H');
    stride_cols_ = GetTensorDim(strides, data_format_, 'W');
    OP_REQUIRES(
        context,
        (GetTensorDim(strides, data_format_, 'N') == 1 &&
         GetTensorDim(strides, data_format_, 'C') == 1),
        errors::InvalidArgument("Convolutional strides are not supported in "
                                "the batch or depth dimensions."));

    // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
    constexpr bool is_int8x4 = std::is_same<T, qint8>::value;

    // Note: Only NCHW_VECT_C format is supported for int8.
    // This is because it is expected to be the fastest, and our previous tests
    // found cudnn 6 does not fully support the other formats for int8 mode.
    OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
                errors::InvalidArgument(
                    "qint8 should be used with data_format NCHW_VECT_C."));

    OP_REQUIRES(context, (is_int8x4 == (filter_format_ == FORMAT_OIHW_VECT_I)),
                errors::InvalidArgument(
                    "qint8 should be used with filter_format OIHW_VECT_I."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_type_));
    eigen_padding_type_ = BrainPadding2EigenPadding(padding_type_);
    string activation_mode_str;
    OP_REQUIRES_OK(context,
                   context->GetAttr("activation_mode", &activation_mode_str));
    OP_REQUIRES_OK(context, GetActivationModeFromString(activation_mode_str,
                                                        &activation_mode_));
    OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU,
                errors::InvalidArgument("Current implementation only supports "
                                        "RELU as the activation function."));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  Status CheckShape(const Tensor& tensor, const string& tensor_name) {
    const int num_dims = tensor.dims();
    for (int i = 0; i < num_dims; i++) {
      if (!FastBoundsCheck(tensor.dim_size(i),
                           std::numeric_limits<int32>::max())) {
        return errors::InvalidArgument(tensor_name, " dimension ", i,
                                       " too large");
      }
    }
    // If there is a 5th dimension it is the VECT_C or VECT_I dimension.
    if (num_dims == 5 && tensor.dim_size(4) != 4) {
      return errors::InvalidArgument("The last dimension of ", tensor_name,
                                     " must be of size 4 for qint8.");
    }
    return Status::OK();
  }

  void Compute(OpKernelContext* context) override {
    // The conv_input tensor is one of the following formats:
    // NHWC, NCHW, NCHW_VECT_C.
    const Tensor& conv_input = context->input(kConvInput);
    OP_REQUIRES_OK(context, CheckShape(conv_input, "conv_input"));

    // The filter tensor is one of the following formats:
    // HWIO, OIHW, OIHW_VECT_I.
    const Tensor& filter = context->input(kFilter);
    OP_REQUIRES_OK(context, CheckShape(filter, "filter"));

    // Input bias is a 1-D tensor, with size matching output depth.
    const Tensor& bias = context->input(kBias);
    OP_REQUIRES_OK(context, CheckShape(bias, "bias"));

    const Tensor& conv_input_scale_tensor = context->input(kConvInputScale);
    const Tensor& side_input_scale_tensor = context->input(kSideInputScale);

    auto conv_input_scale = *reinterpret_cast<const ScaleType*>(
        conv_input_scale_tensor.tensor_data().data());
    auto side_input_scale = *reinterpret_cast<const ScaleType*>(
        side_input_scale_tensor.tensor_data().data());

    // If side_input_scale != 0, then side_input is not ignored and
    // has the same type and dimensions as the output.
    const Tensor& side_input = context->input(kSideInput);
    if (side_input_scale != 0) {
      OP_REQUIRES_OK(context, CheckShape(side_input, "side_input"));
    }

    // TODO(pauldonnelly): Switch to a more efficient mechanism to access
    // dimension indexes and per-dimension attributes.
    const int32 filter_rows = GetFilterDim(filter, filter_format_, 'H');
    const int32 filter_cols = GetFilterDim(filter, filter_format_, 'W');
    const int32 output_depth = GetFilterDim(filter, filter_format_, 'O');

    const int32 batch_size = GetTensorDim(conv_input, data_format_, 'N');
    const int32 conv_input_rows = GetTensorDim(conv_input, data_format_, 'H');
    const int32 conv_input_cols = GetTensorDim(conv_input, data_format_, 'W');

    int64 output_rows = 0, output_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_rows, filter_rows,
                                                  stride_rows_, padding_type_,
                                                  &output_rows, &pad_rows));
    OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_cols, filter_cols,
                                                  stride_cols_, padding_type_,
                                                  &output_cols, &pad_cols));
    // Initialize the output tensor shape according to data_format_.
    TensorShape output_shape = ShapeFromFormat(
        data_format_, batch_size, output_rows, output_cols, output_depth);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));

    VLOG(2) << "FusedConv2DBiasActivation: conv_input_cols = "
            << conv_input_cols << ", conv_input_rows = " << conv_input_rows
            << ", filter_cols = " << filter_cols
            << ", filter_rows = " << filter_rows
            << ", stride_cols = " << stride_cols_
            << ", stride_rows = " << stride_rows_
            << ", output_depth = " << output_depth
            << ", output_cols = " << output_cols
            << ", output_rows = " << output_rows
            << ", output_shape.num_elements = " << output_shape.num_elements();

    // If there is nothing to compute, return.
    if (output_shape.num_elements() == 0) {
      return;
    }

    launcher_.launch(context, cudnn_use_autotune_, conv_input,
                     conv_input_scale, filter, stride_rows_, stride_cols_,
                     eigen_padding_type_, side_input, side_input_scale, bias,
                     activation_mode_, data_format_, filter_format_, output);
  }

 private:
  int32 stride_rows_, stride_cols_;
  Padding padding_type_;
  Eigen::PaddingType eigen_padding_type_;
  ActivationMode activation_mode_;
  TensorFormat data_format_;
  FilterTensorFormat filter_format_;
  LaunchFusedConv2DBiasActivationOp<Device, T, BiasType, ScaleType> launcher_;
  bool cudnn_use_autotune_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DBiasActivationOp);
};

#if GOOGLE_CUDA
namespace dnn = ::perftools::gputools::dnn;

// A dummy type to group forward convolution autotune results together.
struct ConvBiasActivationAutoTuneGroup {
  static string name() { return "ConvBiasActivation"; }
};
typedef AutoTuneSingleton<ConvBiasActivationAutoTuneGroup, FusedConvParameters,
                          dnn::AlgorithmConfig>
    AutoTuneConvBiasActivation;

// Allocates 'transformed_tensor' and transforms 'nhwc_tensor' into it
// using the specified 'batch_size', 'rows', 'cols', and 'depth' dimensions.
template <typename T, size_t NDIMS>
Status TransformNHWCToNCHW(OpKernelContext* ctx, const Tensor& nhwc_tensor,
                           int batch_size, int rows, int cols, int depth,
                           Tensor* transformed_tensor, const Tensor** result) {
  TensorShape nchw_shape =
      ShapeFromFormat(FORMAT_NCHW, batch_size, rows, cols, depth);
  if (depth > 1) {
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                          transformed_tensor));
    functor::NHWCToNCHW<GPUDevice, T, NDIMS>()(
        ctx->eigen_device<GPUDevice>(), nhwc_tensor.tensor<T, NDIMS>(),
        transformed_tensor->tensor<T, NDIMS>());
  } else {
    // If depth <= 1, then just reshape.
    CHECK(transformed_tensor->CopyFrom(nhwc_tensor, nchw_shape));
  }
  *result = transformed_tensor;
  return Status::OK();
}

template <typename T, typename BiasType, typename ScaleType>
void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
    launch(OpKernelContext* ctx, bool cudnn_use_autotune,
           const Tensor& conv_input_param, ScaleType conv_input_scale,
           const Tensor& filter_param, int32 row_stride, int32 col_stride,
           const Eigen::PaddingType& padding, const Tensor& side_input_param,
           ScaleType side_input_scale, const Tensor& bias,
           ActivationMode activation_mode, TensorFormat data_format,
           FilterTensorFormat filter_format, Tensor* output_param) {
  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  // TODO(yangzihao): refactor all the complicated/duplicated code in regular
  // conv ops to a shared conv utility.

  // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
  constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
  constexpr int rank = is_int8x4 ? 5 : 4;
  constexpr int vect = is_int8x4 ? 4 : 1;

  const int batch_size = GetTensorDim(conv_input_param, data_format, 'N');
  int conv_input_rows = GetTensorDim(conv_input_param, data_format, 'H');
  int conv_input_cols = GetTensorDim(conv_input_param, data_format, 'W');

  const int conv_input_depth =
      GetTensorDim(conv_input_param, data_format, 'C') * vect;
  const int output_rows = GetTensorDim(*output_param, data_format, 'H');
  const int output_cols = GetTensorDim(*output_param, data_format, 'W');
  const int output_depth = GetFilterDim(filter_param, filter_format, 'O');
  const int filter_rows = GetFilterDim(filter_param, filter_format, 'H');
  const int filter_cols = GetFilterDim(filter_param, filter_format, 'W');
  int padding_rows = 0;
  int padding_cols = 0;
  const Tensor* conv_input = &conv_input_param;

  Tensor maybe_padded_conv_input;
  if (padding == Eigen::PADDING_SAME) {
    // Total padding on rows and cols is
    //   Pr = (R' - 1) * S + Kr - R
    //   Pc = (C' - 1) * S + Kc - C
    // where (R', C') are output dimensions, (R, C) are input dimensions, S
    // is stride, (Kr, Kc) are filter dimensions.
    // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
    // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
    // we pad more on the right and bottom than on the top and left.
    padding_rows = std::max<int>(
        0, (output_rows - 1) * row_stride + filter_rows - conv_input_rows);
    padding_cols = std::max<int>(
        0, (output_cols - 1) * col_stride + filter_cols - conv_input_cols);
    const int padding_rows_parity = padding_rows & 1;
    const int padding_cols_parity = padding_cols & 1;
    if ((padding_rows_parity | padding_cols_parity) != 0) {
      Tensor transformed_input;
      const int new_conv_input_rows = conv_input_rows + padding_rows_parity;
      const int new_conv_input_cols = conv_input_cols + padding_cols_parity;

      using VectT = typename Int8x4ToInt32<typename RawType<T>::type>::type;
      auto pad_data_format = is_int8x4 ? FORMAT_NCHW : data_format;

      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(
                   DataTypeToEnum<T>::value,
                   ShapeFromFormat(data_format, batch_size, new_conv_input_rows,
                                   new_conv_input_cols, conv_input_depth),
                   &maybe_padded_conv_input));

      auto conv_input_eigen_tensor =
          To32Bit(conv_input_param.reinterpret_last_dimension<VectT, 4>());
      auto padded_conv_input_eigen_tensor = To32Bit(
          maybe_padded_conv_input.reinterpret_last_dimension<VectT, 4>());

      functor::PadInput<GPUDevice, VectT, int, 4>()(
          ctx->eigen_device<GPUDevice>(), conv_input_eigen_tensor, {{0, 0}},
          {{padding_rows_parity, padding_cols_parity}},
          padded_conv_input_eigen_tensor, pad_data_format);

      conv_input = &maybe_padded_conv_input;
      conv_input_rows = new_conv_input_rows;
      conv_input_cols = new_conv_input_cols;
    }
  }

  Tensor maybe_transformed_conv_input, maybe_transformed_side_input;
  Tensor maybe_transformed_output;
  const Tensor* side_input = &side_input_param;
  Tensor* output = output_param;

  // NOTE: Here and elsewhere, checking 'is_int8x4' may look unnecessary
  // and inefficient, but it is actually both a time and code size optimization,
  // since 'is_int8x4' is a constexpr determined by the template parameter.
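  // For the non-int8x4 path, the fused cuDNN convolution below is configured
  // with an NCHW (kBatchDepthYX) layout, so NHWC conv_input, side_input and
  // output tensors are transposed to NCHW here, and the output is transposed
  // back to NHWC at the end of this function.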
  if (!is_int8x4 && data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(ctx, (TransformNHWCToNCHW<T, rank>(
                            ctx, *conv_input, batch_size, conv_input_rows,
                            conv_input_cols, conv_input_depth,
                            &maybe_transformed_conv_input, &conv_input)));
    if (side_input_scale != 0) {
      OP_REQUIRES_OK(
          ctx, (TransformNHWCToNCHW<T, rank>(
                   ctx, side_input_param, batch_size, output_rows, output_cols,
                   output_depth, &maybe_transformed_side_input, &side_input)));
    }
    if (output_depth > 1) {
      // Allocate a tensor for the NCHW output of the kernel and point output
      // to it. Afterwards, we will transform it to NHWC while copying back to
      // 'output_param'.
      TensorShape nchw_shape = ShapeFromFormat(
          FORMAT_NCHW, batch_size, output_rows, output_cols, output_depth);
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                        &maybe_transformed_output));
      output = &maybe_transformed_output;
    }
  }

  constexpr auto data_layout = is_int8x4 ? dnn::DataLayout::kBatchDepthYX4
                                         : dnn::DataLayout::kBatchDepthYX;
  constexpr auto filter_layout = is_int8x4 ? dnn::FilterLayout::kOutputInputYX4
                                           : dnn::FilterLayout::kOutputInputYX;

  dnn::BatchDescriptor conv_input_desc;
  conv_input_desc.set_count(batch_size)
      .set_feature_map_count(conv_input_depth)
      .set_height(conv_input_rows)
      .set_width(conv_input_cols)
      .set_layout(data_layout);
  dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(filter_rows)
      .set_input_filter_width(filter_cols)
      .set_input_feature_map_count(conv_input_depth)
      .set_output_feature_map_count(output_depth)
      .set_layout(filter_layout);
  dnn::BatchDescriptor side_input_desc;
  side_input_desc.set_count(batch_size)
      .set_height(output_rows)
      .set_width(output_cols)
      .set_feature_map_count(output_depth)
      .set_layout(data_layout);
  dnn::BatchDescriptor bias_desc;
  bias_desc.set_count(1)
      .set_height(1)
      .set_width(1)
      .set_feature_map_count(output_depth)
      .set_layout(dnn::DataLayout::kBatchDepthYX);
  dnn::BatchDescriptor output_desc;
  output_desc.set_count(batch_size)
      .set_height(output_rows)
      .set_width(output_cols)
      .set_feature_map_count(output_depth)
      .set_layout(data_layout);
  dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_filter_stride(row_stride)
      .set_horizontal_filter_stride(col_stride)
      .set_zero_padding_height(padding_rows / 2)
      .set_zero_padding_width(padding_cols / 2);

  Tensor maybe_transformed_filter;
  // Initialized to the unmodified filter; overwritten below if a layout
  // transform is needed.
  const Tensor* filter = &filter_param;
  if (is_int8x4) {
    // We have already checked filter is OIHW_VECT_I in the constructor.
    filter = &filter_param;
  } else if (filter_format == FORMAT_HWIO) {
    // Shuffle filter tensor from HWIO to OIHW:
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            ShapeFromFilterFormat(
                                FORMAT_OIHW, filter_param.shape(), FORMAT_HWIO),
                            &maybe_transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 4>()(
        ctx->eigen_device<GPUDevice>(), To32Bit(filter_param.tensor<T, 4>()),
        To32Bit(maybe_transformed_filter.tensor<T, 4>()));
    filter = &maybe_transformed_filter;
  }

  auto conv_input_ptr =
      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
                         conv_input->template flat<T>().data()),
                     conv_input->template flat<T>().size());
  auto filter_ptr =
      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
                         filter->template flat<T>().data()),
                     filter->template flat<T>().size());
  auto side_input_ptr =
      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
                         side_input->template flat<T>().data()),
                     side_input->template flat<T>().size());
  auto output_ptr =
      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
                         output->template flat<T>().data()),
                     output->template flat<T>().size());
  auto bias_ptr = AsDeviceMemory(bias.template flat<BiasType>().data(),
                                 bias.template flat<BiasType>().size());

  static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
      // default value is in bytes despite the name of the environment variable
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
  );

  int device_id = stream->parent()->device_ordinal();
  FusedConvParameters fused_conv_parameters = {
      batch_size,
      conv_input_depth,
      {{conv_input_rows, conv_input_cols}},
      output_depth,
      {{filter_rows, filter_cols}},
      {{row_stride, col_stride}},
      {{padding_rows, padding_cols}},
      conv_input->dtype(),
      device_id,
      (side_input_scale != 0),
      activation_mode,
  };

  dnn::AlgorithmConfig algorithm_config;
  if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find(
                                fused_conv_parameters, &algorithm_config)) {
    std::vector<dnn::AlgorithmDesc> algorithms;
    CHECK(stream->parent()->GetConvolveAlgorithms(
        fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(),
        &algorithms));
    dnn::ProfileResult best_result;
    dnn::ProfileResult best_result_no_scratch;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times for better
      // accuracy.
      CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
      dnn::ProfileResult profile_result;
      bool cudnn_launch_status =
          stream
              ->ThenFusedConvolveWithAlgorithm(
                  conv_input_desc, conv_input_ptr, conv_input_scale,
                  filter_desc, filter_ptr, conv_desc, side_input_ptr,
                  side_input_scale, bias_desc, bias_ptr,
                  dnn::ActivationMode::kRelu, output_desc, &output_ptr,
                  &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm),
                  &profile_result)
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          if (profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
            best_result = profile_result;
          }
          if (scratch_allocator.TotalByteSize() == 0 &&
              profile_result.elapsed_time_in_ms() <
                  best_result_no_scratch.elapsed_time_in_ms()) {
            best_result_no_scratch = profile_result;
          }
        }
      }
    }
    OP_REQUIRES(ctx,
                best_result.is_valid() || best_result_no_scratch.is_valid(),
                errors::NotFound("No algorithm worked!"));
    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    if (best_result_no_scratch.is_valid()) {
      algorithm_config.set_algorithm_no_scratch(
          best_result_no_scratch.algorithm());
    }
    AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters,
                                                      algorithm_config);
  }

  CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
  bool cudnn_launch_status =
      stream
          ->ThenFusedConvolveWithAlgorithm(
              conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc,
              filter_ptr, conv_desc, side_input_ptr, side_input_scale,
              bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc,
              &output_ptr, &scratch_allocator, algorithm_config,
              /*output_profile_result=*/nullptr)
          .ok();

  if (!cudnn_launch_status) {
    ctx->SetStatus(errors::Internal("cuDNN launch failure : conv_input shape(",
                                    conv_input->shape().DebugString(),
                                    ") filter shape(",
                                    filter->shape().DebugString(), ")"));
  }

  // Convert the output tensor back from NCHW to NHWC if necessary.
  if (!is_int8x4 && (data_format == FORMAT_NHWC) && (output_depth > 1)) {
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        const_cast<const Tensor*>(output)->tensor<T, 4>(),
        output_param->tensor<T, 4>());
  }
}

// Forward declarations of the functor specializations for GPU used above.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                               \
  template <>                                                             \
  void PadInput<GPUDevice, T, int, 4>::operator()(                       \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
      const std::array<int, 2>& padding_left,                            \
      const std::array<int, 2>& padding_right,                           \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
  extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
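// Note: conv_input_scale and side_input_scale are constrained to HostMemory
// because Compute() reads them directly on the host via tensor_data().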

REGISTER_KERNEL_BUILDER(
    Name("FusedConv2DBiasActivation")
        .Device(DEVICE_GPU)
        .TypeConstraint<float>("T")
        .TypeConstraint<float>("Tbias")
        .HostMemory("conv_input_scale")
        .HostMemory("side_input_scale"),
    FusedConv2DBiasActivationOp<GPUDevice, float, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedConv2DBiasActivation")
        .Device(DEVICE_GPU)
        .TypeConstraint<qint8>("T")
        .TypeConstraint<float>("Tbias")
        .HostMemory("conv_input_scale")
        .HostMemory("side_input_scale"),
    FusedConv2DBiasActivationOp<GPUDevice, qint8, float, float>);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow