/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Classes used to plan how to execute a model across multiple devices.
1891e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet
1991e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet#ifndef ANDROID_ML_NN_RUNTIME_EXECUTION_PLAN_H
2091e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet#define ANDROID_ML_NN_RUNTIME_EXECUTION_PLAN_H
2191e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet
#include <map>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

#include "HalInterfaces.h"
#include "Memory.h"
#include "ModelBuilder.h"
#include "NeuralNetworks.h"
#include "Utils.h"
298913ae3283de7752aed108c1b26aef1adacb049fDavid Gross
3091e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouilletnamespace android {
3191e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouilletnamespace nn {
3291e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet
// Forward declarations.  ExecutionPlan is needed up front because
// ExecutionStep refers to it before its definition further down.
class ExecutionPlan;
class ExecutionBuilder;
class StepExecutor;

class CompilationBuilder;
class Device;
class Memory;
3991e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet
4091e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouilletclass ExecutionStep {
41c4c264098a728268ad28084ea6e0263d9c1d7868David Grosspublic:
428913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    typedef std::vector<std::pair<uint32_t, uint32_t>> RemapVectorType;
43b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    typedef std::set<std::pair<uint32_t, uint32_t>> SubModelOutputSetType;
448913ae3283de7752aed108c1b26aef1adacb049fDavid Gross
45f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross    enum OperandKind { INPUT, OUTPUT };
46f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross
478913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    ExecutionStep(ExecutionPlan* plan,
488913ae3283de7752aed108c1b26aef1adacb049fDavid Gross                  uint32_t stepIndex,
498913ae3283de7752aed108c1b26aef1adacb049fDavid Gross                  std::shared_ptr<Device> device);
5091e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet    int addOperation(int operationIndex, const ModelBuilder& fromModel);
5191e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet    int addOperand(uint32_t fromOperandIndex, uint32_t* toOperandIndex,
52f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross                   const ModelBuilder& fromModel, OperandKind kind);
5391e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet
5496811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross    // Each container entry is of the form (fromModel index, subModel index)
5596811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross    const RemapVectorType& getModelInputs() const {
5696811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross        return mModelInputs;
5796811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross    }
5896811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross    const RemapVectorType& getModelOutputs() const {
5996811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross        return mModelOutputs;
6096811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross    }
61c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    const RemapVectorType& getTempsAsSubModelInputs() const {
62c4c264098a728268ad28084ea6e0263d9c1d7868David Gross        return mTempsAsSubModelInputs;
638913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    }
64c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    const SubModelOutputSetType& getTempsAsSubModelOutputs() const {
65c4c264098a728268ad28084ea6e0263d9c1d7868David Gross        return mTempsAsSubModelOutputs;
66c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    }
67c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    const RemapVectorType& getOutputsAsSubModelInputs() const {
68c4c264098a728268ad28084ea6e0263d9c1d7868David Gross        return mOutputsAsSubModelInputs;
69c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    }
70c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    const std::vector<uint32_t>& getOutputsAsSubModelInputsIndexToFromModel() const {
71c4c264098a728268ad28084ea6e0263d9c1d7868David Gross        return mOutputsAsSubModelInputsIndexToFromModel;
7296811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross    }
73b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross
74c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    void recordTempAsSubModelOutput(uint32_t fromModelIndex) {
758913ae3283de7752aed108c1b26aef1adacb049fDavid Gross        const auto it = mOperandMap.find(fromModelIndex);
768913ae3283de7752aed108c1b26aef1adacb049fDavid Gross        nnAssert(it != mOperandMap.end());
77c4c264098a728268ad28084ea6e0263d9c1d7868David Gross        mTempsAsSubModelOutputs.insert(std::make_pair(fromModelIndex, it->second));
788913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    }
798913ae3283de7752aed108c1b26aef1adacb049fDavid Gross
80b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    // If this step has a submodel output of unknown size, sets
81b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    // *hasOutputOfUnknownSize to true; otherwise, leaves it
82b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    // unchanged.
831e9666208595bc251a8958155b1e41eca90b69dbMichael Butler    int finishSubModel(const ModelBuilder* fromModel, bool* hasOutputOfUnknownSize,
841e9666208595bc251a8958155b1e41eca90b69dbMichael Butler                       int32_t executionPreference);
85891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross
8657167f7ec8bfe682139a9a4d60cd8aa913899441Michael Butler    const ModelBuilder* getSubModel() const { return &mSubModel; }
87891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    std::shared_ptr<Device> getDevice() const { return mDevice; }
88891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross
89891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    // only available after calling finishSubModel()
90891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    sp<IPreparedModel> getPreparedSubModel() const { return mPreparedSubModel; }
91891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross
92891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    // Map inputs and outputs from ExecutionBuilder to StepExecutor.
93891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    void mapInputsAndOutputs(std::shared_ptr<StepExecutor> stepExecutor) const;
948913ae3283de7752aed108c1b26aef1adacb049fDavid Gross
958913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    void dump() const;
96c4c264098a728268ad28084ea6e0263d9c1d7868David Gross
9791e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouilletprivate:
98c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    void logSubModel() const;
99c4c264098a728268ad28084ea6e0263d9c1d7868David Gross
100f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross    // TODO: Some of the data is working state information that
101f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross    // shouldn't be needed after we've constructed but not executed
102f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross    // the step.
103f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross
1048913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    ExecutionPlan* mPlan;
1058913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    uint32_t mIndex;  // index of step within plan
10657167f7ec8bfe682139a9a4d60cd8aa913899441Michael Butler    ModelBuilder mSubModel;
1078913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    std::shared_ptr<Device> mDevice;  // nullptr signifies CPU
1081f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    sp<IPreparedModel> mPreparedSubModel;  // not used for CPU
109f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross
110f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross    // Inputs of original model that are also inputs of this submodel:
111f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross    //     (fromModel index, subModel index)
1128913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    RemapVectorType mModelInputs;
113f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross    // Outputs of original model that are also outputs of this submodel:
114f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross    //     (fromModel index, subModel index)
1158913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    RemapVectorType mModelOutputs;
116f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross    // Temporaries of original model that are inputs of this submodel:
117f4e1c640547a44c7a37209e81ee5f3831b7d0fdcDavid Gross    //     (fromModel index, subModel index)
118c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    RemapVectorType mTempsAsSubModelInputs;
1198913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    // Temporaries of original model that are outputs of this submodel:
1208913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    //     (fromModel index, subModel index)
121c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    SubModelOutputSetType mTempsAsSubModelOutputs;
122c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    // Outputs of original model that are inputs of this submodel:
123c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    //     (fromModel index, subModel index)
124c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    RemapVectorType mOutputsAsSubModelInputs;
12591e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet    // Converts operand indexes from the main model to the submodel.
12691e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet    std::unordered_map<uint32_t, uint32_t> mOperandMap;
127891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    // Converts input indexes from the submodel to the main model
128891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    // (these are input indexes, not operand indexes).  This vector
129891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    // only describes inputs of the submodel that are also inputs of
130c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    // the main model -- that is, mModelInputs but not mTempsAsSubModelInputs.
131891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    std::vector<uint32_t> mInputIndexSubModelToFromModel;
132891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    // Converts output indexes from the submodel to the main model
133891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    // (these are output indexes, not operand indexes).  This vector
134891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    // only describes outputs of the submodel that are also outputs of
135c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    // the main model -- that is, mModelOutputs but not mTempsAsSubModelOutputs.
136891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross    std::vector<uint32_t> mOutputIndexSubModelToFromModel;
137c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    // Converts indexes into mOutputsAsSubModelInputs to indexes into
138c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    // main model outputs (these are input and output indexes, not
139c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    // operand indexes).  To be specific, if the main model outputs
140c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    // are mainModelOutputs,
141c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    //
142c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    //     mOutputsAsSubModelInputsIndexToFromModel.size() ==
143c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    //     mOutputsAsSubModelInputs.size()
144c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    //
145c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    // and when (0 <= i < mOutputsAsSubModelInputs.size()),
146c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    //
147c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    //     mainModelOutputs[mOutputsAsSubModelInputsIndexToFromModel[i]] ==
148c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    //     mOutputsAsSubModelInputs[i].first
149c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    std::vector<uint32_t> mOutputsAsSubModelInputsIndexToFromModel;
15091e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet};
15191e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet
15291e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouilletclass ExecutionPlan {
1538913ae3283de7752aed108c1b26aef1adacb049fDavid Grosspublic:
1541f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    ExecutionPlan(const ExecutionPlan&) = delete;
1551f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    ExecutionPlan& operator=(const ExecutionPlan&) = delete;
1568913ae3283de7752aed108c1b26aef1adacb049fDavid Gross
1571f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    ExecutionPlan() { }
1581f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    ~ExecutionPlan() { delete mBody; }
1591f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross
160b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    // Controller is part of the interface to a mechanism for
161b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    // performing an execution in N steps.
162b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    //
163b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    // Usage pattern:
164b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    // - Instantiate Controller with ExecutionPlan::makeController().
165b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    // - Call ExecutionPlan::next() on Controller N+1 times.  The first N times,
166b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    //   *executor is set to point to a new StepExecutor corresponding
167b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    //   to that step.  The N+1st time, *executor is set to nullptr,
168b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    //   signifying there are no more steps.
169b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    // - If ExecutionPlan::next() returns anything other than ANEURALNETWORKS_NO_ERROR,
170b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    //   a problem has occurred.
171b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    class Controller {
172b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross        friend class ExecutionPlan;
173b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    private:
174a2a03635c8f215cb75be68ff1939bf4dec285ef8David Gross        Controller(const Controller&) = delete;
175a2a03635c8f215cb75be68ff1939bf4dec285ef8David Gross        Controller& operator=(const Controller&) = delete;
176a2a03635c8f215cb75be68ff1939bf4dec285ef8David Gross
17796811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross        // Map from the operand index of a TEMPORARY in the original
1788fb14e90ceb360adfbac0f708d27161b7c5b7fc5David Gross        // model to an offset into mTemporaries used to represent that
1798fb14e90ceb360adfbac0f708d27161b7c5b7fc5David Gross        // TEMPORARY as an inter-partition input or output.
1808fb14e90ceb360adfbac0f708d27161b7c5b7fc5David Gross        typedef std::map<uint32_t, uint32_t> SubModelInputsAndOutputsType;
18196811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross
182b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross        static const size_t kBadStepIndex = ~size_t(0);
183b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross
18496811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross        Controller(const ExecutionPlan* plan, const ExecutionBuilder* executionBuilder,
1858fb14e90ceb360adfbac0f708d27161b7c5b7fc5David Gross                   std::shared_ptr<const SubModelInputsAndOutputsType> subModelInputsAndOutputs,
1868fb14e90ceb360adfbac0f708d27161b7c5b7fc5David Gross                   uint32_t totalSizeOfTemporaries);
187b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross
188a2a03635c8f215cb75be68ff1939bf4dec285ef8David Gross        const ExecutionPlan* mPlan;
189a2a03635c8f215cb75be68ff1939bf4dec285ef8David Gross        const ExecutionBuilder* mExecutionBuilder;
19096811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross        std::shared_ptr<const SubModelInputsAndOutputsType> mSubModelInputsAndOutputs;  // may be nullptr
1918fb14e90ceb360adfbac0f708d27161b7c5b7fc5David Gross        Memory mTemporaries;
192a2a03635c8f215cb75be68ff1939bf4dec285ef8David Gross        size_t mNextStepIndex;
193b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross    };
194b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross
195a2a03635c8f215cb75be68ff1939bf4dec285ef8David Gross    std::shared_ptr<Controller> makeController(const ExecutionBuilder* executionBuilder) const;
196b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross
197a2a03635c8f215cb75be68ff1939bf4dec285ef8David Gross    int next(std::shared_ptr<Controller> controller, std::shared_ptr<StepExecutor>* executor) const;
198b26049114bc4c64e6bea3a5d5d129fcaec8e69b6David Gross
1995e8feed5e8a07bab1ec395e5a01bb8900db00cecDavid Gross    // Create the same executor as the last one created by next().
2005e8feed5e8a07bab1ec395e5a01bb8900db00cecDavid Gross    int fallback(std::shared_ptr<Controller> controller, std::shared_ptr<StepExecutor>* executor) const;
2015e8feed5e8a07bab1ec395e5a01bb8900db00cecDavid Gross
2021f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    std::shared_ptr<ExecutionStep> createNewStep(const std::shared_ptr<Device> device);
2031f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross
2041f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    void becomeSingleStep(const std::shared_ptr<Device> device,
2051f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross                          const ModelBuilder* model);
2061f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross
2071e9666208595bc251a8958155b1e41eca90b69dbMichael Butler    int finish(const ModelBuilder* fromModel, int32_t executionPreference);
2088913ae3283de7752aed108c1b26aef1adacb049fDavid Gross
2098913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    void recordTemporaryDef(uint32_t fromModelIndex, uint32_t stepIndex) {
2101f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        auto& temporaryToDefiningStep = compound()->mTemporaryToDefiningStep;
2111f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        nnAssert(temporaryToDefiningStep.count(fromModelIndex) == 0);
2121f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        temporaryToDefiningStep.insert(std::make_pair(fromModelIndex, stepIndex));
21391e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet    }
2148913ae3283de7752aed108c1b26aef1adacb049fDavid Gross
2158913ae3283de7752aed108c1b26aef1adacb049fDavid Gross    void dump() const;
2168913ae3283de7752aed108c1b26aef1adacb049fDavid Gross
217def0a14aa77689f12120cfb4f136eea659038cc0David Gross    // These functions are solely intended for use by unit tests of
218def0a14aa77689f12120cfb4f136eea659038cc0David Gross    // the partitioning algorithm.
219def0a14aa77689f12120cfb4f136eea659038cc0David Gross    enum class Kind { ERROR, EMPTY, SIMPLE, COMPOUND };
220def0a14aa77689f12120cfb4f136eea659038cc0David Gross    Kind forTest_getKind() const;
221def0a14aa77689f12120cfb4f136eea659038cc0David Gross    std::shared_ptr<const Device> forTest_simpleGetDevice() const;
222def0a14aa77689f12120cfb4f136eea659038cc0David Gross    const std::vector<std::shared_ptr<ExecutionStep>>& forTest_compoundGetSteps() const;
223bb255b6e87dc343eb90dec998be1cf153106ab65Mika Raento    bool forTest_hasSubModelOutputsOfUnknownSize() const;
224def0a14aa77689f12120cfb4f136eea659038cc0David Gross
2258913ae3283de7752aed108c1b26aef1adacb049fDavid Grossprivate:
226c4c264098a728268ad28084ea6e0263d9c1d7868David Gross    void findTempsAsSubModelOutputs();
2278913ae3283de7752aed108c1b26aef1adacb049fDavid Gross
2281f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    struct Body {
2291f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        virtual ~Body() {}
2301f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        virtual void dump() const = 0;
2311e9666208595bc251a8958155b1e41eca90b69dbMichael Butler        virtual int finish(const ModelBuilder* fromModel, int32_t executionPreference) = 0;
232bb255b6e87dc343eb90dec998be1cf153106ab65Mika Raento        virtual bool hasSubModelOutputsOfUnknownSize() const = 0;
233891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross        bool mSuccessfulFinish = false;
2341f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    };
2351f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross
2361f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    struct SimpleBody : Body {
2371f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        SimpleBody(std::shared_ptr<Device> device, const ModelBuilder* model) :
2381f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross                mDevice(device), mModel(model) {}
2398913ae3283de7752aed108c1b26aef1adacb049fDavid Gross
2401f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        void dump() const override;
2411e9666208595bc251a8958155b1e41eca90b69dbMichael Butler        int finish(const ModelBuilder* fromModel, int32_t executionPreference) override;
242bb255b6e87dc343eb90dec998be1cf153106ab65Mika Raento        virtual bool hasSubModelOutputsOfUnknownSize() const override { return false; }
2438913ae3283de7752aed108c1b26aef1adacb049fDavid Gross
2441f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        std::shared_ptr<Device> mDevice;  // nullptr signifies CPU
2451f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        const ModelBuilder* mModel;
2461f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        sp<IPreparedModel> mPreparedModel;  // not used for CPU
2471f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    };
2481f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross
2491f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    struct CompoundBody : Body {
2501f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        void dump() const override;
2511e9666208595bc251a8958155b1e41eca90b69dbMichael Butler        int finish(const ModelBuilder* fromModel, int32_t executionPreference) override;
252bb255b6e87dc343eb90dec998be1cf153106ab65Mika Raento        virtual bool hasSubModelOutputsOfUnknownSize() const override {
253bb255b6e87dc343eb90dec998be1cf153106ab65Mika Raento            return mHasSubModelOutputOfUnknownSize;
254bb255b6e87dc343eb90dec998be1cf153106ab65Mika Raento        }
2551f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross
2561f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        // TODO: Some of the data is working state information that
2571f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        // shouldn't be needed after we've constructed but not
2581f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        // executed the plan.
2591f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross
2601f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        std::vector<std::shared_ptr<ExecutionStep>> mSteps;
2611f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross
2621f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        // Map from original operand index to defining step index.
2631f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        // Used for all (and only) TEMPORARY_VARIABLEs.
2641f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        std::unordered_map<uint32_t, uint32_t> mTemporaryToDefiningStep;
2651f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross
266891b10f7048c62a37a74c4b570be220089dfd55eDavid Gross        bool mHasSubModelOutputOfUnknownSize = false;
2671f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    private:
268c4c264098a728268ad28084ea6e0263d9c1d7868David Gross        void findTempsAsSubModelOutputs();
2691f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    };
2701f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross
2711f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    enum { EMPTY, SIMPLE, COMPOUND } mState = EMPTY;
2721f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    Body* mBody = nullptr;
2731f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    CompoundBody* compound() {
2741f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        nnAssert(mState == COMPOUND);
2751f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross        return static_cast<CompoundBody*>(mBody);
2761f4381539b7e89c42336ee7cd1addb9a4c317b34David Gross    }
27796811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross    const CompoundBody* compound() const {
27896811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross        nnAssert(mState == COMPOUND);
27996811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross        return static_cast<const CompoundBody*>(mBody);
28096811e2b1347889a25bd9686f47ca3cbf061fb1bDavid Gross    }
28191e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet};
28291e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet
28391e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet}  // namespace nn
28491e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet}  // namespace android
28591e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet
28691e8417c4c395e3922d12abfd956b93b71121976Jean-Luc Brouillet#endif  // ANDROID_ML_NN_RUNTIME_EXECUTION_PLAN_H
287