// NeuralNetworksWrapper.h revision 42dc6a6cd68877cd85e3bc475b41bda0fd946c41
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Provides C++ classes to more easily use the Neural Networks API.
18
19#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
20#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
21
#include "NeuralNetworks.h"

#include <math.h>

#include <limits>
#include <utility>
#include <vector>
26
27namespace android {
28namespace nn {
29namespace wrapper {
30
// C++ mirror of the ANEURALNETWORKS_* operand type codes. Values must stay
// numerically identical to the C constants, since they are passed through
// to the C API via static_cast.
enum class Type {
    FLOAT16 = ANEURALNETWORKS_FLOAT16,
    FLOAT32 = ANEURALNETWORKS_FLOAT32,
    INT8 = ANEURALNETWORKS_INT8,
    UINT8 = ANEURALNETWORKS_UINT8,
    INT16 = ANEURALNETWORKS_INT16,
    UINT16 = ANEURALNETWORKS_UINT16,
    INT32 = ANEURALNETWORKS_INT32,
    UINT32 = ANEURALNETWORKS_UINT32,
    TENSOR_FLOAT16 = ANEURALNETWORKS_TENSOR_FLOAT16,
    TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
    TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32,
    TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
};
45
// C++ mirror of the ANEURALNETWORKS_PREFER_* compilation preferences; passed
// through to ANeuralNetworksCompilation_setPreference via static_cast.
enum class ExecutePreference {
    PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
    PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
    PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
};
51
// C++ mirror of the ANEURALNETWORKS_* result codes returned by the C API.
// Wrapper methods cast the raw int return values into this enum.
enum class Result {
    NO_ERROR = ANEURALNETWORKS_NO_ERROR,
    OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
    INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
    UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
    BAD_DATA = ANEURALNETWORKS_BAD_DATA,
};
59
60struct OperandType {
61    ANeuralNetworksOperandType operandType;
62    // uint32_t type;
63    std::vector<uint32_t> dimensions;
64
65    OperandType(Type type, const std::vector<uint32_t>& d) : dimensions(d) {
66        operandType.type = static_cast<uint32_t>(type);
67        operandType.scale = 0.0f;
68        operandType.offset = 0;
69
70        operandType.dimensions.count = static_cast<uint32_t>(dimensions.size());
71        operandType.dimensions.data = dimensions.data();
72    }
73
74    OperandType(Type type, float scale, const std::vector<uint32_t>& d) : OperandType(type, d) {
75        operandType.scale = scale;
76    }
77
78    OperandType(Type type, float f_min, float f_max, const std::vector<uint32_t>& d)
79        : OperandType(type, d) {
80        uint8_t q_min = std::numeric_limits<uint8_t>::min();
81        uint8_t q_max = std::numeric_limits<uint8_t>::max();
82        float range = q_max - q_min;
83        float scale = (f_max - f_min) / range;
84        int32_t offset =
85                    fmin(q_max, fmax(q_min, static_cast<uint8_t>(round(q_min - f_min / scale))));
86
87        operandType.scale = scale;
88        operandType.offset = offset;
89    }
90};
91
// Initializes the Neural Networks runtime. Must be paired with Shutdown().
// Returns the runtime's status code converted to the wrapper Result enum.
inline Result Initialize() {
    return static_cast<Result>(ANeuralNetworksInitialize());
}
95
// Shuts down the Neural Networks runtime previously set up by Initialize().
inline void Shutdown() {
    ANeuralNetworksShutdown();
}
99
100class Memory {
101public:
102
103    Memory(size_t size, int protect, int fd, size_t offset) {
104        mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) ==
105                         ANEURALNETWORKS_NO_ERROR;
106    }
107
108    ~Memory() { ANeuralNetworksMemory_free(mMemory); }
109
110    // Disallow copy semantics to ensure the runtime object can only be freed
111    // once. Copy semantics could be enabled if some sort of reference counting
112    // or deep-copy system for runtime objects is added later.
113    Memory(const Memory&) = delete;
114    Memory& operator=(const Memory&) = delete;
115
116    // Move semantics to remove access to the runtime object from the wrapper
117    // object that is being moved. This ensures the runtime object will be
118    // freed only once.
119    Memory(Memory&& other) {
120        *this = std::move(other);
121    }
122    Memory& operator=(Memory&& other) {
123        if (this != &other) {
124            mMemory = other.mMemory;
125            mValid = other.mValid;
126            other.mMemory = nullptr;
127            other.mValid = false;
128        }
129        return *this;
130    }
131
132    ANeuralNetworksMemory* get() const { return mMemory; }
133    bool isValid() const { return mValid; }
134
135private:
136    ANeuralNetworksMemory* mMemory = nullptr;
137    bool mValid = true;
138};
139
140class Model {
141public:
142    Model() {
143        // TODO handle the value returned by this call
144        ANeuralNetworksModel_create(&mModel);
145    }
146    ~Model() { ANeuralNetworksModel_free(mModel); }
147
148    // Disallow copy semantics to ensure the runtime object can only be freed
149    // once. Copy semantics could be enabled if some sort of reference counting
150    // or deep-copy system for runtime objects is added later.
151    Model(const Model&) = delete;
152    Model& operator=(const Model&) = delete;
153
154    // Move semantics to remove access to the runtime object from the wrapper
155    // object that is being moved. This ensures the runtime object will be
156    // freed only once.
157    Model(Model&& other) {
158        *this = std::move(other);
159    }
160    Model& operator=(Model&& other) {
161        if (this != &other) {
162            mModel = other.mModel;
163            mNextOperandId = other.mNextOperandId;
164            mValid = other.mValid;
165            other.mModel = nullptr;
166            other.mNextOperandId = 0;
167            other.mValid = false;
168        }
169        return *this;
170    }
171
172    int finish() { return ANeuralNetworksModel_finish(mModel); }
173
174    uint32_t addOperand(const OperandType* type) {
175        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
176            ANEURALNETWORKS_NO_ERROR) {
177            mValid = false;
178        }
179        return mNextOperandId++;
180    }
181
182    void setOperandValue(uint32_t index, const void* buffer, size_t length) {
183        if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
184            ANEURALNETWORKS_NO_ERROR) {
185            mValid = false;
186        }
187    }
188
189    void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
190                                   size_t length) {
191        if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
192                                                           length) != ANEURALNETWORKS_NO_ERROR) {
193            mValid = false;
194        }
195    }
196
197    void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
198                      const std::vector<uint32_t>& outputs) {
199        ANeuralNetworksIntList in, out;
200        Set(&in, inputs);
201        Set(&out, outputs);
202        if (ANeuralNetworksModel_addOperation(mModel, type, &in, &out) !=
203            ANEURALNETWORKS_NO_ERROR) {
204            mValid = false;
205        }
206    }
207    void setInputsAndOutputs(const std::vector<uint32_t>& inputs,
208                             const std::vector<uint32_t>& outputs) {
209        ANeuralNetworksIntList in, out;
210        Set(&in, inputs);
211        Set(&out, outputs);
212        if (ANeuralNetworksModel_setInputsAndOutputs(mModel, &in, &out) !=
213            ANEURALNETWORKS_NO_ERROR) {
214            mValid = false;
215        }
216    }
217    ANeuralNetworksModel* getHandle() const { return mModel; }
218    bool isValid() const { return mValid; }
219
220private:
221    /**
222     * WARNING list won't be valid once vec is destroyed or modified.
223     */
224    void Set(ANeuralNetworksIntList* list, const std::vector<uint32_t>& vec) {
225        list->count = static_cast<uint32_t>(vec.size());
226        list->data = vec.data();
227    }
228
229    ANeuralNetworksModel* mModel = nullptr;
230    // We keep track of the operand ID as a convenience to the caller.
231    uint32_t mNextOperandId = 0;
232    bool mValid = true;
233};
234
235class Compilation {
236public:
237    Compilation(const Model* model) {
238        int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
239        if (result != 0) {
240            // TODO Handle the error
241        }
242    }
243
244    ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }
245
246    Compilation(const Compilation&) = delete;
247    Compilation& operator=(const Compilation &) = delete;
248
249    Compilation(Compilation&& other) {
250        *this = std::move(other);
251    }
252    Compilation& operator=(Compilation&& other) {
253        if (this != &other) {
254            mCompilation = other.mCompilation;
255            other.mCompilation = nullptr;
256        }
257        return *this;
258    }
259
260    Result setPreference(ExecutePreference preference) {
261        return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
262                    mCompilation, static_cast<uint32_t>(preference)));
263    }
264
265    // TODO startCompile
266
267    Result compile() {
268        Result result = static_cast<Result>(ANeuralNetworksCompilation_start(mCompilation));
269        if (result != Result::NO_ERROR) {
270            return result;
271        }
272        // TODO how to manage the lifetime of compilations when multiple waiters
273        // is not clear.
274        return static_cast<Result>(ANeuralNetworksCompilation_wait(mCompilation));
275    }
276
277    ANeuralNetworksCompilation* getHandle() const { return mCompilation; }
278
279private:
280    ANeuralNetworksCompilation* mCompilation = nullptr;
281};
282
283class Request {
284public:
285    Request(const Compilation* compilation) {
286        int result = ANeuralNetworksRequest_create(compilation->getHandle(), &mRequest);
287        if (result != 0) {
288            // TODO Handle the error
289        }
290    }
291
292    ~Request() { ANeuralNetworksRequest_free(mRequest); }
293
294    // Disallow copy semantics to ensure the runtime object can only be freed
295    // once. Copy semantics could be enabled if some sort of reference counting
296    // or deep-copy system for runtime objects is added later.
297    Request(const Request&) = delete;
298    Request& operator=(const Request&) = delete;
299
300    // Move semantics to remove access to the runtime object from the wrapper
301    // object that is being moved. This ensures the runtime object will be
302    // freed only once.
303    Request(Request&& other) {
304        *this = std::move(other);
305    }
306    Request& operator=(Request&& other) {
307        if (this != &other) {
308            mRequest = other.mRequest;
309            other.mRequest = nullptr;
310        }
311        return *this;
312    }
313
314    Result setInput(uint32_t index, const void* buffer, size_t length,
315                    const ANeuralNetworksOperandType* type = nullptr) {
316        return static_cast<Result>(
317                    ANeuralNetworksRequest_setInput(mRequest, index, type, buffer, length));
318    }
319
320    Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
321                              uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
322        return static_cast<Result>(ANeuralNetworksRequest_setInputFromMemory(
323                    mRequest, index, type, memory->get(), offset, length));
324    }
325
326    Result setOutput(uint32_t index, void* buffer, size_t length,
327                     const ANeuralNetworksOperandType* type = nullptr) {
328        return static_cast<Result>(
329                    ANeuralNetworksRequest_setOutput(mRequest, index, type, buffer, length));
330    }
331
332    Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
333                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
334        return static_cast<Result>(ANeuralNetworksRequest_setOutputFromMemory(
335                    mRequest, index, type, memory->get(), offset, length));
336    }
337
338    Result startCompute() {
339        Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest));
340        return result;
341    }
342
343    Result wait() {
344        return static_cast<Result>(ANeuralNetworksRequest_wait(mRequest));
345    }
346
347    Result compute() {
348        Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest));
349        if (result != Result::NO_ERROR) {
350            return result;
351        }
352        // TODO how to manage the lifetime of events when multiple waiters is not
353        // clear.
354        return static_cast<Result>(ANeuralNetworksRequest_wait(mRequest));
355    }
356
357private:
358    ANeuralNetworksRequest* mRequest = nullptr;
359};
360
361}  // namespace wrapper
362}  // namespace nn
363}  // namespace android
364
365#endif  //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
366