1cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 3cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 4cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFloweryou may not use this file except in compliance with the License. 5cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerYou may obtain a copy of the License at 6cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 7cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 8cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 9cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 10cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 11cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerSee the License for the specific language governing permissions and 13cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerlimitations under the License. 14cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower==============================================================================*/ 15cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include <iterator> 16cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include <memory> 17cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include <string> 18cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include <unordered_map> 19cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include <vector> 20cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 21cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 22cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include "tensorflow/contrib/lite/toco/model.h" 23cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower#include "tensorflow/contrib/lite/toco/tooling_util.h" 24cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 25cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowernamespace toco { 26cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 27cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowernamespace { 28cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 29cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowertemplate <typename T> 30cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerbool AreAllBufferElementsZero(const std::vector<T>& buffer_data) { 31cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower for (auto x : buffer_data) { 32cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower if (x != 0) { 33cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return false; 34cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 35cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 36cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return true; 37cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower} 38cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 39cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowertemplate <ArrayDataType Type> 40cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowervoid FillArrayWithZeros(Array* array) { 41cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower CHECK(array->data_type == Type); 42cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower std::vector<DataType<Type>>& data = array->GetMutableBuffer<Type>().data; 43cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower data.resize(RequiredBufferSizeForShape(array->shape())); 44cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower for (size_t i = 0; i < data.size(); i++) { 45cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower data[i] = 0; 46cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 47cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower} 48cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 49cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower} // namespace 50cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 51cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower// Removes a multiplication by array of constant zeros by making the output 52cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower// array an array of constant zeros and removing the input arrays if they are no 53cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower// longer needed. 54cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlowerbool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { 55cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower const auto mul_it = model->operators.begin() + op_index; 56cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower auto* mul_op = mul_it->get(); 57cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower if (mul_op->type != OperatorType::kMul) { 58cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return false; 59cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 60cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower const auto& output_array_name = mul_op->outputs[0]; 61cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower auto& output_array = model->GetArray(output_array_name); 62cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 63cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower // Yield if the output shape is not known yet. 64cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower if (!output_array.has_shape()) { 65cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return false; 66cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 67cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 68cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower // This transformation only handles the case where one operand is all 0's and 69cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower // the other is non-constant. Other cases are handled by constant propagation 70cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower // or the trivial binary removal pass. 71cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower const bool is_input_constant[2] = { 72cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower IsConstantParameterArray(*model, mul_op->inputs[0]), 73cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower IsConstantParameterArray(*model, mul_op->inputs[1]), 74cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower }; 75cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower if (!is_input_constant[0] && !is_input_constant[1]) { 76cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower // Neither input is constant, so nothing we can resolve here. 77cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return false; 78cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 79cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower if (is_input_constant[0] && is_input_constant[1]) { 80cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower // Both inputs are constants. That's a job for constants propagation, not 81cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower // for us to handle here. 82cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return false; 83cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 84cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower const int index_of_constant_input = is_input_constant[0] ? 0 : 1; 85cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower const int index_of_variable_input = is_input_constant[0] ? 1 : 0; 86cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower CHECK(is_input_constant[index_of_constant_input]); 87cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower CHECK(!is_input_constant[index_of_variable_input]); 88cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 89cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower const auto& constant_input_array = 90cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower model->GetArray(mul_op->inputs[index_of_constant_input]); 91cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 92cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower CHECK(constant_input_array.data_type == output_array.data_type); 93cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower switch (output_array.data_type) { 94cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower case ArrayDataType::kFloat: { 95cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower const auto& constant_input_data = 96cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower constant_input_array.GetBuffer<ArrayDataType::kFloat>().data; 97cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower if (!AreAllBufferElementsZero<DataType<ArrayDataType::kFloat>>( 98cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower constant_input_data)) { 99cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return false; 100cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 101cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower FillArrayWithZeros<ArrayDataType::kFloat>(&output_array); 102cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } break; 103cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower case ArrayDataType::kUint8: { 104cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower const auto& constant_input_data = 105cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower constant_input_array.GetBuffer<ArrayDataType::kUint8>().data; 106cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower if (!AreAllBufferElementsZero<DataType<ArrayDataType::kUint8>>( 107cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower constant_input_data)) { 108cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return false; 109cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 110cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower FillArrayWithZeros<ArrayDataType::kUint8>(&output_array); 111cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } break; 112cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower case ArrayDataType::kInt32: { 113cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower const auto& constant_input_data = 114cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower constant_input_array.GetBuffer<ArrayDataType::kInt32>().data; 115cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt32>>( 116cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower constant_input_data)) { 117cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return false; 118cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 119cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower FillArrayWithZeros<ArrayDataType::kInt32>(&output_array); 120cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } break; 121cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower case ArrayDataType::kInt64: { 122cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower const auto& constant_input_data = 123cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower constant_input_array.GetBuffer<ArrayDataType::kInt64>().data; 124cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt64>>( 125cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower constant_input_data)) { 126cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return false; 127cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 128cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower FillArrayWithZeros<ArrayDataType::kInt64>(&output_array); 129cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } break; 130cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower default: 131cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower AddMessageF( 132cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower "Cannot resolve multiply by 0 because of unsupported data type\n"); 133cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return false; 134cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 135cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 136cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower // Erase input arrays to the multiply if no longer used 137cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower if (IsDiscardableArray(*model, mul_op->inputs[0]) && 138cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower CountOpsWithInput(*model, mul_op->inputs[0]) == 1) { 139cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower model->EraseArray(mul_op->inputs[0]); 140cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 141cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower if (IsDiscardableArray(*model, mul_op->inputs[1]) && 142cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower CountOpsWithInput(*model, mul_op->inputs[1]) == 1) { 143cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower model->EraseArray(mul_op->inputs[1]); 144cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower } 145cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 146cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower // Erase the multiply operator. 147cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower model->operators.erase(mul_it); 148cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 149cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower return true; 150cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower} 151cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower 152cf74c749aa9f7fb8eabb4a254c4f53cc2dbadae3A. Unique TensorFlower} // namespace toco 153