1cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
3cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
4cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFloweryou may not use this file except in compliance with the License.
5cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerYou may obtain a copy of the License at
6cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
7cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
8cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
9cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
10cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
11cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerSee the License for the specific language governing permissions and
13cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerlimitations under the License.
14cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower==============================================================================*/
15cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include <iterator>
16cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include <memory>
17cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include <string>
18cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include <unordered_map>
19cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include <vector>
20cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
21cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
22cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include "tensorflow/contrib/lite/toco/model.h"
23cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include "tensorflow/contrib/lite/toco/tooling_util.h"
24cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
25cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowernamespace toco {
26cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
27cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowernamespace {
28cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
// Reports whether every element stored in `buffer_data` compares equal to
// zero. An empty buffer contains no non-zero element, so it yields true.
template <typename T>
bool AreAllBufferElementsZero(const std::vector<T>& buffer_data) {
  auto cursor = buffer_data.begin();
  const auto buffer_end = buffer_data.end();
  // Advance past the leading run of zero-valued elements; the scan stops at
  // the first element for which `!= 0` holds.
  while (cursor != buffer_end && !(*cursor != 0)) {
    ++cursor;
  }
  // Reaching the end means no non-zero element was encountered.
  return cursor == buffer_end;
}
38cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
39cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowertemplate <ArrayDataType Type>
40cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowervoid FillArrayWithZeros(Array* array) {
41cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  CHECK(array->data_type == Type);
42cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  std::vector<DataType<Type>>& data = array->GetMutableBuffer<Type>().data;
43cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  data.resize(RequiredBufferSizeForShape(array->shape()));
44cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  for (size_t i = 0; i < data.size(); i++) {
45cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    data[i] = 0;
46cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  }
47cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower}
48cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
49cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower}  // namespace
50cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
51cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower// Removes a multiplication by array of constant zeros by making the output
52cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower// array an array of constant zeros and removing the input arrays if they are no
53cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower// longer needed.
54cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerbool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
55cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  const auto mul_it = model->operators.begin() + op_index;
56cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  auto* mul_op = mul_it->get();
57cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  if (mul_op->type != OperatorType::kMul) {
58cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    return false;
59cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  }
60cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  const auto& output_array_name = mul_op->outputs[0];
61cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  auto& output_array = model->GetArray(output_array_name);
62cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
63cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  // Yield if the output shape is not known yet.
64cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  if (!output_array.has_shape()) {
65cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    return false;
66cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  }
67cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
68cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  // This transformation only handles the case where one operand is all 0's and
69cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  // the other is non-constant. Other cases are handled by constant propagation
70cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  // or the trivial binary removal pass.
71cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  const bool is_input_constant[2] = {
72cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      IsConstantParameterArray(*model, mul_op->inputs[0]),
73cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      IsConstantParameterArray(*model, mul_op->inputs[1]),
74cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  };
75cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  if (!is_input_constant[0] && !is_input_constant[1]) {
76cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    // Neither input is constant, so nothing we can resolve here.
77cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    return false;
78cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  }
79cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  if (is_input_constant[0] && is_input_constant[1]) {
80cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    // Both inputs are constants. That's a job for constants propagation, not
81cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    // for us to handle here.
82cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    return false;
83cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  }
84cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
85cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
86cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  CHECK(is_input_constant[index_of_constant_input]);
87cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  CHECK(!is_input_constant[index_of_variable_input]);
88cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
89cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  const auto& constant_input_array =
90cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      model->GetArray(mul_op->inputs[index_of_constant_input]);
91cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
92cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  CHECK(constant_input_array.data_type == output_array.data_type);
93cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  switch (output_array.data_type) {
94cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    case ArrayDataType::kFloat: {
95cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      const auto& constant_input_data =
96cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower          constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
97cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      if (!AreAllBufferElementsZero<DataType<ArrayDataType::kFloat>>(
98cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower              constant_input_data)) {
99cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower        return false;
100cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      }
101cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      FillArrayWithZeros<ArrayDataType::kFloat>(&output_array);
102cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    } break;
103cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    case ArrayDataType::kUint8: {
104cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      const auto& constant_input_data =
105cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower          constant_input_array.GetBuffer<ArrayDataType::kUint8>().data;
106cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      if (!AreAllBufferElementsZero<DataType<ArrayDataType::kUint8>>(
107cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower              constant_input_data)) {
108cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower        return false;
109cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      }
110cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      FillArrayWithZeros<ArrayDataType::kUint8>(&output_array);
111cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    } break;
112cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    case ArrayDataType::kInt32: {
113cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      const auto& constant_input_data =
114cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower          constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
115cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt32>>(
116cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower              constant_input_data)) {
117cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower        return false;
118cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      }
119cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      FillArrayWithZeros<ArrayDataType::kInt32>(&output_array);
120cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    } break;
121cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    case ArrayDataType::kInt64: {
122cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      const auto& constant_input_data =
123cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower          constant_input_array.GetBuffer<ArrayDataType::kInt64>().data;
124cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt64>>(
125cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower              constant_input_data)) {
126cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower        return false;
127cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      }
128cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      FillArrayWithZeros<ArrayDataType::kInt64>(&output_array);
129cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    } break;
130cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    default:
131cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      AddMessageF(
132cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower          "Cannot resolve multiply by 0 because of unsupported data type\n");
133cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      return false;
134cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  }
135cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
136cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  // Erase input arrays to the multiply if no longer used
137cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  if (IsDiscardableArray(*model, mul_op->inputs[0]) &&
138cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      CountOpsWithInput(*model, mul_op->inputs[0]) == 1) {
139cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    model->EraseArray(mul_op->inputs[0]);
140cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  }
141cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  if (IsDiscardableArray(*model, mul_op->inputs[1]) &&
142cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower      CountOpsWithInput(*model, mul_op->inputs[1]) == 1) {
143cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower    model->EraseArray(mul_op->inputs[1]);
144cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  }
145cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
146cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  // Erase the multiply operator.
147cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  model->operators.erase(mul_it);
148cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
149cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower  return true;
150cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower}
151cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower
152cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower}  // namespace toco
153