/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

void FuseAddOrSubParamsIntoFollowingAffine(Model* model, Operator* following_op,
                                           const Operator* add_or_sub_op,
                                           int index_of_constant_input) {
  CHECK(add_or_sub_op->type == OperatorType::kAdd ||
        add_or_sub_op->type == OperatorType::kSub);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  // If the op is a subtraction, the constant input should be the right hand
  // side. This should have been checked before this point.
  CHECK(add_or_sub_op->type != OperatorType::kSub ||
        index_of_constant_input == 1);
  if (following_op->inputs.size() < 3) {
    LOG(FATAL) << "Missing bias parameter";
  }
  const auto& weights = model->GetArray(following_op->inputs[1]);
  auto& bias = model->GetArray(following_op->inputs[2]);
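  // The bias values are about to be modified, so any previously recorded
  // min/max range for the bias array would be stale; drop it.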
  bias.minmax = nullptr;
  const auto& operand =
      model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);
  // We're only supporting the case of a scalar operand. Should have
  // been checked earlier.
  CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1);

  const float scalar_operand =
      operand.GetBuffer<ArrayDataType::kFloat>().data[0];
  // At this point we reduce the case of subtraction to that of addition
  // by negating the operand.
  float add_scalar_operand = 0.f;
  if (add_or_sub_op->type == OperatorType::kAdd) {
    add_scalar_operand = scalar_operand;
  } else if (add_or_sub_op->type == OperatorType::kSub &&
             index_of_constant_input == 1) {
    add_scalar_operand = -scalar_operand;
  } else {
    LOG(FATAL) << "Should not get here";
  }
  // From here on we are fusing an addition. add_or_sub_op->type does not
  // matter anymore.

  const Shape& weights_shape = weights.shape();
  const Shape& bias_shape = bias.shape();
  const auto& weights_buffer = weights.GetBuffer<ArrayDataType::kFloat>();
  const float* const weights_data = weights_buffer.data.data();
  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const bias_data = bias_buffer.data.data();

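  // Fusing (x + c) into y = W * x + b amounts to folding W * c into the bias:
  // for each output channel, add c times the sum of that channel's weights to
  // the corresponding bias entry.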
  if (following_op->type == OperatorType::kConv ||
      following_op->type == OperatorType::kFullyConnected) {
    const int output_depth = weights_shape.dims(0);
    // TODO(b/62904716): Bias array should become 1-D when padding removed.
    CHECK_EQ(output_depth, bias_shape.dims(bias_shape.dimensions_count() - 1));
    const int weights_size = RequiredBufferSizeForShape(weights_shape);
    const int weights_per_depth = weights_size / output_depth;
    CHECK_EQ(weights_size, weights_per_depth * output_depth);

    for (int d = 0; d < output_depth; d++) {
      float accumulation = 0;
      for (int i = 0; i < weights_per_depth; i++) {
        accumulation +=
            add_scalar_operand * weights_data[d * weights_per_depth + i];
      }
      bias_data[d] += accumulation;
    }
  } else if (following_op->type == OperatorType::kDepthwiseConv) {
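    // For DepthwiseConv, the output channel is the last (innermost) weights
    // dimension, so weights for channel c are indexed as
    // weights_data[k * output_depth + c].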
    const int output_depth =
        weights_shape.dims(weights_shape.dimensions_count() - 1);
    const int weights_size = RequiredBufferSizeForShape(weights_shape);
    const int weights_per_depth = weights_size / output_depth;
    CHECK_EQ(weights_size, weights_per_depth * output_depth);

    for (int c = 0; c < output_depth; c++) {
      float accumulation = 0;
      for (int k = 0; k < weights_per_depth; k++) {
        accumulation += add_scalar_operand * weights_data[k * output_depth + c];
      }
      bias_data[c] += accumulation;
    }
  } else {
    LOG(FATAL) << "Should not get here.";
  }
}

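// Fusing (x * c) into y = W * x + b amounts to scaling the weights, since
// W * (c * x) + b == (c * W) * x + b. Division by a constant scalar is
// handled the same way, dividing the weights instead.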
void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op,
                                           const Operator* mul_or_div_op,
                                           int index_of_constant_input) {
  CHECK(mul_or_div_op->type == OperatorType::kMul ||
        mul_or_div_op->type == OperatorType::kDiv);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  // If the op is a division, the constant input should be the right hand side.
  // This should have been checked before this point.
  CHECK(mul_or_div_op->type != OperatorType::kDiv ||
        index_of_constant_input == 1);
  const auto& weights_name = following_op->inputs[1];
  const auto& bias_name = following_op->inputs[2];
  auto& weights = model->GetArray(weights_name);
  DropMinMax(model, weights_name);
  DropMinMax(model, bias_name);
  const auto& operand =
      model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);
  // We're only supporting the case of a scalar operand. Should have
  // been checked earlier.
  CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1);

  const float scalar_operand =
      operand.GetBuffer<ArrayDataType::kFloat>().data[0];

  float* weights_data =
      weights.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
  const int weights_size = RequiredBufferSizeForShape(weights.shape());
  for (int i = 0; i < weights_size; i++) {
    if (mul_or_div_op->type == OperatorType::kMul) {
      weights_data[i] *= scalar_operand;
    } else if (mul_or_div_op->type == OperatorType::kDiv) {
      weights_data[i] /= scalar_operand;
    } else {
      LOG(FATAL) << "Should not get here";
    }
  }
}

}  // namespace

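// Rewrites (x [+,-,*,/] constant_scalar) -> {Conv, DepthwiseConv,
// FullyConnected} into x -> {Conv, DepthwiseConv, FullyConnected} with an
// adjusted bias (for add/sub) or rescaled weights (for mul/div).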
bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
  const auto binary_it = model->operators.begin() + op_index;
  auto* binary_op = binary_it->get();
  if (binary_op->type != OperatorType::kAdd &&
      binary_op->type != OperatorType::kMul &&
      binary_op->type != OperatorType::kSub &&
      binary_op->type != OperatorType::kDiv) {
    return false;
  }

  CHECK_EQ(binary_op->inputs.size(), 2);

  // We can only fuse a binary op when its two operands break down as follows:
  //   1. One operand is the (variable) output of a typical affine (linear plus
  //      bias) op of a finite list of possible types: at the moment Conv,
  //      DepthwiseConv and FullyConnected are supported.
  //   2. The other operand is a constant param array.
  const bool is_input_constant[2] = {
      IsConstantParameterArray(*model, binary_op->inputs[0]),
      IsConstantParameterArray(*model, binary_op->inputs[1]),
  };
  if (!is_input_constant[0] && !is_input_constant[1]) {
    // Neither input is constant, so there is no constant to fuse into the
    // following op.
    return false;
  }
  if (is_input_constant[0] && is_input_constant[1]) {
    // Both inputs are constants. That's a job for constants
    // propagation, not for us to handle here.
    return false;
  }
  const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
  const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
  CHECK(is_input_constant[index_of_constant_input]);
  CHECK(!is_input_constant[index_of_variable_input]);

  // For division, we can only fuse if the denominator is constant.
  if (binary_op->type == OperatorType::kDiv) {
    if (index_of_constant_input != 1) {
      AddMessageF("Not fusing %s because the denominator is not constant",
                  LogName(*binary_op));
      return false;
    }
  }

  const auto& operand_shape =
      model->GetArray(binary_op->inputs[index_of_constant_input]).shape();
  for (const auto& dim : operand_shape.dims()) {
    if (dim > 1) {
      AddMessageF(
          "Not fusing %s into the following affine op, because we only know "
          "how to do so when the constant operand is a scalar",
          LogName(*binary_op));
      return false;
    }
  }

  if (binary_op->fused_activation_function !=
      FusedActivationFunctionType::kNone) {
    AddMessageF("Not fusing %s because it has a fused activation function",
                LogName(*binary_op));
    return false;
  }

  Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]);

  if (!following_op) {
    AddMessageF(
        "Not fusing %s because it is not consumed by exactly one other op",
        LogName(*binary_op));
    return false;
  }

  if (following_op->type != OperatorType::kConv &&
      following_op->type != OperatorType::kFullyConnected &&
      following_op->type != OperatorType::kDepthwiseConv) {
    AddMessageF(
        "Not fusing %s because the following %s is not one of the supported "
        "types",
        LogName(*binary_op), LogName(*following_op));
    return false;
  }

  if (following_op->inputs.size() < 3) {
    AddMessageF(
        "Not fusing %s because the following %s does not have a bias vector",
        LogName(*binary_op), LogName(*following_op));
    return false;
  }

  const auto& weights = model->GetArray(following_op->inputs[1]);
  const auto& bias = model->GetArray(following_op->inputs[2]);
  if (!weights.buffer || !bias.buffer) {
    AddMessageF(
        "Not fusing %s because the following %s has non-constant weights or "
        "bias arrays",
        LogName(*binary_op), LogName(*following_op));
    return false;
  }

  // Try to fuse the binary params into the following op's params
  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
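    // Folding an added constant into the bias is only correct with VALID
    // padding: with SAME padding, the zero-padded border values would not
    // have received the added constant, so the fused graph would compute a
    // different result at the borders.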
    if (following_op->type == OperatorType::kConv) {
      if (static_cast<ConvOperator*>(following_op)->padding.type !=
          PaddingType::kValid) {
        AddMessageF(
            "Not fusing %s because the following %s does not use VALID padding",
            LogName(*binary_op), LogName(*following_op));
        return false;
      }
    }
    if (following_op->type == OperatorType::kDepthwiseConv) {
      if (static_cast<DepthwiseConvOperator*>(following_op)->padding.type !=
          PaddingType::kValid) {
        AddMessageF(
            "Not fusing %s because the following %s does not use VALID padding",
            LogName(*binary_op), LogName(*following_op));
        return false;
      }
    }
    FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op,
                                          index_of_constant_input);
  } else if (binary_op->type == OperatorType::kMul ||
             binary_op->type == OperatorType::kDiv) {
    FuseMulOrDivParamsIntoFollowingAffine(model, following_op, binary_op,
                                          index_of_constant_input);
  } else {
    LOG(FATAL) << "should not get here";
  }

  AddMessageF("Fusing %s into the following %s", LogName(*binary_op),
              LogName(*following_op));

  model->EraseArray(binary_op->outputs[0]);
  following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
  const auto& old_constant_param_name =
      binary_op->inputs[index_of_constant_input];
  CHECK(IsConstantParameterArray(*model, old_constant_param_name));
  if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
    model->EraseArray(old_constant_param_name);
  }
  model->operators.erase(binary_it);
  return true;
}

}  // namespace toco