/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/resize_bilinear_op.h"

#include <memory>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

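// Kernel for the ResizeBilinear op: resizes a batch of images to the
// requested output size using bilinear interpolation.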
template <typename Device, typename T>
class ResizeBilinearOp : public OpKernel {
 public:
  explicit ResizeBilinearOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    ImageResizerState st(align_corners_);
    st.ValidateAndCreateOutput(context, input);

    if (!context->status().ok()) return;

    // Return if the output is empty.
    if (st.output->NumElements() == 0) return;

    typename TTypes<T, 4>::ConstTensor image_data(input.tensor<T, 4>());
    TTypes<float, 4>::Tensor output_data = st.output->tensor<float, 4>();

    functor::ResizeBilinear<Device, T>()(context->eigen_device<Device>(),
                                         image_data, st.height_scale,
                                         st.width_scale, output_data);
  }

 private:
  bool align_corners_;
};

namespace {
// Compute the interpolation indices only once.
struct CachedInterpolation {
  int64 lower;  // Lower source index used in the interpolation
  int64 upper;  // Upper source index used in the interpolation
  // 1-D linear interpolation scale (see:
  // https://en.wikipedia.org/wiki/Bilinear_interpolation)
  float lerp;
};

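// For each output index in [0, out_size), computes the pair of source indices
// that bracket the sample point (i * scale) and the fractional lerp weight
// between them. The caller must provide out_size + 1 entries; the final entry
// is a zero-filled sentinel.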
inline void compute_interpolation_weights(const int64 out_size,
                                          const int64 in_size,
                                          const float scale,
                                          CachedInterpolation* interpolation) {
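  // Zero the sentinel entry past the end of the table.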
  interpolation[out_size].lower = 0;
  interpolation[out_size].upper = 0;
  for (int64 i = out_size - 1; i >= 0; --i) {
    const float in = i * scale;
    interpolation[i].lower = static_cast<int64>(in);
    interpolation[i].upper = std::min(interpolation[i].lower + 1, in_size - 1);
    interpolation[i].lerp = in - interpolation[i].lower;
  }
}

/**
 * Computes the bilinear interpolation from the appropriate 4 float points
 * and the linear interpolation weights.
 */
inline float compute_lerp(const float top_left, const float top_right,
                          const float bottom_left, const float bottom_right,
                          const float x_lerp, const float y_lerp) {
  const float top = top_left + (top_right - top_left) * x_lerp;
  const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
  return top + (bottom - top) * y_lerp;
}

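// Produces the resized images by walking the output pixels and blending, for
// each one, the four neighboring source pixels selected by the precomputed
// x/y interpolation tables.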
template <typename T>
void resize_image(
    typename TTypes<T, 4>::ConstTensor images, const int batch_size,
    const int64 in_height, const int64 in_width, const int64 out_height,
    const int64 out_width, const int channels,
    const std::vector<CachedInterpolation>& xs,
    const std::vector<CachedInterpolation>& ys,
    typename TTypes<float, 4>::Tensor output) TF_ATTRIBUTE_NOINLINE;
template <typename T>
void resize_image(typename TTypes<T, 4>::ConstTensor images,
                  const int batch_size, const int64 in_height,
                  const int64 in_width, const int64 out_height,
                  const int64 out_width, const int channels,
                  const std::vector<CachedInterpolation>& xs_vec,
                  const std::vector<CachedInterpolation>& ys,
                  typename TTypes<float, 4>::Tensor output) {
  const int64 in_row_size = in_width * channels;
  const int64 in_batch_num_values = in_height * in_row_size;
  const int64 out_row_size = out_width * channels;

  const T* input_b_ptr = images.data();
  const CachedInterpolation* xs = xs_vec.data();

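  // Fast path for the common RGB case: the channel loop is unrolled by hand
  // so each output pixel is computed with straight-line code.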
  if (channels == 3) {
    float* output_y_ptr = output.data();
    for (int b = 0; b < batch_size; ++b) {
      for (int64 y = 0; y < out_height; ++y) {
        const T* ys_input_lower_ptr = input_b_ptr + ys[y].lower * in_row_size;
        const T* ys_input_upper_ptr = input_b_ptr + ys[y].upper * in_row_size;
        const float ys_lerp = ys[y].lerp;
        for (int64 x = 0; x < out_width; ++x) {
          const int64 xs_lower = xs[x].lower;
          const int64 xs_upper = xs[x].upper;
          const float xs_lerp = xs[x].lerp;

          // Read channel 0.
          const float top_left0(ys_input_lower_ptr[xs_lower + 0]);
          const float top_right0(ys_input_lower_ptr[xs_upper + 0]);
          const float bottom_left0(ys_input_upper_ptr[xs_lower + 0]);
          const float bottom_right0(ys_input_upper_ptr[xs_upper + 0]);

          // Read channel 1.
          const float top_left1(ys_input_lower_ptr[xs_lower + 1]);
          const float top_right1(ys_input_lower_ptr[xs_upper + 1]);
          const float bottom_left1(ys_input_upper_ptr[xs_lower + 1]);
          const float bottom_right1(ys_input_upper_ptr[xs_upper + 1]);

          // Read channel 2.
          const float top_left2(ys_input_lower_ptr[xs_lower + 2]);
          const float top_right2(ys_input_lower_ptr[xs_upper + 2]);
          const float bottom_left2(ys_input_upper_ptr[xs_lower + 2]);
          const float bottom_right2(ys_input_upper_ptr[xs_upper + 2]);

          // Compute output.
          output_y_ptr[x * channels + 0] =
              compute_lerp(top_left0, top_right0, bottom_left0, bottom_right0,
                           xs_lerp, ys_lerp);
          output_y_ptr[x * channels + 1] =
              compute_lerp(top_left1, top_right1, bottom_left1, bottom_right1,
                           xs_lerp, ys_lerp);
          output_y_ptr[x * channels + 2] =
              compute_lerp(top_left2, top_right2, bottom_left2, bottom_right2,
                           xs_lerp, ys_lerp);
        }
        output_y_ptr += out_row_size;
      }
      input_b_ptr += in_batch_num_values;
    }
  } else {
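    // Generic path: loop over an arbitrary number of channels.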
    float* output_y_ptr = output.data();
    for (int b = 0; b < batch_size; ++b) {
      for (int64 y = 0; y < out_height; ++y) {
        const T* ys_input_lower_ptr = input_b_ptr + ys[y].lower * in_row_size;
        const T* ys_input_upper_ptr = input_b_ptr + ys[y].upper * in_row_size;
        const float ys_lerp = ys[y].lerp;
        for (int64 x = 0; x < out_width; ++x) {
          auto xs_lower = xs[x].lower;
          auto xs_upper = xs[x].upper;
          auto xs_lerp = xs[x].lerp;
          for (int c = 0; c < channels; ++c) {
            const float top_left(ys_input_lower_ptr[xs_lower + c]);
            const float top_right(ys_input_lower_ptr[xs_upper + c]);
            const float bottom_left(ys_input_upper_ptr[xs_lower + c]);
            const float bottom_right(ys_input_upper_ptr[xs_upper + c]);
            output_y_ptr[x * channels + c] =
                compute_lerp(top_left, top_right, bottom_left, bottom_right,
                             xs_lerp, ys_lerp);
          }
        }
        output_y_ptr += out_row_size;
      }
      input_b_ptr += in_batch_num_values;
    }
  }
}

}  // namespace

// Partial specialization of ResizeBilinear functor for a CPUDevice.
namespace functor {
template <typename T>
struct ResizeBilinear<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor images,
                  const float height_scale, const float width_scale,
                  typename TTypes<float, 4>::Tensor output) {
    const int batch_size = images.dimension(0);
    const int64 in_height = images.dimension(1);
    const int64 in_width = images.dimension(2);
    const int channels = images.dimension(3);

    const int64 out_height = output.dimension(1);
    const int64 out_width = output.dimension(2);

    // Handle no-op resizes efficiently.
    if (out_height == in_height && out_width == in_width) {
      output = images.template cast<float>();
      return;
    }

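    // One extra element per table holds the sentinel written by
    // compute_interpolation_weights.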
    std::vector<CachedInterpolation> ys(out_height + 1);
    std::vector<CachedInterpolation> xs(out_width + 1);

    // Compute the cached interpolation weights on the x and y dimensions.
    compute_interpolation_weights(out_height, in_height, height_scale,
                                  ys.data());
    compute_interpolation_weights(out_width, in_width, width_scale, xs.data());

    // Scale the x interpolation indices by the channel count to avoid a
    // multiplication in the inner loop.
    for (int i = 0; i < xs.size(); ++i) {
      xs[i].lower *= channels;
      xs[i].upper *= channels;
    }

    resize_image<T>(images, batch_size, in_height, in_width, out_height,
                    out_width, channels, xs, ys, output);
  }
};
}  // namespace functor

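// Kernel for the ResizeBilinearGrad op: given the gradient with respect to
// the resized images, produces the gradient with respect to the original
// input images.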
template <typename Device, typename T>
class ResizeBilinearOpGrad : public OpKernel {
 public:
  explicit ResizeBilinearOpGrad(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
  }

  void Compute(OpKernelContext* context) override {
    // Validate input.
    // First argument is gradient with respect to resized image.
    const Tensor& input = context->input(0);
    const Tensor& original_image = context->input(1);

    ImageResizerGradientState st(align_corners_);
    st.ValidateAndCreateOutput(context, input, original_image);

    if (!context->status().ok()) return;

    TTypes<float, 4>::ConstTensor input_grad = input.tensor<float, 4>();
    typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());

    functor::ResizeBilinearGrad<Device, T>()(context->eigen_device<Device>(),
                                             input_grad, st.height_scale,
                                             st.width_scale, output_grad);
  }

 private:
  bool align_corners_;
};

// Partial specialization of ResizeBilinearGrad functor for a CPUDevice.
namespace functor {
template <typename T>
struct ResizeBilinearGrad<CPUDevice, T> {
  void operator()(const CPUDevice& d,
                  typename TTypes<float, 4>::ConstTensor input_grad,
                  const float height_scale, const float width_scale,
                  typename TTypes<T, 4>::Tensor output_grad) {
    const int batch = output_grad.dimension(0);
    const int64 original_height = output_grad.dimension(1);
    const int64 original_width = output_grad.dimension(2);
    const int channels = output_grad.dimension(3);

    const int64 resized_height = input_grad.dimension(1);
    const int64 resized_width = input_grad.dimension(2);

    output_grad.setZero();

    // Each resized pixel was computed as a weighted average of four input
    // pixels. Here we walk over each resized pixel, find the four input
    // pixels that contributed to it, and scatter the incoming gradient to
    // them, weighted by the same interpolation coefficients:
    // resized(b, y, x, c) = top_left * (1 - y_lerp) * (1 - x_lerp)
    //                       +  top_right * (1 - y_lerp) * x_lerp
    //                       +  bottom_left * y_lerp * (1 - x_lerp)
    //                       +  bottom_right * y_lerp * x_lerp
    for (int64 b = 0; b < batch; ++b) {
      for (int64 y = 0; y < resized_height; ++y) {
        const float in_y = y * height_scale;
        const int64 top_y_index = static_cast<int64>(floorf(in_y));
        const int64 bottom_y_index =
            std::min(static_cast<int64>(ceilf(in_y)), original_height - 1);
        const float y_lerp = in_y - top_y_index;
        const float inverse_y_lerp = (1.0f - y_lerp);
        for (int64 x = 0; x < resized_width; ++x) {
          const float in_x = x * width_scale;
          const int64 left_x_index = static_cast<int64>(floorf(in_x));
          const int64 right_x_index =
              std::min(static_cast<int64>(ceilf(in_x)), original_width - 1);
          const float x_lerp = in_x - left_x_index;
          const float inverse_x_lerp = (1.0f - x_lerp);
          for (int64 c = 0; c < channels; ++c) {
            output_grad(b, top_y_index, left_x_index, c) +=
                T(input_grad(b, y, x, c) * inverse_y_lerp * inverse_x_lerp);
            output_grad(b, top_y_index, right_x_index, c) +=
                T(input_grad(b, y, x, c) * inverse_y_lerp * x_lerp);
            output_grad(b, bottom_y_index, left_x_index, c) +=
                T(input_grad(b, y, x, c) * y_lerp * inverse_x_lerp);
            output_grad(b, bottom_y_index, right_x_index, c) +=
                T(input_grad(b, y, x, c) * y_lerp * x_lerp);
          }
        }
      }
    }
  }
};
}  // namespace functor

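// Kernel registrations. HostMemory("size") keeps the "size" input on the
// host, where it is read to compute the output shape.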
#define REGISTER_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("ResizeBilinear")      \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("size"),    \
                          ResizeBilinearOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#define REGISTER_GRAD_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("ResizeBilinearGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      ResizeBilinearOpGrad<CPUDevice, T>);

TF_CALL_half(REGISTER_GRAD_KERNEL);
TF_CALL_float(REGISTER_GRAD_KERNEL);
TF_CALL_double(REGISTER_GRAD_KERNEL);

#undef REGISTER_GRAD_KERNEL

#if GOOGLE_CUDA

#define REGISTER_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("ResizeBilinear")      \
                              .Device(DEVICE_GPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("size"),    \
                          ResizeBilinearOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_KERNEL);

#undef REGISTER_KERNEL

#define REGISTER_GRAD_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("ResizeBilinearGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      ResizeBilinearOpGrad<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GRAD_KERNEL);

#undef REGISTER_GRAD_KERNEL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow