1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Contains all the entry points to the C Neural Networks API. 18// We do basic validation of the operands and then call the class 19// that implements the functionality. 20 21#define LOG_TAG "NeuralNetworks" 22 23#include "NeuralNetworks.h" 24 25#include "Callbacks.h" 26#include "CompilationBuilder.h" 27#include "ExecutionBuilder.h" 28#include "Manager.h" 29#include "Memory.h" 30#include "NeuralNetworksOEM.h" 31#include "ModelBuilder.h" 32 33#include <memory> 34#include <vector> 35 36// Make sure the constants defined in the header files have not changed values. 37// IMPORTANT: When adding new values, update kNumberOfDataTypes or kNumberOfDataTypesOEM 38// in Utils.h. 39static_assert(ANEURALNETWORKS_FLOAT32 == 0, "ANEURALNETWORKS_FLOAT32 has changed"); 40static_assert(ANEURALNETWORKS_INT32 == 1, "ANEURALNETWORKS_INT32 has changed"); 41static_assert(ANEURALNETWORKS_UINT32 == 2, "ANEURALNETWORKS_UINT32 has changed"); 42static_assert(ANEURALNETWORKS_TENSOR_FLOAT32 == 3, 43 "ANEURALNETWORKS_TENSOR_FLOAT32 has changed"); 44static_assert(ANEURALNETWORKS_TENSOR_INT32 == 4, "ANEURALNETWORKS_TENSOR_INT32 has changed"); 45static_assert(ANEURALNETWORKS_TENSOR_QUANT8_ASYMM == 5, 46 "ANEURALNETWORKS_TENSOR_QUANT8_ASYMM has changed"); 47static_assert(ANEURALNETWORKS_OEM_SCALAR == 10000, "ANEURALNETWORKS_OEM_SCALAR has changed"); 48static_assert(ANEURALNETWORKS_TENSOR_OEM_BYTE == 10001, 49 "ANEURALNETWORKS_TENSOR_OEM_BYTE has changed"); 50 51// IMPORTANT: When adding new values, update kNumberOfOperationTypes or 52// kNumberOfOperationTypesOEMin Utils.h. 53static_assert(ANEURALNETWORKS_ADD == 0, "ANEURALNETWORKS_ADD has changed"); 54static_assert(ANEURALNETWORKS_AVERAGE_POOL_2D == 1, 55 "ANEURALNETWORKS_AVERAGE_POOL_2D has changed"); 56static_assert(ANEURALNETWORKS_CONCATENATION == 2, "ANEURALNETWORKS_CONCATENATION has changed"); 57static_assert(ANEURALNETWORKS_CONV_2D == 3, "ANEURALNETWORKS_CONV_2D has changed"); 58static_assert(ANEURALNETWORKS_DEPTHWISE_CONV_2D == 4, 59 "ANEURALNETWORKS_DEPTHWISE_CONV_2D has changed"); 60static_assert(ANEURALNETWORKS_DEPTH_TO_SPACE == 5, 61 "ANEURALNETWORKS_DEPTH_TO_SPACE has changed"); 62static_assert(ANEURALNETWORKS_DEQUANTIZE == 6, "ANEURALNETWORKS_DEQUANTIZE has changed"); 63static_assert(ANEURALNETWORKS_EMBEDDING_LOOKUP == 7, 64 "ANEURALNETWORKS_EMBEDDING_LOOKUP has changed"); 65static_assert(ANEURALNETWORKS_FLOOR == 8, "ANEURALNETWORKS_FLOOR has changed"); 66static_assert(ANEURALNETWORKS_FULLY_CONNECTED == 9, 67 "ANEURALNETWORKS_FULLY_CONNECTED has changed"); 68static_assert(ANEURALNETWORKS_HASHTABLE_LOOKUP == 10, 69 "ANEURALNETWORKS_HASHTABLE_LOOKUP has changed"); 70static_assert(ANEURALNETWORKS_L2_NORMALIZATION == 11, 71 "ANEURALNETWORKS_L2_NORMALIZATION has changed"); 72static_assert(ANEURALNETWORKS_L2_POOL_2D == 12, "ANEURALNETWORKS_L2_POOL has changed"); 73static_assert(ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION == 13, 74 "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION has changed"); 75static_assert(ANEURALNETWORKS_LOGISTIC == 14, "ANEURALNETWORKS_LOGISTIC has changed"); 76static_assert(ANEURALNETWORKS_LSH_PROJECTION == 15, 77 "ANEURALNETWORKS_LSH_PROJECTION has changed"); 78static_assert(ANEURALNETWORKS_LSTM == 16, "ANEURALNETWORKS_LSTM has changed"); 79static_assert(ANEURALNETWORKS_MAX_POOL_2D == 17, "ANEURALNETWORKS_MAX_POOL has changed"); 80static_assert(ANEURALNETWORKS_MUL == 18, "ANEURALNETWORKS_MUL has changed"); 81static_assert(ANEURALNETWORKS_RELU == 19, "ANEURALNETWORKS_RELU has changed"); 82static_assert(ANEURALNETWORKS_RELU1 == 20, "ANEURALNETWORKS_RELU1 has changed"); 83static_assert(ANEURALNETWORKS_RELU6 == 21, "ANEURALNETWORKS_RELU6 has changed"); 84static_assert(ANEURALNETWORKS_RESHAPE == 22, "ANEURALNETWORKS_RESHAPE has changed"); 85static_assert(ANEURALNETWORKS_RESIZE_BILINEAR == 23, 86 "ANEURALNETWORKS_RESIZE_BILINEAR has changed"); 87static_assert(ANEURALNETWORKS_RNN == 24, "ANEURALNETWORKS_RNN has changed"); 88static_assert(ANEURALNETWORKS_SOFTMAX == 25, "ANEURALNETWORKS_SOFTMAX has changed"); 89static_assert(ANEURALNETWORKS_SPACE_TO_DEPTH == 26, 90 "ANEURALNETWORKS_SPACE_TO_DEPTH has changed"); 91static_assert(ANEURALNETWORKS_SVDF == 27, "ANEURALNETWORKS_SVDF has changed"); 92static_assert(ANEURALNETWORKS_TANH == 28, "ANEURALNETWORKS_TANH has changed"); 93static_assert(ANEURALNETWORKS_OEM_OPERATION == 10000, 94 "ANEURALNETWORKS_OEM_OPERATION has changed"); 95 96static_assert(ANEURALNETWORKS_FUSED_NONE == 0, "ANEURALNETWORKS_FUSED_NONE has changed"); 97static_assert(ANEURALNETWORKS_FUSED_RELU == 1, "ANEURALNETWORKS_FUSED_RELU has changed"); 98static_assert(ANEURALNETWORKS_FUSED_RELU1 == 2, "ANEURALNETWORKS_FUSED_RELU1 has changed"); 99static_assert(ANEURALNETWORKS_FUSED_RELU6 == 3, "ANEURALNETWORKS_FUSED_RELU6 has changed"); 100 101static_assert(ANEURALNETWORKS_PREFER_LOW_POWER == 0, 102 "ANEURALNETWORKS_PREFER_LOW_POWER has changed"); 103static_assert(ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER == 1, 104 "ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER has changed"); 105static_assert(ANEURALNETWORKS_PREFER_SUSTAINED_SPEED == 2, 106 "ANEURALNETWORKS_PREFER_SUSTAINED_SPEED has changed"); 107 108static_assert(ANEURALNETWORKS_NO_ERROR == 0, "ANEURALNETWORKS_NO_ERROR has changed"); 109static_assert(ANEURALNETWORKS_OUT_OF_MEMORY == 1, "ANEURALNETWORKS_OUT_OF_MEMORY has changed"); 110static_assert(ANEURALNETWORKS_INCOMPLETE == 2, "ANEURALNETWORKS_INCOMPLETE has changed"); 111static_assert(ANEURALNETWORKS_UNEXPECTED_NULL == 3, 112 "ANEURALNETWORKS_UNEXPECTED_NULL has changed"); 113static_assert(ANEURALNETWORKS_BAD_DATA == 4, "ANEURALNETWORKS_BAD_DATA has changed"); 114static_assert(ANEURALNETWORKS_OP_FAILED == 5, "ANEURALNETWORKS_OP_FAILED has changed"); 115static_assert(ANEURALNETWORKS_BAD_STATE == 6, "ANEURALNETWORKS_BAD_STATE has changed"); 116 117static_assert(ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES == 128, 118 "ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES has changed"); 119 120// Make sure that the constants are compatible with the values defined in 121// hardware/interfaces/neuralnetworks/1.0/types.hal. 122static_assert(static_cast<int32_t>(OperandType::OEM) == ANEURALNETWORKS_OEM_SCALAR, 123 "OEM != ANEURALNETWORKS_OEM"); 124static_assert(static_cast<int32_t>(OperandType::FLOAT32) == ANEURALNETWORKS_FLOAT32, 125 "FLOAT32 != ANEURALNETWORKS_FLOAT32"); 126static_assert(static_cast<int32_t>(OperandType::INT32) == ANEURALNETWORKS_INT32, 127 "INT32 != ANEURALNETWORKS_INT32"); 128static_assert(static_cast<int32_t>(OperandType::UINT32) == ANEURALNETWORKS_UINT32, 129 "UINT32 != ANEURALNETWORKS_UINT32"); 130static_assert(static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) == ANEURALNETWORKS_TENSOR_OEM_BYTE, 131 "TENSOR_OEM_BYTE != ANEURALNETWORKS_TENSOR_OEM_BYTE"); 132static_assert(static_cast<int32_t>(OperandType::TENSOR_FLOAT32) == ANEURALNETWORKS_TENSOR_FLOAT32, 133 "TENSOR_FLOAT32 != ANEURALNETWORKS_TENSOR_FLOAT32"); 134static_assert(static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) == 135 ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, 136 "TENSOR_QUANT8_ASYMM != ANEURALNETWORKS_TENSOR_QUANT8_ASYMM"); 137 138static_assert(static_cast<int32_t>(OperationType::ADD) == ANEURALNETWORKS_ADD, 139 "OperationType::ADD != ANEURALNETWORKS_ADD"); 140static_assert(static_cast<int32_t>(OperationType::AVERAGE_POOL_2D) == 141 ANEURALNETWORKS_AVERAGE_POOL_2D, 142 "OperationType::AVERAGE_POOL_2D != ANEURALNETWORKS_AVERAGE_POOL_2D"); 143static_assert(static_cast<int32_t>(OperationType::CONV_2D) == ANEURALNETWORKS_CONV_2D, 144 "OperationType::CONV_2D != ANEURALNETWORKS_CONV_2D"); 145static_assert(static_cast<int32_t>(OperationType::DEPTHWISE_CONV_2D) == 146 ANEURALNETWORKS_DEPTHWISE_CONV_2D, 147 "OperationType::DEPTHWISE_CONV_2D != ANEURALNETWORKS_DEPTHWISE_CONV_2D"); 148static_assert(static_cast<int32_t>(OperationType::DEPTH_TO_SPACE) == 149 ANEURALNETWORKS_DEPTH_TO_SPACE, 150 "OperationType::DEPTH_TO_SPACE != ANEURALNETWORKS_DEPTH_TO_SPACE"); 151static_assert(static_cast<int32_t>(OperationType::DEQUANTIZE) == ANEURALNETWORKS_DEQUANTIZE, 152 "OperationType::DEQUANTIZE != ANEURALNETWORKS_DEQUANTIZE"); 153static_assert(static_cast<int32_t>(OperationType::EMBEDDING_LOOKUP) == 154 ANEURALNETWORKS_EMBEDDING_LOOKUP, 155 "OperationType::EMBEDDING_LOOKUP != ANEURALNETWORKS_EMBEDDING_LOOKUP"); 156static_assert(static_cast<int32_t>(OperationType::FLOOR) == ANEURALNETWORKS_FLOOR, 157 "OperationType::FLOOR != ANEURALNETWORKS_FLOOR"); 158static_assert(static_cast<int32_t>(OperationType::FULLY_CONNECTED) == 159 ANEURALNETWORKS_FULLY_CONNECTED, 160 "OperationType::FULLY_CONNECTED != ANEURALNETWORKS_FULLY_CONNECTED"); 161static_assert(static_cast<int32_t>(OperationType::HASHTABLE_LOOKUP) == 162 ANEURALNETWORKS_HASHTABLE_LOOKUP, 163 "OperationType::HASHTABLE_LOOKUP != ANEURALNETWORKS_HASHTABLE_LOOKUP"); 164static_assert(static_cast<int32_t>(OperationType::L2_NORMALIZATION) == 165 ANEURALNETWORKS_L2_NORMALIZATION, 166 "OperationType::L2_NORMALIZATION != ANEURALNETWORKS_L2_NORMALIZATION"); 167static_assert(static_cast<int32_t>(OperationType::L2_POOL_2D) == ANEURALNETWORKS_L2_POOL_2D, 168 "OperationType::L2_POOL_2D != ANEURALNETWORKS_L2_POOL_2D"); 169static_assert(static_cast<int32_t>(OperationType::LOCAL_RESPONSE_NORMALIZATION) == 170 ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION, 171 "OperationType::LOCAL_RESPONSE_NORMALIZATION != " 172 "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION"); 173static_assert(static_cast<int32_t>(OperationType::LOGISTIC) == ANEURALNETWORKS_LOGISTIC, 174 "OperationType::LOGISTIC != ANEURALNETWORKS_LOGISTIC"); 175static_assert(static_cast<int32_t>(OperationType::LSH_PROJECTION) == 176 ANEURALNETWORKS_LSH_PROJECTION, 177 "OperationType::LSH_PROJECTION != ANEURALNETWORKS_LSH_PROJECTION"); 178static_assert(static_cast<int32_t>(OperationType::LSTM) == ANEURALNETWORKS_LSTM, 179 "OperationType::LSTM != ANEURALNETWORKS_LSTM"); 180static_assert(static_cast<int32_t>(OperationType::MAX_POOL_2D) == ANEURALNETWORKS_MAX_POOL_2D, 181 "OperationType::MAX_POOL_2D != ANEURALNETWORKS_MAX_POOL_2D"); 182static_assert(static_cast<int32_t>(OperationType::MUL) == ANEURALNETWORKS_MUL, 183 "OperationType::MUL != ANEURALNETWORKS_MUL"); 184static_assert(static_cast<int32_t>(OperationType::RELU) == ANEURALNETWORKS_RELU, 185 "OperationType::RELU != ANEURALNETWORKS_RELU"); 186static_assert(static_cast<int32_t>(OperationType::RELU1) == ANEURALNETWORKS_RELU1, 187 "OperationType::RELU1 != ANEURALNETWORKS_RELU1"); 188static_assert(static_cast<int32_t>(OperationType::RELU6) == ANEURALNETWORKS_RELU6, 189 "OperationType::RELU6 != ANEURALNETWORKS_RELU6"); 190static_assert(static_cast<int32_t>(OperationType::RESHAPE) == ANEURALNETWORKS_RESHAPE, 191 "OperationType::RESHAPE != ANEURALNETWORKS_RESHAPE"); 192static_assert(static_cast<int32_t>(OperationType::RESIZE_BILINEAR) == 193 ANEURALNETWORKS_RESIZE_BILINEAR, 194 "OperationType::RESIZE_BILINEAR != ANEURALNETWORKS_RESIZE_BILINEAR"); 195static_assert(static_cast<int32_t>(OperationType::RNN) == ANEURALNETWORKS_RNN, 196 "OperationType::RNN != ANEURALNETWORKS_RNN"); 197static_assert(static_cast<int32_t>(OperationType::SOFTMAX) == ANEURALNETWORKS_SOFTMAX, 198 "OperationType::SOFTMAX != ANEURALNETWORKS_SOFTMAX"); 199static_assert(static_cast<int32_t>(OperationType::SPACE_TO_DEPTH) == 200 ANEURALNETWORKS_SPACE_TO_DEPTH, 201 "OperationType::SPACE_TO_DEPTH != ANEURALNETWORKS_SPACE_TO_DEPTH"); 202static_assert(static_cast<int32_t>(OperationType::SVDF) == ANEURALNETWORKS_SVDF, 203 "OperationType::SVDF != ANEURALNETWORKS_SVDF"); 204static_assert(static_cast<int32_t>(OperationType::TANH) == ANEURALNETWORKS_TANH, 205 "OperationType::TANH != ANEURALNETWORKS_TANH"); 206 207static_assert(static_cast<int32_t>(FusedActivationFunc::NONE) == ANEURALNETWORKS_FUSED_NONE, 208 "FusedActivationFunc::NONE != ANEURALNETWORKS_FUSED_NONE"); 209static_assert(static_cast<int32_t>(FusedActivationFunc::RELU) == ANEURALNETWORKS_FUSED_RELU, 210 "FusedActivationFunc::RELU != ANEURALNETWORKS_FUSED_RELU"); 211static_assert(static_cast<int32_t>(FusedActivationFunc::RELU1) == ANEURALNETWORKS_FUSED_RELU1, 212 "FusedActivationFunc::RELU1 != ANEURALNETWORKS_FUSED_RELU1"); 213static_assert(static_cast<int32_t>(FusedActivationFunc::RELU6) == ANEURALNETWORKS_FUSED_RELU6, 214 "FusedActivationFunc::RELU6 != ANEURALNETWORKS_FUSED_RELU6"); 215 216using android::sp; 217using namespace android::nn; 218 219int ANeuralNetworksMemory_createFromFd(size_t size, int prot, int fd, size_t offset, 220 ANeuralNetworksMemory** memory) { 221 *memory = nullptr; 222 std::unique_ptr<MemoryFd> m = std::make_unique<MemoryFd>(); 223 if (m == nullptr) { 224 return ANEURALNETWORKS_OUT_OF_MEMORY; 225 } 226 int n = m->set(size, prot, fd, offset); 227 if (n != ANEURALNETWORKS_NO_ERROR) { 228 return n; 229 } 230 *memory = reinterpret_cast<ANeuralNetworksMemory*>(m.release()); 231 return ANEURALNETWORKS_NO_ERROR; 232} 233 234void ANeuralNetworksMemory_free(ANeuralNetworksMemory* memory) { 235 // No validation. Free of nullptr is valid. 236 Memory* m = reinterpret_cast<Memory*>(memory); 237 delete m; 238} 239 240int ANeuralNetworksModel_create(ANeuralNetworksModel** model) { 241 initVLogMask(); 242 if (!model) { 243 LOG(ERROR) << "ANeuralNetworksModel_create passed a nullptr"; 244 return ANEURALNETWORKS_UNEXPECTED_NULL; 245 } 246 ModelBuilder* m = new ModelBuilder(); 247 if (m == nullptr) { 248 *model = nullptr; 249 return ANEURALNETWORKS_OUT_OF_MEMORY; 250 } 251 *model = reinterpret_cast<ANeuralNetworksModel*>(m); 252 return ANEURALNETWORKS_NO_ERROR; 253} 254 255void ANeuralNetworksModel_free(ANeuralNetworksModel* model) { 256 // No validation. Free of nullptr is valid. 257 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model); 258 delete m; 259} 260 261int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) { 262 if (!model) { 263 LOG(ERROR) << "ANeuralNetworksModel_finish passed a nullptr"; 264 return ANEURALNETWORKS_UNEXPECTED_NULL; 265 } 266 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model); 267 return m->finish(); 268} 269 270int ANeuralNetworksModel_addOperand(ANeuralNetworksModel* model, 271 const ANeuralNetworksOperandType* type) { 272 if (!model || !type) { 273 LOG(ERROR) << "ANeuralNetworksModel_addOperand passed a nullptr"; 274 return ANEURALNETWORKS_UNEXPECTED_NULL; 275 } 276 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model); 277 return m->addOperand(*type); 278} 279 280int ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model, int32_t index, 281 const void* buffer, size_t length) { 282 if (!model || !buffer) { 283 LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr"; 284 return ANEURALNETWORKS_UNEXPECTED_NULL; 285 } 286 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model); 287 return m->setOperandValue(index, buffer, length); 288} 289 290int ANeuralNetworksModel_setOperandValueFromMemory(ANeuralNetworksModel* model, int32_t index, 291 const ANeuralNetworksMemory* memory, 292 size_t offset, size_t length) { 293 if (!model || !memory) { 294 LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr"; 295 return ANEURALNETWORKS_UNEXPECTED_NULL; 296 } 297 const Memory* mem = reinterpret_cast<const Memory*>(memory); 298 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model); 299 return m->setOperandValueFromMemory(index, mem, offset, length); 300} 301 302int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model, 303 ANeuralNetworksOperationType type, uint32_t inputCount, 304 const uint32_t* inputs, uint32_t outputCount, 305 const uint32_t* outputs) { 306 if (!model || !inputs || !outputs) { 307 LOG(ERROR) << "ANeuralNetworksModel_addOperation passed a nullptr"; 308 return ANEURALNETWORKS_UNEXPECTED_NULL; 309 } 310 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model); 311 return m->addOperation(type, inputCount, inputs, outputCount, outputs); 312} 313 314int ANeuralNetworksModel_identifyInputsAndOutputs(ANeuralNetworksModel* model, uint32_t inputCount, 315 const uint32_t* inputs, uint32_t outputCount, 316 const uint32_t* outputs) { 317 if (!model || !inputs || !outputs) { 318 LOG(ERROR) << ("ANeuralNetworksModel_identifyInputsAndOutputs passed a nullptr"); 319 return ANEURALNETWORKS_UNEXPECTED_NULL; 320 } 321 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model); 322 return m->identifyInputsAndOutputs(inputCount, inputs, outputCount, outputs); 323} 324 325int ANeuralNetworksCompilation_create(ANeuralNetworksModel* model, 326 ANeuralNetworksCompilation** compilation) { 327 if (!model || !compilation) { 328 LOG(ERROR) << "ANeuralNetworksCompilation_create passed a nullptr"; 329 return ANEURALNETWORKS_UNEXPECTED_NULL; 330 } 331 332 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model); 333 CompilationBuilder* c = nullptr; 334 int result = m->createCompilation(&c); 335 *compilation = reinterpret_cast<ANeuralNetworksCompilation*>(c); 336 return result; 337} 338 339void ANeuralNetworksCompilation_free(ANeuralNetworksCompilation* compilation) { 340 // No validation. Free of nullptr is valid. 341 // TODO specification says that a compilation-in-flight can be deleted 342 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation); 343 delete c; 344} 345 346int ANeuralNetworksCompilation_setPreference(ANeuralNetworksCompilation* compilation, 347 int32_t preference) { 348 if (!compilation) { 349 LOG(ERROR) << "ANeuralNetworksCompilation_setPreference passed a nullptr"; 350 return ANEURALNETWORKS_UNEXPECTED_NULL; 351 } 352 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation); 353 return c->setPreference(preference); 354} 355 356int ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation* compilation) { 357 if (!compilation) { 358 LOG(ERROR) << "ANeuralNetworksCompilation_finish passed a nullptr"; 359 return ANEURALNETWORKS_UNEXPECTED_NULL; 360 } 361 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation); 362 return c->finish(); 363} 364 365int ANeuralNetworksExecution_create(ANeuralNetworksCompilation* compilation, 366 ANeuralNetworksExecution** execution) { 367 if (!compilation || !execution) { 368 LOG(ERROR) << "ANeuralNetworksExecution_create passed a nullptr"; 369 return ANEURALNETWORKS_UNEXPECTED_NULL; 370 } 371 372 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation); 373 ExecutionBuilder* r = nullptr; 374 int result = c->createExecution(&r); 375 *execution = reinterpret_cast<ANeuralNetworksExecution*>(r); 376 return result; 377} 378 379void ANeuralNetworksExecution_free(ANeuralNetworksExecution* execution) { 380 // TODO specification says that an execution-in-flight can be deleted 381 // No validation. Free of nullptr is valid. 382 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution); 383 delete r; 384} 385 386int ANeuralNetworksExecution_setInput(ANeuralNetworksExecution* execution, int32_t index, 387 const ANeuralNetworksOperandType* type, const void* buffer, 388 size_t length) { 389 // TODO: For a non-optional input, also verify that buffer is not null. 390 if (!execution) { 391 LOG(ERROR) << "ANeuralNetworksExecution_setInput passed a nullptr"; 392 return ANEURALNETWORKS_UNEXPECTED_NULL; 393 } 394 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution); 395 return r->setInput(index, type, buffer, length); 396} 397 398int ANeuralNetworksExecution_setInputFromMemory(ANeuralNetworksExecution* execution, int32_t index, 399 const ANeuralNetworksOperandType* type, 400 const ANeuralNetworksMemory* memory, size_t offset, 401 size_t length) { 402 if (!execution || !memory) { 403 LOG(ERROR) << "ANeuralNetworksExecution_setInputFromMemory passed a nullptr"; 404 return ANEURALNETWORKS_UNEXPECTED_NULL; 405 } 406 407 const Memory* m = reinterpret_cast<const Memory*>(memory); 408 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution); 409 return r->setInputFromMemory(index, type, m, offset, length); 410} 411 412int ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution* execution, int32_t index, 413 const ANeuralNetworksOperandType* type, void* buffer, 414 size_t length) { 415 if (!execution || !buffer) { 416 LOG(ERROR) << "ANeuralNetworksExecution_setOutput passed a nullptr"; 417 return ANEURALNETWORKS_UNEXPECTED_NULL; 418 } 419 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution); 420 return r->setOutput(index, type, buffer, length); 421} 422 423int ANeuralNetworksExecution_setOutputFromMemory(ANeuralNetworksExecution* execution, int32_t index, 424 const ANeuralNetworksOperandType* type, 425 const ANeuralNetworksMemory* memory, size_t offset, 426 size_t length) { 427 if (!execution || !memory) { 428 LOG(ERROR) << "ANeuralNetworksExecution_setOutputFromMemory passed a nullptr"; 429 return ANEURALNETWORKS_UNEXPECTED_NULL; 430 } 431 432 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution); 433 const Memory* m = reinterpret_cast<const Memory*>(memory); 434 return r->setOutputFromMemory(index, type, m, offset, length); 435} 436 437int ANeuralNetworksExecution_startCompute(ANeuralNetworksExecution* execution, 438 ANeuralNetworksEvent** event) { 439 if (!execution || !event) { 440 LOG(ERROR) << "ANeuralNetworksExecution_startCompute passed a nullptr"; 441 return ANEURALNETWORKS_UNEXPECTED_NULL; 442 } 443 // TODO validate the rest 444 445 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution); 446 447 // Dynamically allocate an sp to wrap an ExecutionCallback, seen in the NN 448 // API as an abstract event object. The sp<ExecutionCallback> object is 449 // returned when the execution has been successfully launched, otherwise a 450 // nullptr is returned. The sp is used for ref-counting purposes. Without 451 // it, the HIDL service could attempt to communicate with a dead callback 452 // object. 453 std::unique_ptr<sp<ExecutionCallback>> e = std::make_unique<sp<ExecutionCallback>>(); 454 *event = nullptr; 455 456 int n = r->startCompute(e.get()); 457 if (n != ANEURALNETWORKS_NO_ERROR) { 458 return n; 459 } 460 *event = reinterpret_cast<ANeuralNetworksEvent*>(e.release()); 461 return ANEURALNETWORKS_NO_ERROR; 462} 463 464int ANeuralNetworksEvent_wait(ANeuralNetworksEvent* event) { 465 if (event == nullptr) { 466 LOG(ERROR) << "ANeuralNetworksEvent_wait passed a nullptr"; 467 return ANEURALNETWORKS_UNEXPECTED_NULL; 468 } 469 470 sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event); 471 (*e)->wait(); 472 return ANEURALNETWORKS_NO_ERROR; 473} 474 475void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) { 476 // No validation. Free of nullptr is valid. 477 if (event) { 478 sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event); 479 (*e)->wait(); 480 delete e; 481 } 482} 483