1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <unistd.h> 16#include <cassert> 17#include <cmath> 18#include <cstdio> 19#include <cstdlib> 20#include <iostream> 21#include <limits> 22 23#include "tensorflow/contrib/lite/builtin_op_data.h" 24#include "tensorflow/contrib/lite/context.h" 25#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" 26#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" 27#include "tensorflow/contrib/lite/kernels/internal/tensor.h" 28#include "tensorflow/contrib/lite/kernels/kernel_util.h" 29#include "tensorflow/contrib/lite/kernels/op_macros.h" 30 31namespace tflite { 32namespace ops { 33namespace builtin { 34namespace concatenation { 35 36// This file has two implementation of Concatenation. 37enum KernelType { 38 kReference, 39 kGenericOptimized, 40}; 41 42TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 43 auto* params = 44 reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data); 45 int axis = params->axis; 46 int num_inputs = node->inputs->size; 47 48 // The number of dimensions of the input tensors must match, and all 49 // dimensions except 'axis' must be equal. 50 TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]]; 51 TfLiteType input_type = t0->type; 52 if (axis < 0) axis += t0->dims->size; 53 TF_LITE_ENSURE(context, axis >= 0); 54 TF_LITE_ENSURE(context, axis < t0->dims->size); 55 56 // TODO(ahentz): These are limitations of our implementation that could be 57 // removed with a bit of effort. 58 TF_LITE_ENSURE(context, t0->dims->size <= 4); 59 TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); 60 TF_LITE_ENSURE(context, 61 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); 62 63 // Output dimensions will match input dimensions, except 'axis', which 64 // will be the sum of inputs 65 int sum_axis = t0->dims->data[axis]; 66 for (int i = 1; i < num_inputs; ++i) { 67 TfLiteTensor* t = &context->tensors[node->inputs->data[i]]; 68 TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size); 69 TF_LITE_ENSURE_EQ(context, t->type, input_type); 70 if (input_type == kTfLiteUInt8) { 71 TF_LITE_ENSURE_EQ(context, t->params.zero_point, t0->params.zero_point); 72 TF_LITE_ENSURE_EQ(context, t->params.scale, t0->params.scale); 73 } 74 for (int d = 0; d < t0->dims->size; ++d) { 75 if (d == axis) { 76 sum_axis += t->dims->data[axis]; 77 } else { 78 TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]); 79 } 80 } 81 } 82 83 TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size); 84 for (int d = 0; d < t0->dims->size; ++d) { 85 output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d]; 86 } 87 88 TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; 89 TF_LITE_ENSURE_EQ(context, output->type, input_type); 90 if (input_type == kTfLiteUInt8) { 91 TF_LITE_ENSURE_EQ(context, output->params.zero_point, 92 t0->params.zero_point); 93 TF_LITE_ENSURE_EQ(context, output->params.scale, t0->params.scale); 94 } 95 96 return context->ResizeTensor(context, output, output_size); 97} 98 99template <KernelType kernel_type> 100TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 101 auto* params = 102 reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data); 103 int axis = params->axis; 104 TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; 105 if (axis < 0) axis += output->dims->size; 106 107// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should 108// allocate and populate these during Prepare(). 109// TODO(ycling): Activation function parameter is ignored. For now we dont have 110// a model with a Concatenation with fused activation function. 111#define TF_LITE_CONCATENATION(type, scalar) \ 112 VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \ 113 type::Concatenation<FusedActivationFunctionType::kNone, scalar>( \ 114 RemapDim(NumDimensions(output), axis), all_inputs.data(), \ 115 all_inputs.dims(), node->inputs->size, GetTensorData<scalar>(output), \ 116 GetTensorDims(output)) 117 118 switch (output->type) { // Already know in/outtypes are same. 119 case kTfLiteFloat32: 120 if (kernel_type == kReference) { 121 TF_LITE_CONCATENATION(reference_ops, float); 122 } else { 123 TF_LITE_CONCATENATION(optimized_ops, float); 124 } 125 break; 126 case kTfLiteUInt8: 127 if (kernel_type == kReference) { 128 TF_LITE_CONCATENATION(reference_ops, uint8_t); 129 } else { 130 TF_LITE_CONCATENATION(optimized_ops, uint8_t); 131 } 132 break; 133 default: 134 context->ReportError(context, 135 "Only float32 and uint8 are currently supported."); 136 return kTfLiteError; 137 } 138 139#undef TF_LITE_CONCATENATION 140 141 return kTfLiteOk; 142} 143 144#undef TF_LITE_MACRO_DISPATCH 145 146} // namespace concatenation 147 148TfLiteRegistration* Register_CONCATENATION_REF() { 149 static TfLiteRegistration r = { 150 nullptr, nullptr, concatenation::Prepare, 151 concatenation::Eval<concatenation::kReference>}; 152 return &r; 153} 154 155TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() { 156 static TfLiteRegistration r = { 157 nullptr, nullptr, concatenation::Prepare, 158 concatenation::Eval<concatenation::kGenericOptimized>}; 159 return &r; 160} 161 162TfLiteRegistration* Register_CONCATENATION() { 163 // TODO(ahentz): It turns out the two versions of Concatenation are almost 164 // identical, so we should consider removing one. 165 return Register_CONCATENATION_GENERIC_OPT(); 166} 167 168} // namespace builtin 169} // namespace ops 170} // namespace tflite 171