/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

void FuseAddOrSubParamsIntoFollowingAffine(Model* model, Operator* following_op,
                                           const Operator* add_or_sub_op,
                                           int index_of_constant_input) {
  CHECK(add_or_sub_op->type == OperatorType::kAdd ||
        add_or_sub_op->type == OperatorType::kSub);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  // If the op is a subtraction, the constant input should be the right-hand
  // side. This should have been checked before this point.
  CHECK(add_or_sub_op->type != OperatorType::kSub ||
        index_of_constant_input == 1);
  if (following_op->inputs.size() < 3) {
    LOG(FATAL) << "Missing bias parameter";
  }
  const auto& weights = model->GetArray(following_op->inputs[1]);
  auto& bias = model->GetArray(following_op->inputs[2]);
  bias.minmax = nullptr;
  const auto& operand =
      model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);
  // We only support the case of a scalar operand. This should have been
  // checked earlier.
  CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1);

  const float scalar_operand =
      operand.GetBuffer<ArrayDataType::kFloat>().data[0];
  // At this point we reduce the case of subtraction to that of addition
  // by negating the operand.
  float add_scalar_operand = 0.f;
  if (add_or_sub_op->type == OperatorType::kAdd) {
    add_scalar_operand = scalar_operand;
  } else if (add_or_sub_op->type == OperatorType::kSub &&
             index_of_constant_input == 1) {
    add_scalar_operand = -scalar_operand;
  } else {
    LOG(FATAL) << "Should not get here";
  }
  // From here on we are fusing an addition. add_or_sub_op->type does not
  // matter anymore.

  const Shape& weights_shape = weights.shape();
  const Shape& bias_shape = bias.shape();
  const auto& weights_buffer = weights.GetBuffer<ArrayDataType::kFloat>();
  const float* const weights_data = weights_buffer.data.data();
  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const bias_data = bias_buffer.data.data();
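
  // The input to following_op is (x + c) for a scalar c, and following_op
  // computes W*x + b. Folding the constant into the bias relies on
  //   W*(x + c) + b = W*x + (b + c * sum_of_weights_per_output_channel),
  // i.e. each output channel's bias is incremented by add_scalar_operand
  // times the sum of that channel's weights. The per-channel loops below
  // compute exactly that.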
  if (following_op->type == OperatorType::kConv ||
      following_op->type == OperatorType::kFullyConnected) {
    const int output_depth = weights_shape.dims(0);
    // TODO(b/62904716): Bias array should become 1-D when padding removed.
    CHECK_EQ(output_depth, bias_shape.dims(bias_shape.dimensions_count() - 1));
    const int weights_size = RequiredBufferSizeForShape(weights_shape);
    const int weights_per_depth = weights_size / output_depth;
    CHECK_EQ(weights_size, weights_per_depth * output_depth);

    for (int d = 0; d < output_depth; d++) {
      float accumulation = 0;
      for (int i = 0; i < weights_per_depth; i++) {
        accumulation +=
            add_scalar_operand * weights_data[d * weights_per_depth + i];
      }
      bias_data[d] += accumulation;
    }
  } else if (following_op->type == OperatorType::kDepthwiseConv) {
    const int output_depth =
        weights_shape.dims(weights_shape.dimensions_count() - 1);
    const int weights_size = RequiredBufferSizeForShape(weights_shape);
    const int weights_per_depth = weights_size / output_depth;
    CHECK_EQ(weights_size, weights_per_depth * output_depth);

    for (int c = 0; c < output_depth; c++) {
      float accumulation = 0;
      for (int k = 0; k < weights_per_depth; k++) {
        accumulation += add_scalar_operand * weights_data[k * output_depth + c];
      }
      bias_data[c] += accumulation;
    }
  } else {
    LOG(FATAL) << "Should not get here.";
  }
}

void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op,
                                           const Operator* mul_or_div_op,
                                           int index_of_constant_input) {
  CHECK(mul_or_div_op->type == OperatorType::kMul ||
        mul_or_div_op->type == OperatorType::kDiv);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  // If the op is a division, the constant input should be the right-hand
  // side. This should have been checked before this point.
  CHECK(mul_or_div_op->type != OperatorType::kDiv ||
        index_of_constant_input == 1);
  const auto& weights_name = following_op->inputs[1];
  const auto& bias_name = following_op->inputs[2];
  auto& weights = model->GetArray(weights_name);
  DropMinMax(model, weights_name);
  DropMinMax(model, bias_name);
  const auto& operand =
      model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);
  // We only support the case of a scalar operand. This should have been
  // checked earlier.
  CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1);

  const float scalar_operand =
      operand.GetBuffer<ArrayDataType::kFloat>().data[0];
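
  // Multiplying the input of the affine op by a scalar c is equivalent to
  // scaling the weights by c, since W*(c*x) + b = (c*W)*x + b; likewise,
  // division folds into dividing the weights. The bias is unchanged, so
  // only the weights buffer is rewritten below.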
  float* weights_data =
      weights.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
  const int weights_size = RequiredBufferSizeForShape(weights.shape());
  for (int i = 0; i < weights_size; i++) {
    if (mul_or_div_op->type == OperatorType::kMul) {
      weights_data[i] *= scalar_operand;
    } else if (mul_or_div_op->type == OperatorType::kDiv) {
      weights_data[i] /= scalar_operand;
    } else {
      LOG(FATAL) << "Should not get here";
    }
  }
}

}  // namespace

bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
  const auto binary_it = model->operators.begin() + op_index;
  auto* binary_op = binary_it->get();
  if (binary_op->type != OperatorType::kAdd &&
      binary_op->type != OperatorType::kMul &&
      binary_op->type != OperatorType::kSub &&
      binary_op->type != OperatorType::kDiv) {
    return false;
  }

  CHECK_EQ(binary_op->inputs.size(), 2);

  // We can only fuse a binary op when its two operands break down as follows:
  // 1. One operand is the (variable) output of a typical affine (linear plus
  //    bias) op of one of a finite list of possible types: at the moment,
  //    Conv, DepthwiseConv and FullyConnected are supported.
  // 2. The other operand is a constant param array.
  const bool is_input_constant[2] = {
      IsConstantParameterArray(*model, binary_op->inputs[0]),
      IsConstantParameterArray(*model, binary_op->inputs[1]),
  };
  if (!is_input_constant[0] && !is_input_constant[1]) {
    // Neither input is constant, so there is nothing we can fuse into a
    // constant.
    return false;
  }
  if (is_input_constant[0] && is_input_constant[1]) {
    // Both inputs are constants. That's a job for constants propagation, not
    // for us to handle here.
    return false;
  }
  const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
  const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
  CHECK(is_input_constant[index_of_constant_input]);
  CHECK(!is_input_constant[index_of_variable_input]);

  // For division, we can only fuse if the denominator is constant.
  if (binary_op->type == OperatorType::kDiv) {
    if (index_of_constant_input != 1) {
      AddMessageF("Not fusing %s because the denominator is not constant",
                  LogName(*binary_op));
      return false;
    }
  }

  const auto& operand_shape =
      model->GetArray(binary_op->inputs[index_of_constant_input]).shape();
  for (const auto& dim : operand_shape.dims()) {
    if (dim > 1) {
      AddMessageF(
          "Not fusing %s into the following affine op, because we only know "
          "how to do so when the constant operand is a scalar",
          LogName(*binary_op));
      return false;
    }
  }

  if (binary_op->fused_activation_function !=
      FusedActivationFunctionType::kNone) {
    AddMessageF("Not fusing %s because it has a fused activation function",
                LogName(*binary_op));
    return false;
  }

  Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]);

  if (!following_op) {
    AddMessageF(
        "Not fusing %s because it is not consumed by exactly one other op",
        LogName(*binary_op));
    return false;
  }

  if (following_op->type != OperatorType::kConv &&
      following_op->type != OperatorType::kFullyConnected &&
      following_op->type != OperatorType::kDepthwiseConv) {
    AddMessageF(
        "Not fusing %s because the following %s is not of one of the supported "
        "types",
        LogName(*binary_op), LogName(*following_op));
    return false;
  }

  if (following_op->inputs.size() < 3) {
    AddMessageF(
        "Not fusing %s because the following %s does not have a bias vector",
        LogName(*binary_op), LogName(*following_op));
    return false;
  }

  const auto& weights = model->GetArray(following_op->inputs[1]);
  const auto& bias = model->GetArray(following_op->inputs[2]);
  if (!weights.buffer || !bias.buffer) {
    AddMessageF(
        "Not fusing %s because the following %s has non-constant weights or "
        "bias arrays",
        LogName(*binary_op), LogName(*following_op));
    return false;
  }
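
  // For Add/Sub, fusion additionally requires VALID padding on a following
  // Conv/DepthwiseConv: the fused bias assumes that every weight tap saw an
  // input offset by the scalar, but with SAME padding the border taps read
  // implicit zeros that are not offset, so fusing would change the results
  // near the borders. Hence the padding checks below.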
  // Try to fuse the binary params into the following op's params.
  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
    if (following_op->type == OperatorType::kConv) {
      if (static_cast<ConvOperator*>(following_op)->padding.type !=
          PaddingType::kValid) {
        AddMessageF(
            "Not fusing %s because the following %s does not use VALID padding",
            LogName(*binary_op), LogName(*following_op));
        return false;
      }
    }
    if (following_op->type == OperatorType::kDepthwiseConv) {
      if (static_cast<DepthwiseConvOperator*>(following_op)->padding.type !=
          PaddingType::kValid) {
        AddMessageF(
            "Not fusing %s because the following %s does not use VALID padding",
            LogName(*binary_op), LogName(*following_op));
        return false;
      }
    }
    FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op,
                                          index_of_constant_input);
  } else if (binary_op->type == OperatorType::kMul ||
             binary_op->type == OperatorType::kDiv) {
    FuseMulOrDivParamsIntoFollowingAffine(model, following_op, binary_op,
                                          index_of_constant_input);
  } else {
    LOG(FATAL) << "Should not get here.";
  }

  AddMessageF("Fusing %s into the following %s", LogName(*binary_op),
              LogName(*following_op));

  model->EraseArray(binary_op->outputs[0]);
  following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
  const auto& old_constant_param_name =
      binary_op->inputs[index_of_constant_input];
  CHECK(IsConstantParameterArray(*model, old_constant_param_name));
  if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
    model->EraseArray(old_constant_param_name);
  }
  model->operators.erase(binary_it);
  return true;
}

}  // namespace toco