NeuralNetworksWrapper.h revision 96775128e3bcfdc5be51b62edc50309c83861fe8
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Provides C++ classes to more easily use the Neural Networks API.
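//
// A minimal usage sketch (illustrative only, error handling omitted; the
// model-building and execution steps are filled in by the per-class examples
// further down):
//
//     using namespace android::nn::wrapper;
//
//     if (Initialize() == Result::NO_ERROR) {
//         {
//             Model model;
//             // ... describe operands and operations, see Model below ...
//             Request request(&model);
//             // ... bind buffers and compute, see Request below ...
//         }
//         Shutdown();
//     }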

#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H

#include "NeuralNetworks.h"

#include <vector>

namespace android {
namespace nn {
namespace wrapper {

enum class Type {
    FLOAT16 = ANEURALNETWORKS_FLOAT16,
    FLOAT32 = ANEURALNETWORKS_FLOAT32,
    INT8 = ANEURALNETWORKS_INT8,
    UINT8 = ANEURALNETWORKS_UINT8,
    INT16 = ANEURALNETWORKS_INT16,
    UINT16 = ANEURALNETWORKS_UINT16,
    INT32 = ANEURALNETWORKS_INT32,
    UINT32 = ANEURALNETWORKS_UINT32,
    TENSOR_FLOAT16 = ANEURALNETWORKS_TENSOR_FLOAT16,
    TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
    TENSOR_SYMMETRICAL_QUANT8 = ANEURALNETWORKS_TENSOR_SYMMETRICAL_QUANT8,
};

enum class ExecutePreference {
    PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
    PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
    PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
};

enum class Result {
    NO_ERROR = ANEURALNETWORKS_NO_ERROR,
    OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
    INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
    UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
    BAD_DATA = ANEURALNETWORKS_BAD_DATA,
};

struct OperandType {
    ANeuralNetworksOperandType operandType;
    // uint32_t type;
    // The dimensions are stored here so that operandType.dimensions keeps
    // pointing at valid memory for the lifetime of this OperandType.
    std::vector<uint32_t> dimensions;

    OperandType(Type type, const std::vector<uint32_t>& d) : dimensions(d) {
        operandType.type = static_cast<uint32_t>(type);
        operandType.dimensions.count = static_cast<uint32_t>(dimensions.size());
        operandType.dimensions.data = dimensions.data();
    }
};

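// A small sketch of how OperandType is meant to be used (illustrative only):
// a scalar operand has an empty dimensions list, a tensor operand lists its
// shape.
//
//     OperandType scalarType(Type::INT32, {});
//     OperandType tensorType(Type::TENSOR_FLOAT32, {2, 3});
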
inline Result Initialize() {
    return static_cast<Result>(ANeuralNetworksInitialize());
}

inline void Shutdown() {
    ANeuralNetworksShutdown();
}

class Model {
public:
    Model() {
        // Mark the wrapper invalid if the underlying model cannot be created.
        if (ANeuralNetworksModel_create(&mModel) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }
    ~Model() { ANeuralNetworksModel_free(mModel); }

    uint32_t addOperand(const OperandType* type) {
        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
        return mNextOperandId++;
    }

    void setOperandValue(uint32_t index, const void* buffer, size_t length) {
        if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
                      const std::vector<uint32_t>& outputs) {
        ANeuralNetworksIntList in, out;
        Set(&in, inputs);
        Set(&out, outputs);
        if (ANeuralNetworksModel_addOperation(mModel, type, &in, &out) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void setInputsAndOutputs(const std::vector<uint32_t>& inputs,
                             const std::vector<uint32_t>& outputs) {
        ANeuralNetworksIntList in, out;
        Set(&in, inputs);
        Set(&out, outputs);
        if (ANeuralNetworksModel_setInputsAndOutputs(mModel, &in, &out) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    ANeuralNetworksModel* getHandle() const { return mModel; }
    bool isValid() const { return mValid; }

    static Model* createBaselineModel(uint32_t modelId) {
        Model* model = new Model();
        if (ANeuralNetworksModel_createBaselineModel(&model->mModel, modelId) !=
            ANEURALNETWORKS_NO_ERROR) {
            delete model;
            model = nullptr;
        }
        return model;
    }

private:
    /**
     * WARNING: The list is only valid as long as vec is neither destroyed nor
     * modified.
     */
    void Set(ANeuralNetworksIntList* list, const std::vector<uint32_t>& vec) {
        list->count = static_cast<uint32_t>(vec.size());
        list->data = vec.data();
    }

    ANeuralNetworksModel* mModel = nullptr;
    // We keep track of the operand ID as a convenience to the caller.
    uint32_t mNextOperandId = 0;
    bool mValid = true;
};

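// A sketch of building a model with the class above (illustrative only):
// ANEURALNETWORKS_ADD stands in for whatever operation code from
// NeuralNetworks.h the model actually needs, and real operations may require
// additional operands.
//
//     Model model;
//     OperandType tensorType(Type::TENSOR_FLOAT32, {1, 4});
//     uint32_t a = model.addOperand(&tensorType);
//     uint32_t b = model.addOperand(&tensorType);
//     uint32_t sum = model.addOperand(&tensorType);
//     model.addOperation(ANEURALNETWORKS_ADD, {a, b}, {sum});
//     model.setInputsAndOutputs({a, b}, {sum});
//     if (!model.isValid()) {
//         // ... handle the failure ...
//     }
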
class Event {
public:
    ~Event() { ANeuralNetworksEvent_free(mEvent); }

    Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); }

    // Takes ownership of newEvent, freeing any event previously owned.
    void set(ANeuralNetworksEvent* newEvent) {
        ANeuralNetworksEvent_free(mEvent);
        mEvent = newEvent;
    }

private:
    ANeuralNetworksEvent* mEvent = nullptr;
};

class Request {
public:
    Request(const Model* model) {
        int result = ANeuralNetworksRequest_create(model->getHandle(), &mRequest);
        if (result != 0) {
            // TODO Handle the error
        }
    }

    ~Request() { ANeuralNetworksRequest_free(mRequest); }

    Result setPreference(ExecutePreference preference) {
        return static_cast<Result>(ANeuralNetworksRequest_setPreference(
                    mRequest, static_cast<uint32_t>(preference)));
    }

    Result setInput(uint32_t index, const void* buffer, size_t length,
                    const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(
                    ANeuralNetworksRequest_setInput(mRequest, index, type, buffer, length));
    }

    Result setInputFromHardwareBuffer(uint32_t index, const AHardwareBuffer* buffer,
                                      const ANeuralNetworksOperandType* type) {
        return static_cast<Result>(ANeuralNetworksRequest_setInputFromHardwareBuffer(
                    mRequest, index, type, buffer));
    }

    Result setOutput(uint32_t index, void* buffer, size_t length,
                     const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(
                    ANeuralNetworksRequest_setOutput(mRequest, index, type, buffer, length));
    }

    Result setOutputFromHardwareBuffer(uint32_t index, const AHardwareBuffer* buffer,
                                       const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(ANeuralNetworksRequest_setOutputFromHardwareBuffer(
                    mRequest, index, type, buffer));
    }

    Result startCompute(Event* event) {
        ANeuralNetworksEvent* ev = nullptr;
        Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &ev));
        event->set(ev);
        return result;
    }

    Result compute() {
        ANeuralNetworksEvent* event = nullptr;
        Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &event));
        if (result != Result::NO_ERROR) {
            return result;
        }
        // TODO: How to manage the lifetime of events when there are multiple
        // waiters is not clear.
        result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
        ANeuralNetworksEvent_free(event);
        return result;
    }

private:
    ANeuralNetworksRequest* mRequest = nullptr;
};

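// A sketch of running a model with Request and Event (illustrative only,
// continuing the Model example above; error handling is abbreviated). The
// input and output buffers must stay valid until the computation finishes.
//
//     Request request(&model);
//     request.setPreference(ExecutePreference::PREFER_FAST_SINGLE_ANSWER);
//     float in0[4] = {1, 2, 3, 4};
//     float in1[4] = {5, 6, 7, 8};
//     float out[4] = {};
//     request.setInput(0, in0, sizeof(in0));
//     request.setInput(1, in1, sizeof(in1));
//     request.setOutput(0, out, sizeof(out));
//     Event event;
//     if (request.startCompute(&event) == Result::NO_ERROR) {
//         event.wait();  // or use request.compute() for a blocking call
//     }
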
}  // namespace wrapper
}  // namespace nn
}  // namespace android

#endif  // ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H