/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "neuralnetworks_hidl_hal_test"

#include "VtsHalNeuralnetworks.h"

#include "Callbacks.h"
#include "TestHarness.h"
#include "Utils.h"

#include <android-base/logging.h>
#include <android/hidl/memory/1.0/IMemory.h>
#include <hidlmemory/mapping.h>

namespace android {
namespace hardware {
namespace neuralnetworks {
namespace V1_0 {
namespace vts {
namespace functional {

using ::android::hardware::neuralnetworks::V1_0::implementation::ExecutionCallback;
using ::android::hardware::neuralnetworks::V1_0::implementation::PreparedModelCallback;
using ::android::hidl::memory::V1_0::IMemory;
using test_helper::MixedTyped;
using test_helper::MixedTypedExampleType;
using test_helper::for_all;

///////////////////////// UTILITY FUNCTIONS /////////////////////////

static void createPreparedModel(const sp<IDevice>& device, const V1_0::Model& model,
                                sp<IPreparedModel>* preparedModel) {
    ASSERT_NE(nullptr, preparedModel);

    // see if service can handle model
    bool fullySupportsModel = false;
    Return<void> supportedOpsLaunchStatus = device->getSupportedOperations(
        model, [&fullySupportsModel](ErrorStatus status, const hidl_vec<bool>& supported) {
            ASSERT_EQ(ErrorStatus::NONE, status);
            ASSERT_NE(0ul, supported.size());
            fullySupportsModel =
                std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; });
        });
    ASSERT_TRUE(supportedOpsLaunchStatus.isOk());

    // launch prepare model
    sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
    ASSERT_NE(nullptr, preparedModelCallback.get());
    Return<ErrorStatus> prepareLaunchStatus = device->prepareModel(model, preparedModelCallback);
    ASSERT_TRUE(prepareLaunchStatus.isOk());
    ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));

    // retrieve prepared model
    preparedModelCallback->wait();
    ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
    *preparedModel = preparedModelCallback->getPreparedModel();

    // The getSupportedOperations call returns a list of operations that are
    // guaranteed not to fail if prepareModel is called, and
    // 'fullySupportsModel' is true if and only if the entire model is
    // guaranteed. If a driver has any doubt that it can prepare an operation,
    // it must return false. So here, if a driver isn't sure whether it can
    // support an operation, but reports that it successfully prepared the
    // model, the test can continue.
    if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) {
        ASSERT_EQ(nullptr, preparedModel->get());
        LOG(INFO) << "NN VTS: Unable to test Request validation because vendor service cannot "
                     "prepare model that it does not support.";
        std::cout << "[          ]   Unable to test Request validation because vendor service "
                     "cannot prepare model that it does not support."
                  << std::endl;
        return;
    }
    ASSERT_EQ(ErrorStatus::NONE, prepareReturnStatus);
    ASSERT_NE(nullptr, preparedModel->get());
}

// Primary validation function. This function will take a valid request, apply a
// mutation to it to invalidate the request, then pass it to interface calls
// that use the request. Note that the request here is passed by value, and any
// mutation to the request does not leave this function.
static void validate(const sp<IPreparedModel>& preparedModel, const std::string& message,
                     Request request, const std::function<void(Request*)>& mutation) {
    mutation(&request);
    SCOPED_TRACE(message + " [execute]");

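    // The invalidated request must be rejected both synchronously (the launch status
    // returned by execute) and asynchronously (the status delivered to the callback).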
    sp<ExecutionCallback> executionCallback = new ExecutionCallback();
    ASSERT_NE(nullptr, executionCallback.get());
    Return<ErrorStatus> executeLaunchStatus = preparedModel->execute(request, executionCallback);
    ASSERT_TRUE(executeLaunchStatus.isOk());
    ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(executeLaunchStatus));

    executionCallback->wait();
    ErrorStatus executionReturnStatus = executionCallback->getStatus();
    ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, executionReturnStatus);
}

// Delete element from hidl_vec. hidl_vec doesn't support a "remove" operation,
// so this is efficiently accomplished by moving the element to the end and
// resizing the hidl_vec to one less.
template <typename Type>
static void hidl_vec_removeAt(hidl_vec<Type>* vec, uint32_t index) {
    if (vec) {
        std::rotate(vec->begin() + index, vec->begin() + index + 1, vec->end());
        vec->resize(vec->size() - 1);
    }
}

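// Append an element to a hidl_vec and return the index at which it was inserted.
// hidl_vec has no push_back, so this grows the vector by one and writes the new
// value into the last slot.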
template <typename Type>
static uint32_t hidl_vec_push_back(hidl_vec<Type>* vec, const Type& value) {
    // assume vec is valid
    const uint32_t index = vec->size();
    vec->resize(index + 1);
    (*vec)[index] = value;
    return index;
}

///////////////////////// REMOVE INPUT ////////////////////////////////////

static void removeInputTest(const sp<IPreparedModel>& preparedModel, const Request& request) {
    for (size_t input = 0; input < request.inputs.size(); ++input) {
        const std::string message = "removeInput: removed input " + std::to_string(input);
        validate(preparedModel, message, request,
                 [input](Request* request) { hidl_vec_removeAt(&request->inputs, input); });
    }
}

///////////////////////// REMOVE OUTPUT ////////////////////////////////////

static void removeOutputTest(const sp<IPreparedModel>& preparedModel, const Request& request) {
    for (size_t output = 0; output < request.outputs.size(); ++output) {
        const std::string message = "removeOutput: removed output " + std::to_string(output);
        validate(preparedModel, message, request,
                 [output](Request* request) { hidl_vec_removeAt(&request->outputs, output); });
    }
}

///////////////////////////// ENTRY POINT //////////////////////////////////

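// Convert each golden test example (MixedTyped inputs/outputs) into a V1_0 Request
// whose arguments point into two shared memory pools: one holding all input data
// and one reserving space for all output data.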
std::vector<Request> createRequests(const std::vector<MixedTypedExampleType>& examples) {
    const uint32_t INPUT = 0;
    const uint32_t OUTPUT = 1;

    std::vector<Request> requests;

    for (auto& example : examples) {
        const MixedTyped& inputs = example.first;
        const MixedTyped& outputs = example.second;

        std::vector<RequestArgument> inputs_info, outputs_info;
        uint32_t inputSize = 0, outputSize = 0;

        // This function only partially specifies the metadata (vector of RequestArguments).
        // The contents are copied over below.
        for_all(inputs, [&inputs_info, &inputSize](int index, auto, auto s) {
            if (inputs_info.size() <= static_cast<size_t>(index)) inputs_info.resize(index + 1);
            RequestArgument arg = {
                .location = {.poolIndex = INPUT, .offset = 0, .length = static_cast<uint32_t>(s)},
                .dimensions = {},
            };
            RequestArgument arg_empty = {
                .hasNoValue = true,
            };
            inputs_info[index] = s ? arg : arg_empty;
            inputSize += s;
        });
        // Compute offset for inputs 1 and so on
        {
            size_t offset = 0;
            for (auto& i : inputs_info) {
                if (!i.hasNoValue) i.location.offset = offset;
                offset += i.location.length;
            }
        }

        // Go through all outputs, initialize RequestArgument descriptors
        for_all(outputs, [&outputs_info, &outputSize](int index, auto, auto s) {
            if (outputs_info.size() <= static_cast<size_t>(index)) outputs_info.resize(index + 1);
            RequestArgument arg = {
                .location = {.poolIndex = OUTPUT, .offset = 0, .length = static_cast<uint32_t>(s)},
                .dimensions = {},
            };
            outputs_info[index] = arg;
            outputSize += s;
        });
        // Compute offset for outputs 1 and so on
        {
            size_t offset = 0;
            for (auto& i : outputs_info) {
                i.location.offset = offset;
                offset += i.location.length;
            }
        }
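        // Allocate one shared memory pool sized for all inputs and one for all
        // outputs; give up on this example if either allocation fails.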
        std::vector<hidl_memory> pools = {nn::allocateSharedMemory(inputSize),
                                          nn::allocateSharedMemory(outputSize)};
        if (pools[INPUT].size() == 0 || pools[OUTPUT].size() == 0) {
            return {};
        }

        // map the input pool into this process
        sp<IMemory> inputMemory = mapMemory(pools[INPUT]);
        if (inputMemory == nullptr) {
            return {};
        }
        char* inputPtr = reinterpret_cast<char*>(static_cast<void*>(inputMemory->getPointer()));
        if (inputPtr == nullptr) {
            return {};
        }

        // initialize pool
        inputMemory->update();
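        // Copy each example input into the input pool at the offset computed above.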
        for_all(inputs, [&inputs_info, inputPtr](int index, auto p, auto s) {
            char* begin = (char*)p;
            char* end = begin + s;
            // TODO: handle more than one input
            std::copy(begin, end, inputPtr + inputs_info[index].location.offset);
        });
        inputMemory->commit();

        requests.push_back({.inputs = inputs_info, .outputs = outputs_info, .pools = pools});
    }

    return requests;
}

void ValidationTest::validateRequests(const V1_0::Model& model,
                                      const std::vector<Request>& requests) {
    // create IPreparedModel
    sp<IPreparedModel> preparedModel;
    ASSERT_NO_FATAL_FAILURE(createPreparedModel(device, model, &preparedModel));
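    // A null preparedModel means the vendor service declined to prepare a model it
    // does not fully support (see createPreparedModel above); skip validation then.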
    if (preparedModel == nullptr) {
        return;
    }

    // validate each request
    for (const Request& request : requests) {
        removeInputTest(preparedModel, request);
        removeOutputTest(preparedModel, request);
    }
}

}  // namespace functional
}  // namespace vts
}  // namespace V1_0
}  // namespace neuralnetworks
}  // namespace hardware
}  // namespace android