10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License");
40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License.
50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at
60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    http://www.apache.org/licenses/LICENSE-2.0
80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software
100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS,
110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and
130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License.
140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory>
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string>
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector>
180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h"
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h"
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco {
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace {
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellestd::vector<std::unique_ptr<Operator>>::iterator FindOperator(
280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    Model* model, const Operator& op) {
290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto it = model->operators.begin();
300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (; it != model->operators.end(); ++it) {
310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    if (it->get() == &op) {
320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      break;
330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return it;
360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool GetStateArrayForBackEdge(const Model& model,
390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                              const string& back_edge_source_array,
400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                              string* state_array = nullptr) {
410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (const auto& rnn_state : model.flags.rnn_states()) {
420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    if (back_edge_source_array == rnn_state.back_edge_source_array()) {
430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      // Found LSTM cell output
440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      if (state_array) {
450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        *state_array = rnn_state.state_array();
460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      }
470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      return true;
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return false;
510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// Returns true if the given operator has exactly 1 input, and is connected to
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// the given op_type.
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// We use kNone to indicate an input unattached to an operator output. Usually
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// these are the static input arrays.
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool MatchOperatorInputs(const Operator& op, const Model& model,
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                         OperatorType op_type, Operator** connected_op) {
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check for required number of inputs
600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (op.inputs.size() != 1) {
610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check if first input is disconnected/connected to an operator
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* x = GetOpWithOutput(model, op.inputs[0]);
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((op_type == OperatorType::kNone) && (x != nullptr)) {
670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((op_type != OperatorType::kNone) && (x == nullptr)) {
700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check that first operator, if connected, is of correct type
740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((x != nullptr) && (x->type != op_type)) {
750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Successfully matched. Optionally return matching input operators.
790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (connected_op) {
800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    *connected_op = x;
810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return true;
840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// Returns true if the given operator has exactly 2 inputs, which are connected
870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// to the given op_types.
880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// We use kNone to indicate an input unattached to an operator output. Usually
890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// these are the static input arrays.
900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool MatchOperatorInputs(const Operator& op, const Model& model,
910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                         OperatorType a_op_type, Operator** a_op,
920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                         OperatorType b_op_type, Operator** b_op) {
930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check for required number of inputs
940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (op.inputs.size() != 2) {
950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check if first input is disconnected/connected to an operator
990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* x = GetOpWithOutput(model, op.inputs[0]);
1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check that first operator, if connected, is of correct type
1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((x != nullptr) && (x->type != a_op_type)) {
1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check if second input is disconnected/connected to an operator
1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* y = GetOpWithOutput(model, op.inputs[1]);
1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
1180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check that second operator, if connected, is of correct type
1220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((y != nullptr) && (y->type != b_op_type)) {
1230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Successfully matched. Optionally return matching input operators.
1270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (a_op != nullptr) {
1280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    *a_op = x;
1290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (b_op != nullptr) {
1310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    *b_op = y;
1320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return true;
1340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// Returns true if the given operator has exactly 3 inputs, which are connected
1370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// to the given op_types.
1380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// We use kNone to indicate an input unattached to an operator output. Usually
1390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// these are the static input arrays.
1400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool MatchOperatorInputs(const Operator& op, const Model& model,
1410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                         OperatorType a_op_type, Operator** a_op,
1420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                         OperatorType b_op_type, Operator** b_op,
1430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                         OperatorType c_op_type, Operator** c_op) {
1440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check for required number of inputs
1450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (op.inputs.size() != 3) {
1460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check if first input is disconnected/connected to an operator
1500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* x = GetOpWithOutput(model, op.inputs[0]);
1510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
1520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
1550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check that first operator, if connected, is of correct type
1590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((x != nullptr) && (x->type != a_op_type)) {
1600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check if second input is disconnected/connected to an operator
1640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* y = GetOpWithOutput(model, op.inputs[1]);
1650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
1660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
1690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check that second operator, if connected, is of correct type
1730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((y != nullptr) && (y->type != b_op_type)) {
1740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check if third input is disconnected/connected to an operator
1780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* z = GetOpWithOutput(model, op.inputs[2]);
1790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((c_op_type == OperatorType::kNone) && (z != nullptr)) {
1800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((c_op_type != OperatorType::kNone) && (z == nullptr)) {
1830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Check that third operator, if connected, is of correct type
1870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if ((z != nullptr) && (z->type != c_op_type)) {
1880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Successfully matched. Optionally return matching input operators.
1920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (a_op != nullptr) {
1930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    *a_op = x;
1940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (b_op != nullptr) {
1960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    *b_op = y;
1970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (c_op != nullptr) {
1990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    *c_op = z;
2000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return true;
2020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
2030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
2040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace
2050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
2060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
2070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // This LSTM cell identification method is not invariant to commutation of
2080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // commutative operator inputs. For example, if input[0] and input[1] of the
2090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // final output multiplication were swapped, this method would not identify it
2100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // as an LSTM cell. This is OK in most cases, because
2110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // tf.rnn.contrib.BasicLSTMCell always generates LSTM cells the same way.
2120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
2130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Final output multiply
2140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto op_it = model->operators.begin() + op_index;
2150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* final_output_mul = op_it->get();
2160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (final_output_mul->type != OperatorType::kMul) {
2170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator *state_output_tanh, *fc_output_sig;
2200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh,
2210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           &state_output_tanh, OperatorType::kLogistic,
2220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           &fc_output_sig)) {
2230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
2260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // State output TanH
2270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // (We don't count an operator as ID'd until we verify it has the correct
2280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // operator types feeding into it.)
2290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* state_combine_add;
2300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd,
2310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           &state_combine_add)) {
2320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  string prev_state;
2350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0],
2360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                &prev_state)) {
2370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
2400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // State forget & remember addition
2410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator *state_forget_mul, *state_remember_mul;
2420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul,
2430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           &state_forget_mul, OperatorType::kMul,
2440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           &state_remember_mul)) {
2450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (state_forget_mul->inputs[0] != prev_state) {
2480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
2510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // State forget gate
2520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* state_forget_sig;
2530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone,
2540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           nullptr, OperatorType::kLogistic,
2550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           &state_forget_sig)) {
2560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
2590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // State remember gate
2600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator *state_remember_sig, *state_info_tanh;
2610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic,
2620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           &state_remember_sig, OperatorType::kTanh,
2630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           &state_info_tanh)) {
2640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
2670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // State remember "information" activation function
2680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* fc_output_split;
2690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!MatchOperatorInputs(*state_info_tanh, *model,
2700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           OperatorType::kTensorFlowSplit, &fc_output_split)) {
2710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // State remember gate activation function
2740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* tmp;
2750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!MatchOperatorInputs(*state_remember_sig, *model,
2760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           OperatorType::kTensorFlowSplit, &tmp) ||
2770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      (tmp != fc_output_split)) {
2780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // State forget gate activation function
2810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!MatchOperatorInputs(*state_forget_sig, *model,
2820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           OperatorType::kTensorFlowSplit, &tmp) ||
2830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      (tmp != fc_output_split)) {
2840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Fully connected output activation function
2870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!MatchOperatorInputs(*fc_output_sig, *model,
2880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           OperatorType::kTensorFlowSplit, &tmp) ||
2890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      (tmp != fc_output_split)) {
2900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Fully connected output split
2930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* fully_connected;
2940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone,
2950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           nullptr, OperatorType::kFullyConnected,
2960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           &fully_connected)) {
2970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
2980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
2990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Fully connected op
3010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* concat_inputs;
3020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!MatchOperatorInputs(*fully_connected, *model,
3030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           OperatorType::kConcatenation, &concat_inputs,
3040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           OperatorType::kNone, nullptr, OperatorType::kNone,
3050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                           nullptr)) {
3060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
3070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
3080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Emplace a new LSTM cell operator
3100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* lstm_cell_op = new LstmCellOperator;
3110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS);
3120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  lstm_cell_op->inputs[LstmCellOperator::DATA_INPUT] = concat_inputs->inputs[0];
3130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] =
3140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      concat_inputs->inputs[1];
3150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  lstm_cell_op->inputs[LstmCellOperator::WEIGHTS_INPUT] =
3160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      fully_connected->inputs[1];
3170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  lstm_cell_op->inputs[LstmCellOperator::BIASES_INPUT] =
3180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      fully_connected->inputs[2];
3190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state;
3200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
3210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] =
3220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      state_output_tanh->inputs[0];
3230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] =
3240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      final_output_mul->outputs[0];
3250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.emplace(op_it, lstm_cell_op);
3260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  AddMessageF("Creating %s replacing equivalent subgraph",
3270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              LogName(*lstm_cell_op));
3280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Create temp arrays used internally during runtime.
3300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const string base_name(FindLongestCommonPrefix(
3310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT],
3320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT]));
3330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const string& concat_temp_array_name =
3340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      AvailableArrayName(*model, base_name + "concat_temp");
3350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->GetOrCreateArray(concat_temp_array_name);
3360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
3370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const string& activ_temp_array_name =
3380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      AvailableArrayName(*model, base_name + "activ_temp");
3390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->GetOrCreateArray(activ_temp_array_name);
3400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name;
3410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  AddMessageF("Created temp outputs %s and %s on operator %s",
3420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              concat_temp_array_name, activ_temp_array_name,
3430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              LogName(*lstm_cell_op));
3440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Delete arrays and operators replaced by the LSTM cell operator. Order is
3460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // important - DeleteArrayIfUnused() only succeeds if dependent operators
3470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // have been removed first. Start at the output and work towards the input.
3480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *final_output_mul));
3490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(state_output_tanh->outputs[0], model);
3500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(fc_output_sig->outputs[0], model);
3510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *state_output_tanh));
3520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *fc_output_sig));
3530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *state_combine_add));
3540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(state_forget_mul->outputs[0], model);
3550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(state_remember_mul->outputs[0], model);
3560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *state_forget_mul));
3570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *state_remember_mul));
3580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(state_forget_sig->outputs[0], model);
3590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(state_info_tanh->outputs[0], model);
3600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(state_remember_sig->outputs[0], model);
3610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *state_forget_sig));
3620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *state_info_tanh));
3630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *state_remember_sig));
3640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(fc_output_split->outputs[0], model);
3650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(fc_output_split->outputs[1], model);
3660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(fc_output_split->outputs[2], model);
3670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(fc_output_split->outputs[3], model);
3680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  string dims_array = fc_output_split->inputs[0];
3690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *fc_output_split));
3700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(dims_array, model);
3710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(fully_connected->outputs[0], model);
3720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *fully_connected));
3730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DeleteArrayIfUnused(concat_inputs->outputs[0], model);
3740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, *concat_inputs));
3750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return true;
3760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
3770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
3780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace toco
379