1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <cmath> 16#include <memory> 17#include <string> 18#include <unordered_map> 19#include <vector> 20 21#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 22#include "tensorflow/contrib/lite/toco/model.h" 23#include "tensorflow/contrib/lite/toco/tooling_util.h" 24#include "tensorflow/core/platform/logging.h" 25 26namespace toco { 27 28namespace { 29 30std::vector<std::unique_ptr<Operator>>::iterator FindOperator( 31 Model* model, const Operator* op) { 32 auto it = model->operators.begin(); 33 for (; it != model->operators.end(); ++it) { 34 if (it->get() == op) { 35 break; 36 } 37 } 38 return it; 39} 40} // namespace 41 42bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { 43 const auto div_it = model->operators.begin() + op_index; 44 const auto* div_or_mul_op = div_it->get(); 45 OperatorType expected_op_type_producing_div_or_mul_input; 46 if (div_or_mul_op->type == OperatorType::kDiv) { 47 expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt; 48 } else if (div_or_mul_op->type == OperatorType::kMul) { 49 expected_op_type_producing_div_or_mul_input = 50 OperatorType::kTensorFlowRsqrt; 51 } else { 52 return false; 53 } 54 CHECK_EQ(div_or_mul_op->inputs.size(), 2); 55 Operator* op_producing_div_or_mul_input[2] = { 56 GetOpWithOutput(*model, div_or_mul_op->inputs[0]), 57 GetOpWithOutput(*model, div_or_mul_op->inputs[1]), 58 }; 59 if (!op_producing_div_or_mul_input[1] || 60 op_producing_div_or_mul_input[1]->type != 61 expected_op_type_producing_div_or_mul_input) { 62 return false; 63 } 64 Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1]; 65 CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1); 66 Operator* op_producing_sqrt_or_rsqrt_input = 67 GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]); 68 if (!op_producing_sqrt_or_rsqrt_input) { 69 return false; 70 } 71 72 // There may be an Add or a Maximum here, adding or clamping to a "small" 73 // constant scalar. 74 // Reported bug: b/29395854 75 Operator* add_op = nullptr; 76 Operator* op_producing_add_input = nullptr; 77 if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd || 78 op_producing_sqrt_or_rsqrt_input->type == 79 OperatorType::kTensorFlowMaximum) { 80 add_op = op_producing_sqrt_or_rsqrt_input; 81 bool add_can_be_removed = false; 82 CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2); 83 for (int i = 0; i < 2; i++) { 84 const auto& input_array = 85 model->GetArray(op_producing_sqrt_or_rsqrt_input->inputs[i]); 86 if (!input_array.buffer) { 87 continue; 88 } 89 if (input_array.buffer->type != ArrayDataType::kFloat) { 90 continue; 91 } 92 if (RequiredBufferSizeForShape(input_array.shape()) != 1) { 93 continue; 94 } 95 const auto& input_float_data = 96 input_array.GetBuffer<ArrayDataType::kFloat>().data; 97 if (std::abs(input_float_data[0]) > 1e-3f) { 98 continue; 99 } 100 add_can_be_removed = true; 101 op_producing_add_input = GetOpWithOutput(*model, add_op->inputs[1 - i]); 102 break; 103 } 104 if (!add_can_be_removed) { 105 AddMessageF( 106 "Giving up trying to identify L2Normalization subgraph " 107 " because the operator producing the input to the square root, %s," 108 ", does not match the expected pattern", 109 LogName(*op_producing_sqrt_or_rsqrt_input)); 110 return false; 111 } 112 } 113 114 Operator* sum_op = 115 add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input; 116 if (sum_op->type != OperatorType::kTensorFlowSum) { 117 AddMessageF( 118 "Giving up trying to identify L2Normalization subgraph: " 119 "expected Sum op, got %s", 120 LogName(*sum_op)); 121 return false; 122 } 123 124 Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]); 125 if (square_op->type != OperatorType::kTensorFlowSquare) { 126 AddMessageF( 127 "Giving up trying to identify L2Normalization subgraph: " 128 "expected Square op, got %s", 129 LogName(*square_op)); 130 return false; 131 } 132 133 CHECK_EQ(square_op->inputs.size(), 1); 134 135 if (square_op->inputs[0] != div_or_mul_op->inputs[0]) { 136 AddMessageF( 137 "Giving up trying to identify L2Normalization subgraph: %s does not " 138 "take the same input as the Mul/Div node", 139 LogName(*square_op)); 140 return false; 141 } 142 143 // Create and emplace the new L2Normalization 144 auto* l2norm_op = new L2NormalizationOperator; 145 l2norm_op->inputs = {div_or_mul_op->inputs[0]}; 146 l2norm_op->outputs = div_or_mul_op->outputs; 147 model->operators.emplace(div_it, l2norm_op); 148 149 AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op)); 150 151 // Erase the subgraph that is now replaced by L2Normalization 152 model->operators.erase(FindOperator(model, square_op)); 153 model->EraseArray(sum_op->inputs[0]); 154 if (sum_op->inputs.size() > 1) { 155 model->EraseArray(sum_op->inputs[1]); 156 } 157 model->operators.erase(FindOperator(model, sum_op)); 158 if (add_op) { 159 model->EraseArray(add_op->inputs[0]); 160 model->EraseArray(add_op->inputs[1]); 161 model->operators.erase(FindOperator(model, add_op)); 162 } 163 model->EraseArray(sqrt_or_rsqrt_op->inputs[0]); 164 model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op)); 165 model->EraseArray(div_or_mul_op->inputs[1]); 166 model->operators.erase(FindOperator(model, div_or_mul_op)); 167 return true; 168} 169 170} // namespace toco 171