1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <iterator>
16#include <memory>
17#include <string>
18#include <unordered_map>
19#include <vector>
20
21#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
22#include "tensorflow/contrib/lite/toco/model.h"
23#include "tensorflow/contrib/lite/toco/tooling_util.h"
24
25namespace toco {
26
27namespace {
28
// Reports whether `buffer_data` contains no element that compares unequal to
// zero. An empty buffer trivially satisfies this and yields true.
template <typename T>
bool AreAllBufferElementsZero(const std::vector<T>& buffer_data) {
  const std::size_t element_count = buffer_data.size();
  std::size_t idx = 0;
  // Advance past every element that is not "!= 0"; stop at the first one that is.
  while (idx < element_count && !(buffer_data[idx] != 0)) {
    ++idx;
  }
  // Running off the end means no non-zero element was encountered.
  return idx == element_count;
}
38
39template <ArrayDataType Type>
40void FillArrayWithZeros(Array* array) {
41  CHECK(array->data_type == Type);
42  std::vector<DataType<Type>>& data = array->GetMutableBuffer<Type>().data;
43  data.resize(RequiredBufferSizeForShape(array->shape()));
44  for (size_t i = 0; i < data.size(); i++) {
45    data[i] = 0;
46  }
47}
48
49}  // namespace
50
51// Removes a multiplication by array of constant zeros by making the output
52// array an array of constant zeros and removing the input arrays if they are no
53// longer needed.
54bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
55  const auto mul_it = model->operators.begin() + op_index;
56  auto* mul_op = mul_it->get();
57  if (mul_op->type != OperatorType::kMul) {
58    return false;
59  }
60  const auto& output_array_name = mul_op->outputs[0];
61  auto& output_array = model->GetArray(output_array_name);
62
63  // Yield if the output shape is not known yet.
64  if (!output_array.has_shape()) {
65    return false;
66  }
67
68  // This transformation only handles the case where one operand is all 0's and
69  // the other is non-constant. Other cases are handled by constant propagation
70  // or the trivial binary removal pass.
71  const bool is_input_constant[2] = {
72      IsConstantParameterArray(*model, mul_op->inputs[0]),
73      IsConstantParameterArray(*model, mul_op->inputs[1]),
74  };
75  if (!is_input_constant[0] && !is_input_constant[1]) {
76    // Neither input is constant, so nothing we can resolve here.
77    return false;
78  }
79  if (is_input_constant[0] && is_input_constant[1]) {
80    // Both inputs are constants. That's a job for constants propagation, not
81    // for us to handle here.
82    return false;
83  }
84  const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
85  const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
86  CHECK(is_input_constant[index_of_constant_input]);
87  CHECK(!is_input_constant[index_of_variable_input]);
88
89  const auto& constant_input_array =
90      model->GetArray(mul_op->inputs[index_of_constant_input]);
91
92  CHECK(constant_input_array.data_type == output_array.data_type);
93  switch (output_array.data_type) {
94    case ArrayDataType::kFloat: {
95      const auto& constant_input_data =
96          constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
97      if (!AreAllBufferElementsZero<DataType<ArrayDataType::kFloat>>(
98              constant_input_data)) {
99        return false;
100      }
101      FillArrayWithZeros<ArrayDataType::kFloat>(&output_array);
102    } break;
103    case ArrayDataType::kUint8: {
104      const auto& constant_input_data =
105          constant_input_array.GetBuffer<ArrayDataType::kUint8>().data;
106      if (!AreAllBufferElementsZero<DataType<ArrayDataType::kUint8>>(
107              constant_input_data)) {
108        return false;
109      }
110      FillArrayWithZeros<ArrayDataType::kUint8>(&output_array);
111    } break;
112    case ArrayDataType::kInt32: {
113      const auto& constant_input_data =
114          constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
115      if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt32>>(
116              constant_input_data)) {
117        return false;
118      }
119      FillArrayWithZeros<ArrayDataType::kInt32>(&output_array);
120    } break;
121    case ArrayDataType::kInt64: {
122      const auto& constant_input_data =
123          constant_input_array.GetBuffer<ArrayDataType::kInt64>().data;
124      if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt64>>(
125              constant_input_data)) {
126        return false;
127      }
128      FillArrayWithZeros<ArrayDataType::kInt64>(&output_array);
129    } break;
130    default:
131      AddMessageF(
132          "Cannot resolve multiply by 0 because of unsupported data type\n");
133      return false;
134  }
135
136  // Erase input arrays to the multiply if no longer used
137  if (IsDiscardableArray(*model, mul_op->inputs[0]) &&
138      CountOpsWithInput(*model, mul_op->inputs[0]) == 1) {
139    model->EraseArray(mul_op->inputs[0]);
140  }
141  if (IsDiscardableArray(*model, mul_op->inputs[1]) &&
142      CountOpsWithInput(*model, mul_op->inputs[1]) == 1) {
143    model->EraseArray(mul_op->inputs[1]);
144  }
145
146  // Erase the multiply operator.
147  model->operators.erase(mul_it);
148
149  return true;
150}
151
152}  // namespace toco
153