/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

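// Fuses the constant operand of an Add or Sub into the bias of the preceding
// affine op. Only the bias is rewritten; the weights are left untouched.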
void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
                                           const Operator* add_or_sub_op,
                                           int index_of_constant_input) {
  CHECK(add_or_sub_op->type == OperatorType::kAdd ||
        add_or_sub_op->type == OperatorType::kSub);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  if (preceding_op->inputs.size() < 3) {
    LOG(FATAL) << "Missing bias parameter";
  }
  auto& bias = model->GetArray(preceding_op->inputs[2]);
  bias.minmax = nullptr;
  const auto& operand =
      model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);

  const Shape& bias_shape = bias.shape();
  const Shape& operand_shape = operand.shape();
  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const bias_data = bias_buffer.data.data();
  const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
  const float* const operand_data = operand_buffer.data.data();

  // TODO(b/62904716): Bias array should become 1-D when padding removed.
  const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1);
  CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1));

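  // Which arithmetic to apply to each bias element, depending on the op type
  // and on which side of the binary op the constant operand sits:
  //   Add:                      new_bias = bias + operand
  //   Sub, constant on the RHS: new_bias = bias - operand
  //   Sub, constant on the LHS: new_bias = operand - bias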
  enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias };

  const OpType optype = (add_or_sub_op->type == OperatorType::kAdd)
                            ? OpType::BiasPlusOperand
                            : (index_of_constant_input == 1)
                                  ? OpType::BiasMinusOperand
                                  : OpType::OperandMinusBias;

  for (int i = 0; i < depth; i++) {
    float& bias_val = bias_data[i];
    const float operand_val = operand_data[i];
    if (optype == OpType::BiasPlusOperand) {
      bias_val += operand_val;
    } else if (optype == OpType::BiasMinusOperand) {
      bias_val -= operand_val;
    } else if (optype == OpType::OperandMinusBias) {
      bias_val = operand_val - bias_val;
    } else {
      LOG(FATAL) << "Should not get here.";
    }
  }
}

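// Fuses the constant operand of a Mul or Div into the weights and bias of the
// preceding affine op. Since (W*x + b) * c == (c*W)*x + (c*b), both the
// weights and the bias are scaled by the operand (divided, for Div).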
void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
                                           const Operator* mul_or_div_op,
                                           int index_of_constant_input) {
  CHECK(mul_or_div_op->type == OperatorType::kMul ||
        mul_or_div_op->type == OperatorType::kDiv);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  // If the op is a division, the constant input should be the right hand side.
  // This should have been checked before this point.
  CHECK(mul_or_div_op->type != OperatorType::kDiv ||
        index_of_constant_input == 1);
  if (preceding_op->inputs.size() < 3) {
    LOG(FATAL) << "Missing bias parameter";
  }
  const auto& weights_name = preceding_op->inputs[1];
  const auto& bias_name = preceding_op->inputs[2];
  auto& weights = model->GetArray(weights_name);
  DropMinMax(model, weights_name);
  auto& bias = model->GetArray(bias_name);
  DropMinMax(model, bias_name);
  const auto& operand =
      model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);

  const Shape& weights_shape = weights.shape();
  const Shape& bias_shape = bias.shape();
  const Shape& operand_shape = operand.shape();
  auto& weights_buffer = weights.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const weights_data = weights_buffer.data.data();
  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const bias_data = bias_buffer.data.data();
  const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
  const float* const operand_data = operand_buffer.data.data();
  // We support broadcasting the operand along the depth dimension when the
  // operand's depth is 1 (including the case of a scalar operand).
  int operand_channel_increment = 0;
  if (operand_shape.dimensions_count() >= 1 &&
      operand_shape.dims(operand_shape.dimensions_count() - 1) ==
          bias_shape.dims(bias_shape.dimensions_count() - 1)) {
    operand_channel_increment = 1;
  } else if (operand_shape.dimensions_count() == 0 ||
             operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) {
    operand_channel_increment = 0;
  } else {
    LOG(FATAL) << "Operand shape mismatch.";
  }

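  // The position of the output-depth dimension depends on the weights layout:
  // Conv and FullyConnected weights store it as the first dimension, while
  // DepthwiseConv weights store it as the last dimension.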
  int output_depth;

  if (preceding_op->type == OperatorType::kConv ||
      preceding_op->type == OperatorType::kFullyConnected) {
    output_depth = weights_shape.dims(0);
  } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
    output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1);
  } else {
    LOG(FATAL) << "Should not get here";
  }

  const int weights_size = RequiredBufferSizeForShape(weights_shape);
  const int weights_per_depth = weights_size / output_depth;
  CHECK_EQ(weights_size, weights_per_depth * output_depth);

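  // Scale the bias and every weight belonging to output channel c by the
  // corresponding operand value. operand_channel stays at 0 when the operand
  // is broadcast, and walks along the operand otherwise.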
  int operand_channel = 0;
  for (int c = 0; c < output_depth; c++) {
    if (mul_or_div_op->type == OperatorType::kMul) {
      bias_data[c] *= operand_data[operand_channel];
    } else if (mul_or_div_op->type == OperatorType::kDiv) {
      bias_data[c] /= operand_data[operand_channel];
    } else {
      LOG(FATAL) << "Should not get here";
    }
    if (preceding_op->type == OperatorType::kConv ||
        preceding_op->type == OperatorType::kFullyConnected) {
      for (int i = 0; i < weights_per_depth; i++) {
        if (mul_or_div_op->type == OperatorType::kMul) {
          weights_data[c * weights_per_depth + i] *=
              operand_data[operand_channel];
        } else if (mul_or_div_op->type == OperatorType::kDiv) {
          weights_data[c * weights_per_depth + i] /=
              operand_data[operand_channel];
        } else {
          LOG(FATAL) << "Should not get here";
        }
      }
    } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
      for (int k = 0; k < weights_per_depth; k++) {
        if (mul_or_div_op->type == OperatorType::kMul) {
          weights_data[k * output_depth + c] *= operand_data[operand_channel];
        } else if (mul_or_div_op->type == OperatorType::kDiv) {
          weights_data[k * output_depth + c] /= operand_data[operand_channel];
        } else {
          LOG(FATAL) << "Should not get here";
        }
      }
    } else {
      LOG(FATAL) << "Should not get here";
    }
    operand_channel += operand_channel_increment;
  }
}
}  // namespace

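// Attempts to fuse the binary operator at op_index into the op producing its
// variable input. Returns true if the model was changed.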
bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
  const auto binary_it = model->operators.begin() + op_index;
  const auto* binary_op = binary_it->get();
  if (binary_op->type != OperatorType::kAdd &&
      binary_op->type != OperatorType::kMul &&
      binary_op->type != OperatorType::kSub &&
      binary_op->type != OperatorType::kDiv) {
    return false;
  }

  CHECK_EQ(binary_op->inputs.size(), 2);

  // We can only fuse a binary op when its two operands break down as follows:
  //   1. One operand is the (variable) output of a typical affine (linear plus
  //      bias) op of a finite list of possible types: at the moment Conv,
  //      DepthwiseConv and FullyConnected are supported.
  //   2. The other operand is a constant param array.
  const bool is_input_constant[2] = {
      IsConstantParameterArray(*model, binary_op->inputs[0]),
      IsConstantParameterArray(*model, binary_op->inputs[1]),
  };
  if (!is_input_constant[0] && !is_input_constant[1]) {
    // Neither input is constant, so there is nothing to fuse.
    return false;
  }
  if (is_input_constant[0] && is_input_constant[1]) {
    // Both inputs are constants. That's a job for constant
    // propagation, not for us to handle here.
    return false;
  }
  const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
  const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
  CHECK(is_input_constant[index_of_constant_input]);
  CHECK(!is_input_constant[index_of_variable_input]);

  // For division, we can only fuse if the denominator is constant.
  if (binary_op->type == OperatorType::kDiv) {
    if (index_of_constant_input != 1) {
      AddMessageF("Not fusing %s because the denominator is not constant",
                  LogName(*binary_op));
      return false;
    }
  }

  Operator* preceding_op =
      GetOpWithOutput(*model, binary_op->inputs[index_of_variable_input]);
  if (!preceding_op) {
    AddMessageF("Not fusing %s because it is not the output of another op",
                LogName(*binary_op));
    return false;
  }

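  // Don't fuse when the preceding op's output is one of the model's output
  // arrays, since fusing would erase and rename that array.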
  for (const string& output_array : model->flags.output_arrays()) {
    if (preceding_op->outputs[0] == output_array) {
      return false;
    }
  }

  if (preceding_op->type != OperatorType::kConv &&
      preceding_op->type != OperatorType::kFullyConnected &&
      preceding_op->type != OperatorType::kDepthwiseConv) {
    AddMessageF(
        "Not fusing %s because the preceding %s is not of one of the supported "
        "types",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  if (preceding_op->fused_activation_function !=
      FusedActivationFunctionType::kNone) {
    AddMessageF(
        "Not fusing %s because the preceding %s has a fused activation "
        "function",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  if (preceding_op->inputs.size() < 3) {
    AddMessageF(
        "Not fusing %s because the preceding %s does not have a bias vector",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

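  // Add/Sub only rewrites the bias, so only the bias needs to be constant.
  // Mul/Div rescales the weights as well, so both must be constant.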
  const auto& weights = model->GetArray(preceding_op->inputs[1]);
  const auto& bias = model->GetArray(preceding_op->inputs[2]);
  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
    if (!bias.buffer) {
      AddMessageF(
          "Not fusing %s because the preceding %s has a non-constant bias "
          "array",
          LogName(*binary_op), LogName(*preceding_op));
      return false;
    }
  } else {
    if (!weights.buffer || !bias.buffer) {
      AddMessageF(
          "Not fusing %s because the preceding %s has non-constant weights or "
          "bias arrays",
          LogName(*binary_op), LogName(*preceding_op));
      return false;
    }
  }

  int count_ops_consuming_output =
      CountOpsWithInput(*model, preceding_op->outputs[0]);
  DCHECK_GE(count_ops_consuming_output, 1);
  if (count_ops_consuming_output > 1) {
    AddMessageF(
        "Not fusing %s because the output of the preceding %s is consumed by "
        "another op",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
              LogName(*preceding_op));

  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
    FuseAddOrSubParamsIntoPrecedingAffine(model, preceding_op, binary_op,
                                          index_of_constant_input);
  } else if (binary_op->type == OperatorType::kMul ||
             binary_op->type == OperatorType::kDiv) {
    FuseMulOrDivParamsIntoPrecedingAffine(model, preceding_op, binary_op,
                                          index_of_constant_input);
  } else {
    LOG(FATAL) << "Should not get here";
  }

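  // Rewire the graph: the preceding op takes over the binary op's output and
  // fused activation function, the binary op's constant input is erased if
  // nothing else uses it, and the binary op itself is removed.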
  model->EraseArray(preceding_op->outputs[0]);
  preceding_op->outputs[0] = binary_op->outputs[0];
  preceding_op->fused_activation_function =
      binary_op->fused_activation_function;
  const auto& old_constant_param_name =
      binary_op->inputs[index_of_constant_input];
  CHECK(IsConstantParameterArray(*model, old_constant_param_name));
  if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
    model->EraseArray(old_constant_param_name);
  }
  model->operators.erase(binary_it);
  return true;
}

}  // namespace toco