1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/contrib/lite/builtin_op_data.h"
16#include "tensorflow/contrib/lite/context.h"
17#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
18#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
19#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
20#include "tensorflow/contrib/lite/kernels/kernel_util.h"
21#include "tensorflow/contrib/lite/kernels/op_macros.h"
22
23namespace tflite {
24namespace ops {
25namespace builtin {
26namespace space_to_depth {
27
// This file has two implementation of SpaceToDepth. Note that SpaceToDepth
// only works on 4D tensors.
enum KernelType {
  kReference,         // Portable reference implementation (reference_ops).
  kGenericOptimized,  // Platform-independent optimized path (optimized_ops).
};

// Indices of this op's single input and single output in the node's
// tensor lists.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
37
38TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
39  auto* params =
40      reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
41
42  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
43  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
44
45  TfLiteTensor* input = GetInput(context, node, kInputTensor);
46  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
47
48  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
49
50  auto data_type = output->type;
51  TF_LITE_ENSURE(context,
52                 data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
53                     data_type == kTfLiteInt32 || data_type == kTfLiteInt64);
54  TF_LITE_ENSURE_EQ(context, input->type, output->type);
55
56  const int block_size = params->block_size;
57  const int input_height = input->dims->data[1];
58  const int input_width = input->dims->data[2];
59  int output_height = input_height / block_size;
60  int output_width = input_width / block_size;
61
62  TF_LITE_ENSURE_EQ(context, input_height, output_height * block_size);
63  TF_LITE_ENSURE_EQ(context, input_width, output_width * block_size);
64
65  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
66  output_size->data[0] = input->dims->data[0];
67  output_size->data[1] = output_height;
68  output_size->data[2] = output_width;
69  output_size->data[3] = input->dims->data[3] * block_size * block_size;
70
71  return context->ResizeTensor(context, output, output_size);
72}
73
74template <KernelType kernel_type>
75TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
76  auto* params =
77      reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
78
79  TfLiteTensor* input = GetInput(context, node, kInputTensor);
80  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
81
82#define TF_LITE_SPACE_TO_DEPTH(type, scalar)                                  \
83  type::SpaceToDepth<scalar>(                                                 \
84      GetTensorData<scalar>(input), GetTensorDims(input), params->block_size, \
85      GetTensorData<scalar>(output), GetTensorDims(output))
86  switch (input->type) {  // Already know in/out types are same.
87    case kTfLiteFloat32:
88      if (kernel_type == kReference) {
89        TF_LITE_SPACE_TO_DEPTH(reference_ops, float);
90      } else {
91        TF_LITE_SPACE_TO_DEPTH(optimized_ops, float);
92      }
93      break;
94    case kTfLiteUInt8:
95      if (kernel_type == kReference) {
96        TF_LITE_SPACE_TO_DEPTH(reference_ops, uint8_t);
97      } else {
98        TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t);
99      }
100      break;
101    case kTfLiteInt32:
102      if (kernel_type == kReference) {
103        TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t);
104      } else {
105        TF_LITE_SPACE_TO_DEPTH(optimized_ops, int32_t);
106      }
107      break;
108    case kTfLiteInt64:
109      if (kernel_type == kReference) {
110        TF_LITE_SPACE_TO_DEPTH(reference_ops, int64_t);
111      } else {
112        TF_LITE_SPACE_TO_DEPTH(optimized_ops, int64_t);
113      }
114      break;
115    default:
116      context->ReportError(context, "Type not currently supported.");
117      return kTfLiteError;
118  }
119#undef TF_LITE_SPACE_TO_DEPTH
120
121  return kTfLiteOk;
122}
123
124}  // namespace space_to_depth
125
126TfLiteRegistration* Register_SPACE_TO_DEPTH_REF() {
127  static TfLiteRegistration r = {
128      nullptr, nullptr, space_to_depth::Prepare,
129      space_to_depth::Eval<space_to_depth::kReference>};
130  return &r;
131}
132
133TfLiteRegistration* Register_SPACE_TO_DEPTH_GENERIC_OPT() {
134  static TfLiteRegistration r = {
135      nullptr, nullptr, space_to_depth::Prepare,
136      space_to_depth::Eval<space_to_depth::kGenericOptimized>};
137  return &r;
138}
139
// Default registration: the generic optimized kernel is used unless a caller
// explicitly asks for the reference implementation.
TfLiteRegistration* Register_SPACE_TO_DEPTH() {
  return Register_SPACE_TO_DEPTH_GENERIC_OPT();
}
143
144}  // namespace builtin
145}  // namespace ops
146}  // namespace tflite
147