NeuralNetworksWrapper.h revision 7b87fec6bb919d16a8cc2820d470733a2776e8fa
1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Provides C++ classes to more easily use the Neural Networks API. 18 19#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 20#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 21 22#include "NeuralNetworks.h" 23 24#include <math.h> 25#include <vector> 26 27namespace android { 28namespace nn { 29namespace wrapper { 30 31enum class Type { 32 FLOAT16 = ANEURALNETWORKS_FLOAT16, 33 FLOAT32 = ANEURALNETWORKS_FLOAT32, 34 INT8 = ANEURALNETWORKS_INT8, 35 UINT8 = ANEURALNETWORKS_UINT8, 36 INT16 = ANEURALNETWORKS_INT16, 37 UINT16 = ANEURALNETWORKS_UINT16, 38 INT32 = ANEURALNETWORKS_INT32, 39 UINT32 = ANEURALNETWORKS_UINT32, 40 TENSOR_FLOAT16 = ANEURALNETWORKS_TENSOR_FLOAT16, 41 TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32, 42 TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32, 43 TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, 44}; 45 46enum class ExecutePreference { 47 PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER, 48 PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER, 49 PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED 50}; 51 52enum class Result { 53 NO_ERROR = ANEURALNETWORKS_NO_ERROR, 54 OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY, 55 INCOMPLETE = ANEURALNETWORKS_INCOMPLETE, 56 UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL, 57 BAD_DATA = ANEURALNETWORKS_BAD_DATA, 58}; 59 60struct OperandType { 61 ANeuralNetworksOperandType operandType; 62 // uint32_t type; 63 std::vector<uint32_t> dimensions; 64 65 OperandType(Type type, const std::vector<uint32_t>& d) : dimensions(d) { 66 operandType.type = static_cast<uint32_t>(type); 67 operandType.scale = 0.0f; 68 operandType.offset = 0; 69 70 operandType.dimensions.count = static_cast<uint32_t>(dimensions.size()); 71 operandType.dimensions.data = dimensions.data(); 72 } 73 74 OperandType(Type type, float scale, const std::vector<uint32_t>& d) : OperandType(type, d) { 75 operandType.scale = scale; 76 } 77 78 OperandType(Type type, float f_min, float f_max, const std::vector<uint32_t>& d) 79 : OperandType(type, d) { 80 uint8_t q_min = std::numeric_limits<uint8_t>::min(); 81 uint8_t q_max = std::numeric_limits<uint8_t>::max(); 82 float range = q_max - q_min; 83 float scale = (f_max - f_min) / range; 84 int32_t offset = 85 fmin(q_max, fmax(q_min, static_cast<uint8_t>(round(q_min - f_min / scale)))); 86 87 operandType.scale = scale; 88 operandType.offset = offset; 89 } 90}; 91 92inline Result Initialize() { 93 return static_cast<Result>(ANeuralNetworksInitialize()); 94} 95 96inline void Shutdown() { 97 ANeuralNetworksShutdown(); 98} 99 100class Memory { 101public: 102 // TODO Also have constructors for file descriptor, gralloc buffers, etc. 103 Memory(size_t size) { 104 mValid = ANeuralNetworksMemory_createShared(size, &mMemory) == ANEURALNETWORKS_NO_ERROR; 105 } 106 ~Memory() { ANeuralNetworksMemory_free(mMemory); } 107 Result getPointer(uint8_t** buffer) { 108 return static_cast<Result>(ANeuralNetworksMemory_getPointer(mMemory, buffer)); 109 } 110 111 // Disallow copy semantics to ensure the runtime object can only be freed 112 // once. Copy semantics could be enabled if some sort of reference counting 113 // or deep-copy system for runtime objects is added later. 114 Memory(const Memory&) = delete; 115 Memory& operator=(const Memory&) = delete; 116 117 // Move semantics to remove access to the runtime object from the wrapper 118 // object that is being moved. This ensures the runtime object will be 119 // freed only once. 120 Memory(Memory&& other) { 121 *this = std::move(other); 122 } 123 Memory& operator=(Memory&& other) { 124 if (this != &other) { 125 mMemory = other.mMemory; 126 mValid = other.mValid; 127 other.mMemory = nullptr; 128 other.mValid = false; 129 } 130 return *this; 131 } 132 133 ANeuralNetworksMemory* get() const { return mMemory; } 134 bool isValid() const { return mValid; } 135 136private: 137 ANeuralNetworksMemory* mMemory = nullptr; 138 bool mValid = true; 139}; 140 141class Model { 142public: 143 Model() { 144 // TODO handle the value returned by this call 145 ANeuralNetworksModel_create(&mModel); 146 } 147 ~Model() { ANeuralNetworksModel_free(mModel); } 148 149 // Disallow copy semantics to ensure the runtime object can only be freed 150 // once. Copy semantics could be enabled if some sort of reference counting 151 // or deep-copy system for runtime objects is added later. 152 Model(const Model&) = delete; 153 Model& operator=(const Model&) = delete; 154 155 // Move semantics to remove access to the runtime object from the wrapper 156 // object that is being moved. This ensures the runtime object will be 157 // freed only once. 158 Model(Model&& other) { 159 *this = std::move(other); 160 } 161 Model& operator=(Model&& other) { 162 if (this != &other) { 163 mModel = other.mModel; 164 mNextOperandId = other.mNextOperandId; 165 mValid = other.mValid; 166 other.mModel = nullptr; 167 other.mNextOperandId = 0; 168 other.mValid = false; 169 } 170 return *this; 171 } 172 173 uint32_t addOperand(const OperandType* type) { 174 if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) != 175 ANEURALNETWORKS_NO_ERROR) { 176 mValid = false; 177 } 178 return mNextOperandId++; 179 } 180 181 void setOperandValue(uint32_t index, const void* buffer, size_t length) { 182 if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) != 183 ANEURALNETWORKS_NO_ERROR) { 184 mValid = false; 185 } 186 } 187 188 void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 189 size_t length) { 190 if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset, 191 length) != ANEURALNETWORKS_NO_ERROR) { 192 mValid = false; 193 } 194 } 195 196 void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs, 197 const std::vector<uint32_t>& outputs) { 198 ANeuralNetworksIntList in, out; 199 Set(&in, inputs); 200 Set(&out, outputs); 201 if (ANeuralNetworksModel_addOperation(mModel, type, &in, &out) != 202 ANEURALNETWORKS_NO_ERROR) { 203 mValid = false; 204 } 205 } 206 void setInputsAndOutputs(const std::vector<uint32_t>& inputs, 207 const std::vector<uint32_t>& outputs) { 208 ANeuralNetworksIntList in, out; 209 Set(&in, inputs); 210 Set(&out, outputs); 211 if (ANeuralNetworksModel_setInputsAndOutputs(mModel, &in, &out) != 212 ANEURALNETWORKS_NO_ERROR) { 213 mValid = false; 214 } 215 } 216 ANeuralNetworksModel* getHandle() const { return mModel; } 217 bool isValid() const { return mValid; } 218 219private: 220 /** 221 * WARNING list won't be valid once vec is destroyed or modified. 222 */ 223 void Set(ANeuralNetworksIntList* list, const std::vector<uint32_t>& vec) { 224 list->count = static_cast<uint32_t>(vec.size()); 225 list->data = vec.data(); 226 } 227 228 ANeuralNetworksModel* mModel = nullptr; 229 // We keep track of the operand ID as a convenience to the caller. 230 uint32_t mNextOperandId = 0; 231 bool mValid = true; 232}; 233 234class Event { 235public: 236 ~Event() { ANeuralNetworksEvent_free(mEvent); } 237 238 // Disallow copy semantics to ensure the runtime object can only be freed 239 // once. Copy semantics could be enabled if some sort of reference counting 240 // or deep-copy system for runtime objects is added later. 241 Event(const Event&) = delete; 242 Event& operator=(const Event&) = delete; 243 244 // Move semantics to remove access to the runtime object from the wrapper 245 // object that is being moved. This ensures the runtime object will be 246 // freed only once. 247 Event(Event&& other) { 248 *this = std::move(other); 249 } 250 Event& operator=(Event&& other) { 251 if (this != &other) { 252 mEvent = other.mEvent; 253 other.mEvent = nullptr; 254 } 255 return *this; 256 } 257 258 Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); } 259 void set(ANeuralNetworksEvent* newEvent) { 260 ANeuralNetworksEvent_free(mEvent); 261 mEvent = newEvent; 262 } 263 264private: 265 ANeuralNetworksEvent* mEvent = nullptr; 266}; 267 268class Request { 269public: 270 Request(const Model* model) { 271 int result = ANeuralNetworksRequest_create(model->getHandle(), &mRequest); 272 if (result != 0) { 273 // TODO Handle the error 274 } 275 } 276 277 ~Request() { ANeuralNetworksRequest_free(mRequest); } 278 279 // Disallow copy semantics to ensure the runtime object can only be freed 280 // once. Copy semantics could be enabled if some sort of reference counting 281 // or deep-copy system for runtime objects is added later. 282 Request(const Request&) = delete; 283 Request& operator=(const Request&) = delete; 284 285 // Move semantics to remove access to the runtime object from the wrapper 286 // object that is being moved. This ensures the runtime object will be 287 // freed only once. 288 Request(Request&& other) { 289 *this = std::move(other); 290 } 291 Request& operator=(Request&& other) { 292 if (this != &other) { 293 mRequest = other.mRequest; 294 other.mRequest = nullptr; 295 } 296 return *this; 297 } 298 299 Result setPreference(ExecutePreference preference) { 300 return static_cast<Result>(ANeuralNetworksRequest_setPreference( 301 mRequest, static_cast<uint32_t>(preference))); 302 } 303 304 Result setInput(uint32_t index, const void* buffer, size_t length, 305 const ANeuralNetworksOperandType* type = nullptr) { 306 return static_cast<Result>( 307 ANeuralNetworksRequest_setInput(mRequest, index, type, buffer, length)); 308 } 309 310 Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 311 uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { 312 return static_cast<Result>(ANeuralNetworksRequest_setInputFromMemory( 313 mRequest, index, type, memory->get(), offset, length)); 314 } 315 316 Result setOutput(uint32_t index, void* buffer, size_t length, 317 const ANeuralNetworksOperandType* type = nullptr) { 318 return static_cast<Result>( 319 ANeuralNetworksRequest_setOutput(mRequest, index, type, buffer, length)); 320 } 321 322 Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 323 uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { 324 return static_cast<Result>(ANeuralNetworksRequest_setOutputFromMemory( 325 mRequest, index, type, memory->get(), offset, length)); 326 } 327 328 Result startCompute(Event* event) { 329 ANeuralNetworksEvent* ev = nullptr; 330 Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &ev)); 331 event->set(ev); 332 return result; 333 } 334 335 Result compute() { 336 ANeuralNetworksEvent* event = nullptr; 337 Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &event)); 338 if (result != Result::NO_ERROR) { 339 return result; 340 } 341 // TODO how to manage the lifetime of events when multiple waiters is not 342 // clear. 343 return static_cast<Result>(ANeuralNetworksEvent_wait(event)); 344 } 345 346private: 347 ANeuralNetworksRequest* mRequest = nullptr; 348}; 349 350} // namespace wrapper 351} // namespace nn 352} // namespace android 353 354#endif // ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 355