/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace rnn {

constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
constexpr int kHiddenStateTensor = 0;
constexpr int kOutputTensor = 1;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check that we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);

  TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
  TfLiteTensor* input_weights =
      &context->tensors[node->inputs->data[kWeightsTensor]];
  TfLiteTensor* recurrent_weights =
      &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
  TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];

  // Check that the tensor dimensions are consistent with each other and with
  // the input configuration.
  const int batch_size = input->dims->data[0];
  const int num_units = input_weights->dims->data[0];
  TF_LITE_ASSERT_EQ(input->dims->data[1], input_weights->dims->data[1]);
  TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]);
  TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]);
  TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);

  TfLiteTensor* hidden_state =
      &context->tensors[node->outputs->data[kHiddenStateTensor]];
  TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];

  // Resize the hidden state.
  TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
  hidden_state_size_array->data[0] = batch_size;
  hidden_state_size_array->data[1] = num_units;
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
                                                   hidden_state_size_array));

  // Mark the hidden state as a persistent tensor so it is kept across
  // invocations.
  hidden_state->allocation_type = kTfLiteArenaRwPersistent;

  // Resize the output.
  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
  output_size_array->data[0] = batch_size;
  output_size_array->data[1] = num_units;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);

  TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
  TfLiteTensor* input_weights =
      &context->tensors[node->inputs->data[kWeightsTensor]];
  TfLiteTensor* recurrent_weights =
      &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
  TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
  TfLiteTensor* hidden_state =
      &context->tensors[node->outputs->data[kHiddenStateTensor]];
  TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];

  // Initialize the pointer to the bias.
  const float* bias_ptr = bias->data.f;

  const int batch_size = input->dims->data[0];
  const int num_units = input_weights->dims->data[0];
  const int input_size = input->dims->data[1];

  // Initialize the pointer to the hidden state.
  float* hidden_state_ptr_batch = hidden_state->data.f;
  // Initialize the pointers to the input and output.
  const float* input_ptr_batch = input->data.f;
  float* output_ptr_batch = output->data.f;
  // Initialize the pointers to the input and recurrent weights.
  const float* input_weights_ptr = input_weights->data.f;
  const float* recurrent_weights_ptr = recurrent_weights->data.f;

  kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr,
                             recurrent_weights_ptr, bias_ptr, input_size,
                             num_units, batch_size, params->activation,
                             hidden_state_ptr_batch, output_ptr_batch);
  return kTfLiteOk;
}

}  // namespace rnn

TfLiteRegistration* Register_RNN() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 rnn::Prepare, rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
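
// A rough summary of the contract this kernel implements, as implied by the
// shape checks in Prepare() above (num_units = number of RNN cells,
// input_size = feature dimension of each input row):
//
//   input              [batch_size, input_size]   (input 0)
//   input_weights      [num_units, input_size]    (input 1)
//   recurrent_weights  [num_units, num_units]     (input 2)
//   bias               [num_units]                (input 3)
//   hidden_state       [batch_size, num_units]    (output 0, persistent)
//   output             [batch_size, num_units]    (output 1)
//
// On each invocation, kernel_utils::RnnBatchStep computes (approximately)
//   hidden_state = activation(input * input_weights^T +
//                             hidden_state * recurrent_weights^T + bias)
// and writes the updated hidden state to the output tensor, with the
// activation taken from TfLiteRNNParams. The kernel is exposed to the
// interpreter through the builtin op resolver, via something like
// AddBuiltin(BuiltinOperator_RNN, Register_RNN()).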