1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <memory>
16#include <string>
17#include <unordered_map>
18#include <vector>
19
20#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
21#include "tensorflow/contrib/lite/toco/model.h"
22#include "tensorflow/contrib/lite/toco/runtime/types.h"
23#include "tensorflow/contrib/lite/toco/tooling_util.h"
24#include "tensorflow/core/platform/logging.h"
25
26namespace toco {
27
28namespace {
29
30void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
31                                           const Operator* add_or_sub_op,
32                                           int index_of_constant_input) {
33  CHECK(add_or_sub_op->type == OperatorType::kAdd ||
34        add_or_sub_op->type == OperatorType::kSub);
35  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
36  if (preceding_op->inputs.size() < 3) {
37    LOG(FATAL) << "Missing bias parameter";
38  }
39  auto& bias = model->GetArray(preceding_op->inputs[2]);
40  bias.minmax = nullptr;
41  const auto& operand =
42      model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);
43
44  const Shape& bias_shape = bias.shape();
45  const Shape& operand_shape = operand.shape();
46  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
47  float* const bias_data = bias_buffer.data.data();
48  const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
49  const float* const operand_data = operand_buffer.data.data();
50
51  // TODO(b/62904716): Bias array should become 1-D when padding removed.
52  const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1);
53  CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1));
54
55  enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias };
56
57  const OpType optype = (add_or_sub_op->type == OperatorType::kAdd)
58                            ? OpType::BiasPlusOperand
59                            : (index_of_constant_input == 1)
60                                  ? OpType::BiasMinusOperand
61                                  : OpType::OperandMinusBias;
62
63  for (int i = 0; i < depth; i++) {
64    float& bias_val = bias_data[i];
65    const float operand_val = operand_data[i];
66    if (optype == OpType::BiasPlusOperand) {
67      bias_val += operand_val;
68    } else if (optype == OpType::BiasMinusOperand) {
69      bias_val -= operand_val;
70    } else if (optype == OpType::OperandMinusBias) {
71      bias_val = operand_val - bias_val;
72    } else {
73      LOG(FATAL) << "Should not get here.";
74    }
75  }
76}
77
78void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
79                                           const Operator* mul_or_div_op,
80                                           int index_of_constant_input) {
81  CHECK(mul_or_div_op->type == OperatorType::kMul ||
82        mul_or_div_op->type == OperatorType::kDiv);
83  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
84  // If the op is a division, the constant input should be the right hand side.
85  // This should have been checked before this point.
86  CHECK(mul_or_div_op->type != OperatorType::kDiv ||
87        index_of_constant_input == 1);
88  if (preceding_op->inputs.size() < 3) {
89    LOG(FATAL) << "Missing bias parameter";
90  }
91  const auto& weights_name = preceding_op->inputs[1];
92  const auto& bias_name = preceding_op->inputs[2];
93  auto& weights = model->GetArray(weights_name);
94  DropMinMax(model, weights_name);
95  auto& bias = model->GetArray(bias_name);
96  DropMinMax(model, bias_name);
97  const auto& operand =
98      model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);
99
100  const Shape& weights_shape = weights.shape();
101  const Shape& bias_shape = bias.shape();
102  const Shape& operand_shape = operand.shape();
103  auto& weights_buffer = weights.GetMutableBuffer<ArrayDataType::kFloat>();
104  float* const weights_data = weights_buffer.data.data();
105  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
106  float* const bias_data = bias_buffer.data.data();
107  const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
108  const float* const operand_data = operand_buffer.data.data();
109
110  // We support broadcasting the operand along the depth dimension,
111  // when the operand's depth is 1.
112  int operand_channel_increment = 0;
113  if (operand_shape.dimensions_count() >= 1 &&
114      operand_shape.dims(operand_shape.dimensions_count() - 1) ==
115          bias_shape.dims(bias_shape.dimensions_count() - 1)) {
116    operand_channel_increment = 1;
117  } else if (operand_shape.dimensions_count() == 0 ||
118             operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) {
119    operand_channel_increment = 0;
120  } else {
121    LOG(FATAL) << "Operand shape mismatch.";
122  }
123
124  int output_depth;
125
126  if (preceding_op->type == OperatorType::kConv ||
127      preceding_op->type == OperatorType::kFullyConnected) {
128    output_depth = weights_shape.dims(0);
129  } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
130    output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1);
131  } else {
132    LOG(FATAL) << "Should not get here";
133  }
134
135  const int weights_size = RequiredBufferSizeForShape(weights_shape);
136  const int weights_per_depth = weights_size / output_depth;
137  CHECK_EQ(weights_size, weights_per_depth * output_depth);
138
139  int operand_channel = 0;
140  for (int c = 0; c < output_depth; c++) {
141    if (mul_or_div_op->type == OperatorType::kMul) {
142      bias_data[c] *= operand_data[operand_channel];
143    } else if (mul_or_div_op->type == OperatorType::kDiv) {
144      bias_data[c] /= operand_data[operand_channel];
145    } else {
146      LOG(FATAL) << "Should not get here";
147    }
148    if (preceding_op->type == OperatorType::kConv ||
149        preceding_op->type == OperatorType::kFullyConnected) {
150      for (int i = 0; i < weights_per_depth; i++) {
151        if (mul_or_div_op->type == OperatorType::kMul) {
152          weights_data[c * weights_per_depth + i] *=
153              operand_data[operand_channel];
154        } else if (mul_or_div_op->type == OperatorType::kDiv) {
155          weights_data[c * weights_per_depth + i] /=
156              operand_data[operand_channel];
157        } else {
158          LOG(FATAL) << "Should not get here";
159        }
160      }
161    } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
162      for (int k = 0; k < weights_per_depth; k++) {
163        if (mul_or_div_op->type == OperatorType::kMul) {
164          weights_data[k * output_depth + c] *= operand_data[operand_channel];
165        } else if (mul_or_div_op->type == OperatorType::kDiv) {
166          weights_data[k * output_depth + c] /= operand_data[operand_channel];
167        } else {
168          LOG(FATAL) << "Should not get here";
169        }
170      }
171    } else {
172      LOG(FATAL) << "Should not get here";
173    }
174    operand_channel += operand_channel_increment;
175  }
176}
177}  // namespace
178
// Graph transformation: folds an elementwise Add/Sub/Mul/Div with one
// constant operand into the immediately preceding affine op
// (Conv/DepthwiseConv/FullyConnected), rewriting that op's weights/bias and
// removing the binary op from the graph. Returns true iff the graph changed.
// NOTE: the order of the early-return checks below determines which
// diagnostic AddMessageF fires; keep it stable.
bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
  const auto binary_it = model->operators.begin() + op_index;
  const auto* binary_op = binary_it->get();
  // Only the four elementwise binary ops are candidates for fusion.
  if (binary_op->type != OperatorType::kAdd &&
      binary_op->type != OperatorType::kMul &&
      binary_op->type != OperatorType::kSub &&
      binary_op->type != OperatorType::kDiv) {
    return false;
  }

  CHECK_EQ(binary_op->inputs.size(), 2);

  // We only can fuse an binary when the two operands break down as follows:
  //   1. One operand is the (variable) output of a typical affine (linear plus
  //   bias)
  //      op of a finite list of possible types: at the moment Conv,
  //      DepthwiseConv and
  //      FullyConnected are supported.
  //   2. The other operand is a constant param array.
  const bool is_input_constant[2] = {
      IsConstantParameterArray(*model, binary_op->inputs[0]),
      IsConstantParameterArray(*model, binary_op->inputs[1]),
  };
  if (!is_input_constant[0] && !is_input_constant[1]) {
    // Neither input is constant, so nothing we can fuse into a constant.
    return false;
  }
  if (is_input_constant[0] && is_input_constant[1]) {
    // Both inputs are constants. That's a job for constants
    // propagation, not for us to handle here.
    return false;
  }
  const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
  const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
  CHECK(is_input_constant[index_of_constant_input]);
  CHECK(!is_input_constant[index_of_variable_input]);

  // For division, we can only fuse if the denominator is constant.
  // (const / x cannot be expressed as a rescaling of the affine op.)
  if (binary_op->type == OperatorType::kDiv) {
    if (index_of_constant_input != 1) {
      AddMessageF("Not fusing %s because the denominator is not constant",
                  LogName(*binary_op));
      return false;
    }
  }

  // The variable input must be produced by another op in this model.
  Operator* preceding_op =
      GetOpWithOutput(*model, binary_op->inputs[index_of_variable_input]);
  if (!preceding_op) {
    AddMessageF("Not fusing %s because it is not the output of another op",
                LogName(*binary_op));
    return false;
  }

  // Don't fuse if the intermediate array is one of the model's declared
  // outputs — fusing would remove it from the graph.
  for (const string& output_array : model->flags.output_arrays()) {
    if (preceding_op->outputs[0] == output_array) {
      return false;
    }
  }

  if (preceding_op->type != OperatorType::kConv &&
      preceding_op->type != OperatorType::kFullyConnected &&
      preceding_op->type != OperatorType::kDepthwiseConv) {
    AddMessageF(
        "Not fusing %s because the preceding %s is not of one of the supported "
        "types",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  // A fused activation on the affine op is applied before the binary op in
  // the original graph; fusing across it would change the math.
  if (preceding_op->fused_activation_function !=
      FusedActivationFunctionType::kNone) {
    AddMessageF(
        "Not fusing %s because the preceding %s has a fused activation "
        "function",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  // The fusion helpers below rewrite inputs[2] (bias), so it must exist.
  if (preceding_op->inputs.size() < 3) {
    AddMessageF(
        "Not fusing %s because the preceding %s does not have a bias vector",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  // Add/Sub only rewrites the bias; Mul/Div rewrites weights too. The
  // rewritten arrays must be constant (have a buffer) to be editable here.
  const auto& weights = model->GetArray(preceding_op->inputs[1]);
  const auto& bias = model->GetArray(preceding_op->inputs[2]);
  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
    if (!bias.buffer) {
      AddMessageF(
          "Not fusing %s because the preceding %s has a non-constant bias "
          "array",
          LogName(*binary_op), LogName(*preceding_op));
      return false;
    }
  } else {
    if (!weights.buffer || !bias.buffer) {
      AddMessageF(
          "Not fusing %s because the preceding %s has non-constant weights or "
          "bias arrays",
          LogName(*binary_op), LogName(*preceding_op));
      return false;
    }
  }

  // If another op also consumes the affine op's output, we can't remove the
  // intermediate array, so fusion isn't applicable.
  int count_ops_consuming_output =
      CountOpsWithInput(*model, preceding_op->outputs[0]);
  DCHECK_GE(count_ops_consuming_output, 1);
  if (count_ops_consuming_output > 1) {
    AddMessageF(
        "Not fusing %s because the output of the preceding %s is consumed by "
        "another op",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
              LogName(*preceding_op));

  // Fold the constant operand into the affine op's parameters.
  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
    FuseAddOrSubParamsIntoPrecedingAffine(model, preceding_op, binary_op,
                                          index_of_constant_input);
  } else if (binary_op->type == OperatorType::kMul ||
             binary_op->type == OperatorType::kDiv) {
    FuseMulOrDivParamsIntoPrecedingAffine(model, preceding_op, binary_op,
                                          index_of_constant_input);
  } else {
    LOG(FATAL) << "should not get here";
  }

  // Graph surgery. Order matters: drop the now-dead intermediate array,
  // rewire the affine op to produce the binary op's output, adopt the binary
  // op's fused activation, then remove the constant operand (only if this
  // binary op was its sole consumer) and finally the binary op itself.
  model->EraseArray(preceding_op->outputs[0]);
  preceding_op->outputs[0] = binary_op->outputs[0];
  preceding_op->fused_activation_function =
      binary_op->fused_activation_function;
  const auto& old_constant_param_name =
      binary_op->inputs[index_of_constant_input];
  CHECK(IsConstantParameterArray(*model, old_constant_param_name));
  if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
    model->EraseArray(old_constant_param_name);
  }
  model->operators.erase(binary_it);
  return true;
}
325
326}  // namespace toco
327