/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_lstm {

// Input tensor of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr int kInputToInputWeightsTensor = 1;  // Optional
constexpr int kInputToForgetWeightsTensor = 2;
constexpr int kInputToCellWeightsTensor = 3;
constexpr int kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kRecurrentToForgetWeightsTensor = 6;
constexpr int kRecurrentToCellWeightsTensor = 7;
constexpr int kRecurrentToOutputWeightsTensor = 8;

// Peephole weight tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kCellToInputWeightsTensor = 9;    // Optional
constexpr int kCellToForgetWeightsTensor = 10;  // Optional
constexpr int kCellToOutputWeightsTensor = 11;  // Optional

// Gate bias tensors of size {n_cell}
constexpr int kInputGateBiasTensor = 12;  // Optional
constexpr int kForgetGateBiasTensor = 13;
constexpr int kCellGateBiasTensor = 14;
constexpr int kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17;  // Optional

// Output tensors.
constexpr int kScratchBufferTensor = 0;
constexpr int kOutputStateTensor = 1;
constexpr int kCellStateTensor = 2;
constexpr int kOutputTensor = 3;

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
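  // params carries the builtin LSTM options parsed from the model (gate
  // activation plus the cell and projection clipping thresholds used below).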

  // Making sure clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  if (input_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // Make sure the input-gate parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  if (cell_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
  }

  TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  if (cell_to_forget_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
  }

  TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
  if (cell_to_output_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
  }

  // Make sure the peephole weights are either all present or all absent.
  const bool use_cifg = (input_to_input_weights == nullptr);
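  // With CIFG there is no separate input gate, so cell_to_input_weights may be
  // absent even when the other peephole weights are present.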
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
  }

  TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);

  TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);

  TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);

  TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  if (projection_weights) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
  if (projection_bias) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
  }

  // Make sure the projection tensors are consistent:
  // 1) If projection weight is not present, then projection bias should not be
  // present.
  // 2) If projection weight is present, then projection bias is optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  return kTfLiteOk;
}

// Resize the output, state and scratch tensors based on the sizes of the input
// tensors. Also check that the sizes of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check that we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);

  // Infer the sequence length, batch size, input size, number of cells and
  // number of outputs from the input tensors.
  TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE(context, input->dims->size > 2);
  const int max_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];

  TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];
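  // n_output can differ from n_cell only when a projection layer is present.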

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
                                                        n_output, n_cell));

  // Get pointers to the output, state and scratch buffer tensors.
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
  // TODO(ghodrat): Modify this as soon as we have a finalized method for
  // scratch buffers.
  TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);

  // Resize the output and output_state tensors.
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
  output_size->data[0] = max_time;
  output_size->data[1] = n_batch;
  output_size->data[2] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
  output_state_size->data[0] = n_batch;
  output_state_size->data[1] = n_output;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, output_state, output_state_size));

  // Resize the cell_state tensor.
  TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
  cell_size->data[0] = n_batch;
  cell_size->data[1] = n_cell;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, cell_state, cell_size));

  // Mark state tensors as persistent tensors.
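  // Persistent arena tensors keep their contents across invocations, so the
  // LSTM state carries over from one call of the op to the next.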
  output_state->allocation_type = kTfLiteArenaRwPersistent;
  cell_state->allocation_type = kTfLiteArenaRwPersistent;

  TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  if (use_cifg) {
    TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
    scratch_buffer_size->data[0] = n_batch;
    // Reserving space for Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 3;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                     scratch_buffer_size));
  } else {
    TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
    scratch_buffer_size->data[0] = n_batch;
    // Reserving space for Input, Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 4;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                     scratch_buffer_size));
  }
  return kTfLiteOk;
}

// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
  TfLiteTensor* input = GetInput(context, node, kInputTensor);

  TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);

  TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);

  TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);

  TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);

  TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);

  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  const int max_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];
  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Since we have already checked that the weights are either all present or
  // all absent, we can check the existence of only one of them to get the
  // condition.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool use_peephole = (cell_to_output_weights != nullptr);

  // Index the scratch buffer pointers into the global scratch buffer.
  TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
  float* input_gate_scratch = nullptr;
  float* cell_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
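  // The scratch buffer is partitioned into contiguous [n_batch x n_cell]
  // blocks, one per gate; with CIFG the input-gate block is omitted.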
  if (use_cifg) {
    cell_scratch = scratch_buffer->data.f;
    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer->data.f;
    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
  }

  for (int t = 0; t < max_time; t++) {
    const float* input_ptr_time = input->data.f + t * n_batch * n_input;
    // Initialize scratch buffers with bias.
    if (!use_cifg) {
      tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell,
                                            n_batch, input_gate_scratch);
    }
    tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell,
                                          n_batch, forget_gate_scratch);
    tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch,
                                          cell_scratch);
    tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell,
                                          n_batch, output_gate_scratch);

    // For each batch and cell: compute input_weight * input.
    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_to_input_weights->data.f, n_cell, n_input, input_ptr_time,
          n_batch, input_gate_scratch, /*result_stride=*/1);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_forget_weights->data.f, n_cell, n_input, input_ptr_time,
        n_batch, forget_gate_scratch, /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_cell_weights->data.f, n_cell, n_input, input_ptr_time, n_batch,
        cell_scratch, /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_output_weights->data.f, n_cell, n_input, input_ptr_time,
        n_batch, output_gate_scratch, /*result_stride=*/1);

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_input_weights->data.f, n_cell, n_output,
          output_state->data.f, n_batch, input_gate_scratch,
          /*result_stride=*/1);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_forget_weights->data.f, n_cell, n_output,
        output_state->data.f, n_batch, forget_gate_scratch,
        /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_cell_weights->data.f, n_cell, n_output,
        output_state->data.f, n_batch, cell_scratch, /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_output_weights->data.f, n_cell, n_output,
        output_state->data.f, n_batch, output_gate_scratch,
        /*result_stride=*/1);

    // For each batch and cell: update input gate.
    if (!use_cifg) {
      if (use_peephole) {
        tensor_utils::VectorBatchVectorCwiseProductAccumulate(
            cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch,
            input_gate_scratch);
      }
      tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                         input_gate_scratch);
    }

    // For each batch and cell: update forget gate.
    if (use_peephole) {
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch,
          forget_gate_scratch);
    }
    tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                       forget_gate_scratch);

    // For each batch and cell: update the cell.
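    // cell_state = forget_gate .* cell_state + input_gate .* g, where
    // g = activation(cell input); with CIFG, input_gate = (1 - forget_gate).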
    tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
                                           cell_state->data.f, n_batch * n_cell,
                                           cell_state->data.f);
    tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
                                          params->activation, cell_scratch);
    if (use_cifg) {
      tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                               forget_gate_scratch);
      tensor_utils::VectorVectorCwiseProductAccumulate(
          cell_scratch, forget_gate_scratch, n_batch * n_cell,
          cell_state->data.f);
    } else {
      tensor_utils::VectorVectorCwiseProductAccumulate(
          cell_scratch, input_gate_scratch, n_batch * n_cell,
          cell_state->data.f);
    }
    if (params->cell_clip > 0.0) {
      tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell,
                               params->cell_clip, cell_state->data.f);
    }

    // For each batch and cell: update the output gate.
    if (use_peephole) {
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch,
          output_gate_scratch);
    }
    tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                       output_gate_scratch);
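    // output = output_gate .* activation(cell_state); cell_scratch is reused
    // here to hold the activated cell state.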
    tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell,
                                          params->activation, cell_scratch);
    tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                           n_batch * n_cell,
                                           output_gate_scratch);

    // For each batch: update the projection and output_state.
    const bool use_projection_weight = (projection_weights != nullptr);
    const bool use_projection_bias = (projection_bias != nullptr);
    float* output_ptr_time = output->data.f + t * n_batch * n_output;
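    // With a projection layer the output is W_proj * (gated cell activation)
    // plus the optional projection bias, clipped when proj_clip > 0; otherwise
    // the gated cell activation is copied to the output directly.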
    if (use_projection_weight) {
      if (use_projection_bias) {
        tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output,
                                              n_batch, output_ptr_time);
      } else {
        tensor_utils::ZeroVector(output_ptr_time, n_batch * n_output);
      }
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          projection_weights->data.f, n_output, n_cell, output_gate_scratch,
          n_batch, output_ptr_time, /*result_stride=*/1);
      if (params->proj_clip > 0.0) {
        tensor_utils::ClipVector(output_ptr_time, n_batch * n_output,
                                 params->proj_clip, output_ptr_time);
      }
    } else {
      tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
                               output_ptr_time);
    }
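    // The output of this time step also becomes the recurrent state
    // (output_state) consumed by the next time step.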
    tensor_utils::CopyVector(output_ptr_time, n_batch * n_output,
                             output_state->data.f);
  }
  return kTfLiteOk;
}

}  // namespace unidirectional_sequence_lstm

TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 unidirectional_sequence_lstm::Prepare,
                                 unidirectional_sequence_lstm::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite