/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Class used to build a model through a succession of calls
// to the NN API.
1996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 20707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet#ifndef ANDROID_ML_NN_RUNTIME_MODEL_BUILDER_H 21707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet#define ANDROID_ML_NN_RUNTIME_MODEL_BUILDER_H 2296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 23707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet#include "HalInterfaces.h" 248b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet#include "Memory.h" 2596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet#include "NeuralNetworks.h" 2696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet#include "Utils.h" 2796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 2896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletnamespace android { 2996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletnamespace nn { 3096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 3183e24dc4706a5b7089881a55daf05b3924fab3b7David Grossclass CompilationBuilder; 3291e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouilletclass Device; 3391e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouilletclass ExecutionPlan; 348b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouilletclass Memory; 3596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 36707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouilletclass ModelBuilder { 3796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletpublic: 3896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // Adds an operand to the model. 
3996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet int addOperand(const ANeuralNetworksOperandType& type); 4096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet int setOperandValue(uint32_t index, const void* buffer, size_t length); 418b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet int setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 428b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet size_t length); 4396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 44d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet int addOperation(ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, 45d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet uint32_t outputCount, const uint32_t* outputs); 4666d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet int identifyInputsAndOutputs(uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, 4766d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet const uint32_t* outputs); 48084401d6215dca122999261c5ac3718ebf61b29eMichael Butler int relaxComputationFloat32toFloat16(bool allow); 49084401d6215dca122999261c5ac3718ebf61b29eMichael Butler bool isComputationFloat32RelaxedToFloat16() const { return mRelaxComputationFloat32toFloat16; } 5096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 51544739620cd7f37d40524d2407c92042e485c73fDavid Gross int finish(); 52544739620cd7f37d40524d2407c92042e485c73fDavid Gross bool isFinished() const { return mCompletedModel; } 53544739620cd7f37d40524d2407c92042e485c73fDavid Gross 54544739620cd7f37d40524d2407c92042e485c73fDavid Gross int createCompilation(CompilationBuilder** compilation); 55707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet 56707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet void setHidlModel(Model* model) const; 5796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 5896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet uint32_t 
operandCount() const { 5996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // We don't allow more than uint32_t worth of operands 6096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet return static_cast<uint32_t>(mOperands.size()); 6196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 6296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet uint32_t operationCount() const { 6396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // We don't allow more than uint32_t worth of operations 6496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet return static_cast<uint32_t>(mOperations.size()); 6596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet } 66707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet uint32_t inputCount() const { return static_cast<uint32_t>(mInputIndexes.size()); } 67707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet uint32_t outputCount() const { return static_cast<uint32_t>(mOutputIndexes.size()); } 68891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross uint32_t getInputOperandIndex(uint32_t i) const { return mInputIndexes[i]; } 69891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross const Operand& getInputOperand(uint32_t i) const { 70891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross return mOperands[getInputOperandIndex(i)]; 71891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross } 72891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross uint32_t getOutputOperandIndex(uint32_t i) const { return mOutputIndexes[i]; } 73891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross const Operand& getOutputOperand(uint32_t i) const { 74891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross return mOperands[getOutputOperandIndex(i)]; 75891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross } 76707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet const Operand& getOperand(uint32_t index) const { return mOperands[index]; } 7791e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet const Operation& 
getOperation(uint32_t index) const { return mOperations[index]; } 788b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet const MemoryTracker& getMemories() const { return mMemories; } 7991e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet const std::vector<Operation>& getOperations() const { return mOperations; } 8091e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet const uint8_t* getPointerToOperandValue(uint32_t offset) const { 811da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet return mSmallOperandValues.data() + offset; 8291e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet } 8396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 84def0a14aa77689f12120cfb4f136eea659038cc0David Gross int partitionTheWork(const std::vector<std::shared_ptr<Device>>& devices, 85def0a14aa77689f12120cfb4f136eea659038cc0David Gross uint32_t preference, ExecutionPlan* plan) const; 86def0a14aa77689f12120cfb4f136eea659038cc0David Gross 87def0a14aa77689f12120cfb4f136eea659038cc0David Gross private: 880b9453e41a544f9c780eaa15ad65136ad4662ccbDavid Gross // TODO: move partitionTheWork, findBestDeviceForEachOperation, 890b9453e41a544f9c780eaa15ad65136ad4662ccbDavid Gross // sortIntoRunOrder to CompilationBuilder? 
900b9453e41a544f9c780eaa15ad65136ad4662ccbDavid Gross 9191e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet int findBestDeviceForEachOperation(uint32_t preference, 9291e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet const std::vector<std::shared_ptr<Device>>& devices, 9391e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet const size_t deviceCount, 940b9453e41a544f9c780eaa15ad65136ad4662ccbDavid Gross std::vector<int>* bestDeviceForOperation) const; 957d6ac906f000f3afe418e92c0a4ae36b2ea1143eJean-Luc Brouillet PerformanceInfo getPerformanceInfo(const std::shared_ptr<Device> device, 960b9453e41a544f9c780eaa15ad65136ad4662ccbDavid Gross uint32_t operationIndex) const; 9791e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet 98853931ff8e4df815699c30c2948b5a51aa4a206dMiao Wang // Return true if either mCompleteModel or mInvalidModel is true. 99853931ff8e4df815699c30c2948b5a51aa4a206dMiao Wang bool badState(const char* name); 100853931ff8e4df815699c30c2948b5a51aa4a206dMiao Wang 10196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // Sorts the operations to be in the correct order for single threaded 10296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // node-at-a-time execution. 10396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet void sortIntoRunOrder(); 1041da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet 1051da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet // Copies the large values to a shared memory, if we have any. 1061da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet int copyLargeValuesToSharedMemory(); 10796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 10896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // The operations of the graph. 109707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet std::vector<Operation> mOperations; 11096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // The description of the operands of the graph. 
111707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet std::vector<Operand> mOperands; 112707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet // Specifies where to find the list of indexes identifying 113707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet // the inputs and outputs of the model. The offset is into 114707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet // the mOperandIndexes table. 115707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet std::vector<uint32_t> mInputIndexes; 116707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet std::vector<uint32_t> mOutputIndexes; 117707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet 1188b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet MemoryTracker mMemories; 1198b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet 1201da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet // The value of the small operands that are defined at model 12196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // creation time. 1221da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet std::vector<uint8_t> mSmallOperandValues; 1231da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet 1241da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet struct LargeValue { 1251da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet uint32_t operandIndex; 1261da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet const void* buffer; 1271da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet }; 1281da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet // Operand index and buffer pointer for all the large operand values of this model. 1291da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet std::vector<LargeValue> mLargeOperandValues; 1301da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet // The shared memory region that will contain the large values. 
1311da8fed77c5c296afa18f754ec3616e7f02a4cfdJean-Luc Brouillet Memory mLargeValueMemory; 13296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 13365aa556323f4a054f80a75b6c4c721b2a7ed3298David Gross // Once the model has been finished, we should not allow further 13496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet // modifications to the model. 13596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet mutable bool mCompletedModel = false; 136084401d6215dca122999261c5ac3718ebf61b29eMichael Butler 137853931ff8e4df815699c30c2948b5a51aa4a206dMiao Wang // Any invalid manipulation of the model will mark the model invalid. 138853931ff8e4df815699c30c2948b5a51aa4a206dMiao Wang // No further modifications are allowed to the model. 139853931ff8e4df815699c30c2948b5a51aa4a206dMiao Wang mutable bool mInvalidModel = false; 140853931ff8e4df815699c30c2948b5a51aa4a206dMiao Wang 141084401d6215dca122999261c5ac3718ebf61b29eMichael Butler // 'true' indicates TENSOR_FLOAT32 may be calculated with range and/or 142084401d6215dca122999261c5ac3718ebf61b29eMichael Butler // precision as low as that of the IEEE 754 16-bit floating-point format. 143084401d6215dca122999261c5ac3718ebf61b29eMichael Butler // 'false' indicates TENSOR_FLOAT32 must be calculated using at least the 144084401d6215dca122999261c5ac3718ebf61b29eMichael Butler // range and precision of the IEEE 754 32-bit floating-point format. 145084401d6215dca122999261c5ac3718ebf61b29eMichael Butler bool mRelaxComputationFloat32toFloat16 = false; 14696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet}; 14796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 148d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet} // namespace nn 149d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet} // namespace android 15096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet 151d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet#endif // ANDROID_ML_NN_RUNTIME_MODEL_BUILDER_H 152