1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <unistd.h> 16#include <cassert> 17#include <cmath> 18#include <cstdlib> 19#include <cstdio> 20#include <iostream> 21#include <limits> 22 23#include "tensorflow/contrib/lite/builtin_op_data.h" 24#include "tensorflow/contrib/lite/context.h" 25#include "tensorflow/contrib/lite/kernels/activation_functor.h" 26#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" 27#include "tensorflow/contrib/lite/kernels/op_macros.h" 28 29namespace tflite { 30namespace ops { 31namespace builtin { 32namespace bidirectional_sequence_rnn { 33 34constexpr int kInputTensor = 0; 35// Forward and backward cell tensors. 36constexpr int kFwWeightsTensor = 1; 37constexpr int kFwRecurrentWeightsTensor = 2; 38constexpr int kFwBiasTensor = 3; 39constexpr int kBwWeightsTensor = 4; 40constexpr int kBwRecurrentWeightsTensor = 5; 41constexpr int kBwBiasTensor = 6; 42// State and output tensors. 43constexpr int kFwHiddenStateTensor = 0; 44constexpr int kFwOutputTensor = 1; 45constexpr int kBwHiddenStateTensor = 2; 46constexpr int kBwOutputTensor = 3; 47 48TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 49 // Check we have all the inputs and outputs we need. 50 TF_LITE_ENSURE_EQ(context, node->inputs->size, 7); 51 TF_LITE_ENSURE_EQ(context, node->outputs->size, 4); 52 53 TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; 54 TfLiteTensor* fw_input_weights = 55 &context->tensors[node->inputs->data[kFwWeightsTensor]]; 56 TfLiteTensor* fw_recurrent_weights = 57 &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]]; 58 TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]]; 59 TfLiteTensor* bw_input_weights = 60 &context->tensors[node->inputs->data[kBwWeightsTensor]]; 61 TfLiteTensor* bw_recurrent_weights = 62 &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]]; 63 TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]]; 64 65 // Check all the parameters of tensor match within themselves and match the 66 // input configuration. 67 const int batch_size = input->dims->data[0]; 68 const int max_time = input->dims->data[1]; 69 const int fw_num_units = fw_input_weights->dims->data[0]; 70 const int bw_num_units = bw_input_weights->dims->data[0]; 71 TF_LITE_ASSERT_EQ(input->dims->data[2], fw_input_weights->dims->data[1]); 72 TF_LITE_ASSERT_EQ(input->dims->data[2], bw_input_weights->dims->data[1]); 73 TF_LITE_ASSERT_EQ(fw_input_weights->dims->data[0], fw_bias->dims->data[0]); 74 TF_LITE_ASSERT_EQ(bw_input_weights->dims->data[0], bw_bias->dims->data[0]); 75 TF_LITE_ASSERT_EQ(fw_recurrent_weights->dims->data[0], 76 fw_bias->dims->data[0]); 77 TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1], 78 bw_bias->dims->data[0]); 79 80 TfLiteTensor* fw_output = 81 &context->tensors[node->outputs->data[kFwOutputTensor]]; 82 TfLiteTensor* bw_output = 83 &context->tensors[node->outputs->data[kBwOutputTensor]]; 84 85 // Resize hidden states. 86 TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2); 87 fw_hidden_state_size_array->data[0] = batch_size; 88 fw_hidden_state_size_array->data[1] = fw_num_units; 89 TfLiteTensor* fw_hidden_state = 90 &context->tensors[node->outputs->data[kFwHiddenStateTensor]]; 91 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state, 92 fw_hidden_state_size_array)); 93 94 TfLiteIntArray* bw_hidden_state_size_array = TfLiteIntArrayCreate(2); 95 bw_hidden_state_size_array->data[0] = batch_size; 96 bw_hidden_state_size_array->data[1] = fw_num_units; 97 TfLiteTensor* bw_hidden_state = 98 &context->tensors[node->outputs->data[kBwHiddenStateTensor]]; 99 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state, 100 bw_hidden_state_size_array)); 101 102 // Mark hidden states as a persistent tensor. 103 fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent; 104 bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent; 105 106 // Resize outputs. 107 TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3); 108 fw_output_size_array->data[0] = batch_size; 109 fw_output_size_array->data[1] = max_time; 110 fw_output_size_array->data[2] = fw_num_units; 111 TF_LITE_ENSURE_OK( 112 context, context->ResizeTensor(context, fw_output, fw_output_size_array)); 113 TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3); 114 bw_output_size_array->data[0] = batch_size; 115 bw_output_size_array->data[1] = max_time; 116 bw_output_size_array->data[2] = bw_num_units; 117 TF_LITE_ENSURE_OK( 118 context, context->ResizeTensor(context, bw_output, bw_output_size_array)); 119 120 return kTfLiteOk; 121} 122 123TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 124 auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data); 125 126 TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; 127 TfLiteTensor* fw_input_weights = 128 &context->tensors[node->inputs->data[kFwWeightsTensor]]; 129 TfLiteTensor* fw_recurrent_weights = 130 &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]]; 131 TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]]; 132 TfLiteTensor* fw_hidden_state = 133 &context->tensors[node->outputs->data[kFwHiddenStateTensor]]; 134 TfLiteTensor* fw_output = 135 &context->tensors[node->outputs->data[kFwOutputTensor]]; 136 137 TfLiteTensor* bw_input_weights = 138 &context->tensors[node->inputs->data[kBwWeightsTensor]]; 139 TfLiteTensor* bw_recurrent_weights = 140 &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]]; 141 TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]]; 142 TfLiteTensor* bw_hidden_state = 143 &context->tensors[node->outputs->data[kBwHiddenStateTensor]]; 144 TfLiteTensor* bw_output = 145 &context->tensors[node->outputs->data[kBwOutputTensor]]; 146 147 const int batch_size = input->dims->data[0]; 148 const int max_time = input->dims->data[1]; 149 const int input_size = input->dims->data[2]; 150 151 const int fw_num_units = fw_input_weights->dims->data[0]; 152 const float* fw_bias_ptr = fw_bias->data.f; 153 const float* fw_input_weights_ptr = fw_input_weights->data.f; 154 const float* fw_recurrent_weights_ptr = fw_recurrent_weights->data.f; 155 156 const int bw_num_units = bw_input_weights->dims->data[0]; 157 const float* bw_bias_ptr = bw_bias->data.f; 158 const float* bw_input_weights_ptr = bw_input_weights->data.f; 159 const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f; 160 161 for (int b = 0; b < batch_size; b++) { 162 // Forward cell. 163 float* fw_hidden_state_ptr_batch = 164 fw_hidden_state->data.f + b * fw_num_units; 165 for (int s = 0; s < max_time; s++) { 166 const float* input_ptr_batch = 167 input->data.f + b * input_size * max_time + s * input_size; 168 float* output_ptr_batch = 169 fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units; 170 171 kernel_utils::RnnBatchStep( 172 input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr, 173 fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1, 174 params->activation, fw_hidden_state_ptr_batch, output_ptr_batch); 175 } 176 // Backward cell. 177 float* bw_hidden_state_ptr_batch = 178 bw_hidden_state->data.f + b * bw_num_units; 179 for (int s = max_time - 1; s >= 0; s--) { 180 const float* input_ptr_batch = 181 input->data.f + b * input_size * max_time + s * input_size; 182 float* output_ptr_batch = 183 bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units; 184 185 kernel_utils::RnnBatchStep( 186 input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr, 187 bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1, 188 params->activation, bw_hidden_state_ptr_batch, output_ptr_batch); 189 } 190 } 191 return kTfLiteOk; 192} 193 194} // namespace bidirectional_sequence_rnn 195 196TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() { 197 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, 198 bidirectional_sequence_rnn::Prepare, 199 bidirectional_sequence_rnn::Eval}; 200 return &r; 201} 202 203} // namespace builtin 204} // namespace ops 205} // namespace tflite 206