1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef FRAMEWORKS_ML_NN_RNN_H
18#define FRAMEWORKS_ML_NN_RNN_H
19
#include <vector>

#include "ActivationFunctor.h"
21
// Forward declaration of the HAL Operation struct. This header only ever
// refers to it by const reference, so the complete type is not needed here.
namespace android {
namespace hardware {
namespace neuralnetworks {
namespace V1_0 {

struct Operation;

}  // namespace V1_0
}  // namespace neuralnetworks
}  // namespace hardware
}  // namespace android
31
32namespace android {
33namespace nn {
34
// Forward declarations: this header uses these types only through pointers,
// references, or as a template argument, so their definitions are not required.
struct Shape;
struct RunTimeOperandInfo;
37
38class RNN {
39 public:
40  RNN(const android::hardware::neuralnetworks::V1_0::Operation &operation,
41      std::vector<RunTimeOperandInfo> &operands);
42
43  static bool Prepare(const android::hardware::neuralnetworks::V1_0::Operation &operation,
44                      std::vector<RunTimeOperandInfo> &operands,
45                      Shape *hiddenStateShape,
46                      Shape *outputShape);
47  bool Eval();
48
49  static constexpr int kInputTensor = 0;
50  static constexpr int kWeightsTensor = 1;  // Optional
51  static constexpr int kRecurrentWeightsTensor = 2;
52  static constexpr int kBiasTensor = 3;
53  static constexpr int kHiddenStateInTensor = 4;
54  static constexpr int kActivationParam = 5;
55
56  static constexpr int kHiddenStateOutTensor = 0;
57  static constexpr int kOutputTensor = 1;
58
59 private:
60  ActivationFn activation_;
61
62  const RunTimeOperandInfo *input_;
63  const RunTimeOperandInfo *weights_;
64  const RunTimeOperandInfo *recurrent_weights_;
65  const RunTimeOperandInfo *bias_;
66  const RunTimeOperandInfo *hidden_state_in_;
67
68  RunTimeOperandInfo *hidden_state_out_;
69  RunTimeOperandInfo *output_;
70};
71
72}  // namespace nn
73}  // namespace android
74
75#endif  // FRAMEWORKS_ML_NN_RNN_H
76