/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <cmath>
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory>
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string>
180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map>
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector>
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h"
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h"
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h"
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco {
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace {
290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellestd::vector<std::unique_ptr<Operator>>::iterator FindOperator(
310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    Model* model, const Operator* op) {
320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto it = model->operators.begin();
330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (; it != model->operators.end(); ++it) {
340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    if (it->get() == op) {
350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      break;
360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return it;
390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace
410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto div_it = model->operators.begin() + op_index;
440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto* div_or_mul_op = div_it->get();
450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  OperatorType expected_op_type_producing_div_or_mul_input;
460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (div_or_mul_op->type == OperatorType::kDiv) {
470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt;
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  } else if (div_or_mul_op->type == OperatorType::kMul) {
490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    expected_op_type_producing_div_or_mul_input =
500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        OperatorType::kTensorFlowRsqrt;
510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  } else {
520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK_EQ(div_or_mul_op->inputs.size(), 2);
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* op_producing_div_or_mul_input[2] = {
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      GetOpWithOutput(*model, div_or_mul_op->inputs[0]),
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      GetOpWithOutput(*model, div_or_mul_op->inputs[1]),
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  };
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!op_producing_div_or_mul_input[1] ||
600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      op_producing_div_or_mul_input[1]->type !=
610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          expected_op_type_producing_div_or_mul_input) {
620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1];
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1);
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* op_producing_sqrt_or_rsqrt_input =
670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]);
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!op_producing_sqrt_or_rsqrt_input) {
690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // There may be an Add or a Maximum here, adding or clamping to a "small"
730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // constant scalar.
740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Reported bug: b/29395854
750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* add_op = nullptr;
760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* op_producing_add_input = nullptr;
770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd ||
780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      op_producing_sqrt_or_rsqrt_input->type ==
790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          OperatorType::kTensorFlowMaximum) {
800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    add_op = op_producing_sqrt_or_rsqrt_input;
810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    bool add_can_be_removed = false;
820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2);
830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    for (int i = 0; i < 2; i++) {
840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      const auto& input_array =
850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          model->GetArray(op_producing_sqrt_or_rsqrt_input->inputs[i]);
860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      if (!input_array.buffer) {
870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        continue;
880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      }
890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      if (input_array.buffer->type != ArrayDataType::kFloat) {
900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        continue;
910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      }
920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      if (RequiredBufferSizeForShape(input_array.shape()) != 1) {
930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        continue;
940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      }
950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      const auto& input_float_data =
960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          input_array.GetBuffer<ArrayDataType::kFloat>().data;
970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      if (std::abs(input_float_data[0]) > 1e-3f) {
980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        continue;
990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      }
1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      add_can_be_removed = true;
1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      op_producing_add_input = GetOpWithOutput(*model, add_op->inputs[1 - i]);
1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      break;
1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    if (!add_can_be_removed) {
1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      AddMessageF(
1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          "Giving up trying to identify L2Normalization subgraph "
1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          " because the operator producing the input to the square root, %s,"
1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          ", does not match the expected pattern",
1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          LogName(*op_producing_sqrt_or_rsqrt_input));
1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      return false;
1110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* sum_op =
1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input;
1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (sum_op->type != OperatorType::kTensorFlowSum) {
1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    AddMessageF(
1180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        "Giving up trying to identify L2Normalization subgraph: "
1190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        "expected Sum op, got %s",
1200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        LogName(*sum_op));
1210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
1250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (square_op->type != OperatorType::kTensorFlowSquare) {
1260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    AddMessageF(
1270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        "Giving up trying to identify L2Normalization subgraph: "
1280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        "expected Square op, got %s",
1290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        LogName(*square_op));
1300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK_EQ(square_op->inputs.size(), 1);
1340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (square_op->inputs[0] != div_or_mul_op->inputs[0]) {
1360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    AddMessageF(
1370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        "Giving up trying to identify L2Normalization subgraph: %s does not "
1380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        "take the same input as the Mul/Div node",
1390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        LogName(*square_op));
1400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
1410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Create and emplace the new L2Normalization
1440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* l2norm_op = new L2NormalizationOperator;
1450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  l2norm_op->inputs = {div_or_mul_op->inputs[0]};
1460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  l2norm_op->outputs = div_or_mul_op->outputs;
1470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.emplace(div_it, l2norm_op);
1480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op));
1500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Erase the subgraph that is now replaced by L2Normalization
1520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, square_op));
153ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  model->EraseArray(sum_op->inputs[0]);
1540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (sum_op->inputs.size() > 1) {
155ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower    model->EraseArray(sum_op->inputs[1]);
1560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, sum_op));
1580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (add_op) {
159ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower    model->EraseArray(add_op->inputs[0]);
160ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower    model->EraseArray(add_op->inputs[1]);
1610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    model->operators.erase(FindOperator(model, add_op));
1620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
163ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  model->EraseArray(sqrt_or_rsqrt_op->inputs[0]);
1640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
165ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  model->EraseArray(div_or_mul_op->inputs[1]);
1660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(FindOperator(model, div_or_mul_op));
1670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return true;
1680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace toco
171