1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef FRAMEWORKS_ML_NN_SVDF_H
18#define FRAMEWORKS_ML_NN_SVDF_H
19
#include "ActivationFunctor.h"

#include <algorithm>
#include <cmath>
#include <vector>
24
// Forward declaration of the HAL Operation type. Only references to it appear
// in this header, so the full definition is not needed here.
namespace android {
namespace hardware {
namespace neuralnetworks {
namespace V1_0 {

struct Operation;

}  // namespace V1_0
}  // namespace neuralnetworks
}  // namespace hardware
}  // namespace android
34
35namespace android {
36namespace nn {
37
38struct SVDFParams {
39    int rank_;
40    ActivationFn activation_;
41};
42
43struct RunTimeOperandInfo;
44struct Shape;
45
46class SVDF {
47public:
48    SVDF(const android::hardware::neuralnetworks::V1_0::Operation &operation,
49         std::vector<RunTimeOperandInfo>& operands);
50
51    static bool Prepare(
52        const hardware::neuralnetworks::V1_0::Operation &operation,
53        std::vector<RunTimeOperandInfo> &operands, Shape *stateShape,
54        Shape *outputShape);
55    bool Eval();
56
57    static constexpr int kInputTensor = 0;
58    static constexpr int kWeightsFeatureTensor = 1;
59    static constexpr int kWeightsTimeTensor = 2;
60    static constexpr int kBiasTensor = 3;  // Optional
61    static constexpr int kStateInTensor = 4;
62    static constexpr int kRankParam = 5;
63    static constexpr int kActivationParam = 6;
64
65    static constexpr int kStateOutTensor = 0;
66    static constexpr int kOutputTensor = 1;
67
68private:
69    SVDFParams params_;
70
71    const RunTimeOperandInfo *input_;
72    const RunTimeOperandInfo *weights_feature_;
73    const RunTimeOperandInfo *weights_time_;
74    const RunTimeOperandInfo *bias_;
75    const RunTimeOperandInfo *state_in_;
76
77    RunTimeOperandInfo *state_out_;
78    RunTimeOperandInfo *output_;
79};
80
81}  // namespace nn
82}  // namespace android
83
84#endif
85