1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "RNN.h" 18 19#include "CpuExecutor.h" 20#include "HalInterfaces.h" 21 22namespace android { 23namespace nn { 24 25RNN::RNN(const Operation& operation, 26 std::vector<RunTimeOperandInfo>& operands) { 27 input_ = GetInput(operation, operands, kInputTensor); 28 weights_ = GetInput(operation, operands, kWeightsTensor); 29 recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor); 30 hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor); 31 bias_ = GetInput(operation, operands, kBiasTensor); 32 33 activation_ = static_cast<ActivationFn>( 34 getScalarData<int32_t>(operands[operation.inputs[kActivationParam]])); 35 36 hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor); 37 output_ = GetOutput(operation, operands, kOutputTensor); 38} 39 40bool RNN::Prepare(const Operation &operation, 41 std::vector<RunTimeOperandInfo> &operands, 42 Shape *hiddenStateShape, 43 Shape *outputShape) { 44 // Check we have all the inputs and outputs we need. 45 const int num_inputs = NumInputsWithValues(operation, operands); 46 NN_CHECK(num_inputs == 5 || num_inputs == 6); 47 NN_CHECK_EQ(NumOutputs(operation), 2); 48 49 const RunTimeOperandInfo *input = 50 GetInput(operation, operands, kInputTensor); 51 const RunTimeOperandInfo *input_weights = 52 GetInput(operation, operands, kWeightsTensor); 53 const RunTimeOperandInfo *recurrent_weights = 54 GetInput(operation, operands, kRecurrentWeightsTensor); 55 const RunTimeOperandInfo *bias = 56 GetInput(operation, operands, kBiasTensor); 57 58 // Check all the parameters of tensor match within themselves and match the 59 // input configuration. 60 const uint32_t batch_size = SizeOfDimension(input, 0); 61 const uint32_t num_units = SizeOfDimension(input_weights, 0); 62 NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1)); 63 NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0)); 64 NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0)); 65 NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0)); 66 67 const Shape &inputShape = input->shape(); 68 69 // Resize state. 70 hiddenStateShape->type = inputShape.type; 71 hiddenStateShape->dimensions = { batch_size, num_units }; 72 73 // Resize output. 74 outputShape->type = inputShape.type; 75 outputShape->dimensions = { batch_size, num_units }; 76 77 return true; 78} 79 80bool RNN::Eval() { 81 const float* bias_ptr = reinterpret_cast<float*>(bias_->buffer); 82 83 const uint32_t batch_size = input_->shape().dimensions[0]; 84 const uint32_t num_units = weights_->shape().dimensions[0]; 85 const uint32_t input_size = input_->shape().dimensions[1]; 86 const uint32_t input_weights_stride = weights_->shape().dimensions[1]; 87 const uint32_t recurrent_weights_stride = 88 recurrent_weights_->shape().dimensions[1]; 89 90 // For each batch 91 for (uint32_t b = 0; b < batch_size; b++) { 92 // Initialize the pointer to input, output and bias. 93 const float* input_ptr_batch = 94 reinterpret_cast<float*>(input_->buffer) + b * input_size; 95 const float* hidden_state_in_ptr_batch = 96 reinterpret_cast<float*>(hidden_state_in_->buffer) + b * num_units; 97 float* output_ptr_batch = 98 reinterpret_cast<float*>(output_->buffer) + b * num_units; 99 float* hidden_state_out_ptr_batch = 100 reinterpret_cast<float*>(hidden_state_out_->buffer) + b * num_units; 101 102 // Initialize input_weights and recurrent_weights. 103 const float* input_weights_ptr = reinterpret_cast<float*>(weights_->buffer); 104 const float* recurrent_weights_ptr = 105 reinterpret_cast<float*>(recurrent_weights_->buffer); 106 107 // Output = bias 108 for (uint32_t o = 0; o < num_units; o++) { 109 output_ptr_batch[o] = bias_ptr[o]; 110 } 111 112 // Output += input * input_weights 113 for (uint32_t o = 0; o < num_units; o++) { 114 for (uint32_t i = 0; i < input_size; i++) { 115 output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; 116 } 117 input_weights_ptr += input_weights_stride; 118 } 119 120 // Output += recurrent_weights * hidden_state 121 for (uint32_t o = 0; o < num_units; o++) { 122 for (uint32_t h = 0; h < num_units; h++) { 123 output_ptr_batch[o] += 124 hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h]; 125 } 126 recurrent_weights_ptr += recurrent_weights_stride; 127 } 128 129 // Output = activation(Output) and update hidden_state 130 for (uint32_t o = 0; o < num_units; o++) { 131 output_ptr_batch[o] = 132 (ActivationFunctor(activation_))(output_ptr_batch[o]); 133 hidden_state_out_ptr_batch[o] = output_ptr_batch[o]; 134 } 135 } 136 137 return true; 138} 139 140} // namespace nn 141} // namespace android 142