10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License"); 40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License. 50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at 60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle http://www.apache.org/licenses/LICENSE-2.0 80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software 100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS, 110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and 130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License. 140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/ 150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <iterator> 160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory> 170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string> 180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map> 190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector> 200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" 230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h" 240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h" 250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h" 260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco { 280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace { 300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selletemplate <typename Scalar> 320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data, 330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Scalar value) { 340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle for (auto x : buffer_data) { 350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (x != value) { 360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return true; 400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace 420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// A binary operator is called trivial when exactly one of its operands is 440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// a constant and is such that the binary operation is equivalent to 450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// the identity operation on its other input. 460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// For example, an Add operator is trivial if 470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// one of its operands is constant 0, a Mul operator is trivial 480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// if one of its operands is constant 1, etc. 490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { 500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto binary_it = model->operators.begin() + op_index; 510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto* binary_op = binary_it->get(); 520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (binary_op->type != OperatorType::kAdd && 530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle binary_op->type != OperatorType::kMul && 540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle binary_op->type != OperatorType::kSub && 550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle binary_op->type != OperatorType::kDiv) { 560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK_EQ(binary_op->inputs.size(), 2); 600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // This graph transformation is only concerned with the case 620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // when one input is constant and the other is not constant. 630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const bool is_input_constant[2] = { 640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle IsConstantParameterArray(*model, binary_op->inputs[0]), 650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle IsConstantParameterArray(*model, binary_op->inputs[1]), 660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle }; 670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!is_input_constant[0] && !is_input_constant[1]) { 680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Neither input is constant, so nothing we can resolve here. 690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (is_input_constant[0] && is_input_constant[1]) { 720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Both inputs are constants. That's a job for constants 730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // propagation, not for us to handle here. 740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int index_of_constant_input = is_input_constant[0] ? 0 : 1; 770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int index_of_variable_input = is_input_constant[0] ? 1 : 0; 780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(is_input_constant[index_of_constant_input]); 790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(!is_input_constant[index_of_variable_input]); 800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Now check if the constant operand makes this binary 820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // operator trivial. 830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& constant_input_array = 84ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower model->GetArray(binary_op->inputs[index_of_constant_input]); 850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // For now, we only handle floats here. 860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (constant_input_array.data_type != ArrayDataType::kFloat) { 870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& constant_input_float_data = 900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle constant_input_array.GetBuffer<ArrayDataType::kFloat>().data; 910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle bool is_trivial = false; 9265aa9ee2500b0108d89f5fd1368ec3b73b273082A. Unique TensorFlower if (binary_op->type == OperatorType::kAdd) { 930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); 9465aa9ee2500b0108d89f5fd1368ec3b73b273082A. Unique TensorFlower } else if (binary_op->type == OperatorType::kSub) { 950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle is_trivial = index_of_constant_input == 1 && 960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); 9765aa9ee2500b0108d89f5fd1368ec3b73b273082A. Unique TensorFlower } else if (binary_op->type == OperatorType::kMul) { 980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); 9965aa9ee2500b0108d89f5fd1368ec3b73b273082A. Unique TensorFlower } else if (binary_op->type == OperatorType::kDiv) { 1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle is_trivial = index_of_constant_input == 1 && 1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); 1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!is_trivial) { 1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Now we know that this node is trivial, so we can remove it. 1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AddMessageF("Removing trivial %s", LogName(*binary_op)); 1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return RemoveTrivialPassthroughOp(this, model, op_index); 1110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace toco 114