// NeuralNetworksWrapper.h revision 0afe5897f4034528b027294efbe45c836924643c
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Provides C++ classes to more easily use the Neural Networks API.
18
19#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
20#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
21
#include "NeuralNetworks.h"

#include <math.h>

#include <limits>
#include <utility>
#include <vector>
26
27namespace android {
28namespace nn {
29namespace wrapper {
30
// Operand data types. Each enumerator mirrors the corresponding
// ANEURALNETWORKS_* operand-type constant from NeuralNetworks.h so values
// can be cast directly when filling in ANeuralNetworksOperandType::type.
enum class Type {
    FLOAT16 = ANEURALNETWORKS_FLOAT16,
    FLOAT32 = ANEURALNETWORKS_FLOAT32,
    INT8 = ANEURALNETWORKS_INT8,
    UINT8 = ANEURALNETWORKS_UINT8,
    INT16 = ANEURALNETWORKS_INT16,
    UINT16 = ANEURALNETWORKS_UINT16,
    INT32 = ANEURALNETWORKS_INT32,
    UINT32 = ANEURALNETWORKS_UINT32,
    TENSOR_FLOAT16 = ANEURALNETWORKS_TENSOR_FLOAT16,
    TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
    TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32,
    TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
};
45
// Execution-preference hints passed to Request::setPreference; mirrors the
// ANEURALNETWORKS_PREFER_* constants from NeuralNetworks.h.
enum class ExecutePreference {
    PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
    PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
    PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
};
51
// Result codes returned by the wrapper methods; mirrors the
// ANEURALNETWORKS_* status constants from NeuralNetworks.h.
enum class Result {
    NO_ERROR = ANEURALNETWORKS_NO_ERROR,
    OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
    INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
    UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
    BAD_DATA = ANEURALNETWORKS_BAD_DATA,
};
59
60struct OperandType {
61    ANeuralNetworksOperandType operandType;
62    // uint32_t type;
63    std::vector<uint32_t> dimensions;
64
65    OperandType(Type type, const std::vector<uint32_t>& d) : dimensions(d) {
66        operandType.type = static_cast<uint32_t>(type);
67        operandType.scale = 0.0f;
68        operandType.offset = 0;
69
70        operandType.dimensions.count = static_cast<uint32_t>(dimensions.size());
71        operandType.dimensions.data = dimensions.data();
72    }
73
74    OperandType(Type type, float scale, const std::vector<uint32_t>& d) : OperandType(type, d) {
75        operandType.scale = scale;
76    }
77
78    OperandType(Type type, float f_min, float f_max, const std::vector<uint32_t>& d)
79        : OperandType(type, d) {
80        uint8_t q_min = std::numeric_limits<uint8_t>::min();
81        uint8_t q_max = std::numeric_limits<uint8_t>::max();
82        float range = q_max - q_min;
83        float scale = (f_max - f_min) / range;
84        int32_t offset =
85                    fmin(q_max, fmax(q_min, static_cast<uint8_t>(round(q_min - f_min / scale))));
86
87        operandType.scale = scale;
88        operandType.offset = offset;
89    }
90};
91
92inline Result Initialize() {
93    return static_cast<Result>(ANeuralNetworksInitialize());
94}
95
// Shuts down the NN runtime; pairs with Initialize().
inline void Shutdown() {
    ANeuralNetworksShutdown();
}
99
100class Memory {
101public:
102    Memory(size_t size) {
103        mValid = ANeuralNetworksMemory_createShared(size, &mMemory) == ANEURALNETWORKS_NO_ERROR;
104    }
105    Memory(size_t size, int protect, int fd) {
106        mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, &mMemory) ==
107                         ANEURALNETWORKS_NO_ERROR;
108    }
109
110    ~Memory() { ANeuralNetworksMemory_free(mMemory); }
111    Result getPointer(uint8_t** buffer) {
112        return static_cast<Result>(ANeuralNetworksMemory_getPointer(mMemory, buffer));
113    }
114
115    // Disallow copy semantics to ensure the runtime object can only be freed
116    // once. Copy semantics could be enabled if some sort of reference counting
117    // or deep-copy system for runtime objects is added later.
118    Memory(const Memory&) = delete;
119    Memory& operator=(const Memory&) = delete;
120
121    // Move semantics to remove access to the runtime object from the wrapper
122    // object that is being moved. This ensures the runtime object will be
123    // freed only once.
124    Memory(Memory&& other) {
125        *this = std::move(other);
126    }
127    Memory& operator=(Memory&& other) {
128        if (this != &other) {
129            mMemory = other.mMemory;
130            mValid = other.mValid;
131            other.mMemory = nullptr;
132            other.mValid = false;
133        }
134        return *this;
135    }
136
137    ANeuralNetworksMemory* get() const { return mMemory; }
138    bool isValid() const { return mValid; }
139
140private:
141    ANeuralNetworksMemory* mMemory = nullptr;
142    bool mValid = true;
143};
144
145class Model {
146public:
147    Model() {
148        // TODO handle the value returned by this call
149        ANeuralNetworksModel_create(&mModel);
150    }
151    ~Model() { ANeuralNetworksModel_free(mModel); }
152
153    // Disallow copy semantics to ensure the runtime object can only be freed
154    // once. Copy semantics could be enabled if some sort of reference counting
155    // or deep-copy system for runtime objects is added later.
156    Model(const Model&) = delete;
157    Model& operator=(const Model&) = delete;
158
159    // Move semantics to remove access to the runtime object from the wrapper
160    // object that is being moved. This ensures the runtime object will be
161    // freed only once.
162    Model(Model&& other) {
163        *this = std::move(other);
164    }
165    Model& operator=(Model&& other) {
166        if (this != &other) {
167            mModel = other.mModel;
168            mNextOperandId = other.mNextOperandId;
169            mValid = other.mValid;
170            other.mModel = nullptr;
171            other.mNextOperandId = 0;
172            other.mValid = false;
173        }
174        return *this;
175    }
176
177    uint32_t addOperand(const OperandType* type) {
178        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
179            ANEURALNETWORKS_NO_ERROR) {
180            mValid = false;
181        }
182        return mNextOperandId++;
183    }
184
185    void setOperandValue(uint32_t index, const void* buffer, size_t length) {
186        if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
187            ANEURALNETWORKS_NO_ERROR) {
188            mValid = false;
189        }
190    }
191
192    void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
193                                   size_t length) {
194        if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
195                                                           length) != ANEURALNETWORKS_NO_ERROR) {
196            mValid = false;
197        }
198    }
199
200    void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
201                      const std::vector<uint32_t>& outputs) {
202        ANeuralNetworksIntList in, out;
203        Set(&in, inputs);
204        Set(&out, outputs);
205        if (ANeuralNetworksModel_addOperation(mModel, type, &in, &out) !=
206            ANEURALNETWORKS_NO_ERROR) {
207            mValid = false;
208        }
209    }
210    void setInputsAndOutputs(const std::vector<uint32_t>& inputs,
211                             const std::vector<uint32_t>& outputs) {
212        ANeuralNetworksIntList in, out;
213        Set(&in, inputs);
214        Set(&out, outputs);
215        if (ANeuralNetworksModel_setInputsAndOutputs(mModel, &in, &out) !=
216            ANEURALNETWORKS_NO_ERROR) {
217            mValid = false;
218        }
219    }
220    ANeuralNetworksModel* getHandle() const { return mModel; }
221    bool isValid() const { return mValid; }
222
223private:
224    /**
225     * WARNING list won't be valid once vec is destroyed or modified.
226     */
227    void Set(ANeuralNetworksIntList* list, const std::vector<uint32_t>& vec) {
228        list->count = static_cast<uint32_t>(vec.size());
229        list->data = vec.data();
230    }
231
232    ANeuralNetworksModel* mModel = nullptr;
233    // We keep track of the operand ID as a convenience to the caller.
234    uint32_t mNextOperandId = 0;
235    bool mValid = true;
236};
237
238class Event {
239public:
240    ~Event() { ANeuralNetworksEvent_free(mEvent); }
241
242    // Disallow copy semantics to ensure the runtime object can only be freed
243    // once. Copy semantics could be enabled if some sort of reference counting
244    // or deep-copy system for runtime objects is added later.
245    Event(const Event&) = delete;
246    Event& operator=(const Event&) = delete;
247
248    // Move semantics to remove access to the runtime object from the wrapper
249    // object that is being moved. This ensures the runtime object will be
250    // freed only once.
251    Event(Event&& other) {
252        *this = std::move(other);
253    }
254    Event& operator=(Event&& other) {
255        if (this != &other) {
256            mEvent = other.mEvent;
257            other.mEvent = nullptr;
258        }
259        return *this;
260    }
261
262    Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); }
263    void set(ANeuralNetworksEvent* newEvent) {
264        ANeuralNetworksEvent_free(mEvent);
265        mEvent = newEvent;
266    }
267
268private:
269    ANeuralNetworksEvent* mEvent = nullptr;
270};
271
272class Request {
273public:
274    Request(const Model* model) {
275        int result = ANeuralNetworksRequest_create(model->getHandle(), &mRequest);
276        if (result != 0) {
277            // TODO Handle the error
278        }
279    }
280
281    ~Request() { ANeuralNetworksRequest_free(mRequest); }
282
283    // Disallow copy semantics to ensure the runtime object can only be freed
284    // once. Copy semantics could be enabled if some sort of reference counting
285    // or deep-copy system for runtime objects is added later.
286    Request(const Request&) = delete;
287    Request& operator=(const Request&) = delete;
288
289    // Move semantics to remove access to the runtime object from the wrapper
290    // object that is being moved. This ensures the runtime object will be
291    // freed only once.
292    Request(Request&& other) {
293        *this = std::move(other);
294    }
295    Request& operator=(Request&& other) {
296        if (this != &other) {
297            mRequest = other.mRequest;
298            other.mRequest = nullptr;
299        }
300        return *this;
301    }
302
303    Result setPreference(ExecutePreference preference) {
304        return static_cast<Result>(ANeuralNetworksRequest_setPreference(
305                    mRequest, static_cast<uint32_t>(preference)));
306    }
307
308    Result setInput(uint32_t index, const void* buffer, size_t length,
309                    const ANeuralNetworksOperandType* type = nullptr) {
310        return static_cast<Result>(
311                    ANeuralNetworksRequest_setInput(mRequest, index, type, buffer, length));
312    }
313
314    Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
315                              uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
316        return static_cast<Result>(ANeuralNetworksRequest_setInputFromMemory(
317                    mRequest, index, type, memory->get(), offset, length));
318    }
319
320    Result setOutput(uint32_t index, void* buffer, size_t length,
321                     const ANeuralNetworksOperandType* type = nullptr) {
322        return static_cast<Result>(
323                    ANeuralNetworksRequest_setOutput(mRequest, index, type, buffer, length));
324    }
325
326    Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
327                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
328        return static_cast<Result>(ANeuralNetworksRequest_setOutputFromMemory(
329                    mRequest, index, type, memory->get(), offset, length));
330    }
331
332    Result startCompute(Event* event) {
333        ANeuralNetworksEvent* ev = nullptr;
334        Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &ev));
335        event->set(ev);
336        return result;
337    }
338
339    Result compute() {
340        ANeuralNetworksEvent* event = nullptr;
341        Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &event));
342        if (result != Result::NO_ERROR) {
343            return result;
344        }
345        // TODO how to manage the lifetime of events when multiple waiters is not
346        // clear.
347        return static_cast<Result>(ANeuralNetworksEvent_wait(event));
348    }
349
350private:
351    ANeuralNetworksRequest* mRequest = nullptr;
352};
353
354}  // namespace wrapper
355}  // namespace nn
356}  // namespace android
357
358#endif  //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
359