// Copyright 2016 The Sonnet Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/contrib/resampler/kernels/resampler_ops.h"

#include <stdio.h>
#include <cmath>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

using GPUDevice = Eigen::GpuDevice;

namespace {

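// Reads data[batch_id, y, x, chan], assuming the data tensor uses NHWC
// layout; batch_id, chan, data_channels, data_width and data_batch_stride
// are taken from the enclosing scope.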
#define GET_DATA_POINT(x, y)                                                 \
  data[batch_id * data_batch_stride + data_channels * (y * data_width + x) + \
       chan]

template <typename T>
__global__ void Resampler2DKernel(const T* __restrict__ data,
                                  const T* __restrict__ warp,
                                  T* __restrict__ output, const int batch_size,
                                  const int data_height, const int data_width,
                                  const int data_channels,
                                  const int num_sampling_points) {
  const int output_data_size = batch_size * num_sampling_points * data_channels;
  CUDA_1D_KERNEL_LOOP(index, output_data_size) {
    const int out_index = index;

    // Recover (batch_id, sample_id, chan) from the flat index, which is
    // laid out as
    //   index = batch_id * num_sampling_points * num_chans +
    //           sample_id * num_chans + chan_id,
    // with sample_id in [0, num_sampling_points).
    const int data_batch_stride = data_height * data_width * data_channels;
    const int warp_batch_stride = num_sampling_points * 2;
    const int output_batch_stride = num_sampling_points * data_channels;

    const int batch_id = index / output_batch_stride;
    const int index_in_batch = index % output_batch_stride;
    const int chan = index_in_batch % data_channels;
    const int sample_id = index_in_batch / data_channels;
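    // Worked example (illustrative values): with num_sampling_points = 3 and
    // data_channels = 2, output_batch_stride = 6, so index = 9 decomposes
    // into batch_id = 1, index_in_batch = 3, chan = 1, sample_id = 1.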

    // Get coords of 2D point where data will be resampled
    const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
    const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
    const T zero = static_cast<T>(0.0);
    const T one = static_cast<T>(1.0);
    // The interpolation function:
    // a) implicitly pads the input data with 0s (hence the unusual checks
    // with {x,y} > -1)
    // b) returns 0 when sampling outside the (padded) image.
    // The effect is that the sampled signal smoothly goes to 0 outside
    // the original input domain, rather than presenting a jump
    // discontinuity at the image boundaries.
    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
      // Precompute floor (f) and ceil (c) values for x and y.
      const int fx = std::floor(static_cast<float>(x));
      const int fy = std::floor(static_cast<float>(y));
      const int cx = fx + 1;
      const int cy = fy + 1;
      const T dx = static_cast<T>(cx) - x;
      const T dy = static_cast<T>(cy) - y;
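      // Note: dx = cx - x is the distance from x to its ceil, i.e. one minus
      // the fractional part of x, so dx * dy is the bilinear weight of the
      // *floor* corner (fx, fy), and analogously for the other three corners.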

      const T img_fxfy =
          (fx >= 0 && fy >= 0) ? dx * dy * GET_DATA_POINT(fx, fy) : zero;

      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
                             ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy)
                             : zero;

      const T img_fxcy = (fx >= 0 && cy <= data_height - 1)
                             ? dx * (one - dy) * GET_DATA_POINT(fx, cy)
                             : zero;

      const T img_cxfy = (cx <= data_width - 1 && fy >= 0)
                             ? (one - dx) * dy * GET_DATA_POINT(cx, fy)
                             : zero;

      output[out_index] = img_fxfy + img_cxcy + img_fxcy + img_cxfy;
    } else {
      output[out_index] = zero;
    }
  }
}

}  // namespace

namespace functor {

template <typename T>
struct Resampler2DFunctor<GPUDevice, T> {
  void operator()(::tensorflow::OpKernelContext* ctx, const GPUDevice& d,
                  const T* __restrict__ data, const T* __restrict__ warp,
                  T* __restrict__ output, const int batch_size,
                  const int data_height, const int data_width,
                  const int data_channels, const int num_sampling_points) {
    const int output_data_size =
        batch_size * num_sampling_points * data_channels;
    ::tensorflow::CudaLaunchConfig config =
        ::tensorflow::GetCudaLaunchConfig(output_data_size, d);
    Resampler2DKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            data, warp, output, batch_size, data_height, data_width,
            data_channels, num_sampling_points);
  }
};

// TODO(fviola): gcudacc fails at compile time with Eigen::half.
// template struct Resampler2DFunctor<GPUDevice, Eigen::half>;
template struct Resampler2DFunctor<GPUDevice, float>;
template struct Resampler2DFunctor<GPUDevice, double>;
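// A minimal host-side usage sketch (hypothetical tensors `data_t`, `warp_t`,
// `output_t`; shape validation is assumed to happen in the calling op):
//
//   functor::Resampler2DFunctor<GPUDevice, float>()(
//       ctx, ctx->eigen_device<GPUDevice>(),
//       data_t.flat<float>().data(), warp_t.flat<float>().data(),
//       output_t->flat<float>().data(), batch_size, data_height, data_width,
//       data_channels, num_sampling_points);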

}  // namespace functor

namespace {

#define UPDATE_GRAD_DATA_POINT(x, y, v)                                \
  atomicAdd(grad_data + (batch_id * data_batch_stride +                \
                         data_channels * (y * data_width + x) + chan), \
            v)

template <typename T>
__global__ void ResamplerGrad2DKernel(
    const T* __restrict__ data, const T* __restrict__ warp,
    const T* __restrict__ grad_output, T* __restrict__ grad_data,
    T* __restrict__ grad_warp, const int batch_size, const int data_height,
    const int data_width, const int data_channels,
    const int num_sampling_points) {
  const int resampler_output_size =
      batch_size * num_sampling_points * data_channels;
  CUDA_1D_KERNEL_LOOP(index, resampler_output_size) {
    const int out_index = index;

    // Recover (batch_id, sample_id, chan) from the flat index, which is
    // laid out as
    //   index = batch_id * num_sampling_points * num_chans +
    //           sample_id * num_chans + chan_id,
    // with sample_id in [0, num_sampling_points).
    const int data_batch_stride = data_height * data_width * data_channels;
    const int warp_batch_stride = num_sampling_points * 2;
    const int output_batch_stride = num_sampling_points * data_channels;

    const int batch_id = index / output_batch_stride;
    const int index_in_batch = index % output_batch_stride;
    const int chan = index_in_batch % data_channels;
    const int sample_id = index_in_batch / data_channels;

    // Get coords of 2D point where data will be resampled
    const int warp_id_x = batch_id * warp_batch_stride + sample_id * 2;
    const int warp_id_y = warp_id_x + 1;
    const T x = warp[warp_id_x];
    const T y = warp[warp_id_y];
    const T zero = static_cast<T>(0.0);
    const T one = static_cast<T>(1.0);

    // Get grad output
    const T grad_output_value = grad_output[out_index];
    // The interpolation function whose gradient this kernel implements:
    // a) implicitly pads the input data with 0s (hence the unusual checks
    // with {x,y} > -1)
    // b) returns 0 when sampling outside the (padded) image.
    // The effect is that the sampled signal smoothly goes to 0 outside
    // the original input domain, rather than presenting a jump
    // discontinuity at the image boundaries.
    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
      // Precompute floor (f) and ceil (c) values for x and y.
      const int fx = std::floor(static_cast<float>(x));
      const int fy = std::floor(static_cast<float>(y));
      const int cx = fx + 1;
      const int cy = fy + 1;
      const T dx = static_cast<T>(cx) - x;
      const T dy = static_cast<T>(cy) - y;

      const T img_fxfy = (fx >= 0 && fy >= 0) ? GET_DATA_POINT(fx, fy) : zero;

      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
                             ? GET_DATA_POINT(cx, cy)
                             : zero;

      const T img_fxcy =
          (fx >= 0 && cy <= data_height - 1) ? GET_DATA_POINT(fx, cy) : zero;

      const T img_cxfy =
          (cx <= data_width - 1 && fy >= 0) ? GET_DATA_POINT(cx, fy) : zero;

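      // With dx = cx - x and dy = cy - y, the forward pass computes
      //   out = dx*dy*img_fxfy + (1-dx)*(1-dy)*img_cxcy +
      //         dx*(1-dy)*img_fxcy + (1-dx)*dy*img_cxfy,
      // and d(dx)/dx = d(dy)/dy = -1, which yields
      //   d(out)/dx = (1-dy)*(img_cxcy - img_fxcy) + dy*(img_cxfy - img_fxfy)
      //   d(out)/dy = (1-dx)*(img_cxcy - img_cxfy) + dx*(img_fxcy - img_fxfy).
      // These are accumulated below, scaled by the incoming gradient.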
      // Update partial gradients wrt relevant warp field entries
      atomicAdd(grad_warp + warp_id_x,
                grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) +
                                     dy * (img_cxfy - img_fxfy)));
      atomicAdd(grad_warp + warp_id_y,
                grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) +
                                     dx * (img_fxcy - img_fxfy)));

      // Update partial gradients wrt sampled data
      if (fx >= 0 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(fx, fy, grad_output_value * dx * dy);
      }
      if (cx <= data_width - 1 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(cx, cy,
                               grad_output_value * (one - dx) * (one - dy));
      }
      if (fx >= 0 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(fx, cy, grad_output_value * dx * (one - dy));
      }
      if (cx <= data_width - 1 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(cx, fy, grad_output_value * (one - dx) * dy);
      }
    }
  }
}

#undef GET_DATA_POINT
#undef UPDATE_GRAD_DATA_POINT

}  // namespace

namespace functor {

template <typename T>
struct ResamplerGrad2DFunctor<GPUDevice, T> {
  void operator()(::tensorflow::OpKernelContext* ctx, const GPUDevice& d,
                  const T* __restrict__ data, const T* __restrict__ warp,
                  const T* __restrict__ grad_output, T* __restrict__ grad_data,
                  T* __restrict__ grad_warp, const int batch_size,
                  const int data_height, const int data_width,
                  const int data_channels, const int num_sampling_points) {
    // Set gradients to 0, because the kernel incrementally updates the
    // tensor entries by adding partial contributions.
    const int grad_warp_size = batch_size * num_sampling_points * 2;
    const int grad_data_size =
        batch_size * data_height * data_width * data_channels;

    ::tensorflow::CudaLaunchConfig config =
        ::tensorflow::GetCudaLaunchConfig(grad_warp_size, d);
    ::tensorflow::
        SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            grad_warp_size, grad_warp);

    config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d);
    ::tensorflow::
        SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            grad_data_size, grad_data);

    const int resampler_output_size =
        batch_size * num_sampling_points * data_channels;
    config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d);
    ResamplerGrad2DKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            data, warp, grad_output, grad_data, grad_warp, batch_size,
            data_height, data_width, data_channels, num_sampling_points);
  }
};

template struct ResamplerGrad2DFunctor<GPUDevice, float>;
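// Note: only float is instantiated for the gradient. A double instantiation
// would rely on atomicAdd(double*, double), which is natively available only
// on devices of compute capability 6.0+ (older devices need a CAS-based
// fallback).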

}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA