1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16// Lookup projected hash signatures in Predictor model, 17// output predicted labels and weights in decreasing order. 18// 19// Input: 20// Input[0]: A list of hash signatures. int32[num of input] 21// Input[1]: Hash signature keys in the model. int32[keys of model] 22// Input[2]: Labels in the model. int32[keys of model, item per entry] 23// Input[3]: Weights in the model. float[keys of model, item per entry] 24// 25// Output: 26// Output[0]: Predicted labels. int32[num of output] 27// Output[1]: Predicted weights. float[num of output] 28// 29 30#include <algorithm> 31#include <unordered_map> 32#include <vector> 33 34#include "tensorflow/contrib/lite/context.h" 35 36namespace tflite { 37namespace ops { 38namespace custom { 39 40namespace predict { 41 42struct PredictOption { 43 int32_t num_output; 44 float weight_threshold; 45 46 static PredictOption* Cast(void* ptr) { 47 return reinterpret_cast<PredictOption*>(ptr); 48 } 49}; 50 51bool WeightGreater(const std::pair<int32_t, float>& a, 52 const std::pair<int32_t, float>& b) { 53 return a.second > b.second; 54} 55 56void* Init(TfLiteContext* context, const char* custom_option, size_t length) { 57 if (custom_option == nullptr || length != sizeof(PredictOption)) { 58 fprintf(stderr, "No Custom option set\n"); 59 exit(1); 60 } 61 PredictOption* option = new PredictOption; 62 int offset = 0; 63 option->num_output = 64 *reinterpret_cast<const int32_t*>(custom_option + offset); 65 offset += sizeof(int32_t); 66 option->weight_threshold = 67 *reinterpret_cast<const float*>(custom_option + offset); 68 return reinterpret_cast<void*>(option); 69} 70 71void Free(TfLiteContext* context, void* buffer) { 72 delete PredictOption::Cast(buffer); 73} 74 75TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 76 TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); 77 TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); 78 79 TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]]; 80 TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]]; 81 TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]]; 82 TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]]; 83 TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); 84 TF_LITE_ENSURE_EQ(context, model_key->type, kTfLiteInt32); 85 TF_LITE_ENSURE_EQ(context, model_label->type, kTfLiteInt32); 86 TF_LITE_ENSURE_EQ(context, model_weight->type, kTfLiteFloat32); 87 TF_LITE_ENSURE_EQ(context, lookup->dims->size, 1); 88 TF_LITE_ENSURE_EQ(context, model_key->dims->size, 1); 89 TF_LITE_ENSURE_EQ(context, model_label->dims->size, 2); 90 TF_LITE_ENSURE_EQ(context, model_weight->dims->size, 2); 91 TF_LITE_ENSURE_EQ(context, model_key->dims->data[0], 92 model_label->dims->data[0]); 93 TF_LITE_ENSURE_EQ(context, model_key->dims->data[0], 94 model_weight->dims->data[0]); 95 TF_LITE_ENSURE_EQ(context, model_label->dims->data[1], 96 model_weight->dims->data[1]); 97 98 PredictOption* option = PredictOption::Cast(node->user_data); 99 TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]]; 100 TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]]; 101 TF_LITE_ENSURE_EQ(context, output_label->type, kTfLiteInt32); 102 TF_LITE_ENSURE_EQ(context, output_weight->type, kTfLiteFloat32); 103 104 TfLiteIntArray* label_size = TfLiteIntArrayCreate(1); 105 label_size->data[0] = option->num_output; 106 TfLiteIntArray* weight_size = TfLiteIntArrayCreate(1); 107 weight_size->data[0] = option->num_output; 108 TfLiteStatus status = 109 context->ResizeTensor(context, output_label, label_size); 110 if (status != kTfLiteOk) { 111 return status; 112 } 113 return context->ResizeTensor(context, output_weight, weight_size); 114} 115 116TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 117 TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]]; 118 TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]]; 119 TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]]; 120 TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]]; 121 122 // Aggregate by key 123 std::unordered_map<int32_t, float> aggregation; 124 const int num_input = lookup->dims->data[0]; 125 const int num_rows = model_key->dims->data[0]; 126 const int items = model_label->dims->data[1]; 127 int* model_key_end = model_key->data.i32 + num_rows; 128 129 for (int i = 0; i < num_input; i++) { 130 int* ptr = std::lower_bound(model_key->data.i32, model_key_end, 131 lookup->data.i32[i]); 132 if (ptr != nullptr && ptr != model_key_end && *ptr == lookup->data.i32[i]) { 133 int idx = ptr - model_key->data.i32; 134 for (int j = 0; j < items; j++) { 135 aggregation[model_label->data.i32[idx * items + j]] += 136 model_weight->data.f[idx * items + j] / num_input; 137 } 138 } 139 } 140 141 // Sort by value 142 std::vector<std::pair<int32_t, float>> sorted_labels(aggregation.begin(), 143 aggregation.end()); 144 std::sort(sorted_labels.begin(), sorted_labels.end(), WeightGreater); 145 146 PredictOption* option = PredictOption::Cast(node->user_data); 147 TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]]; 148 TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]]; 149 for (int i = 0; i < output_label->dims->data[0]; i++) { 150 if (i >= sorted_labels.size() || 151 sorted_labels[i].second < option->weight_threshold) { 152 // Set -1 to avoid lookup message with id 0, which is set for backoff. 153 output_label->data.i32[i] = -1; 154 output_weight->data.f[i] = 0.0f; 155 } else { 156 output_label->data.i32[i] = sorted_labels[i].first; 157 output_weight->data.f[i] = sorted_labels[i].second; 158 } 159 } 160 161 return kTfLiteOk; 162} 163 164} // namespace predict 165 166TfLiteRegistration* Register_PREDICT() { 167 static TfLiteRegistration r = {predict::Init, predict::Free, predict::Prepare, 168 predict::Eval}; 169 return &r; 170} 171 172} // namespace custom 173} // namespace ops 174} // namespace tflite 175