1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
// Op that looks up rows of a matrix by index.
//
// Input:
//     Tensor[0]: Row indices to look up, dim.size == 1, int32.
//     Tensor[1]: Matrix of multi-dimensional items, dim.size >= 2, any data
//                type. The first dimension is the row, the second is the
//                column.
//
// Output:
//   Output.dim[0] == Tensor[0].dim[0], the number of lookups.
//   Output.dim[1] == Tensor[1].dim[1], the number of items per row.
//   Output.dim[k] == Tensor[1].dim[k] for every k >= 2.
//   Each output row is a raw byte copy of the corresponding input row.
//   If any index is out of bounds, the op fails with an error.
//
30
31#include <unistd.h>
32#include <cassert>
33#include <cmath>
34#include <cstdio>
35#include <cstdlib>
36#include <cstring>
37#include <iostream>
38#include <limits>
39
40#include "tensorflow/contrib/lite/builtin_op_data.h"
41#include "tensorflow/contrib/lite/context.h"
42#include "tensorflow/contrib/lite/kernels/kernel_util.h"
43#include "tensorflow/contrib/lite/kernels/op_macros.h"
44
45namespace tflite {
46namespace ops {
47namespace builtin {
48namespace embedding_lookup {
49
50TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
51  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
52  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
53
54  TfLiteTensor* lookup = GetInput(context, node, 0);
55  TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
56  TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
57
58  TfLiteTensor* value = GetInput(context, node, 1);
59  TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
60
61  TfLiteTensor* output = GetOutput(context, node, 0);
62  TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
63
64  outputSize->data[0] = SizeOfDimension(lookup, 0);
65  outputSize->data[1] = SizeOfDimension(value, 1);
66  for (int i = 2; i < NumDimensions(value); i++) {
67    outputSize->data[i] = SizeOfDimension(value, i);
68  }
69  return context->ResizeTensor(context, output, outputSize);
70}
71
72TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
73  TfLiteTensor* output = GetOutput(context, node, 0);
74  TfLiteTensor* lookup = GetInput(context, node, 0);
75  TfLiteTensor* value = GetInput(context, node, 1);
76
77  const int row_size = SizeOfDimension(value, 0);
78  const int row_bytes = value->bytes / row_size;
79
80  for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
81    int idx = lookup->data.i32[i];
82    if (idx >= row_size || idx < 0) {
83      context->ReportError(context, "Embedding Lookup: index out of bounds.");
84      return kTfLiteError;
85    } else {
86      memcpy(output->data.raw + i * row_bytes,
87             value->data.raw + idx * row_bytes, row_bytes);
88    }
89  }
90
91  return kTfLiteOk;
92}
93
94}  // namespace embedding_lookup
95
96TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
97  static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare,
98                                 embedding_lookup::Eval};
99  return &r;
100}
101
102}  // namespace builtin
103}  // namespace ops
104}  // namespace tflite
105