NeuralNetworksWrapper.h revision 910c9f04913e3bee1a0b6406b6e146457d19c5e7
1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Provides C++ classes to more easily use the Neural Networks API. 18 19#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 20#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 21 22#include "NeuralNetworks.h" 23 24#include <math.h> 25#include <vector> 26 27namespace android { 28namespace nn { 29namespace wrapper { 30 31enum class Type { 32 FLOAT16 = ANEURALNETWORKS_FLOAT16, 33 FLOAT32 = ANEURALNETWORKS_FLOAT32, 34 INT8 = ANEURALNETWORKS_INT8, 35 UINT8 = ANEURALNETWORKS_UINT8, 36 INT16 = ANEURALNETWORKS_INT16, 37 UINT16 = ANEURALNETWORKS_UINT16, 38 INT32 = ANEURALNETWORKS_INT32, 39 UINT32 = ANEURALNETWORKS_UINT32, 40 TENSOR_FLOAT16 = ANEURALNETWORKS_TENSOR_FLOAT16, 41 TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32, 42 TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, 43}; 44 45enum class ExecutePreference { 46 PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER, 47 PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER, 48 PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED 49}; 50 51enum class Result { 52 NO_ERROR = ANEURALNETWORKS_NO_ERROR, 53 OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY, 54 INCOMPLETE = ANEURALNETWORKS_INCOMPLETE, 55 UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL, 56 BAD_DATA = ANEURALNETWORKS_BAD_DATA, 57}; 58 59struct OperandType { 60 ANeuralNetworksOperandType operandType; 61 // uint32_t type; 62 std::vector<uint32_t> dimensions; 63 64 OperandType(Type type, const std::vector<uint32_t>& d) : dimensions(d) { 65 operandType.type = static_cast<uint32_t>(type); 66 operandType.scale = 0.0f; 67 operandType.offset = 0; 68 69 operandType.dimensions.count = static_cast<uint32_t>(dimensions.size()); 70 operandType.dimensions.data = dimensions.data(); 71 } 72 73 OperandType(Type type, float scale, const std::vector<uint32_t>& d) 74 : OperandType(type, d) { 75 operandType.scale = scale; 76 } 77 78 OperandType(Type type, float f_min, float f_max, const std::vector<uint32_t>& d) 79 : OperandType(type, d) { 80 uint8_t q_min = std::numeric_limits<uint8_t>::min(); 81 uint8_t q_max = std::numeric_limits<uint8_t>::max(); 82 float range = q_max - q_min; 83 float scale = (f_max - f_min) / range; 84 int32_t offset = 85 fmin(q_max, fmax(q_min, static_cast<uint8_t>(round(q_min - f_min / scale)))); 86 87 operandType.scale = scale; 88 operandType.offset = offset; 89 } 90}; 91 92inline Result Initialize() { 93 return static_cast<Result>(ANeuralNetworksInitialize()); 94} 95 96inline void Shutdown() { 97 ANeuralNetworksShutdown(); 98} 99 100class Memory { 101public: 102 // TODO Also have constructors for file descriptor, gralloc buffers, etc. 103 Memory(size_t size) { 104 mValid = ANeuralNetworksMemory_create(size, &mMemory) == ANEURALNETWORKS_NO_ERROR; 105 } 106 ~Memory() { ANeuralNetworksMemory_free(mMemory); } 107 uint8_t* getPointer() { 108 return ANeuralNetworksMemory_getPointer(mMemory); 109 } 110 ANeuralNetworksMemory* get() const { return mMemory; } 111 bool isValid() const { return mValid; } 112 113private: 114 ANeuralNetworksMemory* mMemory = nullptr; 115 bool mValid = true; 116}; 117 118class Model { 119public: 120 Model() { 121 // TODO handle the value returned by this call 122 ANeuralNetworksModel_create(&mModel); 123 } 124 ~Model() { ANeuralNetworksModel_free(mModel); } 125 126 uint32_t addOperand(const OperandType* type) { 127 if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) != 128 ANEURALNETWORKS_NO_ERROR) { 129 mValid = false; 130 } 131 return mNextOperandId++; 132 } 133 134 void setOperandValue(uint32_t index, const void* buffer, size_t length) { 135 if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) != 136 ANEURALNETWORKS_NO_ERROR) { 137 mValid = false; 138 } 139 } 140 141 void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 142 size_t length) { 143 if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset, 144 length) != ANEURALNETWORKS_NO_ERROR) { 145 mValid = false; 146 } 147 } 148 149 void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs, 150 const std::vector<uint32_t>& outputs) { 151 ANeuralNetworksIntList in, out; 152 Set(&in, inputs); 153 Set(&out, outputs); 154 if (ANeuralNetworksModel_addOperation(mModel, type, &in, &out) != 155 ANEURALNETWORKS_NO_ERROR) { 156 mValid = false; 157 } 158 } 159 void setInputsAndOutputs(const std::vector<uint32_t>& inputs, 160 const std::vector<uint32_t>& outputs) { 161 ANeuralNetworksIntList in, out; 162 Set(&in, inputs); 163 Set(&out, outputs); 164 if (ANeuralNetworksModel_setInputsAndOutputs(mModel, &in, &out) != 165 ANEURALNETWORKS_NO_ERROR) { 166 mValid = false; 167 } 168 } 169 ANeuralNetworksModel* getHandle() const { return mModel; } 170 bool isValid() const { return mValid; } 171 172private: 173 /** 174 * WARNING list won't be valid once vec is destroyed or modified. 175 */ 176 void Set(ANeuralNetworksIntList* list, const std::vector<uint32_t>& vec) { 177 list->count = static_cast<uint32_t>(vec.size()); 178 list->data = vec.data(); 179 } 180 181 ANeuralNetworksModel* mModel = nullptr; 182 // We keep track of the operand ID as a convenience to the caller. 183 uint32_t mNextOperandId = 0; 184 bool mValid = true; 185}; 186 187class Event { 188public: 189 ~Event() { ANeuralNetworksEvent_free(mEvent); } 190 Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); } 191 void set(ANeuralNetworksEvent* newEvent) { 192 ANeuralNetworksEvent_free(mEvent); 193 mEvent = newEvent; 194 } 195 196private: 197 ANeuralNetworksEvent* mEvent = nullptr; 198}; 199 200class Request { 201public: 202 Request(const Model* model) { 203 int result = ANeuralNetworksRequest_create(model->getHandle(), &mRequest); 204 if (result != 0) { 205 // TODO Handle the error 206 } 207 } 208 209 ~Request() { ANeuralNetworksRequest_free(mRequest); } 210 211 Result setPreference(ExecutePreference preference) { 212 return static_cast<Result>( 213 ANeuralNetworksRequest_setPreference(mRequest, static_cast<uint32_t>(preference))); 214 } 215 216 Result setInput(uint32_t index, const void* buffer, size_t length, 217 const ANeuralNetworksOperandType* type = nullptr) { 218 return static_cast<Result>( 219 ANeuralNetworksRequest_setInput(mRequest, index, type, buffer, length)); 220 } 221 222 Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 223 uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { 224 return static_cast<Result>(ANeuralNetworksRequest_setInputFromMemory(mRequest, index, type, 225 memory->get(), offset, 226 length)); 227 } 228 229 Result setOutput(uint32_t index, void* buffer, size_t length, 230 const ANeuralNetworksOperandType* type = nullptr) { 231 return static_cast<Result>( 232 ANeuralNetworksRequest_setOutput(mRequest, index, type, buffer, length)); 233 } 234 235 Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 236 uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { 237 return static_cast<Result>(ANeuralNetworksRequest_setOutputFromMemory(mRequest, index, type, 238 memory->get(), offset, 239 length)); 240 } 241 242 Result startCompute(Event* event) { 243 ANeuralNetworksEvent* ev = nullptr; 244 Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &ev)); 245 event->set(ev); 246 return result; 247 } 248 249 Result compute() { 250 ANeuralNetworksEvent* event = nullptr; 251 Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &event)); 252 if (result != Result::NO_ERROR) { 253 return result; 254 } 255 // TODO how to manage the lifetime of events when multiple waiters is not 256 // clear. 257 return static_cast<Result>(ANeuralNetworksEvent_wait(event)); 258 } 259 260private: 261 ANeuralNetworksRequest* mRequest = nullptr; 262}; 263 264} // namespace wrapper 265} // namespace nn 266} // namespace android 267 268#endif // ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 269