10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License");
40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License.
50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at
60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    http://www.apache.org/licenses/LICENSE-2.0
80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software
100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS,
110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and
130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License.
140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory>
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string>
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map>
180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector>
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h"
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h"
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h"
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco {
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// This inserts an operator whose output is a float array (name:
290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// flags.input_array()).  It has to wait for any existing operators that
300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// generate this output to be removed by graph transformations.  Note that there
310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// may be more than one operator that takes the input_array as their input, and
320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// that some of these may be removed by graph transformations.
330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                  GraphTransformation* transformation,
350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                  Model* model) {
360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // An operator with the required output may be a dequantize operator already
370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // created.  Alternatively it may be an operator that needs to be removed
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // because it is unused, in which case we wait for RemoveUnusedOp to do its
390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // work.
400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (GetOpWithOutput(*model, input_name)) {
410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // We only apply for the first operator if there is more than one.  This is
450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // not strictly necessary for ordering correctness, since we insert the
460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // dequant operator at the beginning of the op sequence, but it makes the
470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // insertion more predictable (eg forward vs backwards operator sweep).
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (CountOpsWithInput(*model, input_name) > 1) {
490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    if (op != GetFirstOpWithInput(*model, input_name)) {
500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      return false;
510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& input_array = model->GetArray(input_name);
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (input_array.data_type != ArrayDataType::kFloat) {
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (input_array.final_data_type == input_array.data_type ||
600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      input_array.final_data_type == ArrayDataType::kNone) {
610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& dequantized_input_name =
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      AvailableArrayName(*model, input_name + "_dequantized");
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (auto& other_op : model->operators) {
670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    for (string& other_op_input : other_op->inputs) {
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      if (other_op_input == input_name) {
690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        other_op_input = dequantized_input_name;
700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      }
710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& dequantized_input_array =
750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      model->GetOrCreateArray(dequantized_input_name);
760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* image_input_op = new DequantizeOperator;
770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  image_input_op->inputs = {input_name};
780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  image_input_op->outputs = {dequantized_input_name};
790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.emplace(model->operators.begin(), image_input_op);
800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(input_array.final_data_type == ArrayDataType::kUint8);
820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  input_array.data_type = ArrayDataType::kUint8;
830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  dequantized_input_array.data_type = ArrayDataType::kFloat;
840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& input_minmax = input_array.GetMinMax();
850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax();
860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  dequantized_input_minmax = input_minmax;
870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& input_qparams = input_array.GetOrCreateQuantizationParams();
880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      model->flags, input_minmax, &input_qparams);
900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  transformation->AddMessageF(
920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      "Created %s"
930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      " to handle quantized input image data, taking over existing"
940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      " mean_value and std_value flags. Cleared those flags.",
950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      LogName(*image_input_op));
960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return true;
980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) {
1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // This is effectively a transformation applied to edges.  We iterate over the
1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // specified node (op) and proceed for input edges.
1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto it = model->operators.begin() + op_index;
1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto* op = it->get();
1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  bool change_made = false;
1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (auto& input : op->inputs) {
1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    for (auto& input_array : *model->flags.mutable_input_arrays()) {
1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      if (input_array.name() == input) {
1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        if (AddDequantizeOperatorToInput(input_array.name(), op, this, model)) {
1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          change_made = true;
1110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          input_array.clear_mean_value();
1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          input_array.clear_std_value();
1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        }
1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      }
1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return change_made;
1180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace toco
121