/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_lstm {

// Input tensor of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr int kInputToInputWeightsTensor = 1;  // Optional
constexpr int kInputToForgetWeightsTensor = 2;
constexpr int kInputToCellWeightsTensor = 3;
constexpr int kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kRecurrentToForgetWeightsTensor = 6;
constexpr int kRecurrentToCellWeightsTensor = 7;
constexpr int kRecurrentToOutputWeightsTensor = 8;

// Peephole weight tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kCellToInputWeightsTensor = 9;    // Optional
constexpr int kCellToForgetWeightsTensor = 10;  // Optional
constexpr int kCellToOutputWeightsTensor = 11;  // Optional

// Gate bias tensors of size {n_cell}
constexpr int kInputGateBiasTensor = 12;  // Optional
constexpr int kForgetGateBiasTensor = 13;
constexpr int kCellGateBiasTensor = 14;
constexpr int kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17;  // Optional

// Output tensors.
constexpr int kScratchBufferTensor = 0;
constexpr int kOutputStateTensor = 1;
constexpr int kCellStateTensor = 2;
constexpr int kOutputTensor = 3;

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);

  // Make sure the clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  if (input_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // Make sure the input-gate weights are either both present (regular LSTM)
  // or both absent (CIFG-LSTM).
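  // In a CIFG (Coupled Input and Forget Gate) LSTM the input gate has no
  // weights of its own; Eval() below derives it as (1 - forget_gate), which
  // is why the input-gate tensors may all be omitted together.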
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  if (cell_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
  }

  TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  if (cell_to_forget_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
  }

  TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
  if (cell_to_output_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
  }

  // Make sure the peephole weights are either all present or all absent
  // (the input-gate peephole may be absent when CIFG is used).
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when this is not a
  // CIFG-LSTM.
  TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
  }

  TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);

  TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);

  TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);

  TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  if (projection_weights) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
  if (projection_bias) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
  }

  // Make sure the projection tensors are consistent:
  // 1) If the projection weights are not present, the projection bias should
  //    not be present either.
  // 2) If the projection weights are present, the projection bias is
  //    optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  return kTfLiteOk;
}

// Resize the output, state and scratch tensors based on the sizes of the
// input tensors. Also check that the sizes of the input tensors match each
// other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check that we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);

  // Infer the batch size, sequence length, number of cells and number of
  // outputs from the input tensors.
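  // The input is time-major, {max_time, n_batch, n_input}; n_cell is read off
  // the input-to-output weights and n_output off the recurrent-to-output
  // weights, since those two tensors are always present.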
  TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE(context, input->dims->size > 1);
  const int max_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];

  TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
                                                        n_output, n_cell));

  // Get pointers to the output, state and scratch buffer tensors.
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
  // TODO(ghodrat): Modify this as soon as we have a finalized method for
  // scratch buffers.
  TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);

  // Resize the output and output_state tensors.
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
  output_size->data[0] = max_time;
  output_size->data[1] = n_batch;
  output_size->data[2] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
  output_state_size->data[0] = n_batch;
  output_state_size->data[1] = n_output;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, output_state, output_state_size));

  // Resize the cell_state tensor.
  TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
  cell_size->data[0] = n_batch;
  cell_size->data[1] = n_cell;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, cell_state, cell_size));

  // Mark the state tensors as persistent.
  output_state->allocation_type = kTfLiteArenaRwPersistent;
  cell_state->allocation_type = kTfLiteArenaRwPersistent;

  TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  if (use_cifg) {
    TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
    scratch_buffer_size->data[0] = n_batch;
    // Reserve space for the cell, forget and output gates.
    scratch_buffer_size->data[1] = n_cell * 3;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                     scratch_buffer_size));
  } else {
    TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
    scratch_buffer_size->data[0] = n_batch;
    // Reserve space for the input, cell, forget and output gates.
    scratch_buffer_size->data[1] = n_cell * 4;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                     scratch_buffer_size));
  }
  return kTfLiteOk;
}

// The LSTM Op engine.
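//
// For reference, one time step computes the standard LSTM update below, where
// (.) is element-wise multiplication, g() is params->activation, and the
// peephole, CIFG and projection terms apply only when the corresponding
// optional tensors are present:
//   i_t = sigmoid(W_xi x_t + W_hi h_{t-1} + w_ci (.) c_{t-1} + b_i)
//         (CIFG: i_t = 1 - f_t)
//   f_t = sigmoid(W_xf x_t + W_hf h_{t-1} + w_cf (.) c_{t-1} + b_f)
//   c_t = f_t (.) c_{t-1} + i_t (.) g(W_xc x_t + W_hc h_{t-1} + b_c)
//   o_t = sigmoid(W_xo x_t + W_ho h_{t-1} + w_co (.) c_t + b_o)
//   h_t = o_t (.) g(c_t), optionally projected and clipped:
//   h_t = clip(W_proj (o_t (.) g(c_t)) + b_proj, proj_clip)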
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
  TfLiteTensor* input = GetInput(context, node, kInputTensor);

  TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);

  TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);

  TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);

  TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);

  TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);

  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  const int max_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];
  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Since we have already checked that the weights are either all present or
  // all absent, checking the existence of only one of them suffices.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool use_peephole = (cell_to_output_weights != nullptr);

  // Index the scratch buffer pointers into the global scratch buffer.
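  // Prepare() sized this buffer as {n_batch, 3 * n_cell} for CIFG and
  // {n_batch, 4 * n_cell} otherwise, so it is carved into one
  // n_batch * n_cell slab per gate.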
  TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
  float* input_gate_scratch = nullptr;
  float* cell_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_scratch = scratch_buffer->data.f;
    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer->data.f;
    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
  }

  for (int t = 0; t < max_time; t++) {
    const float* input_ptr_time = input->data.f + t * n_batch * n_input;
    // Initialize scratch buffers with bias.
    if (!use_cifg) {
      tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell,
                                            n_batch, input_gate_scratch);
    }
    tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell,
                                          n_batch, forget_gate_scratch);
    tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch,
                                          cell_scratch);
    tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell,
                                          n_batch, output_gate_scratch);

    // For each batch and cell: compute input_weight * input.
    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_to_input_weights->data.f, n_cell, n_input, input_ptr_time,
          n_batch, input_gate_scratch, /*result_stride=*/1);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_forget_weights->data.f, n_cell, n_input, input_ptr_time,
        n_batch, forget_gate_scratch, /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_cell_weights->data.f, n_cell, n_input, input_ptr_time,
        n_batch, cell_scratch, /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_output_weights->data.f, n_cell, n_input, input_ptr_time,
        n_batch, output_gate_scratch, /*result_stride=*/1);

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_input_weights->data.f, n_cell, n_output,
          output_state->data.f, n_batch, input_gate_scratch,
          /*result_stride=*/1);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_forget_weights->data.f, n_cell, n_output,
        output_state->data.f, n_batch, forget_gate_scratch,
        /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_cell_weights->data.f, n_cell, n_output,
        output_state->data.f, n_batch, cell_scratch, /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_output_weights->data.f, n_cell, n_output,
        output_state->data.f, n_batch, output_gate_scratch,
        /*result_stride=*/1);

    // For each batch and cell: update input gate.
    if (!use_cifg) {
      if (use_peephole) {
        tensor_utils::VectorBatchVectorCwiseProductAccumulate(
            cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch,
            input_gate_scratch);
      }
      tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                         input_gate_scratch);
    }

    // For each batch and cell: update forget gate.
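    // With peepholes the forget gate also sees the previous cell state, i.e.
    // f_t = sigmoid(... + w_cf (.) c_{t-1}); cell_state still holds c_{t-1}
    // at this point.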
    if (use_peephole) {
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch,
          forget_gate_scratch);
    }
    tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                       forget_gate_scratch);

    // For each batch and cell: update the cell.
    tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
                                           cell_state->data.f, n_batch * n_cell,
                                           cell_state->data.f);
    tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
                                          params->activation, cell_scratch);
    if (use_cifg) {
      tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                               forget_gate_scratch);
      tensor_utils::VectorVectorCwiseProductAccumulate(
          cell_scratch, forget_gate_scratch, n_batch * n_cell,
          cell_state->data.f);
    } else {
      tensor_utils::VectorVectorCwiseProductAccumulate(
          cell_scratch, input_gate_scratch, n_batch * n_cell,
          cell_state->data.f);
    }
    if (params->cell_clip > 0.0) {
      tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell,
                               params->cell_clip, cell_state->data.f);
    }

    // For each batch and cell: update the output gate.
    if (use_peephole) {
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch,
          output_gate_scratch);
    }
    tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                       output_gate_scratch);
    tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell,
                                          params->activation, cell_scratch);
    tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                           n_batch * n_cell,
                                           output_gate_scratch);

    // For each batch: update the projection and output_state.
    const bool use_projection_weight = (projection_weights != nullptr);
    const bool use_projection_bias = (projection_bias != nullptr);
    float* output_ptr_time = output->data.f + t * n_batch * n_output;
    if (use_projection_weight) {
      if (use_projection_bias) {
        tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output,
                                              n_batch, output_ptr_time);
      } else {
        tensor_utils::ZeroVector(output_ptr_time, n_batch * n_output);
      }
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          projection_weights->data.f, n_output, n_cell, output_gate_scratch,
          n_batch, output_ptr_time, /*result_stride=*/1);
      if (params->proj_clip > 0.0) {
        tensor_utils::ClipVector(output_ptr_time, n_batch * n_output,
                                 params->proj_clip, output_ptr_time);
      }
    } else {
      tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
                               output_ptr_time);
    }
    tensor_utils::CopyVector(output_ptr_time, n_batch * n_output,
                             output_state->data.f);
  }
  return kTfLiteOk;
}

}  // namespace unidirectional_sequence_lstm

TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 unidirectional_sequence_lstm::Prepare,
                                 unidirectional_sequence_lstm::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite