10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License"); 40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License. 50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at 60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle http://www.apache.org/licenses/LICENSE-2.0 80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software 100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS, 110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and 130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License. 140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/ 150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <cmath> 160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory> 170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string> 180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map> 190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector> 200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h" 230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h" 240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h" 250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco { 270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace { 290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellestd::vector<std::unique_ptr<Operator>>::iterator FindOperator( 310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Model* model, const Operator* op) { 320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto it = model->operators.begin(); 330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle for (; it != model->operators.end(); ++it) { 340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (it->get() == op) { 350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle break; 360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return it; 390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace 410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { 430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto div_it = model->operators.begin() + op_index; 440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto* div_or_mul_op = div_it->get(); 450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType expected_op_type_producing_div_or_mul_input; 460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (div_or_mul_op->type == OperatorType::kDiv) { 470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt; 480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } else if (div_or_mul_op->type == OperatorType::kMul) { 490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle expected_op_type_producing_div_or_mul_input = 500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType::kTensorFlowRsqrt; 510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } else { 520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK_EQ(div_or_mul_op->inputs.size(), 2); 550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* op_producing_div_or_mul_input[2] = { 560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle GetOpWithOutput(*model, div_or_mul_op->inputs[0]), 570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle GetOpWithOutput(*model, div_or_mul_op->inputs[1]), 580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }; 590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!op_producing_div_or_mul_input[1] || 600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle op_producing_div_or_mul_input[1]->type != 610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle expected_op_type_producing_div_or_mul_input) { 620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1]; 650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1); 660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* op_producing_sqrt_or_rsqrt_input = 670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]); 680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!op_producing_sqrt_or_rsqrt_input) { 690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // There may be an Add or a Maximum here, adding or clamping to a "small" 730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // constant scalar. 740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Reported bug: b/29395854 750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* add_op = nullptr; 760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* op_producing_add_input = nullptr; 770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd || 780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle op_producing_sqrt_or_rsqrt_input->type == 790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle OperatorType::kTensorFlowMaximum) { 800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle add_op = op_producing_sqrt_or_rsqrt_input; 810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle bool add_can_be_removed = false; 820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2); 830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle for (int i = 0; i < 2; i++) { 840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& input_array = 850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->GetArray(op_producing_sqrt_or_rsqrt_input->inputs[i]); 860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!input_array.buffer) { 870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle continue; 880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (input_array.buffer->type != ArrayDataType::kFloat) { 900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle continue; 910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (RequiredBufferSizeForShape(input_array.shape()) != 1) { 930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle continue; 940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& input_float_data = 960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle input_array.GetBuffer<ArrayDataType::kFloat>().data; 970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (std::abs(input_float_data[0]) > 1e-3f) { 980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle continue; 990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle add_can_be_removed = true; 1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle op_producing_add_input = GetOpWithOutput(*model, add_op->inputs[1 - i]); 1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle break; 1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!add_can_be_removed) { 1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AddMessageF( 1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle "Giving up trying to identify L2Normalization subgraph " 1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle " because the operator producing the input to the square root, %s," 1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle ", does not match the expected pattern", 1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle LogName(*op_producing_sqrt_or_rsqrt_input)); 1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* sum_op = 1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input; 1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (sum_op->type != OperatorType::kTensorFlowSum) { 1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AddMessageF( 1180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle "Giving up trying to identify L2Normalization subgraph: " 1190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle "expected Sum op, got %s", 1200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle LogName(*sum_op)); 1210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]); 1250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (square_op->type != OperatorType::kTensorFlowSquare) { 1260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AddMessageF( 1270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle "Giving up trying to identify L2Normalization subgraph: " 1280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle "expected Square op, got %s", 1290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle LogName(*square_op)); 1300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK_EQ(square_op->inputs.size(), 1); 1340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (square_op->inputs[0] != div_or_mul_op->inputs[0]) { 1360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AddMessageF( 1370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle "Giving up trying to identify L2Normalization subgraph: %s does not " 1380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle "take the same input as the Mul/Div node", 1390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle LogName(*square_op)); 1400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Create and emplace the new L2Normalization 1440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto* l2norm_op = new L2NormalizationOperator; 1450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle l2norm_op->inputs = {div_or_mul_op->inputs[0]}; 1460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle l2norm_op->outputs = div_or_mul_op->outputs; 1470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.emplace(div_it, l2norm_op); 1480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op)); 1500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Erase the subgraph that is now replaced by L2Normalization 1520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, square_op)); 153ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower model->EraseArray(sum_op->inputs[0]); 1540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (sum_op->inputs.size() > 1) { 155ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower model->EraseArray(sum_op->inputs[1]); 1560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, sum_op)); 1580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (add_op) { 159ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower model->EraseArray(add_op->inputs[0]); 160ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower model->EraseArray(add_op->inputs[1]); 1610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, add_op)); 1620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 163ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower model->EraseArray(sqrt_or_rsqrt_op->inputs[0]); 1640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op)); 165ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower model->EraseArray(div_or_mul_op->inputs[1]); 1660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(FindOperator(model, div_or_mul_op)); 1670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return true; 1680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 1690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace toco 171