1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <iterator> 16#include <memory> 17#include <string> 18#include <unordered_map> 19#include <vector> 20 21#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 22#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" 23#include "tensorflow/contrib/lite/toco/model.h" 24#include "tensorflow/contrib/lite/toco/tooling_util.h" 25#include "tensorflow/core/platform/logging.h" 26 27namespace toco { 28 29namespace { 30 31template <typename Scalar> 32bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data, 33 Scalar value) { 34 for (auto x : buffer_data) { 35 if (x != value) { 36 return false; 37 } 38 } 39 return true; 40} 41} // namespace 42 43// A binary operator is called trivial when exactly one of its operands is 44// a constant and is such that the binary operation is equivalent to 45// the identity operation on its other input. 46// For example, an Add operator is trivial if 47// one of its operands is constant 0, a Mul operator is trivial 48// if one of its operands is constant 1, etc. 49bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { 50 const auto binary_it = model->operators.begin() + op_index; 51 auto* binary_op = binary_it->get(); 52 if (binary_op->type != OperatorType::kAdd && 53 binary_op->type != OperatorType::kMul && 54 binary_op->type != OperatorType::kSub && 55 binary_op->type != OperatorType::kDiv) { 56 return false; 57 } 58 59 CHECK_EQ(binary_op->inputs.size(), 2); 60 61 // This graph transformation is only concerned with the case 62 // when one input is constant and the other is not constant. 63 const bool is_input_constant[2] = { 64 IsConstantParameterArray(*model, binary_op->inputs[0]), 65 IsConstantParameterArray(*model, binary_op->inputs[1]), 66 }; 67 if (!is_input_constant[0] && !is_input_constant[1]) { 68 // Neither input is constant, so nothing we can resolve here. 69 return false; 70 } 71 if (is_input_constant[0] && is_input_constant[1]) { 72 // Both inputs are constants. That's a job for constants 73 // propagation, not for us to handle here. 74 return false; 75 } 76 const int index_of_constant_input = is_input_constant[0] ? 0 : 1; 77 const int index_of_variable_input = is_input_constant[0] ? 1 : 0; 78 CHECK(is_input_constant[index_of_constant_input]); 79 CHECK(!is_input_constant[index_of_variable_input]); 80 81 // Now check if the constant operand makes this binary 82 // operator trivial. 83 const auto& constant_input_array = 84 model->GetArray(binary_op->inputs[index_of_constant_input]); 85 // For now, we only handle floats here. 86 if (constant_input_array.data_type != ArrayDataType::kFloat) { 87 return false; 88 } 89 const auto& constant_input_float_data = 90 constant_input_array.GetBuffer<ArrayDataType::kFloat>().data; 91 bool is_trivial = false; 92 if (binary_op->type == OperatorType::kAdd) { 93 is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); 94 } else if (binary_op->type == OperatorType::kSub) { 95 is_trivial = index_of_constant_input == 1 && 96 AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); 97 } else if (binary_op->type == OperatorType::kMul) { 98 is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); 99 } else if (binary_op->type == OperatorType::kDiv) { 100 is_trivial = index_of_constant_input == 1 && 101 AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); 102 } 103 104 if (!is_trivial) { 105 return false; 106 } 107 108 // Now we know that this node is trivial, so we can remove it. 109 AddMessageF("Removing trivial %s", LogName(*binary_op)); 110 return RemoveTrivialPassthroughOp(this, model, op_index); 111} 112 113} // namespace toco 114