1f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni/* 2f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * Copyright (C) 2017 The Android Open Source Project 3f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * 4f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * Licensed under the Apache License, Version 2.0 (the "License"); 5f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * you may not use this file except in compliance with the License. 6f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * You may obtain a copy of the License at 7f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * 8f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * http://www.apache.org/licenses/LICENSE-2.0 9f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * 10f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * Unless required by applicable law or agreed to in writing, software 11f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * distributed under the License is distributed on an "AS IS" BASIS, 12f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * See the License for the specific language governing permissions and 14f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni * limitations under the License. 15f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni */ 16f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 17f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#include "LSHProjection.h" 18f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 19f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#include "NeuralNetworksWrapper.h" 20f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#include "gmock/gmock-generated-matchers.h" 21f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#include "gmock/gmock-matchers.h" 22f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#include "gtest/gtest.h" 23f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 24f1817c663af4f22bc089ef82cd50df4186422c42Yang Niusing ::testing::FloatNear; 25f1817c663af4f22bc089ef82cd50df4186422c42Yang Niusing ::testing::Matcher; 26f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 27f1817c663af4f22bc089ef82cd50df4186422c42Yang Ninamespace android { 28f1817c663af4f22bc089ef82cd50df4186422c42Yang Ninamespace nn { 29f1817c663af4f22bc089ef82cd50df4186422c42Yang Ninamespace wrapper { 30f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 31f1817c663af4f22bc089ef82cd50df4186422c42Yang Niusing ::testing::ElementsAre; 32f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 33f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ 34f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni ACTION(Hash, float) \ 35f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni ACTION(Input, int) \ 36f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni ACTION(Weight, float) 37f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 38f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni// For all output and intermediate states 399e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo#define FOR_ALL_OUTPUT_TENSORS(ACTION) \ 409e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo ACTION(Output, int) 41f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 42f1817c663af4f22bc089ef82cd50df4186422c42Yang Niclass LSHProjectionOpModel { 43f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni public: 44f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni LSHProjectionOpModel(LSHProjectionType type, 45f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni std::initializer_list<uint32_t> hash_shape, 46f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni std::initializer_list<uint32_t> input_shape, 47f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni std::initializer_list<uint32_t> weight_shape) 48f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni : type_(type) { 49f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni std::vector<uint32_t> inputs; 50f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 51f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni OperandType HashTy(Type::TENSOR_FLOAT32, hash_shape); 52f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni inputs.push_back(model_.addOperand(&HashTy)); 539e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo OperandType InputTy(Type::TENSOR_INT32, input_shape); 54f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni inputs.push_back(model_.addOperand(&InputTy)); 55f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni OperandType WeightTy(Type::TENSOR_FLOAT32, weight_shape); 56f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni inputs.push_back(model_.addOperand(&WeightTy)); 57f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 58f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni OperandType TypeParamTy(Type::INT32, {}); 59f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni inputs.push_back(model_.addOperand(&TypeParamTy)); 60f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 61f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni std::vector<uint32_t> outputs; 62f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 63f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t { 64f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni uint32_t sz = 1; 65f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni for (uint32_t d : dims) { 66f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni sz *= d; 67f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni } 68f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni return sz; 69f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni }; 70f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 719e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo uint32_t outShapeDimension = 0; 72f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni if (type == LSHProjectionType_SPARSE) { 73f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni auto it = hash_shape.begin(); 74f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni Output_.insert(Output_.end(), *it, 0.f); 759e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo outShapeDimension = *it; 76f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni } else { 77f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni Output_.insert(Output_.end(), multiAll(hash_shape), 0.f); 789e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo outShapeDimension = multiAll(hash_shape); 79f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni } 80f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 819e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo OperandType OutputTy(Type::TENSOR_INT32, {outShapeDimension}); 829e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo outputs.push_back(model_.addOperand(&OutputTy)); 839e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo 84f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni model_.addOperation(ANEURALNETWORKS_LSH_PROJECTION, inputs, outputs); 8566d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet model_.identifyInputsAndOutputs(inputs, outputs); 86544739620cd7f37d40524d2407c92042e485c73fDavid Gross 87544739620cd7f37d40524d2407c92042e485c73fDavid Gross model_.finish(); 88f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni } 89f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 90f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#define DefineSetter(X, T) \ 91f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni void Set##X(const std::vector<T> &f) { \ 92f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni X##_.insert(X##_.end(), f.begin(), f.end()); \ 93f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni } 94f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 95f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); 96f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 97f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#undef DefineSetter 98f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 99f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni const std::vector<int> &GetOutput() const { return Output_; } 100f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 101f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni void Invoke() { 102f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni ASSERT_TRUE(model_.isValid()); 103f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 10483e24dc4706a5b7089881a55daf05b3924fab3b7David Gross Compilation compilation(&model_); 10565aa556323f4a054f80a75b6c4c721b2a7ed3298David Gross compilation.finish(); 1063ced3cfd5b8f22b632c35f24e585c4847383b195David Gross Execution execution(&compilation); 107f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 1083ced3cfd5b8f22b632c35f24e585c4847383b195David Gross#define SetInputOrWeight(X, T) \ 1093ced3cfd5b8f22b632c35f24e585c4847383b195David Gross ASSERT_EQ(execution.setInput(LSHProjection::k##X##Tensor, X##_.data(), \ 1109e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo sizeof(T) * X##_.size()), \ 111f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni Result::NO_ERROR); 112f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 113f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); 114f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 115f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#undef SetInputOrWeight 116f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 1179e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo#define SetOutput(X, T) \ 1183ced3cfd5b8f22b632c35f24e585c4847383b195David Gross ASSERT_EQ(execution.setOutput(LSHProjection::k##X##Tensor, X##_.data(), \ 1199e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo sizeof(T) * X##_.size()), \ 120f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni Result::NO_ERROR); 121f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 122f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni FOR_ALL_OUTPUT_TENSORS(SetOutput); 123f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 124f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#undef SetOutput 125f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 126f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni ASSERT_EQ( 1273ced3cfd5b8f22b632c35f24e585c4847383b195David Gross execution.setInput(LSHProjection::kTypeParam, &type_, sizeof(type_)), 128f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni Result::NO_ERROR); 129f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 1303ced3cfd5b8f22b632c35f24e585c4847383b195David Gross ASSERT_EQ(execution.compute(), Result::NO_ERROR); 131f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni } 132f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 133f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni private: 134f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni Model model_; 135f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni LSHProjectionType type_; 136f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 137f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni std::vector<float> Hash_; 138f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni std::vector<int> Input_; 139f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni std::vector<float> Weight_; 140f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni std::vector<int> Output_; 141f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni}; // namespace wrapper 142f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 143f1817c663af4f22bc089ef82cd50df4186422c42Yang NiTEST(LSHProjectionOpTest2, DenseWithThreeInputs) { 144f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni LSHProjectionOpModel m(LSHProjectionType_DENSE, {4, 2}, {3, 2}, {3}); 145f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 146f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321}); 147f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765}); 148f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni m.SetWeight({0.12, 0.34, 0.56}); 149f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 150f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni m.Invoke(); 151f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 152f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni EXPECT_THAT(m.GetOutput(), ElementsAre(1, 1, 1, 0, 1, 1, 1, 0)); 153f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni} 154f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 155f1817c663af4f22bc089ef82cd50df4186422c42Yang NiTEST(LSHProjectionOpTest2, SparseWithTwoInputs) { 156f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni LSHProjectionOpModel m(LSHProjectionType_SPARSE, {4, 2}, {3, 2}, {}); 157f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 158f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321}); 159f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765}); 160f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 161f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni m.Invoke(); 162f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 163f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni EXPECT_THAT(m.GetOutput(), ElementsAre(1, 2, 2, 0)); 164f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni} 165f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni 166f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni} // namespace wrapper 167f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni} // namespace nn 168f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni} // namespace android 169