/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Op that looks up items in a hashtable.
//
// Input:
//   Tensor[0]: Keys to look up, dim.size == 1, int32.
//   Tensor[1]: Keys of the hashtable, dim.size == 1, int32.
//              *MUST* be sorted in ascending order.
//   Tensor[2]: Values of the hashtable, dim.size >= 1 (exactly 1 for strings).
//              Tensor[1].Dim[0] == Tensor[2].Dim[0]
//
// Output:
//   Output[0].dim[0] == Tensor[0].dim[0], the number of lookups; the remaining
//     dims match Tensor[2].
//     Each row of the output is a raw byte copy of the matching row of
//     Tensor[2]. When a key does not exist in the hashtable, the returned
//     bytes are all 0s (an empty string for string values).
//
//   Output[1].dim = { Tensor[0].dim[0] }, the number of lookups, uint8.
//     Each item indicates whether the corresponding lookup found a value:
//     0 for a missing key, 1 for a found key.
33 34#include <unistd.h> 35#include <cassert> 36#include <cmath> 37#include <cstdio> 38#include <cstdlib> 39#include <cstring> 40#include <iostream> 41#include <limits> 42 43#include "tensorflow/contrib/lite/builtin_op_data.h" 44#include "tensorflow/contrib/lite/context.h" 45#include "tensorflow/contrib/lite/kernels/kernel_util.h" 46#include "tensorflow/contrib/lite/kernels/op_macros.h" 47#include "tensorflow/contrib/lite/string_util.h" 48 49namespace tflite { 50namespace ops { 51namespace builtin { 52 53namespace { 54 55int greater(const void* a, const void* b) { 56 return *static_cast<const int*>(a) - *static_cast<const int*>(b); 57} 58 59TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 60 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); 61 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); 62 63 TfLiteTensor* lookup = GetInput(context, node, 0); 64 TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); 65 TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); 66 67 TfLiteTensor* key = GetInput(context, node, 1); 68 TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1); 69 TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32); 70 71 TfLiteTensor* value = GetInput(context, node, 2); 72 TF_LITE_ENSURE(context, NumDimensions(value) >= 1); 73 TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0), 74 SizeOfDimension(value, 0)); 75 if (value->type == kTfLiteString) { 76 TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1); 77 } 78 79 TfLiteTensor* hits = GetOutput(context, node, 1); 80 TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8); 81 TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1); 82 hitSize->data[0] = SizeOfDimension(lookup, 0); 83 84 TfLiteTensor* output = GetOutput(context, node, 0); 85 TF_LITE_ENSURE_EQ(context, value->type, output->type); 86 87 TfLiteStatus status = kTfLiteOk; 88 if (output->type != kTfLiteString) { 89 TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); 90 outputSize->data[0] = SizeOfDimension(lookup, 0); 91 for 
(int i = 1; i < NumDimensions(value); i++) { 92 outputSize->data[i] = SizeOfDimension(value, i); 93 } 94 status = context->ResizeTensor(context, output, outputSize); 95 } 96 if (context->ResizeTensor(context, hits, hitSize) == kTfLiteError) { 97 status = kTfLiteError; 98 } 99 return status; 100} 101 102TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 103 TfLiteTensor* output = GetOutput(context, node, 0); 104 TfLiteTensor* hits = GetOutput(context, node, 1); 105 TfLiteTensor* lookup = GetInput(context, node, 0); 106 TfLiteTensor* key = GetInput(context, node, 1); 107 TfLiteTensor* value = GetInput(context, node, 2); 108 109 const int num_rows = SizeOfDimension(value, 0); 110 const int row_bytes = value->bytes / num_rows; 111 void* pointer = nullptr; 112 DynamicBuffer buf; 113 114 for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { 115 int idx = -1; 116 pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows, 117 sizeof(int32_t), greater); 118 if (pointer != nullptr) { 119 idx = (reinterpret_cast<char*>(pointer) - (key->data.raw)) / 120 sizeof(int32_t); 121 } 122 123 if (idx >= num_rows || idx < 0) { 124 if (output->type == kTfLiteString) { 125 buf.AddString(nullptr, 0); 126 } else { 127 memset(output->data.raw + i * row_bytes, 0, row_bytes); 128 } 129 hits->data.uint8[i] = 0; 130 } else { 131 if (output->type == kTfLiteString) { 132 buf.AddString(GetString(value, idx)); 133 } else { 134 memcpy(output->data.raw + i * row_bytes, 135 value->data.raw + idx * row_bytes, row_bytes); 136 } 137 hits->data.uint8[i] = 1; 138 } 139 } 140 if (output->type == kTfLiteString) { 141 buf.WriteToTensor(output); 142 } 143 144 return kTfLiteOk; 145} 146} // namespace 147 148TfLiteRegistration* Register_HASHTABLE_LOOKUP() { 149 static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; 150 return &r; 151} 152 153} // namespace builtin 154} // namespace ops 155} // namespace tflite 156