10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License");
40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License.
50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at
60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    http://www.apache.org/licenses/LICENSE-2.0
80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software
100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS,
110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and
130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License.
140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory>
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string>
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map>
180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector>
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h"
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h"
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h"
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco {
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto fakequant_it = model->operators.begin() + op_index;
290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto* fakequant_base_op = fakequant_it->get();
300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (fakequant_base_op->type != OperatorType::kFakeQuant) {
310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto* fakequant_op =
350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      static_cast<const FakeQuantOperator*>(fakequant_base_op);
360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Yield until the fakequant MinMax has been resolved.
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!fakequant_op->minmax) {
390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // This transformation only applies when the input array is constant.
430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) {
440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& output_array = model->GetArray(fakequant_op->outputs[0]);
490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(input_array.data_type == ArrayDataType::kFloat);
500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  output_array.data_type = ArrayDataType::kFloat;
510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(!output_array.buffer);
520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& input_buffer = input_array.GetBuffer<ArrayDataType::kFloat>();
5360d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower  output_array.GetOrCreateMinMax() = *fakequant_op->minmax;
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kFloat>();
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int size = input_buffer.data.size();
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  output_buffer.data.resize(size);
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  QuantizationParams qparams;
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      model->flags, *fakequant_op->minmax, &qparams);
600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (int i = 0; i < size; i++) {
610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    const double src_val = input_buffer.data[i];
620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    const double unclamped_quantized_val =
630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        std::round(qparams.zero_point + src_val / qparams.scale);
640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    const double quantized_val =
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        std::min(255., std::max(0., unclamped_quantized_val));
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    const double dst_val = qparams.scale * (quantized_val - qparams.zero_point);
670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    output_buffer.data[i] = dst_val;
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
70ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower    model->EraseArray(fakequant_op->inputs[0]);
710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(fakequant_it);
730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return true;
750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace toco
78