1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <vector> 16 17#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 18#include "tensorflow/contrib/lite/toco/model.h" 19#include "tensorflow/contrib/lite/toco/tooling_util.h" 20#include "tensorflow/core/platform/logging.h" 21 22namespace toco { 23 24template <ArrayDataType Type> 25bool ComputeFillArray(Model* model, FillOperator* op) { 26 const auto& val_array = model->GetArray(op->inputs[1]); 27 auto& output_array = model->GetArray(op->outputs[0]); 28 29 CHECK(val_array.data_type == Type); 30 CHECK(output_array.data_type == Type); 31 32 // Compute the array data 33 std::vector<DataType<Type>>& data = 34 output_array.GetMutableBuffer<Type>().data; 35 data.resize(RequiredBufferSizeForShape(output_array.shape())); 36 DataType<Type> fill_val = val_array.GetBuffer<Type>().data[0]; 37 for (size_t i = 0; i < data.size(); i++) { 38 data[i] = fill_val; 39 } 40 41 return true; 42} 43 44bool ResolveConstantFill::Run(Model* model, std::size_t op_index) { 45 const auto fill_it = model->operators.begin() + op_index; 46 auto* base_op = fill_it->get(); 47 if (base_op->type != OperatorType::kFill) { 48 return false; 49 } 50 auto* op = static_cast<FillOperator*>(base_op); 51 52 CHECK_EQ(op->inputs.size(), 2); 53 CHECK_EQ(op->outputs.size(), 1); 54 55 auto& output_array = model->GetArray(op->outputs[0]); 56 if (output_array.data_type == ArrayDataType::kNone) { 57 // Yield until the output type has been set by PropagateArrayDataTypes 58 return false; 59 } 60 61 if (!output_array.has_shape()) { 62 // Yield until the output shape has been set by PropagateFixedShapes 63 return false; 64 } 65 66 const auto& val_array = model->GetArray(op->inputs[1]); 67 if (!val_array.has_shape()) { 68 // Yield until the value shape has been resolved. 69 return false; 70 } 71 if (!IsConstantParameterArray(*model, op->inputs[1])) { 72 // Yield until the value is constant. 73 return false; 74 } 75 CHECK_EQ(RequiredBufferSizeForShape(val_array.shape()), 1); 76 77 switch (output_array.data_type) { 78 case ArrayDataType::kFloat: 79 if (!ComputeFillArray<ArrayDataType::kFloat>(model, op)) { 80 return false; 81 } 82 break; 83 case ArrayDataType::kUint8: 84 if (!ComputeFillArray<ArrayDataType::kUint8>(model, op)) { 85 return false; 86 } 87 break; 88 case ArrayDataType::kInt32: 89 if (!ComputeFillArray<ArrayDataType::kInt32>(model, op)) { 90 return false; 91 } 92 break; 93 case ArrayDataType::kInt64: 94 if (!ComputeFillArray<ArrayDataType::kInt64>(model, op)) { 95 return false; 96 } 97 break; 98 default: 99 LOG(FATAL) << "Unsupported data type given to Fill op with output \"" 100 << op->outputs[0] << "\""; 101 break; 102 } 103 104 // Erase input arrays if no longer used 105 if (IsDiscardableArray(*model, op->inputs[0]) && 106 CountOpsWithInput(*model, op->inputs[0]) == 1) { 107 model->EraseArray(op->inputs[0]); 108 } 109 if (IsDiscardableArray(*model, op->inputs[1]) && 110 CountOpsWithInput(*model, op->inputs[1]) == 1) { 111 model->EraseArray(op->inputs[1]); 112 } 113 114 // Erase the operator 115 model->operators.erase(fill_it); 116 117 return true; 118} 119 120} // namespace toco 121