196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet/*
296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * Copyright (C) 2017 The Android Open Source Project
396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet *
496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * Licensed under the Apache License, Version 2.0 (the "License");
596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * you may not use this file except in compliance with the License.
696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * You may obtain a copy of the License at
796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet *
896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet *      http://www.apache.org/licenses/LICENSE-2.0
996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet *
1096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * Unless required by applicable law or agreed to in writing, software
1196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * distributed under the License is distributed on an "AS IS" BASIS,
1296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * See the License for the specific language governing permissions and
1496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * limitations under the License.
1596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet */
1696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
1796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet// Provides C++ classes to more easily use the Neural Networks API.
1896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
1996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
2096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
2196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
#include "NeuralNetworks.h"

#include <math.h>
#include <utility>
#include <vector>
2696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
2796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletnamespace android {
2896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletnamespace nn {
2996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletnamespace wrapper {
3096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
3196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletenum class Type {
3296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    FLOAT32 = ANEURALNETWORKS_FLOAT32,
3396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    INT32 = ANEURALNETWORKS_INT32,
3496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    UINT32 = ANEURALNETWORKS_UINT32,
3596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
362150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet    TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32,
3727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
3896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet};
3996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
4096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletenum class ExecutePreference {
4196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
4296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
4396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
4496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet};
4596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
4696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletenum class Result {
4796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    NO_ERROR = ANEURALNETWORKS_NO_ERROR,
4896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
4996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
5096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
5196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    BAD_DATA = ANEURALNETWORKS_BAD_DATA,
5296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet};
5396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
5496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletstruct OperandType {
5596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    ANeuralNetworksOperandType operandType;
5666d56404cdfab9ab8aa79d4bda83be3832a3aff9Miao Wang    // int32_t type;
5796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    std::vector<uint32_t> dimensions;
5896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
5966d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet    OperandType(Type type, const std::vector<uint32_t>& d, float scale = 0.0f,
6066d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet                int32_t zeroPoint = 0)
6166d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet        : dimensions(d) {
6266d56404cdfab9ab8aa79d4bda83be3832a3aff9Miao Wang        operandType.type = static_cast<int32_t>(type);
6345bf79e5b9fee354fde7c1f64417d9ca4a1da7daMiao Wang        operandType.scale = scale;
6445bf79e5b9fee354fde7c1f64417d9ca4a1da7daMiao Wang        operandType.zeroPoint = zeroPoint;
6527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
66d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet        operandType.dimensionCount = static_cast<uint32_t>(dimensions.size());
67d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet        operandType.dimensions = dimensions.data();
6896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
6996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet};
7096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
718b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouilletclass Memory {
728b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouilletpublic:
7342dc6a6cd68877cd85e3bc475b41bda0fd946c41Miao Wang    Memory(size_t size, int protect, int fd, size_t offset) {
7442dc6a6cd68877cd85e3bc475b41bda0fd946c41Miao Wang        mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) ==
75d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet                 ANEURALNETWORKS_NO_ERROR;
76105807d963d969197fe78185ed588bfad3dc0ea5Miao Wang    }
77105807d963d969197fe78185ed588bfad3dc0ea5Miao Wang
788b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet    ~Memory() { ANeuralNetworksMemory_free(mMemory); }
792150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet
807b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // Disallow copy semantics to ensure the runtime object can only be freed
817b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // once. Copy semantics could be enabled if some sort of reference counting
827b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // or deep-copy system for runtime objects is added later.
837b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    Memory(const Memory&) = delete;
847b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    Memory& operator=(const Memory&) = delete;
857b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler
867b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // Move semantics to remove access to the runtime object from the wrapper
877b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // object that is being moved. This ensures the runtime object will be
887b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // freed only once.
89d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet    Memory(Memory&& other) { *this = std::move(other); }
907b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    Memory& operator=(Memory&& other) {
917b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler        if (this != &other) {
927b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler            mMemory = other.mMemory;
937b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler            mValid = other.mValid;
947b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler            other.mMemory = nullptr;
957b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler            other.mValid = false;
967b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler        }
977b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler        return *this;
987b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    }
997b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler
1008b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet    ANeuralNetworksMemory* get() const { return mMemory; }
1018b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet    bool isValid() const { return mValid; }
1028b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet
1038b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouilletprivate:
1048b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet    ANeuralNetworksMemory* mMemory = nullptr;
1058b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet    bool mValid = true;
1068b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet};
1078b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet
10896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletclass Model {
10996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletpublic:
11096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    Model() {
11196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        // TODO handle the value returned by this call
11296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        ANeuralNetworksModel_create(&mModel);
11396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
11496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    ~Model() { ANeuralNetworksModel_free(mModel); }
11596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
1167b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // Disallow copy semantics to ensure the runtime object can only be freed
1177b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // once. Copy semantics could be enabled if some sort of reference counting
1187b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // or deep-copy system for runtime objects is added later.
1197b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    Model(const Model&) = delete;
1207b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    Model& operator=(const Model&) = delete;
1217b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler
1227b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // Move semantics to remove access to the runtime object from the wrapper
1237b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // object that is being moved. This ensures the runtime object will be
1247b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // freed only once.
125d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet    Model(Model&& other) { *this = std::move(other); }
1267b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    Model& operator=(Model&& other) {
1277b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler        if (this != &other) {
1287b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler            mModel = other.mModel;
1297b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler            mNextOperandId = other.mNextOperandId;
1307b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler            mValid = other.mValid;
1317b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler            other.mModel = nullptr;
1327b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler            other.mNextOperandId = 0;
1337b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler            other.mValid = false;
1347b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler        }
1357b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler        return *this;
1367b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    }
1377b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler
13866d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet    Result finish() { return static_cast<Result>(ANeuralNetworksModel_finish(mModel)); }
139544739620cd7f37d40524d2407c92042e485c73fDavid Gross
14096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    uint32_t addOperand(const OperandType* type) {
14196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
14296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet            ANEURALNETWORKS_NO_ERROR) {
14396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet            mValid = false;
14496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        }
14596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        return mNextOperandId++;
14696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
14796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
14896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    void setOperandValue(uint32_t index, const void* buffer, size_t length) {
14996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
15096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet            ANEURALNETWORKS_NO_ERROR) {
15196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet            mValid = false;
15296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        }
15396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
15496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
1558b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet    void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
1568b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet                                   size_t length) {
1578b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet        if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
1588b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet                                                           length) != ANEURALNETWORKS_NO_ERROR) {
1598b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet            mValid = false;
1608b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet        }
1618b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet    }
1628b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet
16396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
16496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet                      const std::vector<uint32_t>& outputs) {
165d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet        if (ANeuralNetworksModel_addOperation(mModel, type, static_cast<uint32_t>(inputs.size()),
166d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet                                              inputs.data(), static_cast<uint32_t>(outputs.size()),
167d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet                                              outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
16896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet            mValid = false;
16996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        }
17096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
17166d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet    void identifyInputsAndOutputs(const std::vector<uint32_t>& inputs,
17266d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet                                  const std::vector<uint32_t>& outputs) {
17366d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet        if (ANeuralNetworksModel_identifyInputsAndOutputs(
17466d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet                        mModel, static_cast<uint32_t>(inputs.size()), inputs.data(),
17566d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet                        static_cast<uint32_t>(outputs.size()),
17666d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet                        outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
17796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet            mValid = false;
17896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        }
17996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
18096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    ANeuralNetworksModel* getHandle() const { return mModel; }
18196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    bool isValid() const { return mValid; }
18296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
18396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletprivate:
18496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    ANeuralNetworksModel* mModel = nullptr;
18596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    // We keep track of the operand ID as a convenience to the caller.
18696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    uint32_t mNextOperandId = 0;
18796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    bool mValid = true;
18896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet};
18996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
190425b2594c76e934dfdbc93209253e3c189571149David Grossclass Event {
191425b2594c76e934dfdbc93209253e3c189571149David Grosspublic:
192e3178825b8686f3300a895572691a2e8c1f0676bDavid Gross    Event() {}
193425b2594c76e934dfdbc93209253e3c189571149David Gross    ~Event() { ANeuralNetworksEvent_free(mEvent); }
194425b2594c76e934dfdbc93209253e3c189571149David Gross
195425b2594c76e934dfdbc93209253e3c189571149David Gross    // Disallow copy semantics to ensure the runtime object can only be freed
196425b2594c76e934dfdbc93209253e3c189571149David Gross    // once. Copy semantics could be enabled if some sort of reference counting
197425b2594c76e934dfdbc93209253e3c189571149David Gross    // or deep-copy system for runtime objects is added later.
198425b2594c76e934dfdbc93209253e3c189571149David Gross    Event(const Event&) = delete;
199425b2594c76e934dfdbc93209253e3c189571149David Gross    Event& operator=(const Event&) = delete;
200425b2594c76e934dfdbc93209253e3c189571149David Gross
201425b2594c76e934dfdbc93209253e3c189571149David Gross    // Move semantics to remove access to the runtime object from the wrapper
202425b2594c76e934dfdbc93209253e3c189571149David Gross    // object that is being moved. This ensures the runtime object will be
203425b2594c76e934dfdbc93209253e3c189571149David Gross    // freed only once.
20466d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet    Event(Event&& other) { *this = std::move(other); }
205425b2594c76e934dfdbc93209253e3c189571149David Gross    Event& operator=(Event&& other) {
206425b2594c76e934dfdbc93209253e3c189571149David Gross        if (this != &other) {
207425b2594c76e934dfdbc93209253e3c189571149David Gross            mEvent = other.mEvent;
208425b2594c76e934dfdbc93209253e3c189571149David Gross            other.mEvent = nullptr;
209425b2594c76e934dfdbc93209253e3c189571149David Gross        }
210425b2594c76e934dfdbc93209253e3c189571149David Gross        return *this;
211425b2594c76e934dfdbc93209253e3c189571149David Gross    }
212425b2594c76e934dfdbc93209253e3c189571149David Gross
213425b2594c76e934dfdbc93209253e3c189571149David Gross    Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); }
214425b2594c76e934dfdbc93209253e3c189571149David Gross
215e3178825b8686f3300a895572691a2e8c1f0676bDavid Gross    // Only for use by Execution
216425b2594c76e934dfdbc93209253e3c189571149David Gross    void set(ANeuralNetworksEvent* newEvent) {
217425b2594c76e934dfdbc93209253e3c189571149David Gross        ANeuralNetworksEvent_free(mEvent);
218425b2594c76e934dfdbc93209253e3c189571149David Gross        mEvent = newEvent;
219425b2594c76e934dfdbc93209253e3c189571149David Gross    }
220425b2594c76e934dfdbc93209253e3c189571149David Gross
221425b2594c76e934dfdbc93209253e3c189571149David Grossprivate:
222425b2594c76e934dfdbc93209253e3c189571149David Gross    ANeuralNetworksEvent* mEvent = nullptr;
223425b2594c76e934dfdbc93209253e3c189571149David Gross};
224425b2594c76e934dfdbc93209253e3c189571149David Gross
22583e24dc4706a5b7089881a55daf05b3924fab3b7David Grossclass Compilation {
22683e24dc4706a5b7089881a55daf05b3924fab3b7David Grosspublic:
22783e24dc4706a5b7089881a55daf05b3924fab3b7David Gross    Compilation(const Model* model) {
22883e24dc4706a5b7089881a55daf05b3924fab3b7David Gross        int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
22983e24dc4706a5b7089881a55daf05b3924fab3b7David Gross        if (result != 0) {
23083e24dc4706a5b7089881a55daf05b3924fab3b7David Gross            // TODO Handle the error
23183e24dc4706a5b7089881a55daf05b3924fab3b7David Gross        }
23283e24dc4706a5b7089881a55daf05b3924fab3b7David Gross    }
23383e24dc4706a5b7089881a55daf05b3924fab3b7David Gross
23483e24dc4706a5b7089881a55daf05b3924fab3b7David Gross    ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }
23583e24dc4706a5b7089881a55daf05b3924fab3b7David Gross
23683e24dc4706a5b7089881a55daf05b3924fab3b7David Gross    Compilation(const Compilation&) = delete;
237d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet    Compilation& operator=(const Compilation&) = delete;
23883e24dc4706a5b7089881a55daf05b3924fab3b7David Gross
239d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet    Compilation(Compilation&& other) { *this = std::move(other); }
24083e24dc4706a5b7089881a55daf05b3924fab3b7David Gross    Compilation& operator=(Compilation&& other) {
24183e24dc4706a5b7089881a55daf05b3924fab3b7David Gross        if (this != &other) {
24283e24dc4706a5b7089881a55daf05b3924fab3b7David Gross            mCompilation = other.mCompilation;
24383e24dc4706a5b7089881a55daf05b3924fab3b7David Gross            other.mCompilation = nullptr;
24483e24dc4706a5b7089881a55daf05b3924fab3b7David Gross        }
24583e24dc4706a5b7089881a55daf05b3924fab3b7David Gross        return *this;
24683e24dc4706a5b7089881a55daf05b3924fab3b7David Gross    }
24783e24dc4706a5b7089881a55daf05b3924fab3b7David Gross
24883e24dc4706a5b7089881a55daf05b3924fab3b7David Gross    Result setPreference(ExecutePreference preference) {
24983e24dc4706a5b7089881a55daf05b3924fab3b7David Gross        return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
25066d56404cdfab9ab8aa79d4bda83be3832a3aff9Miao Wang                    mCompilation, static_cast<int32_t>(preference)));
25183e24dc4706a5b7089881a55daf05b3924fab3b7David Gross    }
25283e24dc4706a5b7089881a55daf05b3924fab3b7David Gross
25366d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet    Result finish() { return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation)); }
25483e24dc4706a5b7089881a55daf05b3924fab3b7David Gross
25583e24dc4706a5b7089881a55daf05b3924fab3b7David Gross    ANeuralNetworksCompilation* getHandle() const { return mCompilation; }
25683e24dc4706a5b7089881a55daf05b3924fab3b7David Gross
25783e24dc4706a5b7089881a55daf05b3924fab3b7David Grossprivate:
25883e24dc4706a5b7089881a55daf05b3924fab3b7David Gross    ANeuralNetworksCompilation* mCompilation = nullptr;
25983e24dc4706a5b7089881a55daf05b3924fab3b7David Gross};
26083e24dc4706a5b7089881a55daf05b3924fab3b7David Gross
2613ced3cfd5b8f22b632c35f24e585c4847383b195David Grossclass Execution {
26296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletpublic:
2633ced3cfd5b8f22b632c35f24e585c4847383b195David Gross    Execution(const Compilation* compilation) {
2643ced3cfd5b8f22b632c35f24e585c4847383b195David Gross        int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
26596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        if (result != 0) {
26696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet            // TODO Handle the error
26796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        }
26896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
26996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
2703ced3cfd5b8f22b632c35f24e585c4847383b195David Gross    ~Execution() { ANeuralNetworksExecution_free(mExecution); }
27196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
2727b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // Disallow copy semantics to ensure the runtime object can only be freed
2737b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // once. Copy semantics could be enabled if some sort of reference counting
2747b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // or deep-copy system for runtime objects is added later.
2753ced3cfd5b8f22b632c35f24e585c4847383b195David Gross    Execution(const Execution&) = delete;
2763ced3cfd5b8f22b632c35f24e585c4847383b195David Gross    Execution& operator=(const Execution&) = delete;
2777b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler
2787b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // Move semantics to remove access to the runtime object from the wrapper
2797b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // object that is being moved. This ensures the runtime object will be
2807b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    // freed only once.
281d2d0c031c43e8e5aafc75e8a652d79bcc2aaca99Jean-Luc Brouillet    Execution(Execution&& other) { *this = std::move(other); }
2823ced3cfd5b8f22b632c35f24e585c4847383b195David Gross    Execution& operator=(Execution&& other) {
2837b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler        if (this != &other) {
2843ced3cfd5b8f22b632c35f24e585c4847383b195David Gross            mExecution = other.mExecution;
2853ced3cfd5b8f22b632c35f24e585c4847383b195David Gross            other.mExecution = nullptr;
2867b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler        }
2877b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler        return *this;
2887b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler    }
2897b87fec6bb919d16a8cc2820d470733a2776e8faMichael Butler
29096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    Result setInput(uint32_t index, const void* buffer, size_t length,
29196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet                    const ANeuralNetworksOperandType* type = nullptr) {
29296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        return static_cast<Result>(
2933ced3cfd5b8f22b632c35f24e585c4847383b195David Gross                    ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length));
29496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
29596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
2968b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet    Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
2978b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet                              uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
2983ced3cfd5b8f22b632c35f24e585c4847383b195David Gross        return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory(
2993ced3cfd5b8f22b632c35f24e585c4847383b195David Gross                    mExecution, index, type, memory->get(), offset, length));
30096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
30196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
30296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    Result setOutput(uint32_t index, void* buffer, size_t length,
30396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet                     const ANeuralNetworksOperandType* type = nullptr) {
30496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        return static_cast<Result>(
3053ced3cfd5b8f22b632c35f24e585c4847383b195David Gross                    ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length));
30696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
30796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
3088b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet    Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
3098b99bb1d98a42b67ba1c00e12c7abb3708cf7c05Jean-Luc Brouillet                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
3103ced3cfd5b8f22b632c35f24e585c4847383b195David Gross        return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory(
3113ced3cfd5b8f22b632c35f24e585c4847383b195David Gross                    mExecution, index, type, memory->get(), offset, length));
31296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
31396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
314425b2594c76e934dfdbc93209253e3c189571149David Gross    Result startCompute(Event* event) {
315425b2594c76e934dfdbc93209253e3c189571149David Gross        ANeuralNetworksEvent* ev = nullptr;
316425b2594c76e934dfdbc93209253e3c189571149David Gross        Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev));
317425b2594c76e934dfdbc93209253e3c189571149David Gross        event->set(ev);
31896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        return result;
31996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
32096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
32196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    Result compute() {
322425b2594c76e934dfdbc93209253e3c189571149David Gross        ANeuralNetworksEvent* event = nullptr;
323425b2594c76e934dfdbc93209253e3c189571149David Gross        Result result =
32466d5cb6e3a90aefc8d545f6369080ab88de9d667Jean-Luc Brouillet                    static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &event));
32596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        if (result != Result::NO_ERROR) {
32696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet            return result;
32796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        }
32896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        // TODO how to manage the lifetime of events when multiple waiters is not
32996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        // clear.
330425b2594c76e934dfdbc93209253e3c189571149David Gross        result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
331425b2594c76e934dfdbc93209253e3c189571149David Gross        ANeuralNetworksEvent_free(event);
332425b2594c76e934dfdbc93209253e3c189571149David Gross        return result;
33396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
33496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
33596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletprivate:
3363ced3cfd5b8f22b632c35f24e585c4847383b195David Gross    ANeuralNetworksExecution* mExecution = nullptr;
33796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet};
33896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
3392150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet}  // namespace wrapper
3402150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet}  // namespace nn
3412150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet}  // namespace android
34296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
3432150f1d186b2854fb5aa609594be12a667f845f0Jean-Luc Brouillet#endif  //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
344