1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <string.h>
16#include "tensorflow/contrib/lite/builtin_op_data.h"
17#include "tensorflow/contrib/lite/context.h"
18#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
19#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
20#include "tensorflow/contrib/lite/kernels/kernel_util.h"
21#include "tensorflow/contrib/lite/kernels/op_macros.h"
22#include "tensorflow/contrib/lite/string_util.h"
23
24namespace tflite {
25namespace ops {
26namespace builtin {
27namespace gather {
// Tensor slot indices used by this kernel: the data tensor is input 0, the
// INT32 positions tensor is input 1, and the gathered result is output 0.
constexpr int kInputTensor{0};
constexpr int kInputPositions{1};
constexpr int kOutputTensor{0};
31
32TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
33  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
34  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
35
36  const auto* params =
37      reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
38  TfLiteTensor* input = GetInput(context, node, kInputTensor);
39  TfLiteTensor* positions = GetInput(context, node, kInputPositions);
40  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
41  // Only INT32 positions are supported.
42  TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32);
43  // Check that input and output types match.
44  TF_LITE_ENSURE_EQ(context, input->type, output->type);
45  // TODO(mgubin): only 0D or 1D positions are currently supported.
46  TF_LITE_ENSURE(context, NumDimensions(positions) <= 1);
47  // TODO(mgubin): Only default axis == 0 is supported.
48  TF_LITE_ENSURE_EQ(context, params->axis, 0);
49  // Check conditions for different types.
50  switch (input->type) {
51    case kTfLiteFloat32:
52    case kTfLiteUInt8:
53    case kTfLiteInt32: {
54      // Fully supported by reference_ops::Gather.
55    } break;
56
57    case kTfLiteString: {
58      // Only 1D input is supported.
59      TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
60    } break;
61    default:
62      context->ReportError(context,
63                           "Only float32 and string types are supported");
64      return kTfLiteError;
65  }
66  const int num_dimensions =
67      NumDimensions(input) + NumDimensions(positions) - 1;
68  TF_LITE_ENSURE(context, params->axis <= num_dimensions);
69  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
70  int output_index = 0;
71  for (int i = 0; i < params->axis; ++i) {
72    output_shape->data[output_index++] = input->dims->data[i];
73  }
74  for (int i = 0; i < positions->dims->size; ++i) {
75    output_shape->data[output_index++] = positions->dims->data[i];
76  }
77  for (int i = params->axis + 1; i < input->dims->size; ++i) {
78    output_shape->data[output_index++] = input->dims->data[i];
79  }
80  return context->ResizeTensor(context, output, output_shape);
81}
82
83TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
84  TfLiteTensor* input = GetInput(context, node, kInputTensor);
85  TfLiteTensor* positions = GetInput(context, node, kInputPositions);
86  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
87  const int input_rank = NumDimensions(input);
88#define TF_LITE_GATHER(data_type, index_type)                            \
89  optimized_ops::Gather(                                                 \
90      GetTensorData<data_type>(input), GetTensorDims(input), input_rank, \
91      GetTensorData<index_type>(positions), GetTensorDims(positions),    \
92      GetTensorData<data_type>(output), GetTensorDims(output));
93  switch (input->type) {
94    case kTfLiteFloat32:
95      TF_LITE_GATHER(float, int32_t);
96      break;
97    case kTfLiteUInt8:
98      TF_LITE_GATHER(uint8_t, int32_t);
99      break;
100    case kTfLiteInt32:
101      TF_LITE_GATHER(int32_t, int32_t);
102      break;
103    case kTfLiteString: {
104      DynamicBuffer buffer;
105      const int32* indexes = positions->data.i32;
106      const int num_strings = GetStringCount(input);
107      for (int i = 0; i < positions->dims->data[0]; ++i) {
108        const int pos = indexes[i];
109        TF_LITE_ENSURE(context, pos < num_strings);
110        const auto string_ref = GetString(input, pos);
111        buffer.AddString(string_ref.str, string_ref.len);
112      }
113      buffer.WriteToTensor(output);
114    } break;
115    default:
116      return kTfLiteError;
117  }
118#undef TF_LITE_GATHER
119  return kTfLiteOk;
120}
121}  // namespace gather
122
123TfLiteRegistration* Register_GATHER() {
124  static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare,
125                                 gather::Eval};
126  return &r;
127}
128
129}  // namespace builtin
130}  // namespace ops
131}  // namespace tflite
132