1/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#undef NDEBUG
18
19#include "Callbacks.h"
20#include "CompilationBuilder.h"
21#include "Manager.h"
22#include "ModelBuilder.h"
23#include "NeuralNetworks.h"
24#include "NeuralNetworksWrapper.h"
25#include "SampleDriver.h"
26#include "ValidateHal.h"
27
28#include <algorithm>
29#include <cassert>
30#include <vector>
31
32#include <gtest/gtest.h>
33
34namespace android {
35
36using CompilationBuilder = nn::CompilationBuilder;
37using Device = nn::Device;
38using DeviceManager = nn::DeviceManager;
39using HidlModel = hardware::neuralnetworks::V1_1::Model;
40using PreparedModelCallback = hardware::neuralnetworks::V1_0::implementation::PreparedModelCallback;
41using Result = nn::wrapper::Result;
42using SampleDriver = nn::sample_driver::SampleDriver;
43using WrapperCompilation = nn::wrapper::Compilation;
44using WrapperEvent = nn::wrapper::Event;
45using WrapperExecution = nn::wrapper::Execution;
46using WrapperModel = nn::wrapper::Model;
47using WrapperOperandType = nn::wrapper::OperandType;
48using WrapperType = nn::wrapper::Type;
49
50namespace {
51
52// Wraps an IPreparedModel to allow dummying up the execution status.
53class TestPreparedModel : public IPreparedModel {
54public:
55    // If errorStatus is NONE, then execute behaves normally (and sends back
56    // the actual execution status).  Otherwise, don't bother to execute, and
57    // just send back errorStatus (as the execution status, not the launch
58    // status).
59    TestPreparedModel(sp<IPreparedModel> preparedModel, ErrorStatus errorStatus) :
60            mPreparedModel(preparedModel), mErrorStatus(errorStatus) {}
61
62    Return<ErrorStatus> execute(const Request& request,
63                                const sp<IExecutionCallback>& callback) override {
64        if (mErrorStatus == ErrorStatus::NONE) {
65            return mPreparedModel->execute(request, callback);
66        } else {
67            callback->notify(mErrorStatus);
68            return ErrorStatus::NONE;
69        }
70    }
71private:
72    sp<IPreparedModel> mPreparedModel;
73    ErrorStatus mErrorStatus;
74};
75
76// Behaves like SampleDriver, except that it produces wrapped IPreparedModel.
77class TestDriver : public SampleDriver {
78public:
79    // Allow dummying up the error status for execution of all models
80    // prepared from this driver.  If errorStatus is NONE, then
81    // execute behaves normally (and sends back the actual execution
82    // status).  Otherwise, don't bother to execute, and just send
83    // back errorStatus (as the execution status, not the launch
84    // status).
85    TestDriver(const std::string& name, ErrorStatus errorStatus) :
86            SampleDriver(name.c_str()), mErrorStatus(errorStatus) { }
87
88    Return<void> getCapabilities_1_1(getCapabilities_1_1_cb _hidl_cb) override {
89        android::nn::initVLogMask();
90        Capabilities capabilities =
91                {.float32Performance = {.execTime = 0.75f, .powerUsage = 0.75f},
92                 .quantized8Performance = {.execTime = 0.75f, .powerUsage = 0.75f},
93                 .relaxedFloat32toFloat16Performance = {.execTime = 0.75f, .powerUsage = 0.75f}};
94        _hidl_cb(ErrorStatus::NONE, capabilities);
95        return Void();
96    }
97
98    Return<void> getSupportedOperations_1_1(const HidlModel& model,
99                                            getSupportedOperations_cb cb) override {
100        if (nn::validateModel(model)) {
101            std::vector<bool> supported(model.operations.size(), true);
102            cb(ErrorStatus::NONE, supported);
103        } else {
104            std::vector<bool> supported;
105            cb(ErrorStatus::INVALID_ARGUMENT, supported);
106        }
107        return Void();
108    }
109
110    Return<ErrorStatus> prepareModel_1_1(
111        const HidlModel& model,
112        ExecutionPreference preference,
113        const sp<IPreparedModelCallback>& actualCallback) override {
114
115        sp<PreparedModelCallback> localCallback = new PreparedModelCallback;
116        Return<ErrorStatus> prepareModelReturn =
117                SampleDriver::prepareModel_1_1(model, preference, localCallback);
118        if (!prepareModelReturn.isOkUnchecked()) {
119            return prepareModelReturn;
120        }
121        if (prepareModelReturn != ErrorStatus::NONE) {
122            actualCallback->notify(localCallback->getStatus(), localCallback->getPreparedModel());
123            return prepareModelReturn;
124        }
125        localCallback->wait();
126        if (localCallback->getStatus() != ErrorStatus::NONE) {
127            actualCallback->notify(localCallback->getStatus(), localCallback->getPreparedModel());
128        } else {
129            actualCallback->notify(ErrorStatus::NONE,
130                                   new TestPreparedModel(localCallback->getPreparedModel(),
131                                                         mErrorStatus));
132        }
133        return prepareModelReturn;
134    }
135
136private:
137    ErrorStatus mErrorStatus;
138};
139
140// This class adds some simple utilities on top of
141// ::android::nn::wrapper::Compilation in order to provide access to
142// certain features from CompilationBuilder that are not exposed by
143// the base class.
144class TestCompilation : public WrapperCompilation {
145public:
146    TestCompilation(const WrapperModel* model) : WrapperCompilation(model) {
147        // We need to ensure that we use our TestDriver and do not
148        // fall back to CPU.  (If we allow CPU fallback, then when our
149        // TestDriver reports an execution failure, we'll re-execute
150        // on CPU, and will not see the failure.)
151        builder()->setPartitioning(DeviceManager::kPartitioningWithoutFallback);
152    }
153
154    // Allow dummying up the error status for all executions from this
155    // compilation.  If errorStatus is NONE, then execute behaves
156    // normally (and sends back the actual execution status).
157    // Otherwise, don't bother to execute, and just send back
158    // errorStatus (as the execution status, not the launch status).
159    Result finish(const std::string& deviceName, ErrorStatus errorStatus) {
160        std::vector<std::shared_ptr<Device>> devices;
161        auto device = std::make_shared<Device>(deviceName, new TestDriver(deviceName, errorStatus));
162        assert(device->initialize());
163        devices.push_back(device);
164        return static_cast<Result>(builder()->finish(devices));
165    }
166
167private:
168    CompilationBuilder* builder() {
169        return reinterpret_cast<CompilationBuilder*>(getHandle());
170    }
171};
172
173class ExecutionTest :
174            public ::testing::TestWithParam<std::tuple<ErrorStatus, Result>> {
175public:
176    ExecutionTest() :
177            kName(toString(std::get<0>(GetParam()))),
178            kForceErrorStatus(std::get<0>(GetParam())),
179            kExpectResult(std::get<1>(GetParam())),
180            mModel(makeModel()),
181            mCompilation(&mModel) { }
182
183protected:
184    const std::string kName;
185
186    // Allow dummying up the error status for execution.  If
187    // kForceErrorStatus is NONE, then execution behaves normally (and
188    // sends back the actual execution status).  Otherwise, don't
189    // bother to execute, and just send back kForceErrorStatus (as the
190    // execution status, not the launch status).
191    const ErrorStatus kForceErrorStatus;
192
193    // What result do we expect from the execution?  (The Result
194    // equivalent of kForceErrorStatus.)
195    const Result kExpectResult;
196
197    WrapperModel mModel;
198    TestCompilation mCompilation;
199
200    void setInputOutput(WrapperExecution* execution) {
201        ASSERT_EQ(execution->setInput(0, &mInputBuffer, sizeof(mInputBuffer)), Result::NO_ERROR);
202        ASSERT_EQ(execution->setOutput(0, &mOutputBuffer, sizeof(mOutputBuffer)), Result::NO_ERROR);
203    }
204
205    float mInputBuffer  = 3.14;
206    float mOutputBuffer = 0;
207    const float kOutputBufferExpected = 3;
208
209private:
210    static WrapperModel makeModel() {
211        static const WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, { 1 });
212
213        WrapperModel model;
214        uint32_t input = model.addOperand(&tensorType);
215        uint32_t output = model.addOperand(&tensorType);
216        model.addOperation(ANEURALNETWORKS_FLOOR, { input }, { output });
217        model.identifyInputsAndOutputs({ input }, { output } );
218        assert(model.finish() == Result::NO_ERROR);
219
220        return model;
221    }
222};
223
224TEST_P(ExecutionTest, Wait) {
225    SCOPED_TRACE(kName);
226    ASSERT_EQ(mCompilation.finish(kName, kForceErrorStatus), Result::NO_ERROR);
227    WrapperExecution execution(&mCompilation);
228    ASSERT_NO_FATAL_FAILURE(setInputOutput(&execution));
229    WrapperEvent event;
230    ASSERT_EQ(execution.startCompute(&event), Result::NO_ERROR);
231    ASSERT_EQ(event.wait(), kExpectResult);
232    if (kExpectResult == Result::NO_ERROR) {
233        ASSERT_EQ(mOutputBuffer, kOutputBufferExpected);
234    }
235}
236
237INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest,
238                        ::testing::Values(std::make_tuple(ErrorStatus::NONE,
239                                                          Result::NO_ERROR),
240                                          std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE,
241                                                          Result::OP_FAILED),
242                                          std::make_tuple(ErrorStatus::GENERAL_FAILURE,
243                                                          Result::OP_FAILED),
244                                          std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE,
245                                                          Result::OP_FAILED),
246                                          std::make_tuple(ErrorStatus::INVALID_ARGUMENT,
247                                                          Result::BAD_DATA)));
248
249}  // namespace
250}  // namespace android
251