/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "LSTM.h"

#include "CpuExecutor.h"
#include "HalInterfaces.h"

namespace android {
namespace nn {

namespace {

// Reinterprets an operand's raw byte buffer as a typed pointer.
// No size/alignment checking is done here; callers rely on the operand
// having been validated upstream (see CheckInputTensorDimensions/Prepare).
template <typename T>
inline T *GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

// Const overload of GetBuffer for read-only operands.
template <typename T>
inline const T *GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

}  // anonymous namespace

// Caches pointers to every input/output operand of the LSTM operation and
// reads the three scalar parameters (activation, cell clip, projection clip)
// into params_. Operands marked "optional" may be null inputs: the
// input-gate weights (absent for CIFG), the peephole weights, and the
// projection weights/bias.
LSTMCell::LSTMCell(const Operation& operation,
                   std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    input_to_input_weights_ =
        GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
    input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrent_to_input_weights_ =
        GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
    recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    cell_to_input_weights_ =
        GetInput(operation, operands, kCellToInputWeightsTensor);  // optional (peephole)
    cell_to_forget_weights_ =
        GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional (peephole)
    cell_to_output_weights_ =
        GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional (peephole)

    input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
    output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
    projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional

    output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
    cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);

    // Scalar parameters: fused activation for the cell update, plus the two
    // clipping thresholds (0 disables clipping — see CheckInputTensorDimensions).
    params_.activation_ = static_cast<TfLiteFusedActivation>(getScalarData<int32_t>(
        *GetInput(operation, operands, kActivationParam)));
    params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
    params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));

    output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
    cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);

    // Scratch buffer is exposed as an output operand; Eval() partitions it
    // into per-gate working areas.
    scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}

