/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

// Fuses a constant Add or Sub operand into the bias of the preceding affine
// op: for Add the bias becomes b + c; for Sub it becomes b - c or c - b,
// depending on which side of the subtraction the constant is on.
void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
                                           const Operator* add_or_sub_op,
                                           int index_of_constant_input) {
  CHECK(add_or_sub_op->type == OperatorType::kAdd ||
        add_or_sub_op->type == OperatorType::kSub);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  if (preceding_op->inputs.size() < 3) {
    LOG(FATAL) << "Missing bias parameter";
  }
  auto& bias = model->GetArray(preceding_op->inputs[2]);
  // The bias values are about to change, so drop any stale minmax info.
  bias.minmax = nullptr;
  const auto& operand =
      model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);

  const Shape& bias_shape = bias.shape();
  const Shape& operand_shape = operand.shape();
  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const bias_data = bias_buffer.data.data();
  const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
  const float* const operand_data = operand_buffer.data.data();

  // TODO(b/62904716): Bias array should become 1-D when padding removed.
  const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1);
  CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1));

  enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias };

  const OpType optype = (add_or_sub_op->type == OperatorType::kAdd)
                            ? OpType::BiasPlusOperand
                            : (index_of_constant_input == 1)
                                  ? OpType::BiasMinusOperand
                                  : OpType::OperandMinusBias;

  for (int i = 0; i < depth; i++) {
    float& bias_val = bias_data[i];
    const float operand_val = operand_data[i];
    if (optype == OpType::BiasPlusOperand) {
      bias_val += operand_val;
    } else if (optype == OpType::BiasMinusOperand) {
      bias_val -= operand_val;
    } else if (optype == OpType::OperandMinusBias) {
      bias_val = operand_val - bias_val;
    } else {
      LOG(FATAL) << "Should not get here.";
    }
  }
}

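// Fuses a constant Mul or Div operand into the preceding affine op by
// rescaling both its weights and its bias, using the identity
//   (W * x + b) * s == (W * s) * x + (b * s)
// (with division by s instead for Div). The operand's depth must either
// match the bias depth or be 1, in which case it is broadcast along the
// depth dimension.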
void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
                                           const Operator* mul_or_div_op,
                                           int index_of_constant_input) {
  CHECK(mul_or_div_op->type == OperatorType::kMul ||
        mul_or_div_op->type == OperatorType::kDiv);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  // If the op is a division, the constant input should be the right-hand
  // side. This should have been checked before this point.
  CHECK(mul_or_div_op->type != OperatorType::kDiv ||
        index_of_constant_input == 1);
  if (preceding_op->inputs.size() < 3) {
    LOG(FATAL) << "Missing bias parameter";
  }
  const auto& weights_name = preceding_op->inputs[1];
  const auto& bias_name = preceding_op->inputs[2];
  auto& weights = model->GetArray(weights_name);
  // Weights and bias are about to be rescaled, so drop stale minmax info.
  DropMinMax(model, weights_name);
  auto& bias = model->GetArray(bias_name);
  DropMinMax(model, bias_name);
  const auto& operand =
      model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);

  const Shape& weights_shape = weights.shape();
  const Shape& bias_shape = bias.shape();
  const Shape& operand_shape = operand.shape();
  auto& weights_buffer = weights.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const weights_data = weights_buffer.data.data();
  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const bias_data = bias_buffer.data.data();
  const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
  const float* const operand_data = operand_buffer.data.data();

  // We support broadcasting the operand along the depth dimension when the
  // operand's depth is 1.
  int operand_channel_increment = 0;
  if (operand_shape.dimensions_count() >= 1 &&
      operand_shape.dims(operand_shape.dimensions_count() - 1) ==
          bias_shape.dims(bias_shape.dimensions_count() - 1)) {
    operand_channel_increment = 1;
  } else if (operand_shape.dimensions_count() == 0 ||
             operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) {
    operand_channel_increment = 0;
  } else {
    LOG(FATAL) << "Operand shape mismatch.";
  }

  int output_depth;

  if (preceding_op->type == OperatorType::kConv ||
      preceding_op->type == OperatorType::kFullyConnected) {
    output_depth = weights_shape.dims(0);
  } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
    output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1);
  } else {
    LOG(FATAL) << "Should not get here";
  }

  const int weights_size = RequiredBufferSizeForShape(weights_shape);
  const int weights_per_depth = weights_size / output_depth;
  CHECK_EQ(weights_size, weights_per_depth * output_depth);

  int operand_channel = 0;
  for (int c = 0; c < output_depth; c++) {
    if (mul_or_div_op->type == OperatorType::kMul) {
      bias_data[c] *= operand_data[operand_channel];
    } else if (mul_or_div_op->type == OperatorType::kDiv) {
      bias_data[c] /= operand_data[operand_channel];
    } else {
      LOG(FATAL) << "Should not get here";
    }
    if (preceding_op->type == OperatorType::kConv ||
        preceding_op->type == OperatorType::kFullyConnected) {
      for (int i = 0; i < weights_per_depth; i++) {
        if (mul_or_div_op->type == OperatorType::kMul) {
          weights_data[c * weights_per_depth + i] *=
              operand_data[operand_channel];
        } else if (mul_or_div_op->type == OperatorType::kDiv) {
          weights_data[c * weights_per_depth + i] /=
              operand_data[operand_channel];
        } else {
          LOG(FATAL) << "Should not get here";
        }
      }
    } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
      for (int k = 0; k < weights_per_depth; k++) {
        if (mul_or_div_op->type == OperatorType::kMul) {
          weights_data[k * output_depth + c] *= operand_data[operand_channel];
        } else if (mul_or_div_op->type == OperatorType::kDiv) {
          weights_data[k * output_depth + c] /= operand_data[operand_channel];
        } else {
          LOG(FATAL) << "Should not get here";
        }
      }
    } else {
      LOG(FATAL) << "Should not get here";
    }
    operand_channel += operand_channel_increment;
  }
}

}  // namespace

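// Attempts to fuse the binary op at op_index into the affine op (Conv,
// DepthwiseConv or FullyConnected) that produces its variable input, e.g.
//   Add(Conv(x, W, b), c)  -->  Conv(x, W, b + c)
//   Mul(Conv(x, W, b), s)  -->  Conv(x, W * s, b * s)
// Returns true if the model was changed.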
bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
  const auto binary_it = model->operators.begin() + op_index;
  const auto* binary_op = binary_it->get();
  if (binary_op->type != OperatorType::kAdd &&
      binary_op->type != OperatorType::kMul &&
      binary_op->type != OperatorType::kSub &&
      binary_op->type != OperatorType::kDiv) {
    return false;
  }

  CHECK_EQ(binary_op->inputs.size(), 2);

  // We can only fuse a binary op when its two operands break down as follows:
  // 1. One operand is the (variable) output of a typical affine (linear plus
  //    bias) op of a finite list of possible types: at the moment Conv,
  //    DepthwiseConv and FullyConnected are supported.
  // 2. The other operand is a constant param array.
  const bool is_input_constant[2] = {
      IsConstantParameterArray(*model, binary_op->inputs[0]),
      IsConstantParameterArray(*model, binary_op->inputs[1]),
  };
  if (!is_input_constant[0] && !is_input_constant[1]) {
    // Neither input is constant, so there is no constant to fuse.
    return false;
  }
  if (is_input_constant[0] && is_input_constant[1]) {
    // Both inputs are constant. That's a job for constants propagation, not
    // for us to handle here.
    return false;
  }
  const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
  const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
  CHECK(is_input_constant[index_of_constant_input]);
  CHECK(!is_input_constant[index_of_variable_input]);

  // For division, we can only fuse if the denominator is constant.
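  // Dividing by a constant right-hand side is the same as multiplying by its
  // reciprocal, so the result is still affine in the variable input; a
  // constant numerator (c / x) is not affine and cannot be fused.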
  if (binary_op->type == OperatorType::kDiv) {
    if (index_of_constant_input != 1) {
      AddMessageF("Not fusing %s because the denominator is not constant",
                  LogName(*binary_op));
      return false;
    }
  }

  Operator* preceding_op =
      GetOpWithOutput(*model, binary_op->inputs[index_of_variable_input]);
  if (!preceding_op) {
    AddMessageF("Not fusing %s because it is not the output of another op",
                LogName(*binary_op));
    return false;
  }

  for (const string& output_array : model->flags.output_arrays()) {
    if (preceding_op->outputs[0] == output_array) {
      return false;
    }
  }

  if (preceding_op->type != OperatorType::kConv &&
      preceding_op->type != OperatorType::kFullyConnected &&
      preceding_op->type != OperatorType::kDepthwiseConv) {
    AddMessageF(
        "Not fusing %s because the preceding %s is not of one of the supported "
        "types",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  if (preceding_op->fused_activation_function !=
      FusedActivationFunctionType::kNone) {
    AddMessageF(
        "Not fusing %s because the preceding %s has a fused activation "
        "function",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  if (preceding_op->inputs.size() < 3) {
    AddMessageF(
        "Not fusing %s because the preceding %s does not have a bias vector",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  const auto& weights = model->GetArray(preceding_op->inputs[1]);
  const auto& bias = model->GetArray(preceding_op->inputs[2]);
  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
    if (!bias.buffer) {
      AddMessageF(
          "Not fusing %s because the preceding %s has a non-constant bias "
          "array",
          LogName(*binary_op), LogName(*preceding_op));
      return false;
    }
  } else {
    if (!weights.buffer || !bias.buffer) {
      AddMessageF(
          "Not fusing %s because the preceding %s has non-constant weights or "
          "bias arrays",
          LogName(*binary_op), LogName(*preceding_op));
      return false;
    }
  }

  int count_ops_consuming_output =
      CountOpsWithInput(*model, preceding_op->outputs[0]);
  DCHECK_GE(count_ops_consuming_output, 1);
  if (count_ops_consuming_output > 1) {
    AddMessageF(
        "Not fusing %s because the output of the preceding %s is consumed by "
        "another op",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
              LogName(*preceding_op));

  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
    FuseAddOrSubParamsIntoPrecedingAffine(model, preceding_op, binary_op,
                                          index_of_constant_input);
  } else if (binary_op->type == OperatorType::kMul ||
             binary_op->type == OperatorType::kDiv) {
    FuseMulOrDivParamsIntoPrecedingAffine(model, preceding_op, binary_op,
                                          index_of_constant_input);
  } else {
    LOG(FATAL) << "Should not get here";
  }

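  // At this point the parameters have been fused; all that remains is to
  // splice the binary op out of the graph: the preceding op takes over the
  // binary op's output array and fused activation function, and the binary
  // op and its now-unused arrays are erased.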
  model->EraseArray(preceding_op->outputs[0]);
  preceding_op->outputs[0] = binary_op->outputs[0];
  preceding_op->fused_activation_function =
      binary_op->fused_activation_function;
  const auto& old_constant_param_name =
      binary_op->inputs[index_of_constant_input];
  CHECK(IsConstantParameterArray(*model, old_constant_param_name));
  if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
    model->EraseArray(old_constant_param_name);
  }
  model->operators.erase(binary_it);
  return true;
}

}  // namespace toco