1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <algorithm>
16#include <memory>
17#include <string>
18#include <unordered_map>
19#include <vector>
20
21#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
22#include "tensorflow/contrib/lite/toco/model.h"
23#include "tensorflow/contrib/lite/toco/tooling_util.h"
24#include "tensorflow/core/platform/logging.h"
25
26namespace toco {
27
28namespace {
29
30bool ApplyMinMaxToArray(GraphTransformation* transformation, Model* model,
31                        const MinMax& minmax, const string& array_name) {
32  auto& annotated_array = model->GetArray(array_name);
33  if (annotated_array.minmax) {
34    return false;
35  }
36  annotated_array.GetOrCreateMinMax() = minmax;
37  transformation->AddMessageF(
38      "Read min/max annotation for array %s: min=%g, max=%g", array_name,
39      minmax.min, minmax.max);
40  return true;
41}
42
43}  // end namespace
44
45bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) {
46  const auto fakequant_it = model->operators.begin() + op_index;
47  auto* fakequant_base_op = fakequant_it->get();
48  if (fakequant_base_op->type != OperatorType::kFakeQuant) {
49    return false;
50  }
51  auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
52
53  bool changed = false;
54
55  if (!fakequant_op->minmax) {
56    CHECK_EQ(fakequant_op->inputs.size(), 3);
57    // We need to yield until the min and max parameters have been
58    // resolved to constant arrays.
59    for (int i = 1; i <= 2; i++) {
60      if (!IsConstantParameterArray(*model, fakequant_op->inputs[1])) {
61        return false;
62      }
63    }
64
65    // Obtain the final min/max values
66    const auto& min_array = model->GetArray(fakequant_op->inputs[1]);
67    const auto& max_array = model->GetArray(fakequant_op->inputs[2]);
68    CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1);
69    CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1);
70    fakequant_op->minmax.reset(new MinMax);
71    MinMax& minmax = *fakequant_op->minmax;
72    minmax.min = min_array.GetBuffer<ArrayDataType::kFloat>().data[0];
73    minmax.max = max_array.GetBuffer<ArrayDataType::kFloat>().data[0];
74    // We always want [min, max] to contain 0.
75    minmax.min = std::min(minmax.min, 0.);
76    minmax.max = std::max(minmax.max, 0.);
77
78    // We won't use the input arrays that provided these min and max
79    // values, anymore. Delete them unless they are used by something
80    // else.
81    for (int i = 1; i <= 2; i++) {
82      if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
83        model->EraseArray(fakequant_op->inputs[i]);
84      }
85    }
86    fakequant_op->inputs.resize(1);
87    changed = true;
88  }
89
90  // At this point, this FakeQuantOperator should have a MinMax
91  // attached to it, and should only have 1 input (it should not have
92  // 2nd and 3rd input arrays giving min and max anymore).
93  CHECK(fakequant_op->minmax);
94  CHECK_EQ(1, fakequant_op->inputs.size());
95
96  const MinMax& minmax = *fakequant_op->minmax;
97
98  // Record the MinMax info on the input and output arrays
99  changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->inputs[0]);
100  changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->outputs[0]);
101
102  return changed;
103}
104
105}  // namespace toco
106