10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License");
40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License.
50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at
60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    http://www.apache.org/licenses/LICENSE-2.0
80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software
100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS,
110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and
130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License.
140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <iterator>
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory>
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string>
180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map>
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector>
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h"
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h"
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h"
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco {
280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
namespace {

// Reports whether every element of `buffer_data` equals `value`.
// An empty buffer trivially satisfies the condition and yields true.
// The comparison is exact, which is what we want when testing for literal
// constants such as 0 or 1; for floating-point Scalar a NaN element never
// matches.
template <typename Scalar>
bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data,
                                 Scalar value) {
  auto cursor = buffer_data.cbegin();
  const auto end = buffer_data.cend();
  // Advance past the leading run of matching elements.
  while (cursor != end && *cursor == value) {
    ++cursor;
  }
  // Only if we walked off the end was there no mismatching element.
  return cursor == end;
}

}  // namespace
420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// A binary operator is called trivial when exactly one of its operands is
440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// a constant and is such that the binary operation is equivalent to
450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// the identity operation on its other input.
460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// For example, an Add operator is trivial if
470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// one of its operands is constant 0, a Mul operator is trivial
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// if one of its operands is constant 1, etc.
490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto binary_it = model->operators.begin() + op_index;
510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* binary_op = binary_it->get();
520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (binary_op->type != OperatorType::kAdd &&
530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      binary_op->type != OperatorType::kMul &&
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      binary_op->type != OperatorType::kSub &&
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      binary_op->type != OperatorType::kDiv) {
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK_EQ(binary_op->inputs.size(), 2);
600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // This graph transformation is only concerned with the case
620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // when one input is constant and the other is not constant.
630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const bool is_input_constant[2] = {
640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      IsConstantParameterArray(*model, binary_op->inputs[0]),
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      IsConstantParameterArray(*model, binary_op->inputs[1]),
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  };
670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!is_input_constant[0] && !is_input_constant[1]) {
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    // Neither input is constant, so nothing we can resolve here.
690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (is_input_constant[0] && is_input_constant[1]) {
720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    // Both inputs are constants. That's a job for constants
730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    // propagation, not for us to handle here.
740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(is_input_constant[index_of_constant_input]);
790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(!is_input_constant[index_of_variable_input]);
800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Now check if the constant operand makes this binary
820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // operator trivial.
830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& constant_input_array =
84ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower      model->GetArray(binary_op->inputs[index_of_constant_input]);
850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // For now, we only handle floats here.
860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (constant_input_array.data_type != ArrayDataType::kFloat) {
870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& constant_input_float_data =
900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  bool is_trivial = false;
9265aa9ee2500b0108d89f5fd1368ec3b73b273082A. Unique TensorFlower  if (binary_op->type == OperatorType::kAdd) {
930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
9465aa9ee2500b0108d89f5fd1368ec3b73b273082A. Unique TensorFlower  } else if (binary_op->type == OperatorType::kSub) {
950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    is_trivial = index_of_constant_input == 1 &&
960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                 AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
9765aa9ee2500b0108d89f5fd1368ec3b73b273082A. Unique TensorFlower  } else if (binary_op->type == OperatorType::kMul) {
980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
9965aa9ee2500b0108d89f5fd1368ec3b73b273082A. Unique TensorFlower  } else if (binary_op->type == OperatorType::kDiv) {
1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    is_trivial = index_of_constant_input == 1 &&
1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                 AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!is_trivial) {
1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Now we know that this node is trivial, so we can remove it.
1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  AddMessageF("Removing trivial %s", LogName(*binary_op));
1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return RemoveTrivialPassthroughOp(this, model, op_index);
1110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace toco
114