1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Lookup projected hash signatures in Predictor model,
17// output predicted labels and weights in decreasing order.
18//
19// Input:
20//     Input[0]: A list of hash signatures. int32[num of input]
21//     Input[1]: Hash signature keys in the model. int32[keys of model]
22//     Input[2]: Labels in the model. int32[keys of model, item per entry]
23//     Input[3]: Weights in the model. float[keys of model, item per entry]
24//
25// Output:
26//     Output[0]: Predicted labels. int32[num of output]
27//     Output[1]: Predicted weights. float[num of output]
28//
29
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/context.h"
35
36namespace tflite {
37namespace ops {
38namespace custom {
39
40namespace predict {
41
// Options for the Predict op, deserialized by Init() from the flexbuffer-free
// custom-option blob: a raw int32 followed by a raw float.
struct PredictOption {
  int32_t num_output;      // Number of labels/weights written to the outputs.
  float weight_threshold;  // Aggregated weights below this are suppressed.

  // Recovers the typed option pointer stored in node->user_data.
  // static_cast is the correct named cast for void* -> T* (no reinterpret
  // needed; the pointer originally came from a PredictOption*).
  static PredictOption* Cast(void* ptr) {
    return static_cast<PredictOption*>(ptr);
  }
};

// Init() parses the blob field-by-field at offsets 0 and sizeof(int32_t) and
// validates length against sizeof(PredictOption); both are only correct if
// the struct is packed with no padding. Lock that assumption in.
static_assert(sizeof(PredictOption) == sizeof(int32_t) + sizeof(float),
              "PredictOption must be tightly packed");
50
// Strict-weak-ordering comparator: sorts (label, weight) pairs by weight,
// largest weight first.
bool WeightGreater(const std::pair<int32_t, float>& lhs,
                   const std::pair<int32_t, float>& rhs) {
  return rhs.second < lhs.second;
}
55
56void* Init(TfLiteContext* context, const char* custom_option, size_t length) {
57  if (custom_option == nullptr || length != sizeof(PredictOption)) {
58    fprintf(stderr, "No Custom option set\n");
59    exit(1);
60  }
61  PredictOption* option = new PredictOption;
62  int offset = 0;
63  option->num_output =
64      *reinterpret_cast<const int32_t*>(custom_option + offset);
65  offset += sizeof(int32_t);
66  option->weight_threshold =
67      *reinterpret_cast<const float*>(custom_option + offset);
68  return reinterpret_cast<void*>(option);
69}
70
71void Free(TfLiteContext* context, void* buffer) {
72  delete PredictOption::Cast(buffer);
73}
74
75TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
76  TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
77  TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
78
79  TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
80  TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
81  TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
82  TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
83  TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
84  TF_LITE_ENSURE_EQ(context, model_key->type, kTfLiteInt32);
85  TF_LITE_ENSURE_EQ(context, model_label->type, kTfLiteInt32);
86  TF_LITE_ENSURE_EQ(context, model_weight->type, kTfLiteFloat32);
87  TF_LITE_ENSURE_EQ(context, lookup->dims->size, 1);
88  TF_LITE_ENSURE_EQ(context, model_key->dims->size, 1);
89  TF_LITE_ENSURE_EQ(context, model_label->dims->size, 2);
90  TF_LITE_ENSURE_EQ(context, model_weight->dims->size, 2);
91  TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
92                    model_label->dims->data[0]);
93  TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
94                    model_weight->dims->data[0]);
95  TF_LITE_ENSURE_EQ(context, model_label->dims->data[1],
96                    model_weight->dims->data[1]);
97
98  PredictOption* option = PredictOption::Cast(node->user_data);
99  TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
100  TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
101  TF_LITE_ENSURE_EQ(context, output_label->type, kTfLiteInt32);
102  TF_LITE_ENSURE_EQ(context, output_weight->type, kTfLiteFloat32);
103
104  TfLiteIntArray* label_size = TfLiteIntArrayCreate(1);
105  label_size->data[0] = option->num_output;
106  TfLiteIntArray* weight_size = TfLiteIntArrayCreate(1);
107  weight_size->data[0] = option->num_output;
108  TfLiteStatus status =
109      context->ResizeTensor(context, output_label, label_size);
110  if (status != kTfLiteOk) {
111    return status;
112  }
113  return context->ResizeTensor(context, output_weight, weight_size);
114}
115
116TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
117  TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
118  TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
119  TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
120  TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
121
122  // Aggregate by key
123  std::unordered_map<int32_t, float> aggregation;
124  const int num_input = lookup->dims->data[0];
125  const int num_rows = model_key->dims->data[0];
126  const int items = model_label->dims->data[1];
127  int* model_key_end = model_key->data.i32 + num_rows;
128
129  for (int i = 0; i < num_input; i++) {
130    int* ptr = std::lower_bound(model_key->data.i32, model_key_end,
131                                lookup->data.i32[i]);
132    if (ptr != nullptr && ptr != model_key_end && *ptr == lookup->data.i32[i]) {
133      int idx = ptr - model_key->data.i32;
134      for (int j = 0; j < items; j++) {
135        aggregation[model_label->data.i32[idx * items + j]] +=
136            model_weight->data.f[idx * items + j] / num_input;
137      }
138    }
139  }
140
141  // Sort by value
142  std::vector<std::pair<int32_t, float>> sorted_labels(aggregation.begin(),
143                                                       aggregation.end());
144  std::sort(sorted_labels.begin(), sorted_labels.end(), WeightGreater);
145
146  PredictOption* option = PredictOption::Cast(node->user_data);
147  TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
148  TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
149  for (int i = 0; i < output_label->dims->data[0]; i++) {
150    if (i >= sorted_labels.size() ||
151        sorted_labels[i].second < option->weight_threshold) {
152      // Set -1 to avoid lookup message with id 0, which is set for backoff.
153      output_label->data.i32[i] = -1;
154      output_weight->data.f[i] = 0.0f;
155    } else {
156      output_label->data.i32[i] = sorted_labels[i].first;
157      output_weight->data.f[i] = sorted_labels[i].second;
158    }
159  }
160
161  return kTfLiteOk;
162}
163
164}  // namespace predict
165
// Returns the registration for the PREDICT custom op. The aggregate
// initializer fills TfLiteRegistration's leading members in declaration
// order — presumably {init, free, prepare, invoke}; confirm against the
// TfLiteRegistration definition in context.h before reordering.
TfLiteRegistration* Register_PREDICT() {
  static TfLiteRegistration r = {predict::Init, predict::Free, predict::Prepare,
                                 predict::Eval};
  return &r;
}
171
172}  // namespace custom
173}  // namespace ops
174}  // namespace tflite
175