// NeuralNetworksWrapper.h revision 910c9f04913e3bee1a0b6406b6e146457d19c5e7
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Provides C++ classes to more easily use the Neural Networks API.
18
19#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
20#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
21
#include "NeuralNetworks.h"

#include <math.h>

#include <cstdint>
#include <limits>
#include <vector>
26
27namespace android {
28namespace nn {
29namespace wrapper {
30
// Operand data types, mirroring the ANEURALNETWORKS_* constants so the C
// values can be recovered with a static_cast (see OperandType's constructor).
enum class Type {
    FLOAT16 = ANEURALNETWORKS_FLOAT16,
    FLOAT32 = ANEURALNETWORKS_FLOAT32,
    INT8 = ANEURALNETWORKS_INT8,
    UINT8 = ANEURALNETWORKS_UINT8,
    INT16 = ANEURALNETWORKS_INT16,
    UINT16 = ANEURALNETWORKS_UINT16,
    INT32 = ANEURALNETWORKS_INT32,
    UINT32 = ANEURALNETWORKS_UINT32,
    // Tensor (multi-dimensional) variants.
    TENSOR_FLOAT16 = ANEURALNETWORKS_TENSOR_FLOAT16,
    TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
    TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
};
44
// Execution preference passed to Request::setPreference; mirrors the
// ANEURALNETWORKS_PREFER_* constants.
enum class ExecutePreference {
    PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
    PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
    PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
};
50
// Result codes returned by the wrapper methods; mirrors the
// ANEURALNETWORKS_* error constants so C return values cast directly.
enum class Result {
    NO_ERROR = ANEURALNETWORKS_NO_ERROR,
    OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
    INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
    UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
    BAD_DATA = ANEURALNETWORKS_BAD_DATA,
};
58
59struct OperandType {
60    ANeuralNetworksOperandType operandType;
61    // uint32_t type;
62    std::vector<uint32_t> dimensions;
63
64    OperandType(Type type, const std::vector<uint32_t>& d) : dimensions(d) {
65        operandType.type = static_cast<uint32_t>(type);
66        operandType.scale = 0.0f;
67        operandType.offset = 0;
68
69        operandType.dimensions.count = static_cast<uint32_t>(dimensions.size());
70        operandType.dimensions.data = dimensions.data();
71    }
72
73    OperandType(Type type, float scale, const std::vector<uint32_t>& d)
74            : OperandType(type, d) {
75        operandType.scale = scale;
76    }
77
78    OperandType(Type type, float f_min, float f_max, const std::vector<uint32_t>& d)
79            : OperandType(type, d) {
80        uint8_t q_min = std::numeric_limits<uint8_t>::min();
81        uint8_t q_max = std::numeric_limits<uint8_t>::max();
82        float range = q_max - q_min;
83        float scale = (f_max - f_min) / range;
84        int32_t offset =
85                fmin(q_max, fmax(q_min, static_cast<uint8_t>(round(q_min - f_min / scale))));
86
87        operandType.scale = scale;
88        operandType.offset = offset;
89    }
90};
91
92inline Result Initialize() {
93    return static_cast<Result>(ANeuralNetworksInitialize());
94}
95
// Releases runtime resources acquired by Initialize(). Call after all
// wrapper objects (Model, Request, Memory, Event) have been destroyed.
inline void Shutdown() {
    ANeuralNetworksShutdown();
}
99
100class Memory {
101public:
102    // TODO Also have constructors for file descriptor, gralloc buffers, etc.
103    Memory(size_t size) {
104        mValid = ANeuralNetworksMemory_create(size, &mMemory) == ANEURALNETWORKS_NO_ERROR;
105    }
106    ~Memory() { ANeuralNetworksMemory_free(mMemory); }
107    uint8_t* getPointer() {
108        return ANeuralNetworksMemory_getPointer(mMemory);
109    }
110    ANeuralNetworksMemory* get() const { return mMemory; }
111    bool isValid() const { return mValid; }
112
113private:
114    ANeuralNetworksMemory* mMemory = nullptr;
115    bool mValid = true;
116};
117
118class Model {
119public:
120    Model() {
121        // TODO handle the value returned by this call
122        ANeuralNetworksModel_create(&mModel);
123    }
124    ~Model() { ANeuralNetworksModel_free(mModel); }
125
126    uint32_t addOperand(const OperandType* type) {
127        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
128            ANEURALNETWORKS_NO_ERROR) {
129            mValid = false;
130        }
131        return mNextOperandId++;
132    }
133
134    void setOperandValue(uint32_t index, const void* buffer, size_t length) {
135        if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
136            ANEURALNETWORKS_NO_ERROR) {
137            mValid = false;
138        }
139    }
140
141    void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
142                                   size_t length) {
143        if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
144                                                           length) != ANEURALNETWORKS_NO_ERROR) {
145            mValid = false;
146        }
147    }
148
149    void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
150                      const std::vector<uint32_t>& outputs) {
151        ANeuralNetworksIntList in, out;
152        Set(&in, inputs);
153        Set(&out, outputs);
154        if (ANeuralNetworksModel_addOperation(mModel, type, &in, &out) !=
155            ANEURALNETWORKS_NO_ERROR) {
156            mValid = false;
157        }
158    }
159    void setInputsAndOutputs(const std::vector<uint32_t>& inputs,
160                             const std::vector<uint32_t>& outputs) {
161        ANeuralNetworksIntList in, out;
162        Set(&in, inputs);
163        Set(&out, outputs);
164        if (ANeuralNetworksModel_setInputsAndOutputs(mModel, &in, &out) !=
165            ANEURALNETWORKS_NO_ERROR) {
166            mValid = false;
167        }
168    }
169    ANeuralNetworksModel* getHandle() const { return mModel; }
170    bool isValid() const { return mValid; }
171
172private:
173    /**
174     * WARNING list won't be valid once vec is destroyed or modified.
175     */
176    void Set(ANeuralNetworksIntList* list, const std::vector<uint32_t>& vec) {
177        list->count = static_cast<uint32_t>(vec.size());
178        list->data = vec.data();
179    }
180
181    ANeuralNetworksModel* mModel = nullptr;
182    // We keep track of the operand ID as a convenience to the caller.
183    uint32_t mNextOperandId = 0;
184    bool mValid = true;
185};
186
187class Event {
188public:
189    ~Event() { ANeuralNetworksEvent_free(mEvent); }
190    Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); }
191    void set(ANeuralNetworksEvent* newEvent) {
192        ANeuralNetworksEvent_free(mEvent);
193        mEvent = newEvent;
194    }
195
196private:
197    ANeuralNetworksEvent* mEvent = nullptr;
198};
199
200class Request {
201public:
202    Request(const Model* model) {
203        int result = ANeuralNetworksRequest_create(model->getHandle(), &mRequest);
204        if (result != 0) {
205            // TODO Handle the error
206        }
207    }
208
209    ~Request() { ANeuralNetworksRequest_free(mRequest); }
210
211    Result setPreference(ExecutePreference preference) {
212        return static_cast<Result>(
213                ANeuralNetworksRequest_setPreference(mRequest, static_cast<uint32_t>(preference)));
214    }
215
216    Result setInput(uint32_t index, const void* buffer, size_t length,
217                    const ANeuralNetworksOperandType* type = nullptr) {
218        return static_cast<Result>(
219                ANeuralNetworksRequest_setInput(mRequest, index, type, buffer, length));
220    }
221
222    Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
223                              uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
224        return static_cast<Result>(ANeuralNetworksRequest_setInputFromMemory(mRequest, index, type,
225                                                                             memory->get(), offset,
226                                                                             length));
227    }
228
229    Result setOutput(uint32_t index, void* buffer, size_t length,
230                     const ANeuralNetworksOperandType* type = nullptr) {
231        return static_cast<Result>(
232                ANeuralNetworksRequest_setOutput(mRequest, index, type, buffer, length));
233    }
234
235    Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
236                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
237        return static_cast<Result>(ANeuralNetworksRequest_setOutputFromMemory(mRequest, index, type,
238                                                                              memory->get(), offset,
239                                                                              length));
240    }
241
242    Result startCompute(Event* event) {
243        ANeuralNetworksEvent* ev = nullptr;
244        Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &ev));
245        event->set(ev);
246        return result;
247    }
248
249    Result compute() {
250        ANeuralNetworksEvent* event = nullptr;
251        Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &event));
252        if (result != Result::NO_ERROR) {
253            return result;
254        }
255        // TODO how to manage the lifetime of events when multiple waiters is not
256        // clear.
257        return static_cast<Result>(ANeuralNetworksEvent_wait(event));
258    }
259
260private:
261    ANeuralNetworksRequest* mRequest = nullptr;
262};
263
264} // namespace wrapper
265} // namespace nn
266} // namespace android
267
268#endif //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
269