// NeuralNetworksWrapper.h revision 7612f29b31f97f3b15769264131566b36dea9a25
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Provides C++ classes to more easily use the Neural Networks API.
18
19#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
20#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
21
#include "NeuralNetworks.h"

#include <math.h>

#include <limits>
#include <utility>
#include <vector>
26
27namespace android {
28namespace nn {
29namespace wrapper {
30
// Operand data types. Values mirror the ANEURALNETWORKS_* operand codes
// so a Type can be static_cast directly into the C API's int32_t field.
enum class Type {
    FLOAT32 = ANEURALNETWORKS_FLOAT32,
    INT32 = ANEURALNETWORKS_INT32,
    UINT32 = ANEURALNETWORKS_UINT32,
    TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
    TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32,
    // 8-bit asymmetric quantized tensor; real value = (quantized - offset) * scale.
    TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
};
39
// Compilation preference hints. Values mirror the ANEURALNETWORKS_PREFER_*
// constants and are passed straight through to
// ANeuralNetworksCompilation_setPreference.
enum class ExecutePreference {
    PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
    PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
    PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
};
45
// Result codes returned by the wrapper methods. Values mirror the
// ANEURALNETWORKS_* status codes; C API return values are static_cast
// into this enum throughout the wrapper.
enum class Result {
    NO_ERROR = ANEURALNETWORKS_NO_ERROR,
    OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
    INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
    UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
    BAD_DATA = ANEURALNETWORKS_BAD_DATA,
};
53
54struct OperandType {
55    ANeuralNetworksOperandType operandType;
56    // int32_t type;
57    std::vector<uint32_t> dimensions;
58
59    OperandType(Type type, const std::vector<uint32_t>& d) : dimensions(d) {
60        operandType.type = static_cast<int32_t>(type);
61        operandType.scale = 0.0f;
62        operandType.offset = 0;
63
64        operandType.dimensionCount = static_cast<uint32_t>(dimensions.size());
65        operandType.dimensions = dimensions.data();
66    }
67
68    OperandType(Type type, float scale, const std::vector<uint32_t>& d) : OperandType(type, d) {
69        operandType.scale = scale;
70    }
71
72    OperandType(Type type, float f_min, float f_max, const std::vector<uint32_t>& d)
73        : OperandType(type, d) {
74        uint8_t q_min = std::numeric_limits<uint8_t>::min();
75        uint8_t q_max = std::numeric_limits<uint8_t>::max();
76        float range = q_max - q_min;
77        float scale = (f_max - f_min) / range;
78        int32_t offset =
79                    fmin(q_max, fmax(q_min, static_cast<uint8_t>(round(q_min - f_min / scale))));
80
81        operandType.scale = scale;
82        operandType.offset = offset;
83    }
84};
85
86class Memory {
87public:
88    Memory(size_t size, int protect, int fd, size_t offset) {
89        mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) ==
90                 ANEURALNETWORKS_NO_ERROR;
91    }
92
93    ~Memory() { ANeuralNetworksMemory_free(mMemory); }
94
95    // Disallow copy semantics to ensure the runtime object can only be freed
96    // once. Copy semantics could be enabled if some sort of reference counting
97    // or deep-copy system for runtime objects is added later.
98    Memory(const Memory&) = delete;
99    Memory& operator=(const Memory&) = delete;
100
101    // Move semantics to remove access to the runtime object from the wrapper
102    // object that is being moved. This ensures the runtime object will be
103    // freed only once.
104    Memory(Memory&& other) { *this = std::move(other); }
105    Memory& operator=(Memory&& other) {
106        if (this != &other) {
107            mMemory = other.mMemory;
108            mValid = other.mValid;
109            other.mMemory = nullptr;
110            other.mValid = false;
111        }
112        return *this;
113    }
114
115    ANeuralNetworksMemory* get() const { return mMemory; }
116    bool isValid() const { return mValid; }
117
118private:
119    ANeuralNetworksMemory* mMemory = nullptr;
120    bool mValid = true;
121};
122
123class Model {
124public:
125    Model() {
126        // TODO handle the value returned by this call
127        ANeuralNetworksModel_create(&mModel);
128    }
129    ~Model() { ANeuralNetworksModel_free(mModel); }
130
131    // Disallow copy semantics to ensure the runtime object can only be freed
132    // once. Copy semantics could be enabled if some sort of reference counting
133    // or deep-copy system for runtime objects is added later.
134    Model(const Model&) = delete;
135    Model& operator=(const Model&) = delete;
136
137    // Move semantics to remove access to the runtime object from the wrapper
138    // object that is being moved. This ensures the runtime object will be
139    // freed only once.
140    Model(Model&& other) { *this = std::move(other); }
141    Model& operator=(Model&& other) {
142        if (this != &other) {
143            mModel = other.mModel;
144            mNextOperandId = other.mNextOperandId;
145            mValid = other.mValid;
146            other.mModel = nullptr;
147            other.mNextOperandId = 0;
148            other.mValid = false;
149        }
150        return *this;
151    }
152
153    int finish() { return ANeuralNetworksModel_finish(mModel); }
154
155    uint32_t addOperand(const OperandType* type) {
156        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
157            ANEURALNETWORKS_NO_ERROR) {
158            mValid = false;
159        }
160        return mNextOperandId++;
161    }
162
163    void setOperandValue(uint32_t index, const void* buffer, size_t length) {
164        if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
165            ANEURALNETWORKS_NO_ERROR) {
166            mValid = false;
167        }
168    }
169
170    void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
171                                   size_t length) {
172        if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
173                                                           length) != ANEURALNETWORKS_NO_ERROR) {
174            mValid = false;
175        }
176    }
177
178    void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
179                      const std::vector<uint32_t>& outputs) {
180        if (ANeuralNetworksModel_addOperation(mModel, type, static_cast<uint32_t>(inputs.size()),
181                                              inputs.data(), static_cast<uint32_t>(outputs.size()),
182                                              outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
183            mValid = false;
184        }
185    }
186    void setInputsAndOutputs(const std::vector<uint32_t>& inputs,
187                             const std::vector<uint32_t>& outputs) {
188        if (ANeuralNetworksModel_setInputsAndOutputs(mModel, static_cast<uint32_t>(inputs.size()),
189                                                     inputs.data(),
190                                                     static_cast<uint32_t>(outputs.size()),
191                                                     outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
192            mValid = false;
193        }
194    }
195    ANeuralNetworksModel* getHandle() const { return mModel; }
196    bool isValid() const { return mValid; }
197
198private:
199    ANeuralNetworksModel* mModel = nullptr;
200    // We keep track of the operand ID as a convenience to the caller.
201    uint32_t mNextOperandId = 0;
202    bool mValid = true;
203};
204
205class Compilation {
206public:
207    Compilation(const Model* model) {
208        int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
209        if (result != 0) {
210            // TODO Handle the error
211        }
212    }
213
214    ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }
215
216    Compilation(const Compilation&) = delete;
217    Compilation& operator=(const Compilation&) = delete;
218
219    Compilation(Compilation&& other) { *this = std::move(other); }
220    Compilation& operator=(Compilation&& other) {
221        if (this != &other) {
222            mCompilation = other.mCompilation;
223            other.mCompilation = nullptr;
224        }
225        return *this;
226    }
227
228    Result setPreference(ExecutePreference preference) {
229        return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
230                    mCompilation, static_cast<int32_t>(preference)));
231    }
232
233    // TODO startCompile
234
235    Result compile() {
236        Result result = static_cast<Result>(ANeuralNetworksCompilation_start(mCompilation));
237        if (result != Result::NO_ERROR) {
238            return result;
239        }
240        // TODO how to manage the lifetime of compilations when multiple waiters
241        // is not clear.
242        return static_cast<Result>(ANeuralNetworksCompilation_wait(mCompilation));
243    }
244
245    ANeuralNetworksCompilation* getHandle() const { return mCompilation; }
246
247private:
248    ANeuralNetworksCompilation* mCompilation = nullptr;
249};
250
251class Execution {
252public:
253    Execution(const Compilation* compilation) {
254        int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
255        if (result != 0) {
256            // TODO Handle the error
257        }
258    }
259
260    ~Execution() { ANeuralNetworksExecution_free(mExecution); }
261
262    // Disallow copy semantics to ensure the runtime object can only be freed
263    // once. Copy semantics could be enabled if some sort of reference counting
264    // or deep-copy system for runtime objects is added later.
265    Execution(const Execution&) = delete;
266    Execution& operator=(const Execution&) = delete;
267
268    // Move semantics to remove access to the runtime object from the wrapper
269    // object that is being moved. This ensures the runtime object will be
270    // freed only once.
271    Execution(Execution&& other) { *this = std::move(other); }
272    Execution& operator=(Execution&& other) {
273        if (this != &other) {
274            mExecution = other.mExecution;
275            other.mExecution = nullptr;
276        }
277        return *this;
278    }
279
280    Result setInput(uint32_t index, const void* buffer, size_t length,
281                    const ANeuralNetworksOperandType* type = nullptr) {
282        return static_cast<Result>(
283                    ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length));
284    }
285
286    Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
287                              uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
288        return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory(
289                    mExecution, index, type, memory->get(), offset, length));
290    }
291
292    Result setOutput(uint32_t index, void* buffer, size_t length,
293                     const ANeuralNetworksOperandType* type = nullptr) {
294        return static_cast<Result>(
295                    ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length));
296    }
297
298    Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
299                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
300        return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory(
301                    mExecution, index, type, memory->get(), offset, length));
302    }
303
304    Result startCompute() {
305        Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution));
306        return result;
307    }
308
309    Result wait() { return static_cast<Result>(ANeuralNetworksExecution_wait(mExecution)); }
310
311    Result compute() {
312        Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution));
313        if (result != Result::NO_ERROR) {
314            return result;
315        }
316        // TODO how to manage the lifetime of events when multiple waiters is not
317        // clear.
318        return static_cast<Result>(ANeuralNetworksExecution_wait(mExecution));
319    }
320
321private:
322    ANeuralNetworksExecution* mExecution = nullptr;
323};
324
325}  // namespace wrapper
326}  // namespace nn
327}  // namespace android
328
329#endif  //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
330