1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/contrib/lite/builtin_op_data.h" 16#include "tensorflow/contrib/lite/context.h" 17#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" 18#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" 19#include "tensorflow/contrib/lite/kernels/internal/tensor.h" 20#include "tensorflow/contrib/lite/kernels/kernel_util.h" 21#include "tensorflow/contrib/lite/kernels/op_macros.h" 22 23namespace tflite { 24namespace ops { 25namespace builtin { 26namespace resize_bilinear { 27 28// This file has three implementation of RESIZE_BILINEAR. 29enum KernelType { 30 kReference, 31 kGenericOptimized, // Neon-free 32 kNeonOptimized, 33}; 34 35constexpr int kInputTensor = 0; 36constexpr int kSizeTensor = 1; 37constexpr int kOutputTensor = 0; 38 39TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TfLiteTensor* input, 40 TfLiteTensor* size, TfLiteTensor* output) { 41 TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); 42 output_size->data[0] = input->dims->data[0]; 43 const int32* size_data = GetTensorData<int32>(size); 44 output_size->data[1] = size_data[0]; 45 output_size->data[2] = size_data[1]; 46 output_size->data[3] = input->dims->data[3]; 47 return context->ResizeTensor(context, output, output_size); 48} 49 50TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 51 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); 52 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 53 54 TfLiteTensor* input = GetInput(context, node, kInputTensor); 55 TfLiteTensor* size = GetInput(context, node, kSizeTensor); 56 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 57 58 // TODO(ahentz): Our current implementations rely on the inputs being 4D. 59 TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); 60 TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1); 61 62 // TODO(ahentz): Our current implementations only support float32. 63 TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); 64 TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32); 65 // ResizeBilinear creates a float tensor even when the input is made of 66 // integers. 67 output->type = kTfLiteFloat32; 68 69 if (!IsConstantTensor(size)) { 70 SetTensorToDynamic(output); 71 return kTfLiteOk; 72 } 73 return ResizeOutputTensor(context, input, size, output); 74} 75 76template <KernelType kernel_type> 77TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 78 auto* params = 79 reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data); 80 81 TfLiteTensor* input = GetInput(context, node, kInputTensor); 82 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 83 TfLiteTensor* size = GetInput(context, node, kSizeTensor); 84 85 if (IsDynamicTensor(output)) { 86 TF_LITE_ENSURE_OK(context, 87 ResizeOutputTensor(context, input, size, output)); 88 } 89 90 if (output->type == kTfLiteFloat32) { 91#define TF_LITE_RESIZE_BILINEAR(type) \ 92 type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \ 93 GetTensorData<int32>(size), GetTensorDims(size), \ 94 GetTensorData<float>(output), GetTensorDims(output), \ 95 params->align_corners) 96 97 if (kernel_type == kReference) { 98 TF_LITE_RESIZE_BILINEAR(reference_ops); 99 } 100 if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { 101 TF_LITE_RESIZE_BILINEAR(optimized_ops); 102 } 103#undef TF_LITE_RESIZE_BILINEAR 104 } else { 105 context->ReportError(context, "Inputs and outputs not all float types."); 106 return kTfLiteError; 107 } 108 109 return kTfLiteOk; 110} 111 112} // namespace resize_bilinear 113 114TfLiteRegistration* Register_RESIZE_BILINEAR_REF() { 115 static TfLiteRegistration r = { 116 nullptr, nullptr, resize_bilinear::Prepare, 117 resize_bilinear::Eval<resize_bilinear::kReference>}; 118 return &r; 119} 120 121TfLiteRegistration* Register_RESIZE_BILINEAR_GENERIC_OPT() { 122 static TfLiteRegistration r = { 123 nullptr, nullptr, resize_bilinear::Prepare, 124 resize_bilinear::Eval<resize_bilinear::kGenericOptimized>}; 125 return &r; 126} 127 128TfLiteRegistration* Register_RESIZE_BILINEAR_NEON_OPT() { 129 static TfLiteRegistration r = { 130 nullptr, nullptr, resize_bilinear::Prepare, 131 resize_bilinear::Eval<resize_bilinear::kNeonOptimized>}; 132 return &r; 133} 134 135TfLiteRegistration* Register_RESIZE_BILINEAR() { 136#ifdef USE_NEON 137 return Register_RESIZE_BILINEAR_NEON_OPT(); 138#else 139 return Register_RESIZE_BILINEAR_GENERIC_OPT(); 140#endif 141} 142 143} // namespace builtin 144} // namespace ops 145} // namespace tflite 146