1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <algorithm> 16#include <memory> 17#include <string> 18#include <unordered_map> 19#include <vector> 20 21#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 22#include "tensorflow/contrib/lite/toco/model.h" 23#include "tensorflow/contrib/lite/toco/tooling_util.h" 24#include "tensorflow/core/platform/logging.h" 25 26namespace toco { 27 28namespace { 29 30bool ApplyMinMaxToArray(GraphTransformation* transformation, Model* model, 31 const MinMax& minmax, const string& array_name) { 32 auto& annotated_array = model->GetArray(array_name); 33 if (annotated_array.minmax) { 34 return false; 35 } 36 annotated_array.GetOrCreateMinMax() = minmax; 37 transformation->AddMessageF( 38 "Read min/max annotation for array %s: min=%g, max=%g", array_name, 39 minmax.min, minmax.max); 40 return true; 41} 42 43} // end namespace 44 45bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) { 46 const auto fakequant_it = model->operators.begin() + op_index; 47 auto* fakequant_base_op = fakequant_it->get(); 48 if (fakequant_base_op->type != OperatorType::kFakeQuant) { 49 return false; 50 } 51 auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op); 52 53 bool changed = false; 54 55 if (!fakequant_op->minmax) { 56 CHECK_EQ(fakequant_op->inputs.size(), 3); 57 // We need to yield until the min and max parameters have been 58 // resolved to constant arrays. 59 for (int i = 1; i <= 2; i++) { 60 if (!IsConstantParameterArray(*model, fakequant_op->inputs[1])) { 61 return false; 62 } 63 } 64 65 // Obtain the final min/max values 66 const auto& min_array = model->GetArray(fakequant_op->inputs[1]); 67 const auto& max_array = model->GetArray(fakequant_op->inputs[2]); 68 CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1); 69 CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1); 70 fakequant_op->minmax.reset(new MinMax); 71 MinMax& minmax = *fakequant_op->minmax; 72 minmax.min = min_array.GetBuffer<ArrayDataType::kFloat>().data[0]; 73 minmax.max = max_array.GetBuffer<ArrayDataType::kFloat>().data[0]; 74 // We always want [min, max] to contain 0. 75 minmax.min = std::min(minmax.min, 0.); 76 minmax.max = std::max(minmax.max, 0.); 77 78 // We won't use the input arrays that provided these min and max 79 // values, anymore. Delete them unless they are used by something 80 // else. 81 for (int i = 1; i <= 2; i++) { 82 if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { 83 model->EraseArray(fakequant_op->inputs[i]); 84 } 85 } 86 fakequant_op->inputs.resize(1); 87 changed = true; 88 } 89 90 // At this point, this FakeQuantOperator should have a MinMax 91 // attached to it, and should only have 1 input (it should not have 92 // 2nd and 3rd input arrays giving min and max anymore). 93 CHECK(fakequant_op->minmax); 94 CHECK_EQ(1, fakequant_op->inputs.size()); 95 96 const MinMax& minmax = *fakequant_op->minmax; 97 98 // Record the MinMax info on the input and output arrays 99 changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->inputs[0]); 100 changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->outputs[0]); 101 102 return changed; 103} 104 105} // namespace toco 106