1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "RNN.h"
18
19#include "CpuExecutor.h"
20#include "HalInterfaces.h"
21
22namespace android {
23namespace nn {
24
25RNN::RNN(const Operation& operation,
26         std::vector<RunTimeOperandInfo>& operands) {
27  input_ = GetInput(operation, operands, kInputTensor);
28  weights_ = GetInput(operation, operands, kWeightsTensor);
29  recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
30  hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
31  bias_ = GetInput(operation, operands, kBiasTensor);
32
33  activation_ = static_cast<ActivationFn>(
34      getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
35
36  hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
37  output_ = GetOutput(operation, operands, kOutputTensor);
38}
39
40bool RNN::Prepare(const Operation &operation,
41                  std::vector<RunTimeOperandInfo> &operands,
42                  Shape *hiddenStateShape,
43                  Shape *outputShape) {
44  // Check we have all the inputs and outputs we need.
45  const int num_inputs = NumInputsWithValues(operation, operands);
46  NN_CHECK(num_inputs == 5 || num_inputs == 6);
47  NN_CHECK_EQ(NumOutputs(operation), 2);
48
49  const RunTimeOperandInfo *input =
50      GetInput(operation, operands, kInputTensor);
51  const RunTimeOperandInfo *input_weights =
52      GetInput(operation, operands, kWeightsTensor);
53  const RunTimeOperandInfo *recurrent_weights =
54      GetInput(operation, operands, kRecurrentWeightsTensor);
55  const RunTimeOperandInfo *bias =
56      GetInput(operation, operands, kBiasTensor);
57
58  // Check all the parameters of tensor match within themselves and match the
59  // input configuration.
60  const uint32_t batch_size = SizeOfDimension(input, 0);
61  const uint32_t num_units = SizeOfDimension(input_weights, 0);
62  NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
63  NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
64  NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
65  NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));
66
67  const Shape &inputShape = input->shape();
68
69  // Resize state.
70  hiddenStateShape->type = inputShape.type;
71  hiddenStateShape->dimensions = { batch_size, num_units };
72
73  // Resize output.
74  outputShape->type = inputShape.type;
75  outputShape->dimensions = { batch_size, num_units };
76
77  return true;
78}
79
80bool RNN::Eval() {
81  const float* bias_ptr = reinterpret_cast<float*>(bias_->buffer);
82
83  const uint32_t batch_size = input_->shape().dimensions[0];
84  const uint32_t num_units = weights_->shape().dimensions[0];
85  const uint32_t input_size = input_->shape().dimensions[1];
86  const uint32_t input_weights_stride = weights_->shape().dimensions[1];
87  const uint32_t recurrent_weights_stride =
88      recurrent_weights_->shape().dimensions[1];
89
90  // For each batch
91  for (uint32_t b = 0; b < batch_size; b++) {
92    // Initialize the pointer to input, output and bias.
93    const float* input_ptr_batch =
94        reinterpret_cast<float*>(input_->buffer) + b * input_size;
95    const float* hidden_state_in_ptr_batch =
96        reinterpret_cast<float*>(hidden_state_in_->buffer) + b * num_units;
97    float* output_ptr_batch =
98        reinterpret_cast<float*>(output_->buffer) + b * num_units;
99    float* hidden_state_out_ptr_batch =
100        reinterpret_cast<float*>(hidden_state_out_->buffer) + b * num_units;
101
102    // Initialize input_weights and recurrent_weights.
103    const float* input_weights_ptr = reinterpret_cast<float*>(weights_->buffer);
104    const float* recurrent_weights_ptr =
105        reinterpret_cast<float*>(recurrent_weights_->buffer);
106
107    // Output = bias
108    for (uint32_t o = 0; o < num_units; o++) {
109      output_ptr_batch[o] = bias_ptr[o];
110    }
111
112    // Output += input * input_weights
113    for (uint32_t o = 0; o < num_units; o++) {
114      for (uint32_t i = 0; i < input_size; i++) {
115        output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
116      }
117      input_weights_ptr += input_weights_stride;
118    }
119
120    // Output += recurrent_weights * hidden_state
121    for (uint32_t o = 0; o < num_units; o++) {
122      for (uint32_t h = 0; h < num_units; h++) {
123        output_ptr_batch[o] +=
124            hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
125      }
126      recurrent_weights_ptr += recurrent_weights_stride;
127    }
128
129    // Output = activation(Output) and update hidden_state
130    for (uint32_t o = 0; o < num_units; o++) {
131      output_ptr_batch[o] =
132          (ActivationFunctor(activation_))(output_ptr_batch[o]);
133      hidden_state_out_ptr_batch[o] = output_ptr_batch[o];
134    }
135  }
136
137  return true;
138}
139
140}  // namespace nn
141}  // namespace android
142