NeuralNetworksWrapper.h revision 18c58d289c2346d750301392866229630960b392
1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Provides C++ classes to more easily use the Neural Networks API.
18
19#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
20#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
21
#include "NeuralNetworks.h"

#include <math.h>

#include <limits>
#include <utility>
#include <vector>
26
27namespace android {
28namespace nn {
29namespace wrapper {
30
// Operand data types. Values mirror the ANEURALNETWORKS_* operand type
// constants so they can be passed straight through to the C API.
enum class Type {
    FLOAT16 = ANEURALNETWORKS_FLOAT16,
    FLOAT32 = ANEURALNETWORKS_FLOAT32,
    INT8 = ANEURALNETWORKS_INT8,
    UINT8 = ANEURALNETWORKS_UINT8,
    INT16 = ANEURALNETWORKS_INT16,
    UINT16 = ANEURALNETWORKS_UINT16,
    INT32 = ANEURALNETWORKS_INT32,
    UINT32 = ANEURALNETWORKS_UINT32,
    TENSOR_FLOAT16 = ANEURALNETWORKS_TENSOR_FLOAT16,
    TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
    TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32,
    TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
};
45
// Compilation/execution preference hints, mirroring the
// ANEURALNETWORKS_PREFER_* constants of the C API.
enum class ExecutePreference {
    PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
    PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
    PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
};
51
// Result codes returned by the wrapper methods, mirroring the
// ANEURALNETWORKS_* status constants of the C API.
enum class Result {
    NO_ERROR = ANEURALNETWORKS_NO_ERROR,
    OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
    INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
    UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
    BAD_DATA = ANEURALNETWORKS_BAD_DATA,
};
59
60struct OperandType {
61    ANeuralNetworksOperandType operandType;
62    // uint32_t type;
63    std::vector<uint32_t> dimensions;
64
65    OperandType(Type type, const std::vector<uint32_t>& d) : dimensions(d) {
66        operandType.type = static_cast<uint32_t>(type);
67        operandType.scale = 0.0f;
68        operandType.offset = 0;
69
70        operandType.dimensions.count = static_cast<uint32_t>(dimensions.size());
71        operandType.dimensions.data = dimensions.data();
72    }
73
74    OperandType(Type type, float scale, const std::vector<uint32_t>& d) : OperandType(type, d) {
75        operandType.scale = scale;
76    }
77
78    OperandType(Type type, float f_min, float f_max, const std::vector<uint32_t>& d)
79        : OperandType(type, d) {
80        uint8_t q_min = std::numeric_limits<uint8_t>::min();
81        uint8_t q_max = std::numeric_limits<uint8_t>::max();
82        float range = q_max - q_min;
83        float scale = (f_max - f_min) / range;
84        int32_t offset =
85                    fmin(q_max, fmax(q_min, static_cast<uint8_t>(round(q_min - f_min / scale))));
86
87        operandType.scale = scale;
88        operandType.offset = offset;
89    }
90};
91
92inline Result Initialize() {
93    return static_cast<Result>(ANeuralNetworksInitialize());
94}
95
// Releases the global runtime state acquired by Initialize().
inline void Shutdown() {
    ANeuralNetworksShutdown();
}
99
100class Memory {
101public:
102    Memory(size_t size) {
103        mValid = ANeuralNetworksMemory_createShared(size, &mMemory) == ANEURALNETWORKS_NO_ERROR;
104    }
105    Memory(size_t size, int protect, int fd) {
106        mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, &mMemory) ==
107                         ANEURALNETWORKS_NO_ERROR;
108    }
109
110    ~Memory() { ANeuralNetworksMemory_free(mMemory); }
111    Result getPointer(uint8_t** buffer) {
112        return static_cast<Result>(ANeuralNetworksMemory_getPointer(mMemory, buffer));
113    }
114
115    // Disallow copy semantics to ensure the runtime object can only be freed
116    // once. Copy semantics could be enabled if some sort of reference counting
117    // or deep-copy system for runtime objects is added later.
118    Memory(const Memory&) = delete;
119    Memory& operator=(const Memory&) = delete;
120
121    // Move semantics to remove access to the runtime object from the wrapper
122    // object that is being moved. This ensures the runtime object will be
123    // freed only once.
124    Memory(Memory&& other) {
125        *this = std::move(other);
126    }
127    Memory& operator=(Memory&& other) {
128        if (this != &other) {
129            mMemory = other.mMemory;
130            mValid = other.mValid;
131            other.mMemory = nullptr;
132            other.mValid = false;
133        }
134        return *this;
135    }
136
137    ANeuralNetworksMemory* get() const { return mMemory; }
138    bool isValid() const { return mValid; }
139
140private:
141    ANeuralNetworksMemory* mMemory = nullptr;
142    bool mValid = true;
143};
144
145class Model {
146public:
147    Model() {
148        // TODO handle the value returned by this call
149        ANeuralNetworksModel_create(&mModel);
150    }
151    ~Model() { ANeuralNetworksModel_free(mModel); }
152
153    // Disallow copy semantics to ensure the runtime object can only be freed
154    // once. Copy semantics could be enabled if some sort of reference counting
155    // or deep-copy system for runtime objects is added later.
156    Model(const Model&) = delete;
157    Model& operator=(const Model&) = delete;
158
159    // Move semantics to remove access to the runtime object from the wrapper
160    // object that is being moved. This ensures the runtime object will be
161    // freed only once.
162    Model(Model&& other) {
163        *this = std::move(other);
164    }
165    Model& operator=(Model&& other) {
166        if (this != &other) {
167            mModel = other.mModel;
168            mNextOperandId = other.mNextOperandId;
169            mValid = other.mValid;
170            other.mModel = nullptr;
171            other.mNextOperandId = 0;
172            other.mValid = false;
173        }
174        return *this;
175    }
176
177    int finish() { return ANeuralNetworksModel_finish(mModel); }
178
179    uint32_t addOperand(const OperandType* type) {
180        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
181            ANEURALNETWORKS_NO_ERROR) {
182            mValid = false;
183        }
184        return mNextOperandId++;
185    }
186
187    void setOperandValue(uint32_t index, const void* buffer, size_t length) {
188        if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
189            ANEURALNETWORKS_NO_ERROR) {
190            mValid = false;
191        }
192    }
193
194    void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
195                                   size_t length) {
196        if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
197                                                           length) != ANEURALNETWORKS_NO_ERROR) {
198            mValid = false;
199        }
200    }
201
202    void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
203                      const std::vector<uint32_t>& outputs) {
204        ANeuralNetworksIntList in, out;
205        Set(&in, inputs);
206        Set(&out, outputs);
207        if (ANeuralNetworksModel_addOperation(mModel, type, &in, &out) !=
208            ANEURALNETWORKS_NO_ERROR) {
209            mValid = false;
210        }
211    }
212    void setInputsAndOutputs(const std::vector<uint32_t>& inputs,
213                             const std::vector<uint32_t>& outputs) {
214        ANeuralNetworksIntList in, out;
215        Set(&in, inputs);
216        Set(&out, outputs);
217        if (ANeuralNetworksModel_setInputsAndOutputs(mModel, &in, &out) !=
218            ANEURALNETWORKS_NO_ERROR) {
219            mValid = false;
220        }
221    }
222    ANeuralNetworksModel* getHandle() const { return mModel; }
223    bool isValid() const { return mValid; }
224
225private:
226    /**
227     * WARNING list won't be valid once vec is destroyed or modified.
228     */
229    void Set(ANeuralNetworksIntList* list, const std::vector<uint32_t>& vec) {
230        list->count = static_cast<uint32_t>(vec.size());
231        list->data = vec.data();
232    }
233
234    ANeuralNetworksModel* mModel = nullptr;
235    // We keep track of the operand ID as a convenience to the caller.
236    uint32_t mNextOperandId = 0;
237    bool mValid = true;
238};
239
240class Event {
241public:
242    ~Event() { ANeuralNetworksEvent_free(mEvent); }
243
244    // Disallow copy semantics to ensure the runtime object can only be freed
245    // once. Copy semantics could be enabled if some sort of reference counting
246    // or deep-copy system for runtime objects is added later.
247    Event(const Event&) = delete;
248    Event& operator=(const Event&) = delete;
249
250    // Move semantics to remove access to the runtime object from the wrapper
251    // object that is being moved. This ensures the runtime object will be
252    // freed only once.
253    Event(Event&& other) {
254        *this = std::move(other);
255    }
256    Event& operator=(Event&& other) {
257        if (this != &other) {
258            mEvent = other.mEvent;
259            other.mEvent = nullptr;
260        }
261        return *this;
262    }
263
264    Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); }
265    void set(ANeuralNetworksEvent* newEvent) {
266        ANeuralNetworksEvent_free(mEvent);
267        mEvent = newEvent;
268    }
269
270private:
271    ANeuralNetworksEvent* mEvent = nullptr;
272};
273
274class Compilation {
275public:
276    Compilation(const Model* model) {
277        int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
278        if (result != 0) {
279            // TODO Handle the error
280        }
281    }
282
283    ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }
284
285    Compilation(const Compilation&) = delete;
286    Compilation& operator=(const Compilation &) = delete;
287
288    Compilation(Compilation&& other) {
289        *this = std::move(other);
290    }
291    Compilation& operator=(Compilation&& other) {
292        if (this != &other) {
293            mCompilation = other.mCompilation;
294            other.mCompilation = nullptr;
295        }
296        return *this;
297    }
298
299    Result setPreference(ExecutePreference preference) {
300        return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
301                    mCompilation, static_cast<uint32_t>(preference)));
302    }
303
304    // TODO startCompile
305
306    Result compile() {
307        Result result = static_cast<Result>(ANeuralNetworksCompilation_start(mCompilation));
308        if (result != Result::NO_ERROR) {
309            return result;
310        }
311        // TODO how to manage the lifetime of compilations when multiple waiters
312        // is not clear.
313        return static_cast<Result>(ANeuralNetworksCompilation_wait(mCompilation));
314    }
315
316    ANeuralNetworksCompilation* getHandle() const { return mCompilation; }
317
318private:
319    ANeuralNetworksCompilation* mCompilation = nullptr;
320};
321
322class Request {
323public:
324    Request(const Compilation* compilation) {
325        int result = ANeuralNetworksRequest_create(compilation->getHandle(), &mRequest);
326        if (result != 0) {
327            // TODO Handle the error
328        }
329    }
330
331    ~Request() { ANeuralNetworksRequest_free(mRequest); }
332
333    // Disallow copy semantics to ensure the runtime object can only be freed
334    // once. Copy semantics could be enabled if some sort of reference counting
335    // or deep-copy system for runtime objects is added later.
336    Request(const Request&) = delete;
337    Request& operator=(const Request&) = delete;
338
339    // Move semantics to remove access to the runtime object from the wrapper
340    // object that is being moved. This ensures the runtime object will be
341    // freed only once.
342    Request(Request&& other) {
343        *this = std::move(other);
344    }
345    Request& operator=(Request&& other) {
346        if (this != &other) {
347            mRequest = other.mRequest;
348            other.mRequest = nullptr;
349        }
350        return *this;
351    }
352
353    Result setInput(uint32_t index, const void* buffer, size_t length,
354                    const ANeuralNetworksOperandType* type = nullptr) {
355        return static_cast<Result>(
356                    ANeuralNetworksRequest_setInput(mRequest, index, type, buffer, length));
357    }
358
359    Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
360                              uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
361        return static_cast<Result>(ANeuralNetworksRequest_setInputFromMemory(
362                    mRequest, index, type, memory->get(), offset, length));
363    }
364
365    Result setOutput(uint32_t index, void* buffer, size_t length,
366                     const ANeuralNetworksOperandType* type = nullptr) {
367        return static_cast<Result>(
368                    ANeuralNetworksRequest_setOutput(mRequest, index, type, buffer, length));
369    }
370
371    Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
372                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
373        return static_cast<Result>(ANeuralNetworksRequest_setOutputFromMemory(
374                    mRequest, index, type, memory->get(), offset, length));
375    }
376
377    Result startCompute(Event* event) {
378        ANeuralNetworksEvent* ev = nullptr;
379        Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &ev));
380        event->set(ev);
381        return result;
382    }
383
384    Result compute() {
385        ANeuralNetworksEvent* event = nullptr;
386        Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &event));
387        if (result != Result::NO_ERROR) {
388            return result;
389        }
390        // TODO how to manage the lifetime of events when multiple waiters is not
391        // clear.
392        result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
393        ANeuralNetworksEvent_free(event);
394        return result;
395    }
396
397private:
398    ANeuralNetworksRequest* mRequest = nullptr;
399};
400
401}  // namespace wrapper
402}  // namespace nn
403}  // namespace android
404
405#endif  //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
406