NeuralNetworksWrapper.h revision 425b2594c76e934dfdbc93209253e3c189571149
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Provides C++ classes to more easily use the Neural Networks API.
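//
// Typical usage, sketched below for a single ANEURALNETWORKS_ADD operation.
// This is an illustrative sketch only: it omits error checking and assumes the
// operation and fuse codes declared in NeuralNetworks.h.
//
//   using namespace android::nn::wrapper;
//
//   // Build the model: out = a + b, with no fused activation.
//   Model model;
//   OperandType tensorType(Type::TENSOR_FLOAT32, {1, 2});
//   OperandType scalarType(Type::INT32, {});
//   uint32_t a = model.addOperand(&tensorType);
//   uint32_t b = model.addOperand(&tensorType);
//   uint32_t act = model.addOperand(&scalarType);
//   uint32_t out = model.addOperand(&tensorType);
//   int32_t noActivation = ANEURALNETWORKS_FUSED_NONE;
//   model.setOperandValue(act, &noActivation, sizeof(noActivation));
//   model.addOperation(ANEURALNETWORKS_ADD, {a, b, act}, {out});
//   model.setInputsAndOutputs({a, b}, {out});
//   model.finish();
//
//   // Compile the model for the desired execution preference.
//   Compilation compilation(&model);
//   compilation.setPreference(ExecutePreference::PREFER_FAST_SINGLE_ANSWER);
//   compilation.finish();
//
//   // Run it synchronously on caller-provided buffers.
//   float inA[2] = {1.0f, 2.0f}, inB[2] = {3.0f, 4.0f}, result[2];
//   Execution execution(&compilation);
//   execution.setInput(0, inA, sizeof(inA));
//   execution.setInput(1, inB, sizeof(inB));
//   execution.setOutput(0, result, sizeof(result));
//   execution.compute();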

#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H

#include "NeuralNetworks.h"

#include <math.h>
#include <limits>
#include <utility>
#include <vector>

namespace android {
namespace nn {
namespace wrapper {

enum class Type {
    FLOAT32 = ANEURALNETWORKS_FLOAT32,
    INT32 = ANEURALNETWORKS_INT32,
    UINT32 = ANEURALNETWORKS_UINT32,
    TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
    TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32,
    TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
};

enum class ExecutePreference {
    PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
    PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
    PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
};

enum class Result {
    NO_ERROR = ANEURALNETWORKS_NO_ERROR,
    OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
    INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
    UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
    BAD_DATA = ANEURALNETWORKS_BAD_DATA,
};

struct OperandType {
    ANeuralNetworksOperandType operandType;
    // int32_t type;
    std::vector<uint32_t> dimensions;

    OperandType(Type type, const std::vector<uint32_t>& d) : dimensions(d) {
        operandType.type = static_cast<int32_t>(type);
        operandType.scale = 0.0f;
        operandType.offset = 0;

        operandType.dimensionCount = static_cast<uint32_t>(dimensions.size());
        operandType.dimensions = dimensions.data();
    }

    OperandType(Type type, float scale, const std::vector<uint32_t>& d) : OperandType(type, d) {
        operandType.scale = scale;
    }

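    // Intended for Type::TENSOR_QUANT8_ASYMM: real values in [f_min, f_max] are
    // mapped onto the quantized range [0, 255] so that real = (q - offset) * scale.
    // For example (illustrative), f_min = -1.0f and f_max = 1.0f give
    // scale = 2.0f / 255 and offset = round(0 - (-1.0f) / scale) = 128.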
    OperandType(Type type, float f_min, float f_max, const std::vector<uint32_t>& d)
        : OperandType(type, d) {
        uint8_t q_min = std::numeric_limits<uint8_t>::min();
        uint8_t q_max = std::numeric_limits<uint8_t>::max();
        float range = q_max - q_min;
        float scale = (f_max - f_min) / range;
        // Clamp to the quantized range before the narrowing cast; casting first
        // would let out-of-range values wrap and defeat the clamp.
        int32_t offset =
                static_cast<int32_t>(fmin(q_max, fmax(q_min, round(q_min - f_min / scale))));

        operandType.scale = scale;
        operandType.offset = offset;
    }
};

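// Wraps an ANeuralNetworksMemory created from a file descriptor. Illustrative
// sketch (the fd and dataSize are hypothetical and obtained elsewhere, e.g.
// from a shared-memory region or a mapped file):
//
//   Memory memory(dataSize, PROT_READ | PROT_WRITE, fd, 0);
//   if (!memory.isValid()) { /* handle the error */ }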
class Memory {
public:
    Memory(size_t size, int protect, int fd, size_t offset) {
        mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) ==
                 ANEURALNETWORKS_NO_ERROR;
    }

    ~Memory() { ANeuralNetworksMemory_free(mMemory); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Memory(const Memory&) = delete;
    Memory& operator=(const Memory&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Memory(Memory&& other) { *this = std::move(other); }
    Memory& operator=(Memory&& other) {
        if (this != &other) {
            mMemory = other.mMemory;
            mValid = other.mValid;
            other.mMemory = nullptr;
            other.mValid = false;
        }
        return *this;
    }

    ANeuralNetworksMemory* get() const { return mMemory; }
    bool isValid() const { return mValid; }

private:
    ANeuralNetworksMemory* mMemory = nullptr;
    bool mValid = true;
};

class Model {
public:
    Model() {
        // TODO handle the value returned by this call
        ANeuralNetworksModel_create(&mModel);
    }
    ~Model() { ANeuralNetworksModel_free(mModel); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Model(const Model&) = delete;
    Model& operator=(const Model&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Model(Model&& other) { *this = std::move(other); }
    Model& operator=(Model&& other) {
        if (this != &other) {
            mModel = other.mModel;
            mNextOperandId = other.mNextOperandId;
            mValid = other.mValid;
            other.mModel = nullptr;
            other.mNextOperandId = 0;
            other.mValid = false;
        }
        return *this;
    }

    Result finish() {
        return static_cast<Result>(ANeuralNetworksModel_finish(mModel));
    }

    uint32_t addOperand(const OperandType* type) {
        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
        return mNextOperandId++;
    }

    void setOperandValue(uint32_t index, const void* buffer, size_t length) {
        if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                                   size_t length) {
        if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
                                                           length) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
                      const std::vector<uint32_t>& outputs) {
        if (ANeuralNetworksModel_addOperation(mModel, type, static_cast<uint32_t>(inputs.size()),
                                              inputs.data(), static_cast<uint32_t>(outputs.size()),
                                              outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void setInputsAndOutputs(const std::vector<uint32_t>& inputs,
                             const std::vector<uint32_t>& outputs) {
        if (ANeuralNetworksModel_setInputsAndOutputs(mModel, static_cast<uint32_t>(inputs.size()),
                                                     inputs.data(),
                                                     static_cast<uint32_t>(outputs.size()),
                                                     outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    ANeuralNetworksModel* getHandle() const { return mModel; }
    bool isValid() const { return mValid; }

private:
    ANeuralNetworksModel* mModel = nullptr;
    // We keep track of the operand ID as a convenience to the caller.
    uint32_t mNextOperandId = 0;
    bool mValid = true;
};

class Event {
public:
    Event() {}
    ~Event() { ANeuralNetworksEvent_free(mEvent); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Event(const Event&) = delete;
    Event& operator=(const Event&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Event(Event&& other) {
        *this = std::move(other);
    }
    Event& operator=(Event&& other) {
        if (this != &other) {
            mEvent = other.mEvent;
            other.mEvent = nullptr;
        }
        return *this;
    }

    Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); }

    // Only for use by Execution.
    void set(ANeuralNetworksEvent* newEvent) {
        ANeuralNetworksEvent_free(mEvent);
        mEvent = newEvent;
    }

private:
    ANeuralNetworksEvent* mEvent = nullptr;
};

class Compilation {
public:
    Compilation(const Model* model) {
        int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
        if (result != 0) {
            // TODO Handle the error
        }
    }

    ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }

    Compilation(const Compilation&) = delete;
    Compilation& operator=(const Compilation&) = delete;

    Compilation(Compilation&& other) { *this = std::move(other); }
    Compilation& operator=(Compilation&& other) {
        if (this != &other) {
            mCompilation = other.mCompilation;
            other.mCompilation = nullptr;
        }
        return *this;
    }

    Result setPreference(ExecutePreference preference) {
        return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
                    mCompilation, static_cast<int32_t>(preference)));
    }

    Result finish() {
        return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation));
    }

    ANeuralNetworksCompilation* getHandle() const { return mCompilation; }

private:
    ANeuralNetworksCompilation* mCompilation = nullptr;
};

class Execution {
public:
    Execution(const Compilation* compilation) {
        int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
        if (result != 0) {
            // TODO Handle the error
        }
    }

    ~Execution() { ANeuralNetworksExecution_free(mExecution); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Execution(const Execution&) = delete;
    Execution& operator=(const Execution&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Execution(Execution&& other) { *this = std::move(other); }
    Execution& operator=(Execution&& other) {
        if (this != &other) {
            mExecution = other.mExecution;
            other.mExecution = nullptr;
        }
        return *this;
    }

    Result setInput(uint32_t index, const void* buffer, size_t length,
                    const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(
                    ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length));
    }

    Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                              uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory(
                    mExecution, index, type, memory->get(), offset, length));
    }

    Result setOutput(uint32_t index, void* buffer, size_t length,
                     const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(
                    ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length));
    }

    Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory(
                    mExecution, index, type, memory->get(), offset, length));
    }

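    // Asynchronous usage sketch (inputs and outputs already set; illustrative only):
    //
    //   Event event;
    //   execution.startCompute(&event);
    //   // ... overlap other work here ...
    //   event.wait();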
    Result startCompute(Event* event) {
        ANeuralNetworksEvent* ev = nullptr;
        Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev));
        event->set(ev);
        return result;
    }

    Result compute() {
        ANeuralNetworksEvent* event = nullptr;
        Result result =
                static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &event));
        if (result != Result::NO_ERROR) {
            return result;
        }
        // TODO: How to manage the lifetime of events when there are multiple
        // waiters is not yet clear.
        result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
        ANeuralNetworksEvent_free(event);
        return result;
    }

private:
    ANeuralNetworksExecution* mExecution = nullptr;
};

}  // namespace wrapper
}  // namespace nn
}  // namespace android

#endif  //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H