1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <unistd.h>
16#include <cassert>
17#include <cmath>
18#include <cstdio>
19#include <cstdlib>
20#include <iostream>
21#include <limits>
22
23#include "tensorflow/contrib/lite/builtin_op_data.h"
24#include "tensorflow/contrib/lite/context.h"
25#include "tensorflow/contrib/lite/kernels/activation_functor.h"
26#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
27#include "tensorflow/contrib/lite/kernels/op_macros.h"
28
29namespace tflite {
30namespace ops {
31namespace builtin {
32namespace rnn {
33
// Tensor slots used by the RNN op.
//
// Positions within node->inputs:
constexpr int kInputTensor{0};
constexpr int kWeightsTensor{1};
constexpr int kRecurrentWeightsTensor{2};
constexpr int kBiasTensor{3};

// Positions within node->outputs. KHiddenStateTensor is spelled with a
// capital 'K', unlike its siblings; renaming it requires updating Prepare and
// Eval, which refer to it by this name.
constexpr int KHiddenStateTensor{0};
constexpr int kOutputTensor{1};
40
41TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
42  // Check we have all the inputs and outputs we need.
43  TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
44  TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
45
46  TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
47  TfLiteTensor* input_weights =
48      &context->tensors[node->inputs->data[kWeightsTensor]];
49  TfLiteTensor* recurrent_weights =
50      &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
51  TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
52
53  // Check all the parameters of tensor match within themselves and match the
54  // input configuration.
55  const int batch_size = input->dims->data[0];
56  const int num_units = input_weights->dims->data[0];
57  TF_LITE_ASSERT_EQ(input->dims->data[1], input_weights->dims->data[1]);
58  TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]);
59  TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]);
60  TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
61
62  TfLiteTensor* hidden_state =
63      &context->tensors[node->outputs->data[KHiddenStateTensor]];
64  TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
65
66  // Resize state.
67  TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
68  hidden_state_size_array->data[0] = batch_size;
69  hidden_state_size_array->data[1] = num_units;
70  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
71                                                   hidden_state_size_array));
72
73  // Mark hidden state as a persistent tensor.
74  hidden_state->allocation_type = kTfLiteArenaRwPersistent;
75
76  // Resize output.
77  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
78  output_size_array->data[0] = batch_size;
79  output_size_array->data[1] = num_units;
80  TF_LITE_ENSURE_OK(context,
81                    context->ResizeTensor(context, output, output_size_array));
82
83  return kTfLiteOk;
84}
85
86TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
87  auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
88
89  TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
90  TfLiteTensor* input_weights =
91      &context->tensors[node->inputs->data[kWeightsTensor]];
92  TfLiteTensor* recurrent_weights =
93      &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
94  TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
95  TfLiteTensor* hidden_state =
96      &context->tensors[node->outputs->data[KHiddenStateTensor]];
97  TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
98
99  // Initialize the pointer bias.
100  const float* bias_ptr = bias->data.f;
101
102  const int batch_size = input->dims->data[0];
103  const int num_units = input_weights->dims->data[0];
104  const int input_size = input->dims->data[1];
105
106  // Initialize the pointer to hidden state.
107  float* hidden_state_ptr_batch = hidden_state->data.f;
108  // Initialize the pointer to input and output.
109  const float* input_ptr_batch = input->data.f;
110  float* output_ptr_batch = output->data.f;
111  // Initialize input_weights and recurrent_weights.
112  const float* input_weights_ptr = input_weights->data.f;
113  const float* recurrent_weights_ptr = recurrent_weights->data.f;
114
115  kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr,
116                             recurrent_weights_ptr, bias_ptr, input_size,
117                             num_units, batch_size, params->activation,
118                             hidden_state_ptr_batch, output_ptr_batch);
119  return kTfLiteOk;
120}
121
122}  // namespace rnn
123
124TfLiteRegistration* Register_RNN() {
125  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
126                                 rnn::Prepare, rnn::Eval};
127  return &r;
128}
129
130}  // namespace builtin
131}  // namespace ops
132}  // namespace tflite
133