1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "LSHProjection.h"
18
19#include "NeuralNetworksWrapper.h"
20#include "gmock/gmock-generated-matchers.h"
21#include "gmock/gmock-matchers.h"
22#include "gtest/gtest.h"
23
24using ::testing::FloatNear;
25using ::testing::Matcher;
26
27namespace android {
28namespace nn {
29namespace wrapper {
30
31using ::testing::ElementsAre;
32
// X-macro listing the input tensor buffers of the test model: expands to
// APPLY(Name, ElementType) for Hash, Input and Weight, in that order.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(APPLY) \
  APPLY(Hash, float)                            \
  APPLY(Input, int)                             \
  APPLY(Weight, float)
37
// X-macro listing the output tensor buffers of the test model: expands to
// APPLY(Name, ElementType) for each of them (only Output at present).
#define FOR_ALL_OUTPUT_TENSORS(APPLY) APPLY(Output, int)
41
42class LSHProjectionOpModel {
43 public:
44  LSHProjectionOpModel(LSHProjectionType type,
45                       std::initializer_list<uint32_t> hash_shape,
46                       std::initializer_list<uint32_t> input_shape,
47                       std::initializer_list<uint32_t> weight_shape)
48      : type_(type) {
49    std::vector<uint32_t> inputs;
50
51    OperandType HashTy(Type::TENSOR_FLOAT32, hash_shape);
52    inputs.push_back(model_.addOperand(&HashTy));
53    OperandType InputTy(Type::TENSOR_INT32, input_shape);
54    inputs.push_back(model_.addOperand(&InputTy));
55    OperandType WeightTy(Type::TENSOR_FLOAT32, weight_shape);
56    inputs.push_back(model_.addOperand(&WeightTy));
57
58    OperandType TypeParamTy(Type::INT32, {});
59    inputs.push_back(model_.addOperand(&TypeParamTy));
60
61    std::vector<uint32_t> outputs;
62
63    auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
64      uint32_t sz = 1;
65      for (uint32_t d : dims) {
66        sz *= d;
67      }
68      return sz;
69    };
70
71    uint32_t outShapeDimension = 0;
72    if (type == LSHProjectionType_SPARSE) {
73      auto it = hash_shape.begin();
74      Output_.insert(Output_.end(), *it, 0.f);
75      outShapeDimension = *it;
76    } else {
77      Output_.insert(Output_.end(), multiAll(hash_shape), 0.f);
78      outShapeDimension = multiAll(hash_shape);
79    }
80
81    OperandType OutputTy(Type::TENSOR_INT32, {outShapeDimension});
82    outputs.push_back(model_.addOperand(&OutputTy));
83
84    model_.addOperation(ANEURALNETWORKS_LSH_PROJECTION, inputs, outputs);
85    model_.identifyInputsAndOutputs(inputs, outputs);
86
87    model_.finish();
88  }
89
90#define DefineSetter(X, T)                       \
91  void Set##X(const std::vector<T> &f) {         \
92    X##_.insert(X##_.end(), f.begin(), f.end()); \
93  }
94
95  FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
96
97#undef DefineSetter
98
99  const std::vector<int> &GetOutput() const { return Output_; }
100
101  void Invoke() {
102    ASSERT_TRUE(model_.isValid());
103
104    Compilation compilation(&model_);
105    compilation.finish();
106    Execution execution(&compilation);
107
108#define SetInputOrWeight(X, T)                                             \
109    ASSERT_EQ(execution.setInput(LSHProjection::k##X##Tensor, X##_.data(), \
110                                 sizeof(T) * X##_.size()),                 \
111              Result::NO_ERROR);
112
113    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
114
115#undef SetInputOrWeight
116
117#define SetOutput(X, T)                                                   \
118  ASSERT_EQ(execution.setOutput(LSHProjection::k##X##Tensor, X##_.data(), \
119                                sizeof(T) * X##_.size()),                 \
120            Result::NO_ERROR);
121
122    FOR_ALL_OUTPUT_TENSORS(SetOutput);
123
124#undef SetOutput
125
126    ASSERT_EQ(
127        execution.setInput(LSHProjection::kTypeParam, &type_, sizeof(type_)),
128        Result::NO_ERROR);
129
130    ASSERT_EQ(execution.compute(), Result::NO_ERROR);
131  }
132
133 private:
134  Model model_;
135  LSHProjectionType type_;
136
137  std::vector<float> Hash_;
138  std::vector<int> Input_;
139  std::vector<float> Weight_;
140  std::vector<int> Output_;
141};  // namespace wrapper
142
143TEST(LSHProjectionOpTest2, DenseWithThreeInputs) {
144  LSHProjectionOpModel m(LSHProjectionType_DENSE, {4, 2}, {3, 2}, {3});
145
146  m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
147  m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
148  m.SetWeight({0.12, 0.34, 0.56});
149
150  m.Invoke();
151
152  EXPECT_THAT(m.GetOutput(), ElementsAre(1, 1, 1, 0, 1, 1, 1, 0));
153}
154
155TEST(LSHProjectionOpTest2, SparseWithTwoInputs) {
156  LSHProjectionOpModel m(LSHProjectionType_SPARSE, {4, 2}, {3, 2}, {});
157
158  m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
159  m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
160
161  m.Invoke();
162
163  EXPECT_THAT(m.GetOutput(), ElementsAre(1, 2, 2, 0));
164}
165
166}  // namespace wrapper
167}  // namespace nn
168}  // namespace android
169