1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "LSTM.h"
18
19#include "CpuExecutor.h"
20#include "HalInterfaces.h"
21
22namespace android {
23namespace nn {
24
25namespace {
26
27template <typename T>
28inline T *GetBuffer(RunTimeOperandInfo* operand) {
29  return reinterpret_cast<T*>(operand->buffer);
30}
31
32template <typename T>
33inline const T *GetBuffer(const RunTimeOperandInfo* operand) {
34  return reinterpret_cast<const T*>(operand->buffer);
35}
36
37}  // anonymous namespace
38
39LSTMCell::LSTMCell(const Operation& operation,
40                   std::vector<RunTimeOperandInfo>& operands) {
41  input_ = GetInput(operation, operands, kInputTensor);
42
43  input_to_input_weights_ = GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
44  input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
45  input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
46  input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);
47
48  recurrent_to_input_weights_ =
49      GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
50  recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
51  recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
52  recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);
53
54  cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);    // optional
55  cell_to_forget_weights_ = GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
56  cell_to_output_weights_ = GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional
57
58  input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
59  forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
60  cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
61  output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);
62
63  projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
64  projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional
65
66  output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
67  cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
68
69  params_.activation_ = static_cast<TfLiteFusedActivation>(getScalarData<int32_t>(
70      *GetInput(operation, operands, kActivationParam)));
71  params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
72  params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
73
74  output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
75  cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
76  output_ = GetOutput(operation, operands, kOutputTensor);
77
78  scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
79}
80
81bool LSTMCell::CheckInputTensorDimensions(
82    const Operation &operation, std::vector<RunTimeOperandInfo> &operands,
83    uint32_t n_input, uint32_t n_output, uint32_t n_cell) {
84  LSTMParams params = {
85    .activation_ = static_cast<TfLiteFusedActivation>(getScalarData<int32_t>(
86        *GetInput(operation, operands, LSTMCell::kActivationParam))),
87    .cell_clip_  = getScalarData<float>(*GetInput(operation, operands, LSTMCell::kCellClipParam)),
88    .proj_clip_  = getScalarData<float>(*GetInput(operation, operands, LSTMCell::kProjClipParam))
89  };
90
91  // Making sure clipping parameters have valid values.
92  // == 0 means no clipping
93  //  > 0 means clipping
94  NN_CHECK(params.cell_clip_ >= 0);
95  NN_CHECK(params.proj_clip_ >= 0);
96
97  const RunTimeOperandInfo *input_to_input_weights =
98      GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
99  if (!IsNullInput(input_to_input_weights)) {
100    NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
101    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
102    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
103  }
104
105  const RunTimeOperandInfo *input_to_forget_weights =
106      GetInput(operation, operands, LSTMCell::kInputToForgetWeightsTensor);
107  NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
108  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
109  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);
110
111  const RunTimeOperandInfo *input_to_cell_weights =
112      GetInput(operation, operands, LSTMCell::kInputToCellWeightsTensor);
113  NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
114  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
115  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);
116
117  const RunTimeOperandInfo *recurrent_to_input_weights =
118      GetInput(operation, operands, LSTMCell::kRecurrentToInputWeightsTensor);
119  if (!IsNullInput(recurrent_to_input_weights)) {
120    NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
121    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
122    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
123  }
124
125  const RunTimeOperandInfo *recurrent_to_forget_weights =
126      GetInput(operation, operands, LSTMCell::kRecurrentToForgetWeightsTensor);
127  NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
128  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
129  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);
130
131  const RunTimeOperandInfo *recurrent_to_cell_weights =
132      GetInput(operation, operands, LSTMCell::kRecurrentToCellWeightsTensor);
133  NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
134  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
135  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);
136
137  // We make sure the input-gate's parameters are either both present (regular
138  // LSTM) or not at all (CIFG-LSTM).
139  const bool cifg_weights_all_or_none =
140      (!IsNullInput(input_to_input_weights) &&
141       !IsNullInput(recurrent_to_input_weights)) ||
142      (IsNullInput(input_to_input_weights) &&
143       IsNullInput(recurrent_to_input_weights));
144  NN_CHECK(cifg_weights_all_or_none);
145
146  const RunTimeOperandInfo *cell_to_input_weights =
147      GetInput(operation, operands, LSTMCell::kCellToInputWeightsTensor);
148  if (!IsNullInput(cell_to_input_weights)) {
149    NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
150    NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
151  }
152
153  const RunTimeOperandInfo *cell_to_forget_weights =
154      GetInput(operation, operands, LSTMCell::kCellToForgetWeightsTensor);
155  if (!IsNullInput(cell_to_forget_weights)) {
156    NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
157    NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
158  }
159
160  const RunTimeOperandInfo *cell_to_output_weights =
161      GetInput(operation, operands, LSTMCell::kCellToOutputWeightsTensor);
162  if (!IsNullInput(cell_to_output_weights)) {
163    NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
164    NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
165  }
166
167  // Making sure the peephole weights are there all or none.
168  const bool use_cifg = IsNullInput(input_to_input_weights);
169  const bool peephole_weights_all_or_none =
170      ((!IsNullInput(cell_to_input_weights) || use_cifg) &&
171       !IsNullInput(cell_to_forget_weights) &&
172       !IsNullInput(cell_to_output_weights)) ||
173      (IsNullInput(cell_to_input_weights) &&
174       IsNullInput(cell_to_forget_weights) &&
175       IsNullInput(cell_to_output_weights));
176  NN_CHECK(peephole_weights_all_or_none);
177
178  // Make sure the input gate bias is present only when not a CIFG-LSTM.
179  const RunTimeOperandInfo* input_gate_bias =
180      GetInput(operation, operands, LSTMCell::kInputGateBiasTensor);
181  if (use_cifg) {
182    NN_CHECK(IsNullInput(input_gate_bias));
183  } else {
184    NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
185    NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
186  }
187
188  const RunTimeOperandInfo *forget_gate_bias =
189      GetInput(operation, operands, LSTMCell::kForgetGateBiasTensor);
190  NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
191  NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);
192
193  const RunTimeOperandInfo *cell_bias =
194      GetInput(operation, operands, LSTMCell::kCellGateBiasTensor);
195  NN_CHECK_EQ(NumDimensions(cell_bias), 1);
196  NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);
197
198  const RunTimeOperandInfo *output_gate_bias =
199      GetInput(operation, operands, LSTMCell::kOutputGateBiasTensor);
200  NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
201  NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);
202
203  const RunTimeOperandInfo *projection_weights =
204      GetInput(operation, operands, LSTMCell::kProjectionWeightsTensor);
205  if (!IsNullInput(projection_weights)) {
206    NN_CHECK_EQ(NumDimensions(projection_weights), 2);
207    NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
208    NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
209  }
210
211  const RunTimeOperandInfo *projection_bias =
212      GetInput(operation, operands, LSTMCell::kProjectionBiasTensor);
213  if (!IsNullInput(projection_bias)) {
214    NN_CHECK_EQ(NumDimensions(projection_bias), 1);
215    NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
216  }
217
218  // Making sure the projection tensors are consistent:
219  // 1) If projection weight is not present, then projection bias should not be
220  // present.
221  // 2) If projection weight is present, then projection bias is optional.
222  // TODO: make sure this is correct.
223  const bool projecton_tensors_consistent =
224      (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
225  NN_CHECK(projecton_tensors_consistent == true);
226
227  return true;
228}
229
230bool LSTMCell::Prepare(const Operation &operation,
231                       std::vector<RunTimeOperandInfo> &operands,
232                       Shape *scratchShape,
233                       Shape *outputStateShape,
234                       Shape *cellStateShape,
235                       Shape *outputShape) {
236  // Check we have all the inputs and outputs we need.
237  NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
238           NumInputsWithValues(operation, operands) <= 23);
239  NN_CHECK_EQ(NumOutputs(operation), 4);
240
241  // Inferring batch size, number of outputs and number of cells from the
242  // input tensors.
243  const RunTimeOperandInfo *input =
244      GetInput(operation, operands, LSTMCell::kInputTensor);
245  NN_CHECK(NumDimensions(input) > 1);
246  const uint32_t n_batch = SizeOfDimension(input, 0);
247  const uint32_t n_input = SizeOfDimension(input, 1);
248
249  const RunTimeOperandInfo *input_to_output_weights =
250      GetInput(operation, operands, LSTMCell::kInputToOutputWeightsTensor);
251  const uint32_t n_cell = SizeOfDimension(input_to_output_weights, 0);
252  NN_CHECK_EQ(NumDimensions(input_to_output_weights), 2);
253  NN_CHECK_EQ(SizeOfDimension(input_to_output_weights, 1), n_input);
254
255  const RunTimeOperandInfo *recurrent_to_output_weights =
256      GetInput(operation, operands, LSTMCell::kRecurrentToOutputWeightsTensor);
257  NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights), 2);
258  NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights, 0),
259                    n_cell);
260  const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights, 1);
261
262  // Check that input tensor dimensions matches with each other.
263  if (!CheckInputTensorDimensions(operation, operands, n_input, n_output, n_cell)) {
264    return false;
265  }
266
267  // Resize the output and output_state tensors.
268  const Shape &inputShape = input->shape();
269
270  outputShape->type = inputShape.type;
271  outputShape->dimensions = { n_batch, n_output };
272  outputShape->offset = inputShape.offset;
273  outputShape->scale = inputShape.scale;
274
275  outputStateShape->type = inputShape.type;
276  outputStateShape->dimensions = { n_batch, n_output };
277  outputStateShape->offset = inputShape.offset;
278  outputStateShape->scale = inputShape.scale;
279
280  cellStateShape->type = inputShape.type;
281  cellStateShape->dimensions = { n_batch, n_cell };
282  cellStateShape->offset = inputShape.offset;
283  cellStateShape->scale = inputShape.scale;
284
285  const RunTimeOperandInfo *input_to_input_weights =
286      GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
287  const bool use_cifg = IsNullInput(input_to_input_weights);
288  if (use_cifg) {
289    // Reserving space for Cell, Forget, Output gates
290    scratchShape->dimensions = { n_batch, n_cell * 3 };
291  } else {
292    // Reserving space for Input, Cell, Forget, Output gates
293    scratchShape->dimensions = { n_batch, n_cell * 4 };
294  }
295  scratchShape->type = inputShape.type;
296  scratchShape->offset = inputShape.offset;
297  scratchShape->scale = inputShape.scale;
298
299  return true;
300}
301
// Runs a single time step of the float LSTM cell.
// Reads input_, the *_in_ state tensors and the weight/bias operands bound
// by the constructor; writes cell_state_out_, output_state_out_ and output_,
// using scratch_buffer_ as per-gate workspace. Always returns true: all
// validation is done earlier in Prepare()/CheckInputTensorDimensions().
bool LSTMCell::Eval() {
  const uint32_t n_batch = input_->shape().dimensions[0];
  const uint32_t n_input = input_->shape().dimensions[1];
  // n_cell and n_output will be the same size when there is no projection.
  const uint32_t n_cell = input_to_output_weights_->shape().dimensions[0];
  const uint32_t n_output = recurrent_to_output_weights_->shape().dimensions[1];

  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to the get the condition.
  const bool use_cifg = (input_to_input_weights_->lifetime == OperandLifeTime::NO_VALUE);
  const bool use_peephole = (cell_to_output_weights_->lifetime != OperandLifeTime::NO_VALUE);

  // Index the scratch buffers pointers to the global scratch buffer.
  // Layout (set up in Prepare): consecutive n_batch * n_cell float slabs,
  // [cell | forget | output] for CIFG and [input | cell | forget | output]
  // otherwise.
  float* input_gate_scratch = nullptr;
  float* cell_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_scratch = reinterpret_cast<float*>(scratch_buffer_->buffer);
    forget_gate_scratch = cell_scratch + n_cell * n_batch;
    output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = reinterpret_cast<float*>(scratch_buffer_->buffer);
    cell_scratch = input_gate_scratch + n_cell * n_batch;
    forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
    output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
  }

  // Initialize scratch buffers with bias.
  if (!use_cifg) {
    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(input_gate_bias_),
                                                  n_cell, n_batch, input_gate_scratch);
  }
  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(forget_gate_bias_),
                                                n_cell, n_batch, forget_gate_scratch);
  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(cell_bias_),
                                                n_cell, n_batch, cell_scratch);
  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(output_gate_bias_),
                                                n_cell, n_batch, output_gate_scratch);

  // For each batch and cell: compute input_weight * input.
  // Each product accumulates into the bias-initialized scratch slab.
  if (!use_cifg) {
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(input_to_input_weights_), n_cell, n_input,
        GetBuffer<float>(input_), n_batch, input_gate_scratch, /*result_stride*/1);
  }
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(input_to_forget_weights_), n_cell, n_input,
      GetBuffer<float>(input_), n_batch, forget_gate_scratch, /*result_stride*/1);
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(input_to_cell_weights_), n_cell, n_input,
      GetBuffer<float>(input_), n_batch, cell_scratch, /*result_stride*/1);
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(input_to_output_weights_), n_cell, n_input,
      GetBuffer<float>(input_), n_batch, output_gate_scratch, /*result_stride*/1);

  // For each batch and cell: compute recurrent_weight * output_state.
  if (!use_cifg) {
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(recurrent_to_input_weights_), n_cell, n_output,
        GetBuffer<float>(output_state_in_), n_batch, input_gate_scratch, /*result_stride*/1);
  }
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(recurrent_to_forget_weights_), n_cell, n_output,
      GetBuffer<float>(output_state_in_), n_batch, forget_gate_scratch, /*result_stride*/1);
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(recurrent_to_cell_weights_), n_cell, n_output,
      GetBuffer<float>(output_state_in_), n_batch, cell_scratch, /*result_stride*/1);
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetBuffer<float>(recurrent_to_output_weights_), n_cell, n_output,
      GetBuffer<float>(output_state_in_), n_batch, output_gate_scratch, /*result_stride*/1);

  // For each batch and cell: update input gate.
  // Peephole contribution uses the PREVIOUS cell state, then a sigmoid
  // squashes the gate in place.
  if (!use_cifg) {
    if (use_peephole) {
      tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          GetBuffer<float>(cell_to_input_weights_), n_cell,
          GetBuffer<float>(cell_state_in_), n_batch, input_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch,
                                               n_cell * n_batch,
                                               input_gate_scratch);
  }

  // For each batch and cell: update forget gate.
  if (use_peephole) {
    tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        GetBuffer<float>(cell_to_forget_weights_), n_cell,
        GetBuffer<float>(cell_state_in_), n_batch, forget_gate_scratch);
  }
  tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch,
                                             n_cell * n_batch,
                                             forget_gate_scratch);

  // For each batch and cell: update the cell.
  // cell_state_out = forget_gate * cell_state_in
  //                + gate(input) * activation(cell candidate),
  // where the CIFG input gate is (1 - forget_gate).
  tflite::tensor_utils::VectorVectorCwiseProduct(
      forget_gate_scratch, GetBuffer<float>(cell_state_in_), n_batch * n_cell,
      GetBuffer<float>(cell_state_out_));
  tflite::tensor_utils::ApplyActivationToVector(
      cell_scratch, n_batch * n_cell,
      params_.activation_, cell_scratch);
  if (use_cifg) {
    // Sub1Vector turns the forget gate into the coupled input gate:
    // forget_gate_scratch <- 1 - forget_gate_scratch.
    tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                     forget_gate_scratch);
    tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, forget_gate_scratch, n_batch * n_cell,
        GetBuffer<float>(cell_state_out_));
  } else {
    tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, input_gate_scratch, n_batch * n_cell,
        GetBuffer<float>(cell_state_out_));
  }
  // cell_clip_ == 0 means clipping is disabled (checked >= 0 earlier).
  if (params_.cell_clip_ > 0.0) {
    tflite::tensor_utils::ClipVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
                                     params_.cell_clip_, GetBuffer<float>(cell_state_out_));
  }

  // For each batch and cell: update the output gate.
  // Note the output-gate peephole uses the NEW cell state, unlike the
  // input/forget peepholes above.
  if (use_peephole) {
    tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        GetBuffer<float>(cell_to_output_weights_), n_cell,
        GetBuffer<float>(cell_state_out_), n_batch, output_gate_scratch);
  }
  tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                             output_gate_scratch);
  // cell_scratch is reused here to hold activation(cell_state_out).
  tflite::tensor_utils::ApplyActivationToVector(GetBuffer<float>(cell_state_out_),
                                                n_batch * n_cell,
                                                params_.activation_,
                                                cell_scratch);
  // output_gate_scratch <- output_gate * activation(cell_state_out).
  tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch,
                                                 cell_scratch, n_batch * n_cell,
                                                 output_gate_scratch);

  // For each batch: update the projection and output_state.
  // Without projection weights the gated output is copied through directly
  // (n_cell == n_output in that case, per the comment at the top).
  const bool use_projection_weight =
          (projection_weights_->lifetime != OperandLifeTime::NO_VALUE);
  const bool use_projection_bias = (projection_bias_->lifetime != OperandLifeTime::NO_VALUE);
  if (use_projection_weight) {
    if (use_projection_bias) {
      tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(projection_bias_), n_output,
                                                    n_batch, GetBuffer<float>(output_));
    } else {
      tflite::tensor_utils::ZeroVector(GetBuffer<float>(output_), n_batch * n_output);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(projection_weights_), n_output, n_cell,
        output_gate_scratch, n_batch, GetBuffer<float>(output_),
        /*result_stride*/1);
    if (params_.proj_clip_ > 0.0) {
      tflite::tensor_utils::ClipVector(GetBuffer<float>(output_), n_batch * n_output,
                               params_.proj_clip_, GetBuffer<float>(output_));
    }
  } else {
    tflite::tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
                             GetBuffer<float>(output_));
  }
  // The output also becomes the recurrent state for the next time step.
  tflite::tensor_utils::CopyVector(GetBuffer<float>(output_), n_batch * n_output,
                           GetBuffer<float>(output_state_out_));

  return true;
}
463
464}  // namespace nn
465}  // namespace android
466