1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Op that looks up items from hashtable.
17//
18// Input:
19//     Tensor[0]: Hash key to lookup, dim.size == 1, int32
20//     Tensor[1]: Key of hashtable, dim.size == 1, int32
21//                *MUST* be sorted in ascending order.
22//     Tensor[2]: Value of hashtable, dim.size >= 1
23//                Tensor[1].Dim[0] == Tensor[2].Dim[0]
24//
25// Output:
26//   Output[0].dim[0] == Tensor[0].dim[0], num of lookups
27//   Each item in output is a raw bytes copy of corresponding item in input.
28//   When key does not exist in hashtable, the returned bytes are all 0s.
29//
30//   Output[1].dim = { Tensor[0].dim[0] }, num of lookups
31//   Each item indicates whether the corresponding lookup has a returned value.
32//   0 for missing key, 1 for found key.
33
34#include <unistd.h>
35#include <cassert>
36#include <cmath>
37#include <cstdio>
38#include <cstdlib>
39#include <cstring>
40#include <iostream>
41#include <limits>
42
43#include "tensorflow/contrib/lite/builtin_op_data.h"
44#include "tensorflow/contrib/lite/context.h"
45#include "tensorflow/contrib/lite/kernels/kernel_util.h"
46#include "tensorflow/contrib/lite/kernels/op_macros.h"
47#include "tensorflow/contrib/lite/string_util.h"
48
49namespace tflite {
50namespace ops {
51namespace builtin {
52
53namespace {
54
55int greater(const void* a, const void* b) {
56  return *static_cast<const int*>(a) - *static_cast<const int*>(b);
57}
58
// Validates input/output tensor shapes and types, then sizes the outputs.
//
// Inputs (see file header): lookup (1-D int32), key (1-D int32, assumed
// sorted ascending for the bsearch in Eval), value (rank >= 1, dim 0 equal
// to key's length). Outputs: output 0 mirrors value's shape with dim 0
// replaced by the number of lookups; output 1 is a 1-D uint8 hit mask of
// that same length.
//
// Note: the TF_LITE_ENSURE* macros report the error and return
// kTfLiteError immediately on failure.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);

  // Tensor[0]: the keys to look up.
  TfLiteTensor* lookup = GetInput(context, node, 0);
  TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
  TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);

  // Tensor[1]: the hashtable's key column; must be 1-D int32.
  TfLiteTensor* key = GetInput(context, node, 1);
  TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1);
  TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32);

  // Tensor[2]: the hashtable's values; one row per key.
  TfLiteTensor* value = GetInput(context, node, 2);
  TF_LITE_ENSURE(context, NumDimensions(value) >= 1);
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0),
                    SizeOfDimension(value, 0));
  if (value->type == kTfLiteString) {
    // String values are only supported as a flat 1-D list.
    TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1);
  }

  // Output[1]: the uint8 hit mask, one entry per lookup.
  TfLiteTensor* hits = GetOutput(context, node, 1);
  TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8);
  TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1);
  hitSize->data[0] = SizeOfDimension(lookup, 0);

  // Output[0]: must have the same element type as the value tensor.
  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE_EQ(context, value->type, output->type);

  TfLiteStatus status = kTfLiteOk;
  // String outputs are written via DynamicBuffer in Eval, so only
  // non-string outputs are statically resized here. ResizeTensor takes
  // ownership of the TfLiteIntArray it is given.
  if (output->type != kTfLiteString) {
    TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
    outputSize->data[0] = SizeOfDimension(lookup, 0);
    for (int i = 1; i < NumDimensions(value); i++) {
      outputSize->data[i] = SizeOfDimension(value, i);
    }
    status = context->ResizeTensor(context, output, outputSize);
  }
  // Resize the hit mask even if the output resize failed, but make sure a
  // failure in either resize is reported to the caller.
  if (context->ResizeTensor(context, hits, hitSize) == kTfLiteError) {
    status = kTfLiteError;
  }
  return status;
}
101
102TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
103  TfLiteTensor* output = GetOutput(context, node, 0);
104  TfLiteTensor* hits = GetOutput(context, node, 1);
105  TfLiteTensor* lookup = GetInput(context, node, 0);
106  TfLiteTensor* key = GetInput(context, node, 1);
107  TfLiteTensor* value = GetInput(context, node, 2);
108
109  const int num_rows = SizeOfDimension(value, 0);
110  const int row_bytes = value->bytes / num_rows;
111  void* pointer = nullptr;
112  DynamicBuffer buf;
113
114  for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
115    int idx = -1;
116    pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows,
117                      sizeof(int32_t), greater);
118    if (pointer != nullptr) {
119      idx = (reinterpret_cast<char*>(pointer) - (key->data.raw)) /
120            sizeof(int32_t);
121    }
122
123    if (idx >= num_rows || idx < 0) {
124      if (output->type == kTfLiteString) {
125        buf.AddString(nullptr, 0);
126      } else {
127        memset(output->data.raw + i * row_bytes, 0, row_bytes);
128      }
129      hits->data.uint8[i] = 0;
130    } else {
131      if (output->type == kTfLiteString) {
132        buf.AddString(GetString(value, idx));
133      } else {
134        memcpy(output->data.raw + i * row_bytes,
135               value->data.raw + idx * row_bytes, row_bytes);
136      }
137      hits->data.uint8[i] = 1;
138    }
139  }
140  if (output->type == kTfLiteString) {
141    buf.WriteToTensor(output);
142  }
143
144  return kTfLiteOk;
145}
146}  // namespace
147
148TfLiteRegistration* Register_HASHTABLE_LOOKUP() {
149  static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
150  return &r;
151}
152
153}  // namespace builtin
154}  // namespace ops
155}  // namespace tflite
156