1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <cmath>
16#include <memory>
17#include <string>
18#include <unordered_map>
19#include <vector>
20
21#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
22#include "tensorflow/contrib/lite/toco/model.h"
23#include "tensorflow/contrib/lite/toco/tooling_util.h"
24#include "tensorflow/core/platform/logging.h"
25
26namespace toco {
27
28namespace {
29
30std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
31    Model* model, const Operator* op) {
32  auto it = model->operators.begin();
33  for (; it != model->operators.end(); ++it) {
34    if (it->get() == op) {
35      break;
36    }
37  }
38  return it;
39}
40}  // namespace
41
42bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
43  const auto div_it = model->operators.begin() + op_index;
44  const auto* div_or_mul_op = div_it->get();
45  OperatorType expected_op_type_producing_div_or_mul_input;
46  if (div_or_mul_op->type == OperatorType::kDiv) {
47    expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt;
48  } else if (div_or_mul_op->type == OperatorType::kMul) {
49    expected_op_type_producing_div_or_mul_input =
50        OperatorType::kTensorFlowRsqrt;
51  } else {
52    return false;
53  }
54  CHECK_EQ(div_or_mul_op->inputs.size(), 2);
55  Operator* op_producing_div_or_mul_input[2] = {
56      GetOpWithOutput(*model, div_or_mul_op->inputs[0]),
57      GetOpWithOutput(*model, div_or_mul_op->inputs[1]),
58  };
59  if (!op_producing_div_or_mul_input[1] ||
60      op_producing_div_or_mul_input[1]->type !=
61          expected_op_type_producing_div_or_mul_input) {
62    return false;
63  }
64  Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1];
65  CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1);
66  Operator* op_producing_sqrt_or_rsqrt_input =
67      GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]);
68  if (!op_producing_sqrt_or_rsqrt_input) {
69    return false;
70  }
71
72  // There may be an Add or a Maximum here, adding or clamping to a "small"
73  // constant scalar.
74  // Reported bug: b/29395854
75  Operator* add_op = nullptr;
76  Operator* op_producing_add_input = nullptr;
77  if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd ||
78      op_producing_sqrt_or_rsqrt_input->type ==
79          OperatorType::kTensorFlowMaximum) {
80    add_op = op_producing_sqrt_or_rsqrt_input;
81    bool add_can_be_removed = false;
82    CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2);
83    for (int i = 0; i < 2; i++) {
84      const auto& input_array =
85          model->GetArray(op_producing_sqrt_or_rsqrt_input->inputs[i]);
86      if (!input_array.buffer) {
87        continue;
88      }
89      if (input_array.buffer->type != ArrayDataType::kFloat) {
90        continue;
91      }
92      if (RequiredBufferSizeForShape(input_array.shape()) != 1) {
93        continue;
94      }
95      const auto& input_float_data =
96          input_array.GetBuffer<ArrayDataType::kFloat>().data;
97      if (std::abs(input_float_data[0]) > 1e-3f) {
98        continue;
99      }
100      add_can_be_removed = true;
101      op_producing_add_input = GetOpWithOutput(*model, add_op->inputs[1 - i]);
102      break;
103    }
104    if (!add_can_be_removed) {
105      AddMessageF(
106          "Giving up trying to identify L2Normalization subgraph "
107          " because the operator producing the input to the square root, %s,"
108          ", does not match the expected pattern",
109          LogName(*op_producing_sqrt_or_rsqrt_input));
110      return false;
111    }
112  }
113
114  Operator* sum_op =
115      add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input;
116  if (sum_op->type != OperatorType::kTensorFlowSum) {
117    AddMessageF(
118        "Giving up trying to identify L2Normalization subgraph: "
119        "expected Sum op, got %s",
120        LogName(*sum_op));
121    return false;
122  }
123
124  Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
125  if (square_op->type != OperatorType::kTensorFlowSquare) {
126    AddMessageF(
127        "Giving up trying to identify L2Normalization subgraph: "
128        "expected Square op, got %s",
129        LogName(*square_op));
130    return false;
131  }
132
133  CHECK_EQ(square_op->inputs.size(), 1);
134
135  if (square_op->inputs[0] != div_or_mul_op->inputs[0]) {
136    AddMessageF(
137        "Giving up trying to identify L2Normalization subgraph: %s does not "
138        "take the same input as the Mul/Div node",
139        LogName(*square_op));
140    return false;
141  }
142
143  // Create and emplace the new L2Normalization
144  auto* l2norm_op = new L2NormalizationOperator;
145  l2norm_op->inputs = {div_or_mul_op->inputs[0]};
146  l2norm_op->outputs = div_or_mul_op->outputs;
147  model->operators.emplace(div_it, l2norm_op);
148
149  AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op));
150
151  // Erase the subgraph that is now replaced by L2Normalization
152  model->operators.erase(FindOperator(model, square_op));
153  model->EraseArray(sum_op->inputs[0]);
154  if (sum_op->inputs.size() > 1) {
155    model->EraseArray(sum_op->inputs[1]);
156  }
157  model->operators.erase(FindOperator(model, sum_op));
158  if (add_op) {
159    model->EraseArray(add_op->inputs[0]);
160    model->EraseArray(add_op->inputs[1]);
161    model->operators.erase(FindOperator(model, add_op));
162  }
163  model->EraseArray(sqrt_or_rsqrt_op->inputs[0]);
164  model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
165  model->EraseArray(div_or_mul_op->inputs[1]);
166  model->operators.erase(FindOperator(model, div_or_mul_op));
167  return true;
168}
169
170}  // namespace toco
171