1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <iterator>
16#include <memory>
17#include <string>
18#include <unordered_map>
19#include <vector>
20
21#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
22#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
23#include "tensorflow/contrib/lite/toco/model.h"
24#include "tensorflow/contrib/lite/toco/tooling_util.h"
25#include "tensorflow/core/platform/logging.h"
26
27namespace toco {
28
namespace {

// Returns true iff no element of `buffer_data` compares unequal (operator!=)
// to `value`. An empty buffer has no mismatching element, so it yields true.
template <typename Scalar>
bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data,
                                 Scalar value) {
  const std::size_t element_count = buffer_data.size();
  for (std::size_t i = 0; i < element_count; ++i) {
    if (buffer_data[i] != value) {
      // A single mismatch settles the answer; stop scanning.
      return false;
    }
  }
  return true;
}

}  // namespace
42
43// A binary operator is called trivial when exactly one of its operands is
44// a constant and is such that the binary operation is equivalent to
45// the identity operation on its other input.
46// For example, an Add operator is trivial if
47// one of its operands is constant 0, a Mul operator is trivial
48// if one of its operands is constant 1, etc.
49bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
50  const auto binary_it = model->operators.begin() + op_index;
51  auto* binary_op = binary_it->get();
52  if (binary_op->type != OperatorType::kAdd &&
53      binary_op->type != OperatorType::kMul &&
54      binary_op->type != OperatorType::kSub &&
55      binary_op->type != OperatorType::kDiv) {
56    return false;
57  }
58
59  CHECK_EQ(binary_op->inputs.size(), 2);
60
61  // This graph transformation is only concerned with the case
62  // when one input is constant and the other is not constant.
63  const bool is_input_constant[2] = {
64      IsConstantParameterArray(*model, binary_op->inputs[0]),
65      IsConstantParameterArray(*model, binary_op->inputs[1]),
66  };
67  if (!is_input_constant[0] && !is_input_constant[1]) {
68    // Neither input is constant, so nothing we can resolve here.
69    return false;
70  }
71  if (is_input_constant[0] && is_input_constant[1]) {
72    // Both inputs are constants. That's a job for constants
73    // propagation, not for us to handle here.
74    return false;
75  }
76  const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
77  const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
78  CHECK(is_input_constant[index_of_constant_input]);
79  CHECK(!is_input_constant[index_of_variable_input]);
80
81  // Now check if the constant operand makes this binary
82  // operator trivial.
83  const auto& constant_input_array =
84      model->GetArray(binary_op->inputs[index_of_constant_input]);
85  // For now, we only handle floats here.
86  if (constant_input_array.data_type != ArrayDataType::kFloat) {
87    return false;
88  }
89  const auto& constant_input_float_data =
90      constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
91  bool is_trivial = false;
92  if (binary_op->type == OperatorType::kAdd) {
93    is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
94  } else if (binary_op->type == OperatorType::kSub) {
95    is_trivial = index_of_constant_input == 1 &&
96                 AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
97  } else if (binary_op->type == OperatorType::kMul) {
98    is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
99  } else if (binary_op->type == OperatorType::kDiv) {
100    is_trivial = index_of_constant_input == 1 &&
101                 AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
102  }
103
104  if (!is_trivial) {
105    return false;
106  }
107
108  // Now we know that this node is trivial, so we can remove it.
109  AddMessageF("Removing trivial %s", LogName(*binary_op));
110  return RemoveTrivialPassthroughOp(this, model, op_index);
111}
112
113}  // namespace toco
114