1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <unistd.h>
16#include <cassert>
17#include <cmath>
18#include <cstdio>
19#include <cstdlib>
20#include <iostream>
21#include <limits>
22
23#include "tensorflow/contrib/lite/builtin_op_data.h"
24#include "tensorflow/contrib/lite/context.h"
25#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
26#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
27#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
28#include "tensorflow/contrib/lite/kernels/kernel_util.h"
29#include "tensorflow/contrib/lite/kernels/op_macros.h"
30
31namespace tflite {
32namespace ops {
33namespace builtin {
34namespace concatenation {
35
// This file has two implementations of Concatenation: a reference one and a
// generic optimized one. The template parameter of Eval() selects which is
// used at registration time.
enum KernelType {
  kReference,
  kGenericOptimized,
};
41
42TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
43  auto* params =
44      reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
45  int axis = params->axis;
46  int num_inputs = node->inputs->size;
47
48  // The number of dimensions of the input tensors must match, and all
49  // dimensions except 'axis' must be equal.
50  TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]];
51  TfLiteType input_type = t0->type;
52  if (axis < 0) axis += t0->dims->size;
53  TF_LITE_ENSURE(context, axis >= 0);
54  TF_LITE_ENSURE(context, axis < t0->dims->size);
55
56  // TODO(ahentz): These are limitations of our implementation that could be
57  // removed with a bit of effort.
58  TF_LITE_ENSURE(context, t0->dims->size <= 4);
59  TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
60  TF_LITE_ENSURE(context,
61                 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8);
62
63  // Output dimensions will match input dimensions, except 'axis', which
64  // will be the sum of inputs
65  int sum_axis = t0->dims->data[axis];
66  for (int i = 1; i < num_inputs; ++i) {
67    TfLiteTensor* t = &context->tensors[node->inputs->data[i]];
68    TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
69    TF_LITE_ENSURE_EQ(context, t->type, input_type);
70    if (input_type == kTfLiteUInt8) {
71      TF_LITE_ENSURE_EQ(context, t->params.zero_point, t0->params.zero_point);
72      TF_LITE_ENSURE_EQ(context, t->params.scale, t0->params.scale);
73    }
74    for (int d = 0; d < t0->dims->size; ++d) {
75      if (d == axis) {
76        sum_axis += t->dims->data[axis];
77      } else {
78        TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
79      }
80    }
81  }
82
83  TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size);
84  for (int d = 0; d < t0->dims->size; ++d) {
85    output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
86  }
87
88  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
89  TF_LITE_ENSURE_EQ(context, output->type, input_type);
90  if (input_type == kTfLiteUInt8) {
91    TF_LITE_ENSURE_EQ(context, output->params.zero_point,
92                      t0->params.zero_point);
93    TF_LITE_ENSURE_EQ(context, output->params.scale, t0->params.scale);
94  }
95
96  return context->ResizeTensor(context, output, output_size);
97}
98
// Runs the concatenation. kernel_type selects between the reference and the
// generic optimized implementation at compile time; both dispatch on the
// output type, which Prepare() has already verified matches every input.
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
  int axis = params->axis;
  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
  // A negative axis counts backwards from the last dimension (same
  // normalization as in Prepare()).
  if (axis < 0) axis += output->dims->size;

// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
// allocate and populate these during Prepare().
// TODO(ycling): Activation function parameter is ignored. For now we dont have
// a model with a Concatenation with fused activation function.
#define TF_LITE_CONCATENATION(type, scalar)                                 \
  VectorOfTensors<scalar> all_inputs(*context, *node->inputs);              \
  type::Concatenation<FusedActivationFunctionType::kNone, scalar>(          \
      RemapDim(NumDimensions(output), axis), all_inputs.data(),             \
      all_inputs.dims(), node->inputs->size, GetTensorData<scalar>(output), \
      GetTensorDims(output))

  switch (output->type) {  // Already know in/out types are the same.
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_CONCATENATION(reference_ops, float);
      } else {
        TF_LITE_CONCATENATION(optimized_ops, float);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_CONCATENATION(reference_ops, uint8_t);
      } else {
        TF_LITE_CONCATENATION(optimized_ops, uint8_t);
      }
      break;
    default:
      context->ReportError(context,
                           "Only float32 and uint8 are currently supported.");
      return kTfLiteError;
  }

#undef TF_LITE_CONCATENATION

  return kTfLiteOk;
}
143
144#undef TF_LITE_MACRO_DISPATCH
145
146}  // namespace concatenation
147
148TfLiteRegistration* Register_CONCATENATION_REF() {
149  static TfLiteRegistration r = {
150      nullptr, nullptr, concatenation::Prepare,
151      concatenation::Eval<concatenation::kReference>};
152  return &r;
153}
154
155TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
156  static TfLiteRegistration r = {
157      nullptr, nullptr, concatenation::Prepare,
158      concatenation::Eval<concatenation::kGenericOptimized>};
159  return &r;
160}
161
162TfLiteRegistration* Register_CONCATENATION() {
163  // TODO(ahentz): It turns out the two versions of Concatenation are almost
164  // identical, so we should consider removing one.
165  return Register_CONCATENATION_GENERIC_OPT();
166}
167
168}  // namespace builtin
169}  // namespace ops
170}  // namespace tflite
171