/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Op that looks up rows of a matrix.
//
// Input:
//   Tensor[0]: Row numbers to look up, dim.size == 1, int32.
//   Tensor[1]: 2-dimensional matrix of multi-dimensional items,
//              dim.size >= 2, any data type.
//              The first dimension is the row, the second is the column.
//
// Output:
//   Output.dim[0] == Tensor[0].dim[0], the number of lookups.
//   Output.dim[1] == Tensor[1].dim[1], the number of items per row.
//   Output.dim[k] == Tensor[1].dim[k] for every further dimension k >= 2.
//   Each item in the output is a raw byte copy of the corresponding item in
//   the input.
//   When any index is out of bounds, the op fails.
29// 30 31#include <unistd.h> 32#include <cassert> 33#include <cmath> 34#include <cstdio> 35#include <cstdlib> 36#include <cstring> 37#include <iostream> 38#include <limits> 39 40#include "tensorflow/contrib/lite/builtin_op_data.h" 41#include "tensorflow/contrib/lite/context.h" 42#include "tensorflow/contrib/lite/kernels/kernel_util.h" 43#include "tensorflow/contrib/lite/kernels/op_macros.h" 44 45namespace tflite { 46namespace ops { 47namespace builtin { 48namespace embedding_lookup { 49 50TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 51 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); 52 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 53 54 TfLiteTensor* lookup = GetInput(context, node, 0); 55 TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); 56 TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); 57 58 TfLiteTensor* value = GetInput(context, node, 1); 59 TF_LITE_ENSURE(context, NumDimensions(value) >= 2); 60 61 TfLiteTensor* output = GetOutput(context, node, 0); 62 TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); 63 64 outputSize->data[0] = SizeOfDimension(lookup, 0); 65 outputSize->data[1] = SizeOfDimension(value, 1); 66 for (int i = 2; i < NumDimensions(value); i++) { 67 outputSize->data[i] = SizeOfDimension(value, i); 68 } 69 return context->ResizeTensor(context, output, outputSize); 70} 71 72TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 73 TfLiteTensor* output = GetOutput(context, node, 0); 74 TfLiteTensor* lookup = GetInput(context, node, 0); 75 TfLiteTensor* value = GetInput(context, node, 1); 76 77 const int row_size = SizeOfDimension(value, 0); 78 const int row_bytes = value->bytes / row_size; 79 80 for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { 81 int idx = lookup->data.i32[i]; 82 if (idx >= row_size || idx < 0) { 83 context->ReportError(context, "Embedding Lookup: index out of bounds."); 84 return kTfLiteError; 85 } else { 86 memcpy(output->data.raw + i * row_bytes, 87 
value->data.raw + idx * row_bytes, row_bytes); 88 } 89 } 90 91 return kTfLiteOk; 92} 93 94} // namespace embedding_lookup 95 96TfLiteRegistration* Register_EMBEDDING_LOOKUP() { 97 static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare, 98 embedding_lookup::Eval}; 99 return &r; 100} 101 102} // namespace builtin 103} // namespace ops 104} // namespace tflite 105