1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <unistd.h>
16#include <cassert>
17#include <cmath>
18#include <cstdlib>
19#include <cstdio>
20#include <iostream>
21#include <limits>
22
23#include "tensorflow/contrib/lite/builtin_op_data.h"
24#include "tensorflow/contrib/lite/context.h"
25#include "tensorflow/contrib/lite/kernels/activation_functor.h"
26#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
27#include "tensorflow/contrib/lite/kernels/op_macros.h"
28
29namespace tflite {
30namespace ops {
31namespace builtin {
32namespace bidirectional_sequence_rnn {
33
// Indices into node->inputs: the 3-D input [batch, time, input_size] followed
// by the forward cell's weights/recurrent-weights/bias and then the backward
// cell's, in the same order.
constexpr int kInputTensor = 0;
// Forward and backward cell tensors.
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
constexpr int kBwWeightsTensor = 4;
constexpr int kBwRecurrentWeightsTensor = 5;
constexpr int kBwBiasTensor = 6;
// State and output tensors.
// Indices into node->outputs: per-cell hidden state [batch, num_units] and
// per-cell output [batch, time, num_units].
constexpr int kFwHiddenStateTensor = 0;
constexpr int kFwOutputTensor = 1;
constexpr int kBwHiddenStateTensor = 2;
constexpr int kBwOutputTensor = 3;
47
48TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
49  // Check we have all the inputs and outputs we need.
50  TF_LITE_ENSURE_EQ(context, node->inputs->size, 7);
51  TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
52
53  TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
54  TfLiteTensor* fw_input_weights =
55      &context->tensors[node->inputs->data[kFwWeightsTensor]];
56  TfLiteTensor* fw_recurrent_weights =
57      &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]];
58  TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]];
59  TfLiteTensor* bw_input_weights =
60      &context->tensors[node->inputs->data[kBwWeightsTensor]];
61  TfLiteTensor* bw_recurrent_weights =
62      &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]];
63  TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]];
64
65  // Check all the parameters of tensor match within themselves and match the
66  // input configuration.
67  const int batch_size = input->dims->data[0];
68  const int max_time = input->dims->data[1];
69  const int fw_num_units = fw_input_weights->dims->data[0];
70  const int bw_num_units = bw_input_weights->dims->data[0];
71  TF_LITE_ASSERT_EQ(input->dims->data[2], fw_input_weights->dims->data[1]);
72  TF_LITE_ASSERT_EQ(input->dims->data[2], bw_input_weights->dims->data[1]);
73  TF_LITE_ASSERT_EQ(fw_input_weights->dims->data[0], fw_bias->dims->data[0]);
74  TF_LITE_ASSERT_EQ(bw_input_weights->dims->data[0], bw_bias->dims->data[0]);
75  TF_LITE_ASSERT_EQ(fw_recurrent_weights->dims->data[0],
76                    fw_bias->dims->data[0]);
77  TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1],
78                    bw_bias->dims->data[0]);
79
80  TfLiteTensor* fw_output =
81      &context->tensors[node->outputs->data[kFwOutputTensor]];
82  TfLiteTensor* bw_output =
83      &context->tensors[node->outputs->data[kBwOutputTensor]];
84
85  // Resize hidden states.
86  TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2);
87  fw_hidden_state_size_array->data[0] = batch_size;
88  fw_hidden_state_size_array->data[1] = fw_num_units;
89  TfLiteTensor* fw_hidden_state =
90      &context->tensors[node->outputs->data[kFwHiddenStateTensor]];
91  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state,
92                                                   fw_hidden_state_size_array));
93
94  TfLiteIntArray* bw_hidden_state_size_array = TfLiteIntArrayCreate(2);
95  bw_hidden_state_size_array->data[0] = batch_size;
96  bw_hidden_state_size_array->data[1] = fw_num_units;
97  TfLiteTensor* bw_hidden_state =
98      &context->tensors[node->outputs->data[kBwHiddenStateTensor]];
99  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state,
100                                                   bw_hidden_state_size_array));
101
102  // Mark hidden states as a persistent tensor.
103  fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
104  bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
105
106  // Resize outputs.
107  TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
108  fw_output_size_array->data[0] = batch_size;
109  fw_output_size_array->data[1] = max_time;
110  fw_output_size_array->data[2] = fw_num_units;
111  TF_LITE_ENSURE_OK(
112      context, context->ResizeTensor(context, fw_output, fw_output_size_array));
113  TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
114  bw_output_size_array->data[0] = batch_size;
115  bw_output_size_array->data[1] = max_time;
116  bw_output_size_array->data[2] = bw_num_units;
117  TF_LITE_ENSURE_OK(
118      context, context->ResizeTensor(context, bw_output, bw_output_size_array));
119
120  return kTfLiteOk;
121}
122
123TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
124  auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
125
126  TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
127  TfLiteTensor* fw_input_weights =
128      &context->tensors[node->inputs->data[kFwWeightsTensor]];
129  TfLiteTensor* fw_recurrent_weights =
130      &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]];
131  TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]];
132  TfLiteTensor* fw_hidden_state =
133      &context->tensors[node->outputs->data[kFwHiddenStateTensor]];
134  TfLiteTensor* fw_output =
135      &context->tensors[node->outputs->data[kFwOutputTensor]];
136
137  TfLiteTensor* bw_input_weights =
138      &context->tensors[node->inputs->data[kBwWeightsTensor]];
139  TfLiteTensor* bw_recurrent_weights =
140      &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]];
141  TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]];
142  TfLiteTensor* bw_hidden_state =
143      &context->tensors[node->outputs->data[kBwHiddenStateTensor]];
144  TfLiteTensor* bw_output =
145      &context->tensors[node->outputs->data[kBwOutputTensor]];
146
147  const int batch_size = input->dims->data[0];
148  const int max_time = input->dims->data[1];
149  const int input_size = input->dims->data[2];
150
151  const int fw_num_units = fw_input_weights->dims->data[0];
152  const float* fw_bias_ptr = fw_bias->data.f;
153  const float* fw_input_weights_ptr = fw_input_weights->data.f;
154  const float* fw_recurrent_weights_ptr = fw_recurrent_weights->data.f;
155
156  const int bw_num_units = bw_input_weights->dims->data[0];
157  const float* bw_bias_ptr = bw_bias->data.f;
158  const float* bw_input_weights_ptr = bw_input_weights->data.f;
159  const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f;
160
161  for (int b = 0; b < batch_size; b++) {
162    // Forward cell.
163    float* fw_hidden_state_ptr_batch =
164        fw_hidden_state->data.f + b * fw_num_units;
165    for (int s = 0; s < max_time; s++) {
166      const float* input_ptr_batch =
167          input->data.f + b * input_size * max_time + s * input_size;
168      float* output_ptr_batch =
169          fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
170
171      kernel_utils::RnnBatchStep(
172          input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr,
173          fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1,
174          params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
175    }
176    // Backward cell.
177    float* bw_hidden_state_ptr_batch =
178        bw_hidden_state->data.f + b * bw_num_units;
179    for (int s = max_time - 1; s >= 0; s--) {
180      const float* input_ptr_batch =
181          input->data.f + b * input_size * max_time + s * input_size;
182      float* output_ptr_batch =
183          bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
184
185      kernel_utils::RnnBatchStep(
186          input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr,
187          bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1,
188          params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
189    }
190  }
191  return kTfLiteOk;
192}
193
194}  // namespace bidirectional_sequence_rnn
195
196TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
197  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
198                                 bidirectional_sequence_rnn::Prepare,
199                                 bidirectional_sequence_rnn::Eval};
200  return &r;
201}
202
203}  // namespace builtin
204}  // namespace ops
205}  // namespace tflite
206