NeuralNetworksWrapper.h revision 42dc6a6cd68877cd85e3bc475b41bda0fd946c41
1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Provides C++ classes to more easily use the Neural Networks API. 18 19#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 20#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 21 22#include "NeuralNetworks.h" 23 24#include <math.h> 25#include <vector> 26 27namespace android { 28namespace nn { 29namespace wrapper { 30 31enum class Type { 32 FLOAT16 = ANEURALNETWORKS_FLOAT16, 33 FLOAT32 = ANEURALNETWORKS_FLOAT32, 34 INT8 = ANEURALNETWORKS_INT8, 35 UINT8 = ANEURALNETWORKS_UINT8, 36 INT16 = ANEURALNETWORKS_INT16, 37 UINT16 = ANEURALNETWORKS_UINT16, 38 INT32 = ANEURALNETWORKS_INT32, 39 UINT32 = ANEURALNETWORKS_UINT32, 40 TENSOR_FLOAT16 = ANEURALNETWORKS_TENSOR_FLOAT16, 41 TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32, 42 TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32, 43 TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, 44}; 45 46enum class ExecutePreference { 47 PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER, 48 PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER, 49 PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED 50}; 51 52enum class Result { 53 NO_ERROR = ANEURALNETWORKS_NO_ERROR, 54 OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY, 55 INCOMPLETE = ANEURALNETWORKS_INCOMPLETE, 56 UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL, 57 BAD_DATA = ANEURALNETWORKS_BAD_DATA, 58}; 59 60struct OperandType { 61 ANeuralNetworksOperandType operandType; 62 // uint32_t type; 63 std::vector<uint32_t> dimensions; 64 65 OperandType(Type type, const std::vector<uint32_t>& d) : dimensions(d) { 66 operandType.type = static_cast<uint32_t>(type); 67 operandType.scale = 0.0f; 68 operandType.offset = 0; 69 70 operandType.dimensions.count = static_cast<uint32_t>(dimensions.size()); 71 operandType.dimensions.data = dimensions.data(); 72 } 73 74 OperandType(Type type, float scale, const std::vector<uint32_t>& d) : OperandType(type, d) { 75 operandType.scale = scale; 76 } 77 78 OperandType(Type type, float f_min, float f_max, const std::vector<uint32_t>& d) 79 : OperandType(type, d) { 80 uint8_t q_min = std::numeric_limits<uint8_t>::min(); 81 uint8_t q_max = std::numeric_limits<uint8_t>::max(); 82 float range = q_max - q_min; 83 float scale = (f_max - f_min) / range; 84 int32_t offset = 85 fmin(q_max, fmax(q_min, static_cast<uint8_t>(round(q_min - f_min / scale)))); 86 87 operandType.scale = scale; 88 operandType.offset = offset; 89 } 90}; 91 92inline Result Initialize() { 93 return static_cast<Result>(ANeuralNetworksInitialize()); 94} 95 96inline void Shutdown() { 97 ANeuralNetworksShutdown(); 98} 99 100class Memory { 101public: 102 103 Memory(size_t size, int protect, int fd, size_t offset) { 104 mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) == 105 ANEURALNETWORKS_NO_ERROR; 106 } 107 108 ~Memory() { ANeuralNetworksMemory_free(mMemory); } 109 110 // Disallow copy semantics to ensure the runtime object can only be freed 111 // once. Copy semantics could be enabled if some sort of reference counting 112 // or deep-copy system for runtime objects is added later. 113 Memory(const Memory&) = delete; 114 Memory& operator=(const Memory&) = delete; 115 116 // Move semantics to remove access to the runtime object from the wrapper 117 // object that is being moved. This ensures the runtime object will be 118 // freed only once. 119 Memory(Memory&& other) { 120 *this = std::move(other); 121 } 122 Memory& operator=(Memory&& other) { 123 if (this != &other) { 124 mMemory = other.mMemory; 125 mValid = other.mValid; 126 other.mMemory = nullptr; 127 other.mValid = false; 128 } 129 return *this; 130 } 131 132 ANeuralNetworksMemory* get() const { return mMemory; } 133 bool isValid() const { return mValid; } 134 135private: 136 ANeuralNetworksMemory* mMemory = nullptr; 137 bool mValid = true; 138}; 139 140class Model { 141public: 142 Model() { 143 // TODO handle the value returned by this call 144 ANeuralNetworksModel_create(&mModel); 145 } 146 ~Model() { ANeuralNetworksModel_free(mModel); } 147 148 // Disallow copy semantics to ensure the runtime object can only be freed 149 // once. Copy semantics could be enabled if some sort of reference counting 150 // or deep-copy system for runtime objects is added later. 151 Model(const Model&) = delete; 152 Model& operator=(const Model&) = delete; 153 154 // Move semantics to remove access to the runtime object from the wrapper 155 // object that is being moved. This ensures the runtime object will be 156 // freed only once. 157 Model(Model&& other) { 158 *this = std::move(other); 159 } 160 Model& operator=(Model&& other) { 161 if (this != &other) { 162 mModel = other.mModel; 163 mNextOperandId = other.mNextOperandId; 164 mValid = other.mValid; 165 other.mModel = nullptr; 166 other.mNextOperandId = 0; 167 other.mValid = false; 168 } 169 return *this; 170 } 171 172 int finish() { return ANeuralNetworksModel_finish(mModel); } 173 174 uint32_t addOperand(const OperandType* type) { 175 if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) != 176 ANEURALNETWORKS_NO_ERROR) { 177 mValid = false; 178 } 179 return mNextOperandId++; 180 } 181 182 void setOperandValue(uint32_t index, const void* buffer, size_t length) { 183 if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) != 184 ANEURALNETWORKS_NO_ERROR) { 185 mValid = false; 186 } 187 } 188 189 void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 190 size_t length) { 191 if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset, 192 length) != ANEURALNETWORKS_NO_ERROR) { 193 mValid = false; 194 } 195 } 196 197 void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs, 198 const std::vector<uint32_t>& outputs) { 199 ANeuralNetworksIntList in, out; 200 Set(&in, inputs); 201 Set(&out, outputs); 202 if (ANeuralNetworksModel_addOperation(mModel, type, &in, &out) != 203 ANEURALNETWORKS_NO_ERROR) { 204 mValid = false; 205 } 206 } 207 void setInputsAndOutputs(const std::vector<uint32_t>& inputs, 208 const std::vector<uint32_t>& outputs) { 209 ANeuralNetworksIntList in, out; 210 Set(&in, inputs); 211 Set(&out, outputs); 212 if (ANeuralNetworksModel_setInputsAndOutputs(mModel, &in, &out) != 213 ANEURALNETWORKS_NO_ERROR) { 214 mValid = false; 215 } 216 } 217 ANeuralNetworksModel* getHandle() const { return mModel; } 218 bool isValid() const { return mValid; } 219 220private: 221 /** 222 * WARNING list won't be valid once vec is destroyed or modified. 223 */ 224 void Set(ANeuralNetworksIntList* list, const std::vector<uint32_t>& vec) { 225 list->count = static_cast<uint32_t>(vec.size()); 226 list->data = vec.data(); 227 } 228 229 ANeuralNetworksModel* mModel = nullptr; 230 // We keep track of the operand ID as a convenience to the caller. 231 uint32_t mNextOperandId = 0; 232 bool mValid = true; 233}; 234 235class Compilation { 236public: 237 Compilation(const Model* model) { 238 int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation); 239 if (result != 0) { 240 // TODO Handle the error 241 } 242 } 243 244 ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); } 245 246 Compilation(const Compilation&) = delete; 247 Compilation& operator=(const Compilation &) = delete; 248 249 Compilation(Compilation&& other) { 250 *this = std::move(other); 251 } 252 Compilation& operator=(Compilation&& other) { 253 if (this != &other) { 254 mCompilation = other.mCompilation; 255 other.mCompilation = nullptr; 256 } 257 return *this; 258 } 259 260 Result setPreference(ExecutePreference preference) { 261 return static_cast<Result>(ANeuralNetworksCompilation_setPreference( 262 mCompilation, static_cast<uint32_t>(preference))); 263 } 264 265 // TODO startCompile 266 267 Result compile() { 268 Result result = static_cast<Result>(ANeuralNetworksCompilation_start(mCompilation)); 269 if (result != Result::NO_ERROR) { 270 return result; 271 } 272 // TODO how to manage the lifetime of compilations when multiple waiters 273 // is not clear. 274 return static_cast<Result>(ANeuralNetworksCompilation_wait(mCompilation)); 275 } 276 277 ANeuralNetworksCompilation* getHandle() const { return mCompilation; } 278 279private: 280 ANeuralNetworksCompilation* mCompilation = nullptr; 281}; 282 283class Request { 284public: 285 Request(const Compilation* compilation) { 286 int result = ANeuralNetworksRequest_create(compilation->getHandle(), &mRequest); 287 if (result != 0) { 288 // TODO Handle the error 289 } 290 } 291 292 ~Request() { ANeuralNetworksRequest_free(mRequest); } 293 294 // Disallow copy semantics to ensure the runtime object can only be freed 295 // once. Copy semantics could be enabled if some sort of reference counting 296 // or deep-copy system for runtime objects is added later. 297 Request(const Request&) = delete; 298 Request& operator=(const Request&) = delete; 299 300 // Move semantics to remove access to the runtime object from the wrapper 301 // object that is being moved. This ensures the runtime object will be 302 // freed only once. 303 Request(Request&& other) { 304 *this = std::move(other); 305 } 306 Request& operator=(Request&& other) { 307 if (this != &other) { 308 mRequest = other.mRequest; 309 other.mRequest = nullptr; 310 } 311 return *this; 312 } 313 314 Result setInput(uint32_t index, const void* buffer, size_t length, 315 const ANeuralNetworksOperandType* type = nullptr) { 316 return static_cast<Result>( 317 ANeuralNetworksRequest_setInput(mRequest, index, type, buffer, length)); 318 } 319 320 Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 321 uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { 322 return static_cast<Result>(ANeuralNetworksRequest_setInputFromMemory( 323 mRequest, index, type, memory->get(), offset, length)); 324 } 325 326 Result setOutput(uint32_t index, void* buffer, size_t length, 327 const ANeuralNetworksOperandType* type = nullptr) { 328 return static_cast<Result>( 329 ANeuralNetworksRequest_setOutput(mRequest, index, type, buffer, length)); 330 } 331 332 Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 333 uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { 334 return static_cast<Result>(ANeuralNetworksRequest_setOutputFromMemory( 335 mRequest, index, type, memory->get(), offset, length)); 336 } 337 338 Result startCompute() { 339 Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest)); 340 return result; 341 } 342 343 Result wait() { 344 return static_cast<Result>(ANeuralNetworksRequest_wait(mRequest)); 345 } 346 347 Result compute() { 348 Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest)); 349 if (result != Result::NO_ERROR) { 350 return result; 351 } 352 // TODO how to manage the lifetime of events when multiple waiters is not 353 // clear. 354 return static_cast<Result>(ANeuralNetworksRequest_wait(mRequest)); 355 } 356 357private: 358 ANeuralNetworksRequest* mRequest = nullptr; 359}; 360 361} // namespace wrapper 362} // namespace nn 363} // namespace android 364 365#endif // ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 366