1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/contrib/lite/builtin_op_data.h"
16#include "tensorflow/contrib/lite/context.h"
17#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
18#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
19#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
20#include "tensorflow/contrib/lite/kernels/kernel_util.h"
21#include "tensorflow/contrib/lite/kernels/op_macros.h"
22
23namespace tflite {
24namespace ops {
25namespace builtin {
26namespace resize_bilinear {
27
// This file has three implementations of RESIZE_BILINEAR: a portable
// reference kernel and two optimized variants, selected at registration time.
enum KernelType {
  kReference,
  kGenericOptimized,  // Neon-free
  kNeonOptimized,
};
34
// Indices into the node's input/output tensor arrays.
constexpr int kInputTensor = 0;   // 4D float32 tensor to resize.
constexpr int kSizeTensor = 1;    // 1D int32 tensor: {new_height, new_width}.
constexpr int kOutputTensor = 0;  // Resized 4D float32 tensor.
38
39TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TfLiteTensor* input,
40                                TfLiteTensor* size, TfLiteTensor* output) {
41  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
42  output_size->data[0] = input->dims->data[0];
43  const int32* size_data = GetTensorData<int32>(size);
44  output_size->data[1] = size_data[0];
45  output_size->data[2] = size_data[1];
46  output_size->data[3] = input->dims->data[3];
47  return context->ResizeTensor(context, output, output_size);
48}
49
50TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
51  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
52  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
53
54  TfLiteTensor* input = GetInput(context, node, kInputTensor);
55  TfLiteTensor* size = GetInput(context, node, kSizeTensor);
56  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
57
58  // TODO(ahentz): Our current implementations rely on the inputs being 4D.
59  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
60  TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
61
62  // TODO(ahentz): Our current implementations only support float32.
63  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
64  TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
65  // ResizeBilinear creates a float tensor even when the input is made of
66  // integers.
67  output->type = kTfLiteFloat32;
68
69  if (!IsConstantTensor(size)) {
70    SetTensorToDynamic(output);
71    return kTfLiteOk;
72  }
73  return ResizeOutputTensor(context, input, size, output);
74}
75
// Runs the bilinear resize. If Prepare() deferred shaping (non-constant size
// tensor), the output is resized here first. Dispatches to the reference or
// optimized implementation according to the compile-time kernel_type.
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);

  TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TfLiteTensor* size = GetInput(context, node, kSizeTensor);

  // Dynamic output: the size tensor was not constant at Prepare() time, so
  // the output shape must be computed now from the current size values.
  if (IsDynamicTensor(output)) {
    TF_LITE_ENSURE_OK(context,
                      ResizeOutputTensor(context, input, size, output));
  }

  if (output->type == kTfLiteFloat32) {
// Expands to a call into either reference_ops or optimized_ops with the
// tensors' data pointers, their dims, and the align_corners flag.
#define TF_LITE_RESIZE_BILINEAR(type)                                       \
  type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input),   \
                       GetTensorData<int32>(size), GetTensorDims(size),     \
                       GetTensorData<float>(output), GetTensorDims(output), \
                       params->align_corners)

    if (kernel_type == kReference) {
      TF_LITE_RESIZE_BILINEAR(reference_ops);
    }
    // NOTE(review): there is no NEON-specific path yet; kNeonOptimized falls
    // back to the same generic optimized implementation.
    if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
      TF_LITE_RESIZE_BILINEAR(optimized_ops);
    }
#undef TF_LITE_RESIZE_BILINEAR
  } else {
    context->ReportError(context, "Inputs and outputs not all float types.");
    return kTfLiteError;
  }

  return kTfLiteOk;
}
111
112}  // namespace resize_bilinear
113
114TfLiteRegistration* Register_RESIZE_BILINEAR_REF() {
115  static TfLiteRegistration r = {
116      nullptr, nullptr, resize_bilinear::Prepare,
117      resize_bilinear::Eval<resize_bilinear::kReference>};
118  return &r;
119}
120
121TfLiteRegistration* Register_RESIZE_BILINEAR_GENERIC_OPT() {
122  static TfLiteRegistration r = {
123      nullptr, nullptr, resize_bilinear::Prepare,
124      resize_bilinear::Eval<resize_bilinear::kGenericOptimized>};
125  return &r;
126}
127
128TfLiteRegistration* Register_RESIZE_BILINEAR_NEON_OPT() {
129  static TfLiteRegistration r = {
130      nullptr, nullptr, resize_bilinear::Prepare,
131      resize_bilinear::Eval<resize_bilinear::kNeonOptimized>};
132  return &r;
133}
134
// Default registration: picks the NEON variant when the build targets NEON,
// otherwise the generic optimized variant. Selection is at compile time.
TfLiteRegistration* Register_RESIZE_BILINEAR() {
#ifdef USE_NEON
  return Register_RESIZE_BILINEAR_NEON_OPT();
#else
  return Register_RESIZE_BILINEAR_GENERIC_OPT();
#endif
}
142
143}  // namespace builtin
144}  // namespace ops
145}  // namespace tflite
146