196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet/* 296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * Copyright (C) 2017 The Android Open Source Project 396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * 496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * Licensed under the Apache License, Version 2.0 (the "License"); 596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * you may not use this file except in compliance with the License. 696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * You may obtain a copy of the License at 796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * 896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * http://www.apache.org/licenses/LICENSE-2.0 996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * 1096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * Unless required by applicable law or agreed to in writing, software 1196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * distributed under the License is distributed on an "AS IS" BASIS, 1296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * See the License for the specific language governing permissions and 1496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * limitations under the License. 1596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet */ 1696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 1796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet// Provides C++ classes to more easily use the Neural Networks API. 1896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 1996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 2096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 2196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 2296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet#include "NeuralNetworks.h" 2396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 2427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang#include <math.h> 2596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet#include <vector> 2696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 2796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletnamespace android { 2896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletnamespace nn { 2996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletnamespace wrapper { 3096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 3196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletenum class Type { 3296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet FLOAT32 = ANEURALNETWORKS_FLOAT32, 3396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet INT32 = ANEURALNETWORKS_INT32, 3496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet UINT32 = ANEURALNETWORKS_UINT32, 3596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32, 362150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32, 3727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, 3896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet}; 3996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 4096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletenum class ExecutePreference { 4196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER, 4296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER, 4396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED 4496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet}; 4596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 4696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletenum class Result { 4796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet NO_ERROR = ANEURALNETWORKS_NO_ERROR, 4896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY, 4996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet INCOMPLETE = ANEURALNETWORKS_INCOMPLETE, 5096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL, 5196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet BAD_DATA = ANEURALNETWORKS_BAD_DATA, 5296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet}; 5396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 5496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletstruct OperandType { 5596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet ANeuralNetworksOperandType operandType; 5666d56404cdfab9ab8aa79d4bda83be3832a3aff9Miao Wang // int32_t type; 5796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet std::vector<uint32_t> dimensions; 5896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 5966d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet OperandType(Type type, const std::vector<uint32_t>& d, float scale = 0.0f, 6066d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet int32_t zeroPoint = 0) 6166d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet : dimensions(d) { 6266d56404cdfab9ab8aa79d4bda83be3832a3aff9Miao Wang operandType.type = static_cast<int32_t>(type); 6345bf79e5b9fee354fde7c1f64417d9ca4a1da7daMiao Wang operandType.scale = scale; 6445bf79e5b9fee354fde7c1f64417d9ca4a1da7daMiao Wang operandType.zeroPoint = zeroPoint; 6527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang 66d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet operandType.dimensionCount = static_cast<uint32_t>(dimensions.size()); 67d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet operandType.dimensions = dimensions.data(); 6896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 6996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet}; 7096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 718b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouilletclass Memory { 728b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouilletpublic: 7342dc6a6cd68877cd85e3bc475b41bda0fd946c41Miao Wang Memory(size_t size, int protect, int fd, size_t offset) { 7442dc6a6cd68877cd85e3bc475b41bda0fd946c41Miao Wang mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) == 75d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet ANEURALNETWORKS_NO_ERROR; 76105807d963d969197fe78185ed588bfad3dc0ea5Miao Wang } 77105807d963d969197fe78185ed588bfad3dc0ea5Miao Wang 788b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet ~Memory() { ANeuralNetworksMemory_free(mMemory); } 792150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet 807b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // Disallow copy semantics to ensure the runtime object can only be freed 817b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // once. Copy semantics could be enabled if some sort of reference counting 827b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // or deep-copy system for runtime objects is added later. 837b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler Memory(const Memory&) = delete; 847b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler Memory& operator=(const Memory&) = delete; 857b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler 867b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // Move semantics to remove access to the runtime object from the wrapper 877b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // object that is being moved. This ensures the runtime object will be 887b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // freed only once. 89d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet Memory(Memory&& other) { *this = std::move(other); } 907b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler Memory& operator=(Memory&& other) { 917b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler if (this != &other) { 927b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler mMemory = other.mMemory; 937b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler mValid = other.mValid; 947b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler other.mMemory = nullptr; 957b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler other.mValid = false; 967b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler } 977b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler return *this; 987b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler } 997b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler 1008b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet ANeuralNetworksMemory* get() const { return mMemory; } 1018b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet bool isValid() const { return mValid; } 1028b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet 1038b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouilletprivate: 1048b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet ANeuralNetworksMemory* mMemory = nullptr; 1058b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet bool mValid = true; 1068b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet}; 1078b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet 10896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletclass Model { 10996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletpublic: 11096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet Model() { 11196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // TODO handle the value returned by this call 11296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet ANeuralNetworksModel_create(&mModel); 11396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 11496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet ~Model() { ANeuralNetworksModel_free(mModel); } 11596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 1167b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // Disallow copy semantics to ensure the runtime object can only be freed 1177b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // once. Copy semantics could be enabled if some sort of reference counting 1187b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // or deep-copy system for runtime objects is added later. 1197b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler Model(const Model&) = delete; 1207b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler Model& operator=(const Model&) = delete; 1217b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler 1227b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // Move semantics to remove access to the runtime object from the wrapper 1237b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // object that is being moved. This ensures the runtime object will be 1247b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // freed only once. 125d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet Model(Model&& other) { *this = std::move(other); } 1267b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler Model& operator=(Model&& other) { 1277b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler if (this != &other) { 1287b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler mModel = other.mModel; 1297b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler mNextOperandId = other.mNextOperandId; 1307b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler mValid = other.mValid; 1317b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler other.mModel = nullptr; 1327b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler other.mNextOperandId = 0; 1337b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler other.mValid = false; 1347b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler } 1357b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler return *this; 1367b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler } 1377b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler 13866d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet Result finish() { return static_cast<Result>(ANeuralNetworksModel_finish(mModel)); } 139544739620cd7f37d40524d2407c92042e485c73fDavid Gross 14096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet uint32_t addOperand(const OperandType* type) { 14196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) != 14296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet ANEURALNETWORKS_NO_ERROR) { 14396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet mValid = false; 14496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 14596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet return mNextOperandId++; 14696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 14796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 14896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet void setOperandValue(uint32_t index, const void* buffer, size_t length) { 14996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) != 15096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet ANEURALNETWORKS_NO_ERROR) { 15196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet mValid = false; 15296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 15396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 15496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 1558b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 1568b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet size_t length) { 1578b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset, 1588b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet length) != ANEURALNETWORKS_NO_ERROR) { 1598b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet mValid = false; 1608b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet } 1618b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet } 1628b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet 16396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs, 16496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet const std::vector<uint32_t>& outputs) { 165d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet if (ANeuralNetworksModel_addOperation(mModel, type, static_cast<uint32_t>(inputs.size()), 166d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet inputs.data(), static_cast<uint32_t>(outputs.size()), 167d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet outputs.data()) != ANEURALNETWORKS_NO_ERROR) { 16896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet mValid = false; 16996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 17096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 17166d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet void identifyInputsAndOutputs(const std::vector<uint32_t>& inputs, 17266d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet const std::vector<uint32_t>& outputs) { 17366d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet if (ANeuralNetworksModel_identifyInputsAndOutputs( 17466d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet mModel, static_cast<uint32_t>(inputs.size()), inputs.data(), 17566d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet static_cast<uint32_t>(outputs.size()), 17666d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet outputs.data()) != ANEURALNETWORKS_NO_ERROR) { 17796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet mValid = false; 17896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 17996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 18096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet ANeuralNetworksModel* getHandle() const { return mModel; } 18196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet bool isValid() const { return mValid; } 18296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 18396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletprivate: 18496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet ANeuralNetworksModel* mModel = nullptr; 18596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // We keep track of the operand ID as a convenience to the caller. 18696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet uint32_t mNextOperandId = 0; 18796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet bool mValid = true; 18896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet}; 18996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 190425b2594c76e934dfdbc93209253e3c189571149David Grossclass Event { 191425b2594c76e934dfdbc93209253e3c189571149David Grosspublic: 192e3178825b8686f3300a895572691a2e8c1f0676bDavid Gross Event() {} 193425b2594c76e934dfdbc93209253e3c189571149David Gross ~Event() { ANeuralNetworksEvent_free(mEvent); } 194425b2594c76e934dfdbc93209253e3c189571149David Gross 195425b2594c76e934dfdbc93209253e3c189571149David Gross // Disallow copy semantics to ensure the runtime object can only be freed 196425b2594c76e934dfdbc93209253e3c189571149David Gross // once. Copy semantics could be enabled if some sort of reference counting 197425b2594c76e934dfdbc93209253e3c189571149David Gross // or deep-copy system for runtime objects is added later. 198425b2594c76e934dfdbc93209253e3c189571149David Gross Event(const Event&) = delete; 199425b2594c76e934dfdbc93209253e3c189571149David Gross Event& operator=(const Event&) = delete; 200425b2594c76e934dfdbc93209253e3c189571149David Gross 201425b2594c76e934dfdbc93209253e3c189571149David Gross // Move semantics to remove access to the runtime object from the wrapper 202425b2594c76e934dfdbc93209253e3c189571149David Gross // object that is being moved. This ensures the runtime object will be 203425b2594c76e934dfdbc93209253e3c189571149David Gross // freed only once. 20466d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet Event(Event&& other) { *this = std::move(other); } 205425b2594c76e934dfdbc93209253e3c189571149David Gross Event& operator=(Event&& other) { 206425b2594c76e934dfdbc93209253e3c189571149David Gross if (this != &other) { 207425b2594c76e934dfdbc93209253e3c189571149David Gross mEvent = other.mEvent; 208425b2594c76e934dfdbc93209253e3c189571149David Gross other.mEvent = nullptr; 209425b2594c76e934dfdbc93209253e3c189571149David Gross } 210425b2594c76e934dfdbc93209253e3c189571149David Gross return *this; 211425b2594c76e934dfdbc93209253e3c189571149David Gross } 212425b2594c76e934dfdbc93209253e3c189571149David Gross 213425b2594c76e934dfdbc93209253e3c189571149David Gross Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); } 214425b2594c76e934dfdbc93209253e3c189571149David Gross 215e3178825b8686f3300a895572691a2e8c1f0676bDavid Gross // Only for use by Execution 216425b2594c76e934dfdbc93209253e3c189571149David Gross void set(ANeuralNetworksEvent* newEvent) { 217425b2594c76e934dfdbc93209253e3c189571149David Gross ANeuralNetworksEvent_free(mEvent); 218425b2594c76e934dfdbc93209253e3c189571149David Gross mEvent = newEvent; 219425b2594c76e934dfdbc93209253e3c189571149David Gross } 220425b2594c76e934dfdbc93209253e3c189571149David Gross 221425b2594c76e934dfdbc93209253e3c189571149David Grossprivate: 222425b2594c76e934dfdbc93209253e3c189571149David Gross ANeuralNetworksEvent* mEvent = nullptr; 223425b2594c76e934dfdbc93209253e3c189571149David Gross}; 224425b2594c76e934dfdbc93209253e3c189571149David Gross 22583e24dc4706a5b7089881a55daf05b3924fab3b7David Grossclass Compilation { 22683e24dc4706a5b7089881a55daf05b3924fab3b7David Grosspublic: 22783e24dc4706a5b7089881a55daf05b3924fab3b7David Gross Compilation(const Model* model) { 22883e24dc4706a5b7089881a55daf05b3924fab3b7David Gross int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation); 22983e24dc4706a5b7089881a55daf05b3924fab3b7David Gross if (result != 0) { 23083e24dc4706a5b7089881a55daf05b3924fab3b7David Gross // TODO Handle the error 23183e24dc4706a5b7089881a55daf05b3924fab3b7David Gross } 23283e24dc4706a5b7089881a55daf05b3924fab3b7David Gross } 23383e24dc4706a5b7089881a55daf05b3924fab3b7David Gross 23483e24dc4706a5b7089881a55daf05b3924fab3b7David Gross ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); } 23583e24dc4706a5b7089881a55daf05b3924fab3b7David Gross 23683e24dc4706a5b7089881a55daf05b3924fab3b7David Gross Compilation(const Compilation&) = delete; 237d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet Compilation& operator=(const Compilation&) = delete; 23883e24dc4706a5b7089881a55daf05b3924fab3b7David Gross 239d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet Compilation(Compilation&& other) { *this = std::move(other); } 24083e24dc4706a5b7089881a55daf05b3924fab3b7David Gross Compilation& operator=(Compilation&& other) { 24183e24dc4706a5b7089881a55daf05b3924fab3b7David Gross if (this != &other) { 24283e24dc4706a5b7089881a55daf05b3924fab3b7David Gross mCompilation = other.mCompilation; 24383e24dc4706a5b7089881a55daf05b3924fab3b7David Gross other.mCompilation = nullptr; 24483e24dc4706a5b7089881a55daf05b3924fab3b7David Gross } 24583e24dc4706a5b7089881a55daf05b3924fab3b7David Gross return *this; 24683e24dc4706a5b7089881a55daf05b3924fab3b7David Gross } 24783e24dc4706a5b7089881a55daf05b3924fab3b7David Gross 24883e24dc4706a5b7089881a55daf05b3924fab3b7David Gross Result setPreference(ExecutePreference preference) { 24983e24dc4706a5b7089881a55daf05b3924fab3b7David Gross return static_cast<Result>(ANeuralNetworksCompilation_setPreference( 25066d56404cdfab9ab8aa79d4bda83be3832a3aff9Miao Wang mCompilation, static_cast<int32_t>(preference))); 25183e24dc4706a5b7089881a55daf05b3924fab3b7David Gross } 25283e24dc4706a5b7089881a55daf05b3924fab3b7David Gross 25366d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet Result finish() { return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation)); } 25483e24dc4706a5b7089881a55daf05b3924fab3b7David Gross 25583e24dc4706a5b7089881a55daf05b3924fab3b7David Gross ANeuralNetworksCompilation* getHandle() const { return mCompilation; } 25683e24dc4706a5b7089881a55daf05b3924fab3b7David Gross 25783e24dc4706a5b7089881a55daf05b3924fab3b7David Grossprivate: 25883e24dc4706a5b7089881a55daf05b3924fab3b7David Gross ANeuralNetworksCompilation* mCompilation = nullptr; 25983e24dc4706a5b7089881a55daf05b3924fab3b7David Gross}; 26083e24dc4706a5b7089881a55daf05b3924fab3b7David Gross 2613ced3cfd5b8f22b632c35f24e585c4847383b195David Grossclass Execution { 26296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletpublic: 2633ced3cfd5b8f22b632c35f24e585c4847383b195David Gross Execution(const Compilation* compilation) { 2643ced3cfd5b8f22b632c35f24e585c4847383b195David Gross int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution); 26596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet if (result != 0) { 26696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // TODO Handle the error 26796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 26896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 26996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 2703ced3cfd5b8f22b632c35f24e585c4847383b195David Gross ~Execution() { ANeuralNetworksExecution_free(mExecution); } 27196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 2727b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // Disallow copy semantics to ensure the runtime object can only be freed 2737b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // once. Copy semantics could be enabled if some sort of reference counting 2747b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // or deep-copy system for runtime objects is added later. 2753ced3cfd5b8f22b632c35f24e585c4847383b195David Gross Execution(const Execution&) = delete; 2763ced3cfd5b8f22b632c35f24e585c4847383b195David Gross Execution& operator=(const Execution&) = delete; 2777b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler 2787b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // Move semantics to remove access to the runtime object from the wrapper 2797b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // object that is being moved. This ensures the runtime object will be 2807b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler // freed only once. 281d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet Execution(Execution&& other) { *this = std::move(other); } 2823ced3cfd5b8f22b632c35f24e585c4847383b195David Gross Execution& operator=(Execution&& other) { 2837b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler if (this != &other) { 2843ced3cfd5b8f22b632c35f24e585c4847383b195David Gross mExecution = other.mExecution; 2853ced3cfd5b8f22b632c35f24e585c4847383b195David Gross other.mExecution = nullptr; 2867b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler } 2877b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler return *this; 2887b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler } 2897b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler 29096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet Result setInput(uint32_t index, const void* buffer, size_t length, 29196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet const ANeuralNetworksOperandType* type = nullptr) { 29296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet return static_cast<Result>( 2933ced3cfd5b8f22b632c35f24e585c4847383b195David Gross ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length)); 29496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 29596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 2968b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 2978b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { 2983ced3cfd5b8f22b632c35f24e585c4847383b195David Gross return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory( 2993ced3cfd5b8f22b632c35f24e585c4847383b195David Gross mExecution, index, type, memory->get(), offset, length)); 30096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 30196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 30296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet Result setOutput(uint32_t index, void* buffer, size_t length, 30396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet const ANeuralNetworksOperandType* type = nullptr) { 30496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet return static_cast<Result>( 3053ced3cfd5b8f22b632c35f24e585c4847383b195David Gross ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length)); 30696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 30796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 3088b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 3098b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { 3103ced3cfd5b8f22b632c35f24e585c4847383b195David Gross return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory( 3113ced3cfd5b8f22b632c35f24e585c4847383b195David Gross mExecution, index, type, memory->get(), offset, length)); 31296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 31396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 314425b2594c76e934dfdbc93209253e3c189571149David Gross Result startCompute(Event* event) { 315425b2594c76e934dfdbc93209253e3c189571149David Gross ANeuralNetworksEvent* ev = nullptr; 316425b2594c76e934dfdbc93209253e3c189571149David Gross Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev)); 317425b2594c76e934dfdbc93209253e3c189571149David Gross event->set(ev); 31896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet return result; 31996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 32096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 32196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet Result compute() { 322425b2594c76e934dfdbc93209253e3c189571149David Gross ANeuralNetworksEvent* event = nullptr; 323425b2594c76e934dfdbc93209253e3c189571149David Gross Result result = 32466d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &event)); 32596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet if (result != Result::NO_ERROR) { 32696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet return result; 32796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 32896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // TODO how to manage the lifetime of events when multiple waiters is not 32996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // clear. 330425b2594c76e934dfdbc93209253e3c189571149David Gross result = static_cast<Result>(ANeuralNetworksEvent_wait(event)); 331425b2594c76e934dfdbc93209253e3c189571149David Gross ANeuralNetworksEvent_free(event); 332425b2594c76e934dfdbc93209253e3c189571149David Gross return result; 33396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 33496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 33596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletprivate: 3363ced3cfd5b8f22b632c35f24e585c4847383b195David Gross ANeuralNetworksExecution* mExecution = nullptr; 33796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet}; 33896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 3392150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet} // namespace wrapper 3402150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet} // namespace nn 3412150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet} // namespace android 34296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 3432150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet#endif // ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H 344