1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "HashtableLookup.h" 18 19#include "CpuExecutor.h" 20#include "HalInterfaces.h" 21#include "Operations.h" 22 23namespace android { 24namespace nn { 25 26namespace { 27 28int greater(const void* a, const void* b) { 29 return *static_cast<const int*>(a) - *static_cast<const int*>(b); 30} 31 32} // anonymous namespace 33 34HashtableLookup::HashtableLookup(const Operation& operation, 35 std::vector<RunTimeOperandInfo>& operands) { 36 lookup_ = GetInput(operation, operands, kLookupTensor); 37 key_ = GetInput(operation, operands, kKeyTensor); 38 value_ = GetInput(operation, operands, kValueTensor); 39 40 output_ = GetOutput(operation, operands, kOutputTensor); 41 hits_ = GetOutput(operation, operands, kHitsTensor); 42} 43 44bool HashtableLookup::Eval() { 45 const int num_rows = value_->shape().dimensions[0]; 46 const int row_bytes = sizeOfData(value_->type, value_->dimensions) / num_rows; 47 void* pointer = nullptr; 48 49 for (int i = 0; i < static_cast<int>(lookup_->shape().dimensions[0]); i++) { 50 int idx = -1; 51 pointer = bsearch(lookup_->buffer + sizeof(int) * i, key_->buffer, 52 num_rows, sizeof(int), greater); 53 if (pointer != nullptr) { 54 idx = 55 (reinterpret_cast<uint8_t*>(pointer) - key_->buffer) / sizeof(float); 56 } 57 58 if (idx >= num_rows || idx < 0) { 59 memset(output_->buffer + i * row_bytes, 0, row_bytes); 60 hits_->buffer[i] = 0; 61 } else { 62 memcpy(output_->buffer + i * row_bytes, value_->buffer + idx * row_bytes, 63 row_bytes); 64 hits_->buffer[i] = 1; 65 } 66 } 67 68 return true; 69} 70 71} // namespace nn 72} // namespace android 73