/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
17f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#include "LSHProjection.h"
18f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
19f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#include "NeuralNetworksWrapper.h"
20f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#include "gmock/gmock-generated-matchers.h"
21f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#include "gmock/gmock-matchers.h"
22f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#include "gtest/gtest.h"
23f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
24f1817c663af4f22bc089ef82cd50df4186422c42Yang Niusing ::testing::FloatNear;
25f1817c663af4f22bc089ef82cd50df4186422c42Yang Niusing ::testing::Matcher;
26f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
27f1817c663af4f22bc089ef82cd50df4186422c42Yang Ninamespace android {
28f1817c663af4f22bc089ef82cd50df4186422c42Yang Ninamespace nn {
29f1817c663af4f22bc089ef82cd50df4186422c42Yang Ninamespace wrapper {
30f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
31f1817c663af4f22bc089ef82cd50df4186422c42Yang Niusing ::testing::ElementsAre;
32f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
// X-macro listing every tensor fed into the operation as (name, element type).
// The supplied two-argument macro is expanded once per entry, in the order
// Hash, Input, Weight.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(APPLY) \
  APPLY(Hash, float)                            \
  APPLY(Input, int)                             \
  APPLY(Weight, float)
37f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
// X-macro listing every tensor produced by the operation as
// (name, element type). Only the Output tensor is listed.
#define FOR_ALL_OUTPUT_TENSORS(APPLY) \
  APPLY(Output, int)
41f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
42f1817c663af4f22bc089ef82cd50df4186422c42Yang Niclass LSHProjectionOpModel {
43f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni public:
44f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  LSHProjectionOpModel(LSHProjectionType type,
45f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni                       std::initializer_list<uint32_t> hash_shape,
46f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni                       std::initializer_list<uint32_t> input_shape,
47f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni                       std::initializer_list<uint32_t> weight_shape)
48f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni      : type_(type) {
49f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    std::vector<uint32_t> inputs;
50f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
51f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    OperandType HashTy(Type::TENSOR_FLOAT32, hash_shape);
52f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    inputs.push_back(model_.addOperand(&HashTy));
539e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo    OperandType InputTy(Type::TENSOR_INT32, input_shape);
54f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    inputs.push_back(model_.addOperand(&InputTy));
55f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    OperandType WeightTy(Type::TENSOR_FLOAT32, weight_shape);
56f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    inputs.push_back(model_.addOperand(&WeightTy));
57f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
58f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    OperandType TypeParamTy(Type::INT32, {});
59f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    inputs.push_back(model_.addOperand(&TypeParamTy));
60f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
61f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    std::vector<uint32_t> outputs;
62f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
63f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
64f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni      uint32_t sz = 1;
65f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni      for (uint32_t d : dims) {
66f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni        sz *= d;
67f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni      }
68f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni      return sz;
69f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    };
70f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
719e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo    uint32_t outShapeDimension = 0;
72f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    if (type == LSHProjectionType_SPARSE) {
73f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni      auto it = hash_shape.begin();
74f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni      Output_.insert(Output_.end(), *it, 0.f);
759e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo      outShapeDimension = *it;
76f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    } else {
77f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni      Output_.insert(Output_.end(), multiAll(hash_shape), 0.f);
789e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo      outShapeDimension = multiAll(hash_shape);
79f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    }
80f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
819e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo    OperandType OutputTy(Type::TENSOR_INT32, {outShapeDimension});
829e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo    outputs.push_back(model_.addOperand(&OutputTy));
839e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo
84f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    model_.addOperation(ANEURALNETWORKS_LSH_PROJECTION, inputs, outputs);
8566d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet    model_.identifyInputsAndOutputs(inputs, outputs);
86544739620cd7f37d40524d2407c92042e485c73fDavid Gross
87544739620cd7f37d40524d2407c92042e485c73fDavid Gross    model_.finish();
88f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  }
89f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
90f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#define DefineSetter(X, T)                       \
91f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  void Set##X(const std::vector<T> &f) {         \
92f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    X##_.insert(X##_.end(), f.begin(), f.end()); \
93f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  }
94f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
95f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
96f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
97f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#undef DefineSetter
98f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
99f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  const std::vector<int> &GetOutput() const { return Output_; }
100f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
101f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  void Invoke() {
102f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    ASSERT_TRUE(model_.isValid());
103f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
10483e24dc4706a5b7089881a55daf05b3924fab3b7David Gross    Compilation compilation(&model_);
10565aa556323f4a054f80a75b6c4c721b2a7ed3298David Gross    compilation.finish();
1063ced3cfd5b8f22b632c35f24e585c4847383b195David Gross    Execution execution(&compilation);
107f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
1083ced3cfd5b8f22b632c35f24e585c4847383b195David Gross#define SetInputOrWeight(X, T)                                             \
1093ced3cfd5b8f22b632c35f24e585c4847383b195David Gross    ASSERT_EQ(execution.setInput(LSHProjection::k##X##Tensor, X##_.data(), \
1109e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo                                 sizeof(T) * X##_.size()),                 \
111f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni              Result::NO_ERROR);
112f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
113f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
114f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
115f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#undef SetInputOrWeight
116f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
1179e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo#define SetOutput(X, T)                                                   \
1183ced3cfd5b8f22b632c35f24e585c4847383b195David Gross  ASSERT_EQ(execution.setOutput(LSHProjection::k##X##Tensor, X##_.data(), \
1199e2ba5472927bbd3ab3cfb74eb3fc3477eac95e2Tix Lo                                sizeof(T) * X##_.size()),                 \
120f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni            Result::NO_ERROR);
121f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
122f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    FOR_ALL_OUTPUT_TENSORS(SetOutput);
123f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
124f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni#undef SetOutput
125f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
126f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni    ASSERT_EQ(
1273ced3cfd5b8f22b632c35f24e585c4847383b195David Gross        execution.setInput(LSHProjection::kTypeParam, &type_, sizeof(type_)),
128f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni        Result::NO_ERROR);
129f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
1303ced3cfd5b8f22b632c35f24e585c4847383b195David Gross    ASSERT_EQ(execution.compute(), Result::NO_ERROR);
131f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  }
132f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
133f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni private:
134f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  Model model_;
135f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  LSHProjectionType type_;
136f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
137f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  std::vector<float> Hash_;
138f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  std::vector<int> Input_;
139f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  std::vector<float> Weight_;
140f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  std::vector<int> Output_;
141f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni};  // namespace wrapper
142f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
143f1817c663af4f22bc089ef82cd50df4186422c42Yang NiTEST(LSHProjectionOpTest2, DenseWithThreeInputs) {
144f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  LSHProjectionOpModel m(LSHProjectionType_DENSE, {4, 2}, {3, 2}, {3});
145f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
146f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
147f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
148f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  m.SetWeight({0.12, 0.34, 0.56});
149f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
150f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  m.Invoke();
151f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
152f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  EXPECT_THAT(m.GetOutput(), ElementsAre(1, 1, 1, 0, 1, 1, 1, 0));
153f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni}
154f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
155f1817c663af4f22bc089ef82cd50df4186422c42Yang NiTEST(LSHProjectionOpTest2, SparseWithTwoInputs) {
156f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  LSHProjectionOpModel m(LSHProjectionType_SPARSE, {4, 2}, {3, 2}, {});
157f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
158f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
159f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
160f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
161f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  m.Invoke();
162f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
163f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni  EXPECT_THAT(m.GetOutput(), ElementsAre(1, 2, 2, 0));
164f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni}
165f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni
166f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni}  // namespace wrapper
167f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni}  // namespace nn
168f1817c663af4f22bc089ef82cd50df4186422c42Yang Ni}  // namespace android
169