1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include <cmath>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
25
26namespace toco {
27
28namespace {
29
30template <ArrayDataType A>
31void DequantizeBuffer(Array* array) {
32  const auto old_data = array->GetBuffer<A>().data;
33  array->buffer = nullptr;
34  array->data_type = ArrayDataType::kFloat;
35  auto& new_data = array->GetMutableBuffer<ArrayDataType::kFloat>().data;
36  new_data.resize(old_data.size());
37  const auto& qparams = array->GetQuantizationParams();
38  for (int i = 0; i < old_data.size(); i++) {
39    new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point);
40  }
41}
42
43std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
44    Model* model, const string& array_name) {
45  for (auto it = model->operators.begin(); it != model->operators.end(); ++it) {
46    for (const auto& input : it->get()->inputs) {
47      if (input == array_name) {
48        return it;
49      }
50    }
51  }
52  return model->operators.end();
53}
54
55void ClearArrayQuantizationParams(const string& array_name, Model* model) {
56  auto* array = &model->GetArray(array_name);
57  CHECK(array->quantization_params);
58  for (auto& input_array : *model->flags.mutable_input_arrays()) {
59    if (input_array.name() == array_name) {
60      auto& qparams = *array->quantization_params;
61      const double new_std_value = 1. / qparams.scale;
62      const double new_mean_value = qparams.zero_point;
63      if (input_array.has_std_value()) {
64        CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001);
65      } else {
66        input_array.set_std_value(new_std_value);
67      }
68      if (input_array.has_mean_value()) {
69        CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001);
70      } else {
71        input_array.set_mean_value(new_mean_value);
72      }
73    }
74  }
75  array->quantization_params = nullptr;
76}
77
78bool DequantizeArray(const string& array_name,
79                     GraphTransformation* transformation, Model* model) {
80  auto* array = &model->GetArray(array_name);
81  if (!array->quantization_params) {
82    return false;
83  }
84  transformation->AddMessageF("Dequantizing array: %s", array_name);
85
86  // Dequantize any buffer
87  if (array->buffer) {
88    if (array->data_type == ArrayDataType::kUint8) {
89      DequantizeBuffer<ArrayDataType::kUint8>(array);
90    } else if (array->data_type == ArrayDataType::kInt32) {
91      DequantizeBuffer<ArrayDataType::kInt32>(array);
92    } else {
93      LOG(FATAL) << "Unhandled data type";
94    }
95    CHECK(array->data_type == ArrayDataType::kFloat);
96    CHECK(array->buffer->type == ArrayDataType::kFloat);
97
98    // Clear quantization params, officially makes this a non-quantized array.
99    ClearArrayQuantizationParams(array_name, model);
100    return true;
101  } else {
102    array->data_type = ArrayDataType::kFloat;
103  }
104
105  // Clear quantization params, officially makes this a non-quantized array.
106  ClearArrayQuantizationParams(array_name, model);
107
108  if (array->buffer) {
109    return true;
110  }
111
112  auto* op_outputting_array = GetOpWithOutput(*model, array_name);
113  if (op_outputting_array) {
114    if (op_outputting_array->type == OperatorType::kTensorFlowReshape) {
115      return true;
116    }
117  }
118
119  // If there was no minmax info, we can return now. Indeed,
120  // the below only serves to create a FakeQuant node, but some arrays are
121  // quantized without MinMax (see the CHECK above) and that corresponds to
122  // places where a FakeQuant node is actually not wanted, because the
123  // quantization params are meant to be inferred in another way (e.g. bias
124  // vector for a Conv op, see their special-casing in quantize.cc).
125  if (!array->minmax) {
126    return true;
127  }
128
129  // Determine whether to insert a FakeQuant before or after
130  // this array.
131  bool must_insert_fakequant_before = false;
132  bool must_insert_fakequant_after = false;
133  if (IsInputArray(*model, array_name)) {
134    must_insert_fakequant_after = true;
135  }
136  for (const string& output_array : model->flags.output_arrays()) {
137    if (array_name == output_array) {
138      must_insert_fakequant_before = true;
139    }
140  }
141  for (const auto& rnn_state : model->flags.rnn_states()) {
142    if (array_name == rnn_state.state_array()) {
143      must_insert_fakequant_after = true;
144    }
145    if (array_name == rnn_state.back_edge_source_array()) {
146      must_insert_fakequant_before = true;
147    }
148  }
149  CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after));
150
151  // Create and insert the FakeQuant node
152  auto* fakequant_op = new FakeQuantOperator;
153  model->operators.emplace(FindFirstOpWithInput(model, array_name),
154                           fakequant_op);
155  const string& new_array_name = AvailableArrayName(*model, array_name);
156  auto& new_array = model->GetOrCreateArray(new_array_name);
157  new_array.data_type = ArrayDataType::kFloat;
158  new_array.copy_shape(array->shape());
159  new_array.GetOrCreateMinMax() = array->GetMinMax();
160  fakequant_op->minmax.reset(new MinMax);
161  *fakequant_op->minmax = array->GetMinMax();
162  if (must_insert_fakequant_before) {
163    for (const auto& op : model->operators) {
164      for (string& output : op->outputs) {
165        if (output == array_name) {
166          output = new_array_name;
167        }
168      }
169    }
170    fakequant_op->inputs = {new_array_name};
171    fakequant_op->outputs = {array_name};
172  } else {
173    for (const auto& op : model->operators) {
174      for (string& input : op->inputs) {
175        if (input == array_name) {
176          input = new_array_name;
177        }
178      }
179    }
180    fakequant_op->inputs = {array_name};
181    fakequant_op->outputs = {new_array_name};
182  }
183  return true;
184}
185
186}  // namespace
187
188bool Dequantize::Run(Model* model, std::size_t op_index) {
189  const auto op_it = model->operators.begin() + op_index;
190  auto* op = op_it->get();
191
192  if (op->type == OperatorType::kDequantize) {
193    auto& input_array = model->GetArray(op->inputs[0]);
194    if (input_array.data_type == ArrayDataType::kFloat) {
195      return false;
196    }
197    if (input_array.final_data_type != ArrayDataType::kFloat) {
198      return false;
199    }
200    input_array.data_type = ArrayDataType::kFloat;
201    input_array.quantization_params = nullptr;
202    auto& output_array = model->GetArray(op->outputs[0]);
203    output_array.data_type = ArrayDataType::kFloat;
204    output_array.quantization_params = nullptr;
205    return RemoveTrivialPassthroughOp(this, model, op_index);
206  }
207
208  std::vector<string> arrays;
209  for (const string& input : op->inputs) {
210    arrays.push_back(input);
211  }
212  for (const string& output : op->outputs) {
213    arrays.push_back(output);
214  }
215  bool changed = false;
216  for (const string& array : arrays) {
217    if (!model->IsOptionalArray(array)) {
218      changed |= DequantizeArray(array, this, model);
219    }
220  }
221
222  return changed;
223}
224
225}  // namespace toco
226