NeuralNetworksWrapper.h revision 96775128e3bcfdc5be51b62edc50309c83861fe8
1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Provides C++ classes to more easily use the Neural Networks API. 18 19#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 20#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 21 22#include "NeuralNetworks.h" 23 24#include <vector> 25 26namespace android { 27namespace nn { 28namespace wrapper { 29 30enum class Type { 31 FLOAT16 = ANEURALNETWORKS_FLOAT16, 32 FLOAT32 = ANEURALNETWORKS_FLOAT32, 33 INT8 = ANEURALNETWORKS_INT8, 34 UINT8 = ANEURALNETWORKS_UINT8, 35 INT16 = ANEURALNETWORKS_INT16, 36 UINT16 = ANEURALNETWORKS_UINT16, 37 INT32 = ANEURALNETWORKS_INT32, 38 UINT32 = ANEURALNETWORKS_UINT32, 39 TENSOR_FLOAT16 = ANEURALNETWORKS_TENSOR_FLOAT16, 40 TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32, 41 TENSOR_SYMMETRICAL_QUANT8 = ANEURALNETWORKS_TENSOR_SYMMETRICAL_QUANT8, 42}; 43 44enum class ExecutePreference { 45 PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER, 46 PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER, 47 PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED 48}; 49 50enum class Result { 51 NO_ERROR = ANEURALNETWORKS_NO_ERROR, 52 OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY, 53 INCOMPLETE = ANEURALNETWORKS_INCOMPLETE, 54 UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL, 55 BAD_DATA = ANEURALNETWORKS_BAD_DATA, 56}; 57 58struct OperandType { 59 ANeuralNetworksOperandType operandType; 60 // uint32_t type; 61 std::vector<uint32_t> dimensions; 62 63 OperandType(Type type, const std::vector<uint32_t>& d) : dimensions(d) { 64 operandType.type = static_cast<uint32_t>(type); 65 operandType.dimensions.count = static_cast<uint32_t>(dimensions.size()); 66 operandType.dimensions.data = dimensions.data(); 67 } 68}; 69 70inline Result Initialize() { 71 return static_cast<Result>(ANeuralNetworksInitialize()); 72} 73 74inline void Shutdown() { 75 ANeuralNetworksShutdown(); 76} 77 78class Model { 79public: 80 Model() { 81 // TODO handle the value returned by this call 82 ANeuralNetworksModel_create(&mModel); 83 } 84 ~Model() { ANeuralNetworksModel_free(mModel); } 85 86 uint32_t addOperand(const OperandType* type) { 87 if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) != 88 ANEURALNETWORKS_NO_ERROR) { 89 mValid = false; 90 } 91 return mNextOperandId++; 92 } 93 94 void setOperandValue(uint32_t index, const void* buffer, size_t length) { 95 if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) != 96 ANEURALNETWORKS_NO_ERROR) { 97 mValid = false; 98 } 99 } 100 101 void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs, 102 const std::vector<uint32_t>& outputs) { 103 ANeuralNetworksIntList in, out; 104 Set(&in, inputs); 105 Set(&out, outputs); 106 if (ANeuralNetworksModel_addOperation(mModel, type, &in, &out) != 107 ANEURALNETWORKS_NO_ERROR) { 108 mValid = false; 109 } 110 } 111 void setInputsAndOutputs(const std::vector<uint32_t>& inputs, 112 const std::vector<uint32_t>& outputs) { 113 ANeuralNetworksIntList in, out; 114 Set(&in, inputs); 115 Set(&out, outputs); 116 if (ANeuralNetworksModel_setInputsAndOutputs(mModel, &in, &out) != 117 ANEURALNETWORKS_NO_ERROR) { 118 mValid = false; 119 } 120 } 121 ANeuralNetworksModel* getHandle() const { return mModel; } 122 bool isValid() const { return mValid; } 123 static Model* createBaselineModel(uint32_t modelId) { 124 Model* model = new Model(); 125 if (ANeuralNetworksModel_createBaselineModel(&model->mModel, modelId) != 126 ANEURALNETWORKS_NO_ERROR) { 127 delete model; 128 model = nullptr; 129 } 130 return model; 131 } 132 133private: 134 /** 135 * WARNING list won't be valid once vec is destroyed or modified. 136 */ 137 void Set(ANeuralNetworksIntList* list, const std::vector<uint32_t>& vec) { 138 list->count = static_cast<uint32_t>(vec.size()); 139 list->data = vec.data(); 140 } 141 142 ANeuralNetworksModel* mModel = nullptr; 143 // We keep track of the operand ID as a convenience to the caller. 144 uint32_t mNextOperandId = 0; 145 bool mValid = true; 146}; 147 148class Event { 149public: 150 ~Event() { ANeuralNetworksEvent_free(mEvent); } 151 Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); } 152 void set(ANeuralNetworksEvent* newEvent) { 153 ANeuralNetworksEvent_free(mEvent); 154 mEvent = newEvent; 155 } 156 157private: 158 ANeuralNetworksEvent* mEvent = nullptr; 159}; 160 161class Request { 162public: 163 Request(const Model* model) { 164 int result = ANeuralNetworksRequest_create(model->getHandle(), &mRequest); 165 if (result != 0) { 166 // TODO Handle the error 167 } 168 } 169 170 ~Request() { ANeuralNetworksRequest_free(mRequest); } 171 172 Result setPreference(ExecutePreference preference) { 173 return static_cast<Result>(ANeuralNetworksRequest_setPreference( 174 mRequest, static_cast<uint32_t>(preference))); 175 } 176 177 Result setInput(uint32_t index, const void* buffer, size_t length, 178 const ANeuralNetworksOperandType* type = nullptr) { 179 return static_cast<Result>( 180 ANeuralNetworksRequest_setInput(mRequest, index, type, buffer, length)); 181 } 182 183 Result setInputFromHardwareBuffer(uint32_t index, const AHardwareBuffer* buffer, 184 const ANeuralNetworksOperandType* type) { 185 return static_cast<Result>(ANeuralNetworksRequest_setInputFromHardwareBuffer( 186 mRequest, index, type, buffer)); 187 } 188 189 Result setOutput(uint32_t index, void* buffer, size_t length, 190 const ANeuralNetworksOperandType* type = nullptr) { 191 return static_cast<Result>( 192 ANeuralNetworksRequest_setOutput(mRequest, index, type, buffer, length)); 193 } 194 195 Result setOutputFromHardwareBuffer(uint32_t index, const AHardwareBuffer* buffer, 196 const ANeuralNetworksOperandType* type = nullptr) { 197 return static_cast<Result>(ANeuralNetworksRequest_setOutputFromHardwareBuffer( 198 mRequest, index, type, buffer)); 199 } 200 201 Result startCompute(Event* event) { 202 ANeuralNetworksEvent* ev = nullptr; 203 Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &ev)); 204 event->set(ev); 205 return result; 206 } 207 208 Result compute() { 209 ANeuralNetworksEvent* event = nullptr; 210 Result result = static_cast<Result>(ANeuralNetworksRequest_startCompute(mRequest, &event)); 211 if (result != Result::NO_ERROR) { 212 return result; 213 } 214 // TODO how to manage the lifetime of events when multiple waiters is not 215 // clear. 216 return static_cast<Result>(ANeuralNetworksEvent_wait(event)); 217 } 218 219private: 220 ANeuralNetworksRequest* mRequest = nullptr; 221}; 222 223} // namespace wrapper 224} // namespace nn 225} // namespace android 226 227#endif // ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 228