// Copyright 2016 The Sonnet Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/contrib/resampler/kernels/resampler_ops.h"

#include <stdio.h>
#include <cmath>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

using GPUDevice = Eigen::GpuDevice;

namespace {

#define GET_DATA_POINT(x, y)                                                  \
  data[batch_id * data_batch_stride + data_channels * (y * data_width + x) +  \
       chan]

template <typename T>
__global__ void Resampler2DKernel(const T* __restrict__ data,
                                  const T* __restrict__ warp,
                                  T* __restrict__ output, const int batch_size,
                                  const int data_height, const int data_width,
                                  const int data_channels,
                                  const int num_sampling_points) {
  const int output_data_size = batch_size * num_sampling_points * data_channels;
  CUDA_1D_KERNEL_LOOP(index, output_data_size) {
    const int out_index = index;

    // Get (batch_id, sample_id, chan) from the flattened index.
    // Use this formula
    //   index = batch_id * num_sampling_points * num_chans +
    //           sample_id * num_chans + chan_id,
    // with sample_id in [0, num_sampling_points).
    const int data_batch_stride = data_height * data_width * data_channels;
    const int warp_batch_stride = num_sampling_points * 2;
    const int output_batch_stride = num_sampling_points * data_channels;

    const int batch_id = index / output_batch_stride;
    const int index_in_batch = index % output_batch_stride;
    const int chan = index_in_batch % data_channels;
    const int sample_id = index_in_batch / data_channels;

    // Get coords of 2D point where data will be resampled.
    const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
    const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
    const T zero = static_cast<T>(0.0);
    const T one = static_cast<T>(1.0);
    // The interpolation function:
    // a) implicitly pads the input data with 0s (hence the unusual checks
    //    with {x,y} > -1),
    // b) returns 0 when sampling outside the (padded) image.
    // The effect is that the sampled signal smoothly goes to 0 outside the
    // original input domain, rather than presenting a jump discontinuity at
    // the image boundaries.
    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
      // Precompute floor (f) and ceil (c) values for x and y.
      const int fx = std::floor(static_cast<float>(x));
      const int fy = std::floor(static_cast<float>(y));
      const int cx = fx + 1;
      const int cy = fy + 1;
      const T dx = static_cast<T>(cx) - x;
      const T dy = static_cast<T>(cy) - y;
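
      // Worked form of the bilinear interpolation computed below, where
      // I(x, y) abbreviates GET_DATA_POINT(x, y) and the corner weights are
      // dx = cx - x and dy = cy - y:
      //
      //   output = dx * dy             * I(fx, fy)
      //          + (1 - dx) * (1 - dy) * I(cx, cy)
      //          + dx * (1 - dy)       * I(fx, cy)
      //          + (1 - dx) * dy       * I(cx, fy),
      //
      // with each I(., .) term treated as 0 whenever its corner falls outside
      // the image, which implements the zero padding described above.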
      const T img_fxfy =
          (fx >= 0 && fy >= 0) ? dx * dy * GET_DATA_POINT(fx, fy) : zero;

      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
                             ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy)
                             : zero;

      const T img_fxcy = (fx >= 0 && cy <= data_height - 1)
                             ? dx * (one - dy) * GET_DATA_POINT(fx, cy)
                             : zero;

      const T img_cxfy = (cx <= data_width - 1 && fy >= 0)
                             ? (one - dx) * dy * GET_DATA_POINT(cx, fy)
                             : zero;

      output[out_index] = img_fxfy + img_cxcy + img_fxcy + img_cxfy;
    } else {
      output[out_index] = zero;
    }
  }
}

}  // namespace

namespace functor {

template <typename T>
struct Resampler2DFunctor<GPUDevice, T> {
  void operator()(::tensorflow::OpKernelContext* ctx, const GPUDevice& d,
                  const T* __restrict__ data, const T* __restrict__ warp,
                  T* __restrict__ output, const int batch_size,
                  const int data_height, const int data_width,
                  const int data_channels, const int num_sampling_points) {
    const int output_data_size =
        batch_size * num_sampling_points * data_channels;
    ::tensorflow::CudaLaunchConfig config =
        ::tensorflow::GetCudaLaunchConfig(output_data_size, d);
    Resampler2DKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            data, warp, output, batch_size, data_height, data_width,
            data_channels, num_sampling_points);
  }
};

// TODO(fviola): gcudacc fails at compile time with Eigen::half.
// template struct Resampler2DFunctor<GPUDevice, Eigen::half>;
template struct Resampler2DFunctor<GPUDevice, float>;
template struct Resampler2DFunctor<GPUDevice, double>;

}  // namespace functor
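
// Illustrative call site (a sketch, not code from this file): an op kernel's
// Compute() would invoke the functor roughly as
//
//   functor::Resampler2DFunctor<GPUDevice, float>()(
//       ctx, ctx->eigen_device<GPUDevice>(),
//       data.flat<float>().data(), warp.flat<float>().data(),
//       output->flat<float>().data(), batch_size, data_height, data_width,
//       data_channels, num_sampling_points);
//
// The tensor names above are hypothetical; the actual call site lives in
// resampler_ops.cc.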
namespace {

#define UPDATE_GRAD_DATA_POINT(x, y, v)                                 \
  atomicAdd(grad_data + (batch_id * data_batch_stride +                 \
                         data_channels * (y * data_width + x) + chan),  \
            v)

template <typename T>
__global__ void ResamplerGrad2DKernel(
    const T* __restrict__ data, const T* __restrict__ warp,
    const T* __restrict__ grad_output, T* __restrict__ grad_data,
    T* __restrict__ grad_warp, const int batch_size, const int data_height,
    const int data_width, const int data_channels,
    const int num_sampling_points) {
  const int resampler_output_size =
      batch_size * num_sampling_points * data_channels;
  CUDA_1D_KERNEL_LOOP(index, resampler_output_size) {
    const int out_index = index;

    // Get (batch_id, sample_id, chan) from the flattened index.
    // Use this formula
    //   index = batch_id * num_sampling_points * num_chans +
    //           sample_id * num_chans + chan_id,
    // with sample_id in [0, num_sampling_points).
    const int data_batch_stride = data_height * data_width * data_channels;
    const int warp_batch_stride = num_sampling_points * 2;
    const int output_batch_stride = num_sampling_points * data_channels;

    const int batch_id = index / output_batch_stride;
    const int index_in_batch = index % output_batch_stride;
    const int chan = index_in_batch % data_channels;
    const int sample_id = index_in_batch / data_channels;

    // Get coords of 2D point where data will be resampled.
    const int warp_id_x = batch_id * warp_batch_stride + sample_id * 2;
    const int warp_id_y = warp_id_x + 1;
    const T x = warp[warp_id_x];
    const T y = warp[warp_id_y];
    const T zero = static_cast<T>(0.0);
    const T one = static_cast<T>(1.0);

    // Get grad output.
    const T grad_output_value = grad_output[out_index];
    // The interpolation function whose gradient this kernel implements:
    // a) implicitly pads the input data with 0s (hence the unusual checks
    //    with {x,y} > -1),
    // b) returns 0 when sampling outside the (padded) image.
    // The effect is that the sampled signal smoothly goes to 0 outside the
    // original input domain, rather than presenting a jump discontinuity at
    // the image boundaries.
    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
      // Precompute floor (f) and ceil (c) values for x and y.
      const int fx = std::floor(static_cast<float>(x));
      const int fy = std::floor(static_cast<float>(y));
      const int cx = fx + 1;
      const int cy = fy + 1;
      const T dx = static_cast<T>(cx) - x;
      const T dy = static_cast<T>(cy) - y;

      const T img_fxfy = (fx >= 0 && fy >= 0) ? GET_DATA_POINT(fx, fy) : zero;

      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
                             ? GET_DATA_POINT(cx, cy)
                             : zero;

      const T img_fxcy =
          (fx >= 0 && cy <= data_height - 1) ? GET_DATA_POINT(fx, cy) : zero;

      const T img_cxfy =
          (cx <= data_width - 1 && fy >= 0) ? GET_DATA_POINT(cx, fy) : zero;
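
      // Worked form of the warp gradient computed below: differentiating the
      // forward bilinear formula with dx = cx - x (so d(dx)/dx = -1) and
      // dy = cy - y gives
      //
      //   d(output)/dx = (1 - dy) * (img_cxcy - img_fxcy) +
      //                  dy       * (img_cxfy - img_fxfy)
      //   d(output)/dy = (1 - dx) * (img_cxcy - img_cxfy) +
      //                  dx       * (img_fxcy - img_fxfy),
      //
      // each scaled by grad_output_value via the chain rule. Note that the
      // img_* terms here hold the raw corner values, not the weighted terms
      // used in the forward kernel.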
      // Update partial gradients wrt relevant warp field entries.
      atomicAdd(grad_warp + warp_id_x,
                grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) +
                                     dy * (img_cxfy - img_fxfy)));
      atomicAdd(grad_warp + warp_id_y,
                grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) +
                                     dx * (img_fxcy - img_fxfy)));

      // Update partial gradients wrt sampled data. Several sampling points
      // may hit the same data pixel, so the updates must be atomic.
      if (fx >= 0 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(fx, fy, grad_output_value * dx * dy);
      }
      if (cx <= data_width - 1 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(cx, cy,
                               grad_output_value * (one - dx) * (one - dy));
      }
      if (fx >= 0 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(fx, cy, grad_output_value * dx * (one - dy));
      }
      if (cx <= data_width - 1 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(cx, fy, grad_output_value * (one - dx) * dy);
      }
    }
  }
}

#undef GET_DATA_POINT
#undef UPDATE_GRAD_DATA_POINT

}  // namespace

namespace functor {

template <typename T>
struct ResamplerGrad2DFunctor<GPUDevice, T> {
  void operator()(::tensorflow::OpKernelContext* ctx, const GPUDevice& d,
                  const T* __restrict__ data, const T* __restrict__ warp,
                  const T* __restrict__ grad_output, T* __restrict__ grad_data,
                  T* __restrict__ grad_warp, const int batch_size,
                  const int data_height, const int data_width,
                  const int data_channels, const int num_sampling_points) {
    // Set gradients to 0, because the kernel incrementally updates the
    // tensor entries by adding partial contributions.
    const int grad_warp_size = batch_size * num_sampling_points * 2;
    const int grad_data_size =
        batch_size * data_height * data_width * data_channels;

    ::tensorflow::CudaLaunchConfig config =
        ::tensorflow::GetCudaLaunchConfig(grad_warp_size, d);
    ::tensorflow::
        SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            grad_warp_size, grad_warp);

    config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d);
    ::tensorflow::
        SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            grad_data_size, grad_data);

    const int resampler_output_size =
        batch_size * num_sampling_points * data_channels;
    config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d);
    ResamplerGrad2DKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            data, warp, grad_output, grad_data, grad_warp, batch_size,
            data_height, data_width, data_channels, num_sampling_points);
  }
};

template struct ResamplerGrad2DFunctor<GPUDevice, float>;

}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA