/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
                                           const Operator* add_or_sub_op,
                                           int index_of_constant_input) {
  CHECK(add_or_sub_op->type == OperatorType::kAdd ||
        add_or_sub_op->type == OperatorType::kSub);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  if (preceding_op->inputs.size() < 3) {
    LOG(FATAL) << "Missing bias parameter";
  }
  auto& bias = model->GetArray(preceding_op->inputs[2]);
  bias.minmax = nullptr;
  const auto& operand =
      model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);

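  // Only the bias of the preceding affine op is modified below; for Add/Sub
  // the weights are left untouched.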
  const Shape& bias_shape = bias.shape();
  const Shape& operand_shape = operand.shape();
  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const bias_data = bias_buffer.data.data();
  const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
  const float* const operand_data = operand_buffer.data.data();

  // TODO(b/62904716): Bias array should become 1-D when padding removed.
  const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1);
  CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1));

  enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias };

  const OpType optype = (add_or_sub_op->type == OperatorType::kAdd)
                            ? OpType::BiasPlusOperand
                            : (index_of_constant_input == 1)
                                  ? OpType::BiasMinusOperand
                                  : OpType::OperandMinusBias;

  for (int i = 0; i < depth; i++) {
    float& bias_val = bias_data[i];
    const float operand_val = operand_data[i];
    if (optype == OpType::BiasPlusOperand) {
      bias_val += operand_val;
    } else if (optype == OpType::BiasMinusOperand) {
      bias_val -= operand_val;
    } else if (optype == OpType::OperandMinusBias) {
      bias_val = operand_val - bias_val;
    } else {
      LOG(FATAL) << "Should not get here.";
    }
  }
}

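// Folds a constant Mul/Div operand into both the weights and the bias of the
// preceding affine op, scaling each output channel by the corresponding
// operand value.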
void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
                                           const Operator* mul_or_div_op,
                                           int index_of_constant_input) {
  CHECK(mul_or_div_op->type == OperatorType::kMul ||
        mul_or_div_op->type == OperatorType::kDiv);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  // If the op is a division, the constant input should be the right hand side.
  // This should have been checked before this point.
  CHECK(mul_or_div_op->type != OperatorType::kDiv ||
        index_of_constant_input == 1);
  if (preceding_op->inputs.size() < 3) {
    LOG(FATAL) << "Missing bias parameter";
  }
  const auto& weights_name = preceding_op->inputs[1];
  const auto& bias_name = preceding_op->inputs[2];
  auto& weights = model->GetArray(weights_name);
  DropMinMax(model, weights_name);
  auto& bias = model->GetArray(bias_name);
  DropMinMax(model, bias_name);
  const auto& operand =
      model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);

  const Shape& weights_shape = weights.shape();
  const Shape& bias_shape = bias.shape();
  const Shape& operand_shape = operand.shape();
  auto& weights_buffer = weights.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const weights_data = weights_buffer.data.data();
  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const bias_data = bias_buffer.data.data();
  const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
  const float* const operand_data = operand_buffer.data.data();

  // We support broadcasting the operand along the depth dimension,
  // when the operand's depth is 1.
  int operand_channel_increment = 0;
  if (operand_shape.dimensions_count() >= 1 &&
      operand_shape.dims(operand_shape.dimensions_count() - 1) ==
          bias_shape.dims(bias_shape.dimensions_count() - 1)) {
    operand_channel_increment = 1;
  } else if (operand_shape.dimensions_count() == 0 ||
             operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) {
    operand_channel_increment = 0;
  } else {
    LOG(FATAL) << "Operand shape mismatch.";
  }

  int output_depth;

  if (preceding_op->type == OperatorType::kConv ||
      preceding_op->type == OperatorType::kFullyConnected) {
    output_depth = weights_shape.dims(0);
  } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
    output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1);
  } else {
    LOG(FATAL) << "Should not get here";
  }

  const int weights_size = RequiredBufferSizeForShape(weights_shape);
  const int weights_per_depth = weights_size / output_depth;
  CHECK_EQ(weights_size, weights_per_depth * output_depth);

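  // Scale the bias and the weights of each output channel by the operand
  // value for that channel. When the operand is broadcast (depth 1), the
  // channel increment is 0 and the same value is applied to every channel.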
here"; 169 } 170 } 171 } else { 172 LOG(FATAL) << "Should not get here"; 173 } 174 operand_channel += operand_channel_increment; 175 } 176} 177} // namespace 178 179bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { 180 const auto binary_it = model->operators.begin() + op_index; 181 const auto* binary_op = binary_it->get(); 182 if (binary_op->type != OperatorType::kAdd && 183 binary_op->type != OperatorType::kMul && 184 binary_op->type != OperatorType::kSub && 185 binary_op->type != OperatorType::kDiv) { 186 return false; 187 } 188 189 CHECK_EQ(binary_op->inputs.size(), 2); 190 191 // We only can fuse an binary when the two operands break down as follows: 192 // 1. One operand is the (variable) output of a typical affine (linear plus 193 // bias) 194 // op of a finite list of possible types: at the moment Conv, 195 // DepthwiseConv and 196 // FullyConnected are supported. 197 // 2. The other operand is a constant param array. 198 const bool is_input_constant[2] = { 199 IsConstantParameterArray(*model, binary_op->inputs[0]), 200 IsConstantParameterArray(*model, binary_op->inputs[1]), 201 }; 202 if (!is_input_constant[0] && !is_input_constant[1]) { 203 // Neither input is constant, so nothing we can fuse into a constant. 204 return false; 205 } 206 if (is_input_constant[0] && is_input_constant[1]) { 207 // Both inputs are constants. That's a job for constants 208 // propagation, not for us to handle here. 209 return false; 210 } 211 const int index_of_constant_input = is_input_constant[0] ? 0 : 1; 212 const int index_of_variable_input = is_input_constant[0] ? 1 : 0; 213 CHECK(is_input_constant[index_of_constant_input]); 214 CHECK(!is_input_constant[index_of_variable_input]); 215 216 // For division, we can only fuse if the denominator is constant. 
  if (binary_op->type == OperatorType::kDiv) {
    if (index_of_constant_input != 1) {
      AddMessageF("Not fusing %s because the denominator is not constant",
                  LogName(*binary_op));
      return false;
    }
  }

  Operator* preceding_op =
      GetOpWithOutput(*model, binary_op->inputs[index_of_variable_input]);
  if (!preceding_op) {
    AddMessageF("Not fusing %s because it is not the output of another op",
                LogName(*binary_op));
    return false;
  }

  for (const string& output_array : model->flags.output_arrays()) {
    if (preceding_op->outputs[0] == output_array) {
      return false;
    }
  }

  if (preceding_op->type != OperatorType::kConv &&
      preceding_op->type != OperatorType::kFullyConnected &&
      preceding_op->type != OperatorType::kDepthwiseConv) {
    AddMessageF(
        "Not fusing %s because the preceding %s is not of one of the supported "
        "types",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  if (preceding_op->fused_activation_function !=
      FusedActivationFunctionType::kNone) {
    AddMessageF(
        "Not fusing %s because the preceding %s has a fused activation "
        "function",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  if (preceding_op->inputs.size() < 3) {
    AddMessageF(
        "Not fusing %s because the preceding %s does not have a bias vector",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  const auto& weights = model->GetArray(preceding_op->inputs[1]);
  const auto& bias = model->GetArray(preceding_op->inputs[2]);
  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
    if (!bias.buffer) {
      AddMessageF(
          "Not fusing %s because the preceding %s has a non-constant bias "
          "array",
          LogName(*binary_op), LogName(*preceding_op));
      return false;
    }
  } else {
    if (!weights.buffer || !bias.buffer) {
      AddMessageF(
          "Not fusing %s because the preceding %s has non-constant weights or "
          "bias arrays",
          LogName(*binary_op), LogName(*preceding_op));
      return false;
    }
  }

  int count_ops_consuming_output =
      CountOpsWithInput(*model, preceding_op->outputs[0]);
  DCHECK_GE(count_ops_consuming_output, 1);
  if (count_ops_consuming_output > 1) {
    AddMessageF(
        "Not fusing %s because the output of the preceding %s is consumed by "
        "another op",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
              LogName(*preceding_op));

  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
    FuseAddOrSubParamsIntoPrecedingAffine(model, preceding_op, binary_op,
                                          index_of_constant_input);
  } else if (binary_op->type == OperatorType::kMul ||
             binary_op->type == OperatorType::kDiv) {
    FuseMulOrDivParamsIntoPrecedingAffine(model, preceding_op, binary_op,
                                          index_of_constant_input);
  } else {
    LOG(FATAL) << "should not get here";
  }

  model->EraseArray(preceding_op->outputs[0]);
  preceding_op->outputs[0] = binary_op->outputs[0];
  preceding_op->fused_activation_function =
      binary_op->fused_activation_function;
  const auto& old_constant_param_name =
      binary_op->inputs[index_of_constant_input];
  CHECK(IsConstantParameterArray(*model, old_constant_param_name));
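  // Only erase the constant operand array if this binary op was its sole
  // consumer.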
  if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
    model->EraseArray(old_constant_param_name);
  }
  model->operators.erase(binary_it);
  return true;
}

}  // namespace toco