1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "SVDF.h"

#include <cstring>
#include <vector>

#include "CpuExecutor.h"
#include "HalInterfaces.h"
21
22namespace android {
23namespace nn {
24
25namespace {
26
27template <typename T>
28inline T *GetBuffer(RunTimeOperandInfo* operand) {
29  return reinterpret_cast<T*>(operand->buffer);
30}
31
32template <typename T>
33inline const T *GetBuffer(const RunTimeOperandInfo* operand) {
34  return reinterpret_cast<const T*>(operand->buffer);
35}
36
37}
38
39SVDF::SVDF(const Operation& operation,
40           std::vector<RunTimeOperandInfo>& operands) {
41    input_ = GetInput(operation, operands, kInputTensor);
42    weights_feature_ = GetInput(operation, operands, kWeightsFeatureTensor);
43    weights_time_ = GetInput(operation, operands, kWeightsTimeTensor);
44    bias_ = GetInput(operation, operands, kBiasTensor);
45    state_in_ = GetInput(operation, operands, kStateInTensor);
46
47    params_.rank_ = getScalarData<int>(*GetInput(operation, operands, kRankParam));
48    params_.activation_ = static_cast<TfLiteFusedActivation>(getScalarData<int>(
49        *GetInput(operation, operands, kActivationParam)));
50
51    state_out_ = GetOutput(operation, operands, kStateOutTensor);
52    output_ = GetOutput(operation, operands, kOutputTensor);
53}
54
55bool SVDF::Prepare(const Operation &operation,
56                   std::vector<RunTimeOperandInfo> &operands,
57                   Shape *stateShape,
58                   Shape *outputShape) {
59  // Check we have all the inputs and outputs we need.
60  const int num_inputs = NumInputsWithValues(operation, operands);
61
62  NN_CHECK(num_inputs == 6 || num_inputs == 7);
63  NN_CHECK_EQ(NumOutputs(operation), 2);
64
65  const RunTimeOperandInfo *input =
66      GetInput(operation, operands, SVDF::kInputTensor);
67  const RunTimeOperandInfo *weights_feature =
68      GetInput(operation, operands, SVDF::kWeightsFeatureTensor);
69  const RunTimeOperandInfo *weights_time =
70      GetInput(operation, operands, SVDF::kWeightsTimeTensor);
71
72  // Check all the parameters of tensor match within themselves and match the
73  // input configuration.
74  const int rank = getScalarData<int>(*GetInput(operation, operands, kRankParam));
75  const uint32_t batch_size = SizeOfDimension(input, 0);
76  const uint32_t num_filters = SizeOfDimension(weights_feature, 0);
77  NN_CHECK_EQ(num_filters % rank, 0);
78  const uint32_t num_units = num_filters / rank;
79  const uint32_t memory_size = SizeOfDimension(weights_time, 1);
80  NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(weights_feature, 1));
81  NN_CHECK_EQ(SizeOfDimension(weights_time, 0), num_filters);
82
83  const RunTimeOperandInfo *bias =
84      GetInput(operation, operands, kBiasTensor);
85  if (!IsNullInput(bias)) {
86    NN_CHECK_EQ(SizeOfDimension(bias, 0), num_units);
87  }
88
89  // Resize state.
90  const Shape &inputShape = input->shape();
91  stateShape->type = inputShape.type;
92  stateShape->dimensions = { batch_size, memory_size * num_filters };
93  stateShape->offset = inputShape.offset;
94  stateShape->scale = inputShape.scale;
95
96  // Resize output.
97  outputShape->type = inputShape.type;
98  outputShape->dimensions = { batch_size, num_units };
99  outputShape->offset = inputShape.offset;
100  outputShape->scale = inputShape.scale;
101
102  return true;
103}
104
105bool SVDF::Eval() {
106    const int rank = params_.rank_;
107    const int batch_size = SizeOfDimension(input_, 0);
108    const int input_size = SizeOfDimension(input_, 1);
109    const int num_filters = SizeOfDimension(weights_feature_, 0);
110    const int num_units = num_filters / rank;
111    const int memory_size = SizeOfDimension(weights_time_, 1);
112
113    memcpy(GetBuffer<float>(state_out_), GetBuffer<float>(state_in_),
114           sizeof(float) * batch_size * memory_size * num_filters);
115    // Compute conv1d(inputs, weights_feature).
116    for (int b = 0; b < batch_size; b++) {
117        float* state_ptr_batch = GetBuffer<float>(state_out_) + b * memory_size * num_filters;
118        for (int c = 0; c < num_filters; c++) {
119            float* state_ptr = state_ptr_batch + c * memory_size;
120            state_ptr[memory_size - 1] = 0.0;
121        }
122    }
123    // The state left most column is used to save current cycle activation. This
124    // is achieved by starting at state->data.f[memory_size - 1] and having the
125    // stride equal to memory_size.
126    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
127        GetBuffer<float>(weights_feature_), num_filters, input_size,
128        GetBuffer<float>(input_),  batch_size,
129        &GetBuffer<float>(state_out_)[memory_size - 1], memory_size);
130
131    // Compute matmul(state, weights_time).
132    // The right most column is used to save temporary output (with the size of
133    // num_filters). This is achieved by starting at state->data.f and having the
134    // stride equal to memory_size.
135    float scratch[batch_size * num_filters];
136    for (int b = 0; b < batch_size; b++) {
137        float* state_out_ptr_batch =
138            GetBuffer<float>(state_out_) + b * memory_size * num_filters;
139        float* scratch_ptr_batch = scratch + b * num_filters;
140        tflite::tensor_utils::BatchVectorBatchVectorDotProduct(
141            GetBuffer<float>(weights_time_), state_out_ptr_batch, memory_size, num_filters,
142            scratch_ptr_batch, /*result_stride=*/1);
143    }
144
145    // Initialize output with bias if provided.
146    if (!IsNullInput(bias_)) {
147        tflite::tensor_utils::VectorBatchVectorAssign(
148            GetBuffer<float>(bias_), num_units, batch_size,
149            GetBuffer<float>(output_));
150    } else {
151        tflite::tensor_utils::ZeroVector(
152            GetBuffer<float>(output_), batch_size * num_units);
153    }
154
155    // Reduction sum
156    for (int b = 0; b < batch_size; b++) {
157        float* output_ptr_batch = GetBuffer<float>(output_) + b * num_units;
158        float* scratch_ptr_batch = scratch + b * num_filters;
159        tflite::tensor_utils::ReductionSumVector(
160            scratch_ptr_batch, output_ptr_batch, num_units, rank);
161    }
162
163    // Apply activation.
164    for (int b = 0; b < batch_size; b++) {
165        float* output_ptr_batch = GetBuffer<float>(output_) + b * num_units;
166        tflite::tensor_utils::ApplyActivationToVector(
167            output_ptr_batch, num_units,
168            params_.activation_, output_ptr_batch);
169    }
170
171    // Right shift the state.
172    for (int b = 0; b < batch_size; b++) {
173        float* state_out_ptr_batch =
174            GetBuffer<float>(state_out_) + b * memory_size * num_filters;
175        for (int f = 0; f < num_filters; f++) {
176            tflite::tensor_utils::VectorShiftLeft(state_out_ptr_batch, memory_size,
177                                          /*shift_value=*/0.0);
178            state_out_ptr_batch += memory_size;
179        }
180    }
181    return true;
182}
183
184}  // namespace nn
185}  // namespace android
186