// Validates the dimensions and mutual consistency of all weight/bias tensors
// against the inferred sizes n_input, n_output and n_cell. Also enforces the
// all-or-none rules for the optional tensor groups:
//   - CIFG: input-to-input and recurrent-to-input weights present together
//     (regular LSTM) or absent together (CIFG-LSTM).
//   - Peephole: cell-to-{input,forget,output} weights present together or
//     absent together (cell-to-input may be absent under CIFG).
//   - Projection: bias allowed only when projection weights are present.
// Returns true on success; NN_CHECK* macros handle the failure path.
bool LSTMCell::CheckInputTensorDimensions(
    const Operation &operation, std::vector<RunTimeOperandInfo> &operands,
    uint32_t n_input, uint32_t n_output, uint32_t n_cell) {
    LSTMParams params = {
        .activation_ = static_cast<TfLiteFusedActivation>(getScalarData<int32_t>(
            *GetInput(operation, operands, LSTMCell::kActivationParam))),
        .cell_clip_ =
            getScalarData<float>(*GetInput(operation, operands, LSTMCell::kCellClipParam)),
        .proj_clip_ = getScalarData<float>(*GetInput(operation, operands, LSTMCell::kProjClipParam))
    };

    // Making sure clipping parameters have valid values.
    // == 0 means no clipping
    //  > 0 means clipping
    NN_CHECK(params.cell_clip_ >= 0);
    NN_CHECK(params.proj_clip_ >= 0);

    // Input-to-gate weight matrices are all [n_cell, n_input].
    const RunTimeOperandInfo *input_to_input_weights =
        GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
    if (!IsNullInput(input_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
    }

    const RunTimeOperandInfo *input_to_forget_weights =
        GetInput(operation, operands, LSTMCell::kInputToForgetWeightsTensor);
    NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);

    const RunTimeOperandInfo *input_to_cell_weights =
        GetInput(operation, operands, LSTMCell::kInputToCellWeightsTensor);
    NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);

    // Recurrent weight matrices are all [n_cell, n_output].
    const RunTimeOperandInfo *recurrent_to_input_weights =
        GetInput(operation, operands, LSTMCell::kRecurrentToInputWeightsTensor);
    if (!IsNullInput(recurrent_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
    }

    const RunTimeOperandInfo *recurrent_to_forget_weights =
        GetInput(operation, operands, LSTMCell::kRecurrentToForgetWeightsTensor);
    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);

    const RunTimeOperandInfo *recurrent_to_cell_weights =
        GetInput(operation, operands, LSTMCell::kRecurrentToCellWeightsTensor);
    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifg_weights_all_or_none =
        (!IsNullInput(input_to_input_weights) &&
         !IsNullInput(recurrent_to_input_weights)) ||
        (IsNullInput(input_to_input_weights) &&
         IsNullInput(recurrent_to_input_weights));
    NN_CHECK(cifg_weights_all_or_none);

    // Peephole weight vectors are all [n_cell].
    const RunTimeOperandInfo *cell_to_input_weights =
        GetInput(operation, operands, LSTMCell::kCellToInputWeightsTensor);
    if (!IsNullInput(cell_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
    }

    const RunTimeOperandInfo *cell_to_forget_weights =
        GetInput(operation, operands, LSTMCell::kCellToForgetWeightsTensor);
    if (!IsNullInput(cell_to_forget_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
    }

    const RunTimeOperandInfo *cell_to_output_weights =
        GetInput(operation, operands, LSTMCell::kCellToOutputWeightsTensor);
    if (!IsNullInput(cell_to_output_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
    }

    // Making sure the peephole weights are there all or none.
    // Under CIFG there is no input gate, so cell_to_input may be absent even
    // when the other two peephole weights are present.
    const bool use_cifg = IsNullInput(input_to_input_weights);
    const bool peephole_weights_all_or_none =
        ((!IsNullInput(cell_to_input_weights) || use_cifg) &&
         !IsNullInput(cell_to_forget_weights) &&
         !IsNullInput(cell_to_output_weights)) ||
        (IsNullInput(cell_to_input_weights) &&
         IsNullInput(cell_to_forget_weights) &&
         IsNullInput(cell_to_output_weights));
    NN_CHECK(peephole_weights_all_or_none);

    // Make sure the input gate bias is present only when not a CIFG-LSTM.
    const RunTimeOperandInfo* input_gate_bias =
        GetInput(operation, operands, LSTMCell::kInputGateBiasTensor);
    if (use_cifg) {
        NN_CHECK(IsNullInput(input_gate_bias));
    } else {
        NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
    }

    // Gate bias vectors are all [n_cell].
    const RunTimeOperandInfo *forget_gate_bias =
        GetInput(operation, operands, LSTMCell::kForgetGateBiasTensor);
    NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);

    const RunTimeOperandInfo *cell_bias =
        GetInput(operation, operands, LSTMCell::kCellGateBiasTensor);
    NN_CHECK_EQ(NumDimensions(cell_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);

    const RunTimeOperandInfo *output_gate_bias =
        GetInput(operation, operands, LSTMCell::kOutputGateBiasTensor);
    NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);

    // Projection maps the cell output [n_cell] down to [n_output].
    const RunTimeOperandInfo *projection_weights =
        GetInput(operation, operands, LSTMCell::kProjectionWeightsTensor);
    if (!IsNullInput(projection_weights)) {
        NN_CHECK_EQ(NumDimensions(projection_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
    }

    const RunTimeOperandInfo *projection_bias =
        GetInput(operation, operands, LSTMCell::kProjectionBiasTensor);
    if (!IsNullInput(projection_bias)) {
        NN_CHECK_EQ(NumDimensions(projection_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
    }

    // Making sure the projection tensors are consistent:
    // 1) If projection weight is not present, then projection bias should not be
    // present.
    // 2) If projection weight is present, then projection bias is optional.
    // TODO: make sure this is correct.
    const bool projecton_tensors_consistent =
        (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
    NN_CHECK(projecton_tensors_consistent == true);

    return true;
}

// Infers n_batch/n_input/n_cell/n_output from the input and weight tensors,
// validates all tensors via CheckInputTensorDimensions(), and fills in the
// shapes of the four outputs:
//   *outputShape      = [n_batch, n_output]
//   *outputStateShape = [n_batch, n_output]
//   *cellStateShape   = [n_batch, n_cell]
//   *scratchShape     = [n_batch, n_cell * 3] for CIFG (cell/forget/output
//                       gates) or [n_batch, n_cell * 4] otherwise (plus the
//                       input gate) — must match the layout Eval() assumes.
// All output shapes inherit type/offset/scale from the input tensor.
// Returns false if validation fails.
bool LSTMCell::Prepare(const Operation &operation,
                       std::vector<RunTimeOperandInfo> &operands,
                       Shape *scratchShape,
                       Shape *outputStateShape,
                       Shape *cellStateShape,
                       Shape *outputShape) {
    // Check we have all the inputs and outputs we need.
    // 15 mandatory inputs; up to 23 when all optional tensors are supplied.
    NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
             NumInputsWithValues(operation, operands) <= 23);
    NN_CHECK_EQ(NumOutputs(operation), 4);

    // Inferring batch size, number of outputs and number of cells from the
    // input tensors.
    const RunTimeOperandInfo *input =
        GetInput(operation, operands, LSTMCell::kInputTensor);
    NN_CHECK(NumDimensions(input) > 1);
    const uint32_t n_batch = SizeOfDimension(input, 0);
    const uint32_t n_input = SizeOfDimension(input, 1);

    // n_cell comes from dim 0 of the [n_cell, n_input] input-to-output weights.
    const RunTimeOperandInfo *input_to_output_weights =
        GetInput(operation, operands, LSTMCell::kInputToOutputWeightsTensor);
    const uint32_t n_cell = SizeOfDimension(input_to_output_weights, 0);
    NN_CHECK_EQ(NumDimensions(input_to_output_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_output_weights, 1), n_input);

    // n_output comes from dim 1 of the [n_cell, n_output] recurrent weights.
    const RunTimeOperandInfo *recurrent_to_output_weights =
        GetInput(operation, operands, LSTMCell::kRecurrentToOutputWeightsTensor);
    NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights, 0),
                n_cell);
    const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights, 1);

    // Check that input tensor dimensions matches with each other.
    if (!CheckInputTensorDimensions(operation, operands, n_input, n_output, n_cell)) {
        return false;
    }

    // Resize the output and output_state tensors.
    const Shape &inputShape = input->shape();

    outputShape->type = inputShape.type;
    outputShape->dimensions = { n_batch, n_output };
    outputShape->offset = inputShape.offset;
    outputShape->scale = inputShape.scale;

    outputStateShape->type = inputShape.type;
    outputStateShape->dimensions = { n_batch, n_output };
    outputStateShape->offset = inputShape.offset;
    outputStateShape->scale = inputShape.scale;

    cellStateShape->type = inputShape.type;
    cellStateShape->dimensions = { n_batch, n_cell };
    cellStateShape->offset = inputShape.offset;
    cellStateShape->scale = inputShape.scale;

    const RunTimeOperandInfo *input_to_input_weights =
        GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
    const bool use_cifg = IsNullInput(input_to_input_weights);
    if (use_cifg) {
        // Reserving space for Cell, Forget, Output gates
        scratchShape->dimensions = { n_batch, n_cell * 3 };
    } else {
        // Reserving space for Input, Cell, Forget, Output gates
        scratchShape->dimensions = { n_batch, n_cell * 4 };
    }
    scratchShape->type = inputShape.type;
    scratchShape->offset = inputShape.offset;
    scratchShape->scale = inputShape.scale;

    return true;
}

// Runs one float32 LSTM step over the whole batch:
//   gates   = W_x * input + W_h * output_state_in + bias (+ peephole terms)
//   cell    = forget_gate .* cell_state_in + in_gate .* act(cell_candidate)
//   output  = out_gate .* act(cell)  (optionally projected and clipped)
// Writes cell_state_out_, output_ and output_state_out_ (output_state_out_ is
// a copy of output_). The scratch buffer layout must match Prepare(): three
// n_batch*n_cell areas for CIFG, four otherwise.
bool LSTMCell::Eval() {
    const uint32_t n_batch = input_->shape().dimensions[0];
    const uint32_t n_input = input_->shape().dimensions[1];
    // n_cell and n_output will be the same size when there is no projection.
    const uint32_t n_cell = input_to_output_weights_->shape().dimensions[0];
    const uint32_t n_output = recurrent_to_output_weights_->shape().dimensions[1];

    // Since we have already checked that weights are all there or none, we can
    // check the existence of only one to the get the condition.
    const bool use_cifg = (input_to_input_weights_->lifetime == OperandLifeTime::NO_VALUE);
    const bool use_peephole = (cell_to_output_weights_->lifetime != OperandLifeTime::NO_VALUE);

    // Index the scratch buffers pointers to the global scratch buffer.
    // Each gate area holds n_cell * n_batch floats, batch-major.
    float* input_gate_scratch = nullptr;
    float* cell_scratch = nullptr;
    float* forget_gate_scratch = nullptr;
    float* output_gate_scratch = nullptr;
    if (use_cifg) {
        // CIFG: no input gate; scratch holds [cell | forget | output].
        cell_scratch = reinterpret_cast<float*>(scratch_buffer_->buffer);
        forget_gate_scratch = cell_scratch + n_cell * n_batch;
        output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
    } else {
        // Regular LSTM: scratch holds [input | cell | forget | output].
        input_gate_scratch = reinterpret_cast<float*>(scratch_buffer_->buffer);
        cell_scratch = input_gate_scratch + n_cell * n_batch;
        forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
        output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
    }

    // Initialize scratch buffers with bias.
    if (!use_cifg) {
        tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(input_gate_bias_),
                                                      n_cell, n_batch, input_gate_scratch);
    }
    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(forget_gate_bias_),
                                                  n_cell, n_batch, forget_gate_scratch);
    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(cell_bias_),
                                                  n_cell, n_batch, cell_scratch);
    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(output_gate_bias_),
                                                  n_cell, n_batch, output_gate_scratch);

    // For each batch and cell: compute input_weight * input.
    if (!use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            GetBuffer<float>(input_to_input_weights_), n_cell, n_input,
            GetBuffer<float>(input_), n_batch, input_gate_scratch, /*result_stride*/1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(input_to_forget_weights_), n_cell, n_input,
        GetBuffer<float>(input_), n_batch, forget_gate_scratch, /*result_stride*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(input_to_cell_weights_), n_cell, n_input,
        GetBuffer<float>(input_), n_batch, cell_scratch, /*result_stride*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(input_to_output_weights_), n_cell, n_input,
        GetBuffer<float>(input_), n_batch, output_gate_scratch, /*result_stride*/1);

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            GetBuffer<float>(recurrent_to_input_weights_), n_cell, n_output,
            GetBuffer<float>(output_state_in_), n_batch, input_gate_scratch, /*result_stride*/1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(recurrent_to_forget_weights_), n_cell, n_output,
        GetBuffer<float>(output_state_in_), n_batch, forget_gate_scratch, /*result_stride*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(recurrent_to_cell_weights_), n_cell, n_output,
        GetBuffer<float>(output_state_in_), n_batch, cell_scratch, /*result_stride*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        GetBuffer<float>(recurrent_to_output_weights_), n_cell, n_output,
        GetBuffer<float>(output_state_in_), n_batch, output_gate_scratch, /*result_stride*/1);

    // For each batch and cell: update input gate.
    // Peephole adds cell_to_input_weights .* cell_state_in before the sigmoid.
    if (!use_cifg) {
        if (use_peephole) {
            tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
                GetBuffer<float>(cell_to_input_weights_), n_cell,
                GetBuffer<float>(cell_state_in_), n_batch, input_gate_scratch);
        }
        tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch,
                                                   n_cell * n_batch,
                                                   input_gate_scratch);
    }

    // For each batch and cell: update forget gate.
    if (use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
            GetBuffer<float>(cell_to_forget_weights_), n_cell,
            GetBuffer<float>(cell_state_in_), n_batch, forget_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch,
                                               n_cell * n_batch,
                                               forget_gate_scratch);

    // For each batch and cell: update the cell.
    // cell_state_out = forget_gate .* cell_state_in
    tflite::tensor_utils::VectorVectorCwiseProduct(
        forget_gate_scratch, GetBuffer<float>(cell_state_in_), n_batch * n_cell,
        GetBuffer<float>(cell_state_out_));
    tflite::tensor_utils::ApplyActivationToVector(
        cell_scratch, n_batch * n_cell,
        params_.activation_, cell_scratch);
    if (use_cifg) {
        // CIFG couples the gates: the input gate is (1 - forget_gate).
        // forget_gate_scratch is reused to hold that complement here.
        tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                         forget_gate_scratch);
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
            cell_scratch, forget_gate_scratch, n_batch * n_cell,
            GetBuffer<float>(cell_state_out_));
    } else {
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
            cell_scratch, input_gate_scratch, n_batch * n_cell,
            GetBuffer<float>(cell_state_out_));
    }
    if (params_.cell_clip_ > 0.0) {
        tflite::tensor_utils::ClipVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
                                         params_.cell_clip_, GetBuffer<float>(cell_state_out_));
    }

    // For each batch and cell: update the output gate.
    // Output-gate peephole uses the *new* cell state, unlike the other gates.
    if (use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
            GetBuffer<float>(cell_to_output_weights_), n_cell,
            GetBuffer<float>(cell_state_out_), n_batch, output_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                               output_gate_scratch);
    // cell_scratch is reused to hold act(cell_state_out); the gated result is
    // accumulated back into output_gate_scratch.
    tflite::tensor_utils::ApplyActivationToVector(GetBuffer<float>(cell_state_out_),
                                                  n_batch * n_cell,
                                                  params_.activation_,
                                                  cell_scratch);
    tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch,
                                                   cell_scratch, n_batch * n_cell,
                                                   output_gate_scratch);

    // For each batch: update the projection and output_state.
    const bool use_projection_weight =
        (projection_weights_->lifetime != OperandLifeTime::NO_VALUE);
    const bool use_projection_bias = (projection_bias_->lifetime != OperandLifeTime::NO_VALUE);
    if (use_projection_weight) {
        // output = projection_weights * gated_output (+ projection_bias),
        // optionally clipped to +/- proj_clip_.
        if (use_projection_bias) {
            tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(projection_bias_), n_output,
                                                          n_batch, GetBuffer<float>(output_));
        } else {
            tflite::tensor_utils::ZeroVector(GetBuffer<float>(output_), n_batch * n_output);
        }
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            GetBuffer<float>(projection_weights_), n_output, n_cell,
            output_gate_scratch, n_batch, GetBuffer<float>(output_),
            /*result_stride*/1);
        if (params_.proj_clip_ > 0.0) {
            tflite::tensor_utils::ClipVector(GetBuffer<float>(output_), n_batch * n_output,
                                             params_.proj_clip_, GetBuffer<float>(output_));
        }
    } else {
        // No projection: n_cell == n_output (see note above), so a plain copy
        // of the gated output is the final output.
        tflite::tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
                                         GetBuffer<float>(output_));
    }
    // output_state_out is always a copy of the final output.
    tflite::tensor_utils::CopyVector(GetBuffer<float>(output_), n_batch * n_output,
                                     GetBuffer<float>(output_state_out_));

    return true;
}

}  // namespace nn
}  // namespace android