10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License"); 40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License. 50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at 60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle http://www.apache.org/licenses/LICENSE-2.0 80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software 100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS, 110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and 130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License. 140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/ 150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory> 160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string> 170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector> 180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h" 210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h" 220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco { 240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace { 260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellestd::vector<std::unique_ptr<Operator>>::iterator FindOperator( 280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Model* model, const Operator& op) { 290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto it = model->operators.begin(); 300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle for (; it != model->operators.end(); ++it) { 310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (it->get() == &op) { 320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle break; 330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return it; 360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool GetStateArrayForBackEdge(const Model& model, 390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const string& back_edge_source_array, 400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle string* state_array = nullptr) { 410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle for (const auto& rnn_state : model.flags.rnn_states()) { 420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (back_edge_source_array == rnn_state.back_edge_source_array()) { 430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Found LSTM cell output 440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (state_array) { 450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle *state_array = rnn_state.state_array(); 460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return true; 480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// Returns true if the given operator has exactly 1 input, and is connected to 540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// the given op_type. 550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// We use kNone to indicate an input unattached to an operator output. Usually 560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// these are the static input arrays. 570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool MatchOperatorInputs(const Operator& op, const Model& model, 580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType op_type, Operator** connected_op) { 590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check for required number of inputs 600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (op.inputs.size() != 1) { 610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check if first input is disconnected/connected to an operator 650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* x = GetOpWithOutput(model, op.inputs[0]); 660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((op_type == OperatorType::kNone) && (x != nullptr)) { 670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((op_type != OperatorType::kNone) && (x == nullptr)) { 700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check that first operator, if connected, is of correct type 740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((x != nullptr) && (x->type != op_type)) { 750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Successfully matched. Optionally return matching input operators. 790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (connected_op) { 800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle *connected_op = x; 810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return true; 840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// Returns true if the given operator has exactly 2 inputs, which are connected 870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// to the given op_types. 880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// We use kNone to indicate an input unattached to an operator output. Usually 890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// these are the static input arrays. 900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool MatchOperatorInputs(const Operator& op, const Model& model, 910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType a_op_type, Operator** a_op, 920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType b_op_type, Operator** b_op) { 930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check for required number of inputs 940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (op.inputs.size() != 2) { 950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check if first input is disconnected/connected to an operator 990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* x = GetOpWithOutput(model, op.inputs[0]); 1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { 1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { 1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check that first operator, if connected, is of correct type 1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((x != nullptr) && (x->type != a_op_type)) { 1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check if second input is disconnected/connected to an operator 1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* y = GetOpWithOutput(model, op.inputs[1]); 1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { 1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { 1180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check that second operator, if connected, is of correct type 1220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((y != nullptr) && (y->type != b_op_type)) { 1230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Successfully matched. Optionally return matching input operators. 1270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (a_op != nullptr) { 1280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle *a_op = x; 1290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (b_op != nullptr) { 1310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle *b_op = y; 1320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return true; 1340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 1350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// Returns true if the given operator has exactly 3 inputs, which are connected 1370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// to the given op_types. 1380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// We use kNone to indicate an input unattached to an operator output. Usually 1390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// these are the static input arrays. 1400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool MatchOperatorInputs(const Operator& op, const Model& model, 1410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType a_op_type, Operator** a_op, 1420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType b_op_type, Operator** b_op, 1430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType c_op_type, Operator** c_op) { 1440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check for required number of inputs 1450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (op.inputs.size() != 3) { 1460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check if first input is disconnected/connected to an operator 1500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* x = GetOpWithOutput(model, op.inputs[0]); 1510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { 1520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { 1550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check that first operator, if connected, is of correct type 1590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((x != nullptr) && (x->type != a_op_type)) { 1600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check if second input is disconnected/connected to an operator 1640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* y = GetOpWithOutput(model, op.inputs[1]); 1650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { 1660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { 1690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check that second operator, if connected, is of correct type 1730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((y != nullptr) && (y->type != b_op_type)) { 1740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check if third input is disconnected/connected to an operator 1780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* z = GetOpWithOutput(model, op.inputs[2]); 1790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((c_op_type == OperatorType::kNone) && (z != nullptr)) { 1800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((c_op_type != OperatorType::kNone) && (z == nullptr)) { 1830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Check that third operator, if connected, is of correct type 1870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if ((z != nullptr) && (z->type != c_op_type)) { 1880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Successfully matched. Optionally return matching input operators. 1920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (a_op != nullptr) { 1930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle *a_op = x; 1940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (b_op != nullptr) { 1960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle *b_op = y; 1970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (c_op != nullptr) { 1990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle *c_op = z; 2000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return true; 2020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 2030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace 2050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { 2070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // This LSTM cell identification method is not invariant to commutation of 2080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // commutative operator inputs. For example, if input[0] and input[1] of the 2090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // final output multiplication were swapped, this method would not identify it 2100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // as an LSTM cell. This is OK in most cases, because 2110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // tf.rnn.contrib.BasicLSTMCell always generates LSTM cells the same way. 2120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Final output multiply 2140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto op_it = model->operators.begin() + op_index; 2150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* final_output_mul = op_it->get(); 2160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (final_output_mul->type != OperatorType::kMul) { 2170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator *state_output_tanh, *fc_output_sig; 2200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh, 2210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle &state_output_tanh, OperatorType::kLogistic, 2220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle &fc_output_sig)) { 2230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // State output TanH 2270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // (We don't count an operator as ID'd until we verify it has the correct 2280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // operator types feeding into it.) 2290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* state_combine_add; 2300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd, 2310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle &state_combine_add)) { 2320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle string prev_state; 2350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0], 2360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle &prev_state)) { 2370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // State forget & remember addition 2410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator *state_forget_mul, *state_remember_mul; 2420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul, 2430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle &state_forget_mul, OperatorType::kMul, 2440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle &state_remember_mul)) { 2450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (state_forget_mul->inputs[0] != prev_state) { 2480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // State forget gate 2520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* state_forget_sig; 2530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone, 2540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle nullptr, OperatorType::kLogistic, 2550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle &state_forget_sig)) { 2560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // State remember gate 2600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator *state_remember_sig, *state_info_tanh; 2610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic, 2620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle &state_remember_sig, OperatorType::kTanh, 2630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle &state_info_tanh)) { 2640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 2670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // State remember "information" activation function 2680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* fc_output_split; 2690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!MatchOperatorInputs(*state_info_tanh, *model, 2700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType::kTensorFlowSplit, &fc_output_split)) { 2710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // State remember gate activation function 2740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* tmp; 2750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!MatchOperatorInputs(*state_remember_sig, *model, 2760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType::kTensorFlowSplit, &tmp) || 2770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle (tmp != fc_output_split)) { 2780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // State forget gate activation function 2810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!MatchOperatorInputs(*state_forget_sig, *model, 2820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType::kTensorFlowSplit, &tmp) || 2830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle (tmp != fc_output_split)) { 2840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Fully connected output activation function 2870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!MatchOperatorInputs(*fc_output_sig, *model, 2880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType::kTensorFlowSplit, &tmp) || 2890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle (tmp != fc_output_split)) { 2900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Fully connected output split 2930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* fully_connected; 2940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone, 2950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle nullptr, OperatorType::kFullyConnected, 2960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle &fully_connected)) { 2970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 2980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 2990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Fully connected op 3010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* concat_inputs; 3020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!MatchOperatorInputs(*fully_connected, *model, 3030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType::kConcatenation, &concat_inputs, 3040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType::kNone, nullptr, OperatorType::kNone, 3050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle nullptr)) { 3060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 3070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 3080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Emplace a new LSTM cell operator 3100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto* lstm_cell_op = new LstmCellOperator; 3110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS); 3120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->inputs[LstmCellOperator::DATA_INPUT] = concat_inputs->inputs[0]; 3130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = 3140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle concat_inputs->inputs[1]; 3150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->inputs[LstmCellOperator::WEIGHTS_INPUT] = 3160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle fully_connected->inputs[1]; 3170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->inputs[LstmCellOperator::BIASES_INPUT] = 3180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle fully_connected->inputs[2]; 3190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state; 3200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS); 3210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] = 3220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle state_output_tanh->inputs[0]; 3230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] = 3240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle final_output_mul->outputs[0]; 3250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.emplace(op_it, lstm_cell_op); 3260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AddMessageF("Creating %s replacing equivalent subgraph", 3270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle LogName(*lstm_cell_op)); 3280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Create temp arrays used internally during runtime. 3300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const string base_name(FindLongestCommonPrefix( 3310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT], 3320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT])); 3330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const string& concat_temp_array_name = 3340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AvailableArrayName(*model, base_name + "concat_temp"); 3350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->GetOrCreateArray(concat_temp_array_name); 3360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name; 3370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const string& activ_temp_array_name = 3380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AvailableArrayName(*model, base_name + "activ_temp"); 3390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->GetOrCreateArray(activ_temp_array_name); 3400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name; 3410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AddMessageF("Created temp outputs %s and %s on operator %s", 3420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle concat_temp_array_name, activ_temp_array_name, 3430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle LogName(*lstm_cell_op)); 3440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Delete arrays and operators replaced by the LSTM cell operator. Order is 3460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // important - DeleteArrayIfUnused() only succeeds if dependent operators 3470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // have been removed first. Start at the output and work towards the input. 3480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *final_output_mul)); 3490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(state_output_tanh->outputs[0], model); 3500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(fc_output_sig->outputs[0], model); 3510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *state_output_tanh)); 3520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *fc_output_sig)); 3530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *state_combine_add)); 3540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(state_forget_mul->outputs[0], model); 3550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(state_remember_mul->outputs[0], model); 3560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *state_forget_mul)); 3570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *state_remember_mul)); 3580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(state_forget_sig->outputs[0], model); 3590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(state_info_tanh->outputs[0], model); 3600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(state_remember_sig->outputs[0], model); 3610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *state_forget_sig)); 3620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *state_info_tanh)); 3630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *state_remember_sig)); 3640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(fc_output_split->outputs[0], model); 3650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(fc_output_split->outputs[1], model); 3660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(fc_output_split->outputs[2], model); 3670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(fc_output_split->outputs[3], model); 3680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle string dims_array = fc_output_split->inputs[0]; 3690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *fc_output_split)); 3700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(dims_array, model); 3710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(fully_connected->outputs[0], model); 3720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *fully_connected)); 3730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DeleteArrayIfUnused(concat_inputs->outputs[0], model); 3740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, *concat_inputs)); 3750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return true; 3760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 3770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 3780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace toco 379