1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#ifndef FRAMEWORKS_ML_NN_LSTMCELL_H 18#define FRAMEWORKS_ML_NN_LSTMCELL_H 19 20#include "ActivationFunctor.h" 21 22#include <algorithm> 23#include <cmath> 24 25namespace android { 26namespace hardware { 27namespace neuralnetworks { 28namespace V1_0 { 29struct Operation; 30} 31} // namespace neuralnetworks 32} // namespace hardware 33} // namespace android 34 35namespace android { 36namespace nn { 37 38struct LSTMParams { 39 ActivationFn activation_; 40 float cell_clip_; 41 float proj_clip_; 42}; 43 44struct RunTimeOperandInfo; 45struct Shape; 46 47class LSTMCell { 48 public: 49 LSTMCell(const android::hardware::neuralnetworks::V1_0::Operation &operation, 50 std::vector<RunTimeOperandInfo> &operands); 51 52 static bool Prepare(const android::hardware::neuralnetworks::V1_0::Operation &operation, 53 std::vector<RunTimeOperandInfo> &operands, 54 Shape *scratchShape, 55 Shape *outputStateShape, 56 Shape *cellStateShape, 57 Shape *outputShape); 58 bool Eval(); 59 60 // Input Tensors of size {n_batch, n_input} 61 static constexpr int kInputTensor = 0; 62 63 // Input weight tensors of size: {n_cell, n_input} 64 static constexpr int kInputToInputWeightsTensor = 1; // Optional 65 static constexpr int kInputToForgetWeightsTensor = 2; 66 static constexpr int kInputToCellWeightsTensor = 3; 67 static constexpr int kInputToOutputWeightsTensor = 4; 68 69 // Recurrent weight tensors of size {n_cell, n_output} 70 static constexpr int kRecurrentToInputWeightsTensor = 5; // Optional 71 static constexpr int kRecurrentToForgetWeightsTensor = 6; 72 static constexpr int kRecurrentToCellWeightsTensor = 7; 73 static constexpr int kRecurrentToOutputWeightsTensor = 8; 74 75 // Peephole weights tensors of size {n_cell}, representing a diagonal matrix. 76 static constexpr int kCellToInputWeightsTensor = 9; // Optional 77 static constexpr int kCellToForgetWeightsTensor = 10; // Optional 78 static constexpr int kCellToOutputWeightsTensor = 11; // Optional 79 80 // Gates bias tensors of size {n_cell} 81 static constexpr int kInputGateBiasTensor = 12; // Optional 82 static constexpr int kForgetGateBiasTensor = 13; 83 static constexpr int kCellGateBiasTensor = 14; 84 static constexpr int kOutputGateBiasTensor = 15; 85 86 // Projection weight tensor of size {n_output, n_cell} 87 static constexpr int kProjectionWeightsTensor = 16; // Optional 88 // Projection bias tensor of size {n_output} 89 static constexpr int kProjectionBiasTensor = 17; // Optional 90 91 static constexpr int kOutputStateInTensor = 18; 92 static constexpr int kCellStateInTensor = 19; 93 94 static constexpr int kActivationParam = 20; 95 static constexpr int kCellClipParam = 21; 96 static constexpr int kProjClipParam = 22; 97 98 // Output tensors. 99 static constexpr int kScratchBufferTensor = 0; 100 static constexpr int kOutputStateOutTensor = 1; 101 static constexpr int kCellStateOutTensor = 2; 102 static constexpr int kOutputTensor = 3; 103 104 private: 105 static bool CheckInputTensorDimensions( 106 const android::hardware::neuralnetworks::V1_0::Operation &operation, 107 std::vector<RunTimeOperandInfo> &operands, uint32_t n_input, 108 uint32_t n_output, uint32_t n_cell); 109 LSTMParams params_; 110 111 const RunTimeOperandInfo *input_; 112 113 const RunTimeOperandInfo *input_to_input_weights_; 114 const RunTimeOperandInfo *input_to_forget_weights_; 115 const RunTimeOperandInfo *input_to_cell_weights_; 116 const RunTimeOperandInfo *input_to_output_weights_; 117 118 const RunTimeOperandInfo *recurrent_to_input_weights_; 119 const RunTimeOperandInfo *recurrent_to_forget_weights_; 120 const RunTimeOperandInfo *recurrent_to_cell_weights_; 121 const RunTimeOperandInfo *recurrent_to_output_weights_; 122 123 const RunTimeOperandInfo *cell_to_input_weights_; 124 const RunTimeOperandInfo *cell_to_forget_weights_; 125 const RunTimeOperandInfo *cell_to_output_weights_; 126 127 const RunTimeOperandInfo *input_gate_bias_; 128 const RunTimeOperandInfo *forget_gate_bias_; 129 const RunTimeOperandInfo *cell_bias_; 130 const RunTimeOperandInfo *output_gate_bias_; 131 132 const RunTimeOperandInfo *projection_weights_; 133 const RunTimeOperandInfo *projection_bias_; 134 135 const RunTimeOperandInfo *output_state_in_; 136 const RunTimeOperandInfo *cell_state_in_; 137 138 RunTimeOperandInfo *output_state_out_; 139 RunTimeOperandInfo *cell_state_out_; 140 RunTimeOperandInfo *output_; 141 142 RunTimeOperandInfo *scratch_buffer_; 143}; 144 145} // namespace nn 146} // namespace android 147 148#endif // FRAMEWORKS_ML_NN_LSTMCELL_H 149