1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/contrib/lite/builtin_op_data.h" 16#include "tensorflow/contrib/lite/context.h" 17#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" 18#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" 19#include "tensorflow/contrib/lite/kernels/internal/tensor.h" 20#include "tensorflow/contrib/lite/kernels/kernel_util.h" 21#include "tensorflow/contrib/lite/kernels/op_macros.h" 22 23namespace tflite { 24namespace ops { 25namespace builtin { 26namespace space_to_depth { 27 28// This file has two implementation of SpaceToDepth. Note that SpaceToDepth 29// only works on 4D tensors. 30enum KernelType { 31 kReference, 32 kGenericOptimized, 33}; 34 35constexpr int kInputTensor = 0; 36constexpr int kOutputTensor = 0; 37 38TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 39 auto* params = 40 reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data); 41 42 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); 43 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 44 45 TfLiteTensor* input = GetInput(context, node, kInputTensor); 46 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 47 48 TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); 49 50 auto data_type = output->type; 51 TF_LITE_ENSURE(context, 52 data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 || 53 data_type == kTfLiteInt32 || data_type == kTfLiteInt64); 54 TF_LITE_ENSURE_EQ(context, input->type, output->type); 55 56 const int block_size = params->block_size; 57 const int input_height = input->dims->data[1]; 58 const int input_width = input->dims->data[2]; 59 int output_height = input_height / block_size; 60 int output_width = input_width / block_size; 61 62 TF_LITE_ENSURE_EQ(context, input_height, output_height * block_size); 63 TF_LITE_ENSURE_EQ(context, input_width, output_width * block_size); 64 65 TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); 66 output_size->data[0] = input->dims->data[0]; 67 output_size->data[1] = output_height; 68 output_size->data[2] = output_width; 69 output_size->data[3] = input->dims->data[3] * block_size * block_size; 70 71 return context->ResizeTensor(context, output, output_size); 72} 73 74template <KernelType kernel_type> 75TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 76 auto* params = 77 reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data); 78 79 TfLiteTensor* input = GetInput(context, node, kInputTensor); 80 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 81 82#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \ 83 type::SpaceToDepth<scalar>( \ 84 GetTensorData<scalar>(input), GetTensorDims(input), params->block_size, \ 85 GetTensorData<scalar>(output), GetTensorDims(output)) 86 switch (input->type) { // Already know in/out types are same. 87 case kTfLiteFloat32: 88 if (kernel_type == kReference) { 89 TF_LITE_SPACE_TO_DEPTH(reference_ops, float); 90 } else { 91 TF_LITE_SPACE_TO_DEPTH(optimized_ops, float); 92 } 93 break; 94 case kTfLiteUInt8: 95 if (kernel_type == kReference) { 96 TF_LITE_SPACE_TO_DEPTH(reference_ops, uint8_t); 97 } else { 98 TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t); 99 } 100 break; 101 case kTfLiteInt32: 102 if (kernel_type == kReference) { 103 TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t); 104 } else { 105 TF_LITE_SPACE_TO_DEPTH(optimized_ops, int32_t); 106 } 107 break; 108 case kTfLiteInt64: 109 if (kernel_type == kReference) { 110 TF_LITE_SPACE_TO_DEPTH(reference_ops, int64_t); 111 } else { 112 TF_LITE_SPACE_TO_DEPTH(optimized_ops, int64_t); 113 } 114 break; 115 default: 116 context->ReportError(context, "Type not currently supported."); 117 return kTfLiteError; 118 } 119#undef TF_LITE_SPACE_TO_DEPTH 120 121 return kTfLiteOk; 122} 123 124} // namespace space_to_depth 125 126TfLiteRegistration* Register_SPACE_TO_DEPTH_REF() { 127 static TfLiteRegistration r = { 128 nullptr, nullptr, space_to_depth::Prepare, 129 space_to_depth::Eval<space_to_depth::kReference>}; 130 return &r; 131} 132 133TfLiteRegistration* Register_SPACE_TO_DEPTH_GENERIC_OPT() { 134 static TfLiteRegistration r = { 135 nullptr, nullptr, space_to_depth::Prepare, 136 space_to_depth::Eval<space_to_depth::kGenericOptimized>}; 137 return &r; 138} 139 140TfLiteRegistration* Register_SPACE_TO_DEPTH() { 141 return Register_SPACE_TO_DEPTH_GENERIC_OPT(); 142} 143 144} // namespace builtin 145} // namespace ops 146} // namespace tflite 147