NeuralNetworksWrapper.h revision 66d5cb6e3a90aefc8d545f6369080ab88de9d667
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Provides C++ classes to more easily use the Neural Networks API.
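//
// A typical call sequence with these wrappers looks roughly like the
// following (an illustrative sketch, not part of the original header; the
// per-class examples further down fill in the details):
//
//     Model model;
//     // ... addOperand / setOperandValue / addOperation /
//     //     identifyInputsAndOutputs ...
//     model.finish();
//
//     Compilation compilation(&model);
//     compilation.finish();
//
//     Execution execution(&compilation);
//     // ... setInput / setOutput ...
//     execution.compute();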

#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H

#include "NeuralNetworks.h"

#include <math.h>
#include <utility>
#include <vector>

namespace android {
namespace nn {
namespace wrapper {

enum class Type {
    FLOAT32 = ANEURALNETWORKS_FLOAT32,
    INT32 = ANEURALNETWORKS_INT32,
    UINT32 = ANEURALNETWORKS_UINT32,
    TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
    TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32,
    TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
};

enum class ExecutePreference {
    PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
    PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
    PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
};

enum class Result {
    NO_ERROR = ANEURALNETWORKS_NO_ERROR,
    OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
    INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
    UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
    BAD_DATA = ANEURALNETWORKS_BAD_DATA,
};

// Describes the type, dimensions, and quantization parameters of an operand.
struct OperandType {
    ANeuralNetworksOperandType operandType;
    // int32_t type;
    std::vector<uint32_t> dimensions;

    OperandType(Type type, const std::vector<uint32_t>& d, float scale = 0.0f,
                int32_t zeroPoint = 0)
        : dimensions(d) {
        operandType.type = static_cast<int32_t>(type);
        operandType.scale = scale;
        operandType.zeroPoint = zeroPoint;

        operandType.dimensionCount = static_cast<uint32_t>(dimensions.size());
        operandType.dimensions = dimensions.data();
    }
};
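// Example (an illustrative sketch, not part of the original header): the
// scale and zero point below are arbitrary sample values for a quantized
// tensor; non-quantized types leave them at their defaults.
//
//     OperandType floatTensorType(Type::TENSOR_FLOAT32, {1, 4});
//     OperandType quant8Type(Type::TENSOR_QUANT8_ASYMM, {1, 8}, 0.5f, 127);
//     OperandType scalarType(Type::INT32, {});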

// Wraps an ANeuralNetworksMemory, i.e. a region of shared memory (created
// here from a file descriptor) that can back operand values, inputs, and
// outputs.
class Memory {
public:
    Memory(size_t size, int protect, int fd, size_t offset) {
        mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) ==
                 ANEURALNETWORKS_NO_ERROR;
    }

    ~Memory() { ANeuralNetworksMemory_free(mMemory); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Memory(const Memory&) = delete;
    Memory& operator=(const Memory&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Memory(Memory&& other) { *this = std::move(other); }
    Memory& operator=(Memory&& other) {
        if (this != &other) {
            mMemory = other.mMemory;
            mValid = other.mValid;
            other.mMemory = nullptr;
            other.mValid = false;
        }
        return *this;
    }

    ANeuralNetworksMemory* get() const { return mMemory; }
    bool isValid() const { return mValid; }

private:
    ANeuralNetworksMemory* mMemory = nullptr;
    bool mValid = true;
};
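// Example (an illustrative sketch, not part of the original header): wrap a
// caller-provided file descriptor, e.g. one obtained from ashmem or memfd,
// as NNAPI shared memory. "fd" and "dataSize" are assumed to come from the
// caller, and PROT_READ/PROT_WRITE require <sys/mman.h>.
//
//     Memory memory(dataSize, PROT_READ | PROT_WRITE, fd, 0);
//     if (!memory.isValid()) {
//         // handle the error
//     }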

// Wraps an ANeuralNetworksModel: a description of the operands and
// operations of a graph. A model must be finish()ed before it can be
// compiled.
class Model {
public:
    Model() {
        // TODO handle the value returned by this call
        ANeuralNetworksModel_create(&mModel);
    }
    ~Model() { ANeuralNetworksModel_free(mModel); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Model(const Model&) = delete;
    Model& operator=(const Model&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Model(Model&& other) { *this = std::move(other); }
    Model& operator=(Model&& other) {
        if (this != &other) {
            mModel = other.mModel;
            mNextOperandId = other.mNextOperandId;
            mValid = other.mValid;
            other.mModel = nullptr;
            other.mNextOperandId = 0;
            other.mValid = false;
        }
        return *this;
    }

    Result finish() { return static_cast<Result>(ANeuralNetworksModel_finish(mModel)); }

    uint32_t addOperand(const OperandType* type) {
        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
        return mNextOperandId++;
    }

    void setOperandValue(uint32_t index, const void* buffer, size_t length) {
        if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                                   size_t length) {
        if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
                                                           length) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
                      const std::vector<uint32_t>& outputs) {
        if (ANeuralNetworksModel_addOperation(mModel, type, static_cast<uint32_t>(inputs.size()),
                                              inputs.data(), static_cast<uint32_t>(outputs.size()),
                                              outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void identifyInputsAndOutputs(const std::vector<uint32_t>& inputs,
                                  const std::vector<uint32_t>& outputs) {
        if (ANeuralNetworksModel_identifyInputsAndOutputs(
                        mModel, static_cast<uint32_t>(inputs.size()), inputs.data(),
                        static_cast<uint32_t>(outputs.size()),
                        outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    ANeuralNetworksModel* getHandle() const { return mModel; }
    bool isValid() const { return mValid; }

private:
    ANeuralNetworksModel* mModel = nullptr;
    // We keep track of the operand ID as a convenience to the caller.
    uint32_t mNextOperandId = 0;
    bool mValid = true;
};
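// Example (an illustrative sketch, not part of the original header): a model
// that adds two float tensors. ANEURALNETWORKS_ADD takes the two tensors
// plus a fused-activation scalar as inputs and produces one output tensor.
//
//     OperandType tensorType(Type::TENSOR_FLOAT32, {1, 4});
//     OperandType activationType(Type::INT32, {});
//
//     Model model;
//     uint32_t a = model.addOperand(&tensorType);
//     uint32_t b = model.addOperand(&tensorType);
//     uint32_t act = model.addOperand(&activationType);
//     uint32_t sum = model.addOperand(&tensorType);
//
//     int32_t noActivation = ANEURALNETWORKS_FUSED_NONE;
//     model.setOperandValue(act, &noActivation, sizeof(noActivation));
//     model.addOperation(ANEURALNETWORKS_ADD, {a, b, act}, {sum});
//     model.identifyInputsAndOutputs({a, b}, {sum});
//
//     if (model.finish() != Result::NO_ERROR || !model.isValid()) {
//         // handle the error
//     }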

// Wraps an ANeuralNetworksEvent, used to wait for an asynchronously started
// execution to complete.
class Event {
public:
    Event() {}
    ~Event() { ANeuralNetworksEvent_free(mEvent); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Event(const Event&) = delete;
    Event& operator=(const Event&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Event(Event&& other) { *this = std::move(other); }
    Event& operator=(Event&& other) {
        if (this != &other) {
            mEvent = other.mEvent;
            other.mEvent = nullptr;
        }
        return *this;
    }

    Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); }

    // Only for use by Execution
    void set(ANeuralNetworksEvent* newEvent) {
        ANeuralNetworksEvent_free(mEvent);
        mEvent = newEvent;
    }

private:
    ANeuralNetworksEvent* mEvent = nullptr;
};

// Wraps an ANeuralNetworksCompilation: a model compiled for execution, with
// an optional execution preference.
class Compilation {
public:
    Compilation(const Model* model) {
        int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
        if (result != 0) {
            // TODO Handle the error
        }
    }

    ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Compilation(const Compilation&) = delete;
    Compilation& operator=(const Compilation&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Compilation(Compilation&& other) { *this = std::move(other); }
    Compilation& operator=(Compilation&& other) {
        if (this != &other) {
            mCompilation = other.mCompilation;
            other.mCompilation = nullptr;
        }
        return *this;
    }

    Result setPreference(ExecutePreference preference) {
        return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
                    mCompilation, static_cast<int32_t>(preference)));
    }

    Result finish() { return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation)); }

    ANeuralNetworksCompilation* getHandle() const { return mCompilation; }

private:
    ANeuralNetworksCompilation* mCompilation = nullptr;
};
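// Example (an illustrative sketch, not part of the original header): compile
// a finished Model, here called "model" as in the example above, preferring
// fast single-shot answers.
//
//     Compilation compilation(&model);
//     compilation.setPreference(ExecutePreference::PREFER_FAST_SINGLE_ANSWER);
//     if (compilation.finish() != Result::NO_ERROR) {
//         // handle the error
//     }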

// Wraps an ANeuralNetworksExecution: a single run of a compiled model
// against a set of inputs and outputs.
class Execution {
public:
    Execution(const Compilation* compilation) {
        int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
        if (result != 0) {
            // TODO Handle the error
        }
    }

    ~Execution() { ANeuralNetworksExecution_free(mExecution); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Execution(const Execution&) = delete;
    Execution& operator=(const Execution&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Execution(Execution&& other) { *this = std::move(other); }
    Execution& operator=(Execution&& other) {
        if (this != &other) {
            mExecution = other.mExecution;
            other.mExecution = nullptr;
        }
        return *this;
    }

    Result setInput(uint32_t index, const void* buffer, size_t length,
                    const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(
                    ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length));
    }

    Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                              uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory(
                    mExecution, index, type, memory->get(), offset, length));
    }

    Result setOutput(uint32_t index, void* buffer, size_t length,
                     const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(
                    ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length));
    }

    Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory(
                    mExecution, index, type, memory->get(), offset, length));
    }

    // Starts the computation asynchronously; the caller waits for completion
    // via the provided Event.
    Result startCompute(Event* event) {
        ANeuralNetworksEvent* ev = nullptr;
        Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev));
        event->set(ev);
        return result;
    }

    // Convenience helper that starts the computation and waits for it to
    // complete.
    Result compute() {
        ANeuralNetworksEvent* event = nullptr;
        Result result =
                    static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &event));
        if (result != Result::NO_ERROR) {
            return result;
        }
        // TODO: How to manage the lifetime of events when there are multiple
        // waiters is not clear.
        result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
        ANeuralNetworksEvent_free(event);
        return result;
    }

private:
    ANeuralNetworksExecution* mExecution = nullptr;
};
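// Example (an illustrative sketch, not part of the original header): run one
// inference on the "compilation" from the example above. "inputA", "inputB",
// and "output" are assumed to be caller-provided float arrays whose sizes
// match the operand dimensions.
//
//     Execution execution(&compilation);
//     execution.setInput(0, inputA, sizeof(inputA));
//     execution.setInput(1, inputB, sizeof(inputB));
//     execution.setOutput(0, output, sizeof(output));
//     if (execution.compute() != Result::NO_ERROR) {
//         // handle the error
//     }
//
// The asynchronous variant uses startCompute() together with an Event:
//
//     Event event;
//     execution.startCompute(&event);
//     // ... other work ...
//     event.wait();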

}  // namespace wrapper
}  // namespace nn
}  // namespace android

#endif  //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H