1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/contrib/lite/toco/toco_tooling.h"
16
17#include <cstdlib>
18#include <memory>
19#include <set>
20
21#include "absl/strings/str_join.h"
22#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h"
23#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
24#include "tensorflow/contrib/lite/toco/export_tensorflow.h"
25#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
26#include "tensorflow/contrib/lite/toco/import_tensorflow.h"
27#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
28#include "tensorflow/contrib/lite/toco/tflite/export.h"
29#include "tensorflow/contrib/lite/toco/tflite/import.h"
30#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
31#include "tensorflow/contrib/lite/toco/tooling_util.h"
32#include "tensorflow/core/platform/logging.h"
33
34namespace toco {
35namespace {
36// CHECK-fails if the model contains a kTensorFlowUnsupported operation.
37void CheckUnsupportedOperations(const Model& model) {
38  std::set<string> unsupported_ops;
39  for (auto& op : model.operators) {
40    if (op->type == OperatorType::kTensorFlowUnsupported) {
41      unsupported_ops.insert(
42          static_cast<const TensorFlowUnsupportedOperator*>(op.get())
43              ->tensorflow_op);
44    }
45  }
46  QCHECK(unsupported_ops.empty())
47      << "These unsupported ops were not removed by graph transformations: "
48      << absl::StrJoin(unsupported_ops, ", ");
49}
50
// Populates |transformations| with the general-purpose set of graph
// transformations that is applied to every model before any
// format-specific or quantization passes. |transformations| must be empty
// on entry (enforced below). The caller runs the set repeatedly, so the
// ordering here is a registration order, not a strict one-shot pipeline.
void MakeGeneralGraphTransformationsSet(
    GraphTransformationsSet* transformations) {
  CHECK(transformations->empty());
  // Rewrites of trivial or TF-specific ops into simpler equivalents.
  transformations->Add(new ConvertExpandDimsToReshape);
  transformations->Add(new ConvertTrivialAddNToAdd);
  transformations->Add(new ConvertTrivialStackToReshape);
  transformations->Add(new ConvertTrivialTransposeToReshape);
  transformations->Add(new ConvertReorderAxes);
  // Resolution of op attributes from constant input arrays.
  transformations->Add(new ResolveReshapeAttributes);
  transformations->Add(new ResolveTransposeAttributes);
  // Propagation of data types and fixed shapes through the graph.
  transformations->Add(new PropagateArrayDataTypes);
  transformations->Add(new PropagateFixedSizes);
  // Removal of no-op or unused constructs.
  transformations->Add(new RemoveTensorFlowAssert);
  transformations->Add(new RemoveTensorFlowIdentity);
  transformations->Add(new RemoveTrivialConcatenation);
  transformations->Add(new RemoveTrivialConcatenationInput);
  transformations->Add(new RemoveTrivialSlice);
  transformations->Add(new RemoveUnusedOp);
  transformations->Add(new EnsureBiasVectors);
  transformations->Add(new ResolveReorderAxes);
  transformations->Add(new UnrollBatchMatMul);
  transformations->Add(new ResolveTensorFlowMatMul);
  // Fusing of binary ops into adjacent affine ops, and related reordering.
  transformations->Add(new FuseBinaryIntoPrecedingAffine);
  transformations->Add(new FuseBinaryIntoFollowingAffine);
  transformations->Add(new ReorderActivationFunctions);
  transformations->Add(new ResolveBatchNormalization);
  // Constant folding of ops whose inputs are all constant.
  transformations->Add(new ResolveConstantBinaryOperator);
  transformations->Add(new ResolveConstantFill);
  transformations->Add(new ResolveConstantRange);
  transformations->Add(new ResolveConstantStack);
  transformations->Add(new ResolveConstantStridedSlice);
  transformations->Add(new ResolveConstantTranspose);
  transformations->Add(new ResolveConstantUnaryOperator);
  transformations->Add(new ResolveTensorFlowMerge);
  transformations->Add(new ResolveSqueezeAttributes);
  transformations->Add(new ResolveTensorFlowSwitch);
  transformations->Add(new ResolveTensorFlowTile);
  transformations->Add(new ResolveTensorFlowConcat);
  transformations->Add(new ResolveMultiplyByZero);
  // Recognition of higher-level patterns spanning several ops.
  transformations->Add(new IdentifyL2Normalization);
  transformations->Add(new IdentifyL2Pool);
  transformations->Add(new IdentifyRelu1);
  transformations->Add(new RemoveTrivialBinaryOperator);
  // FakeQuant / min-max handling and remaining attribute resolution.
  transformations->Add(new ReadFakeQuantMinMax);
  transformations->Add(new ResolveSpaceToBatchNDAttributes);
  transformations->Add(new ResolveBatchToSpaceNDAttributes);
  transformations->Add(new ResolvePadAttributes);
  transformations->Add(new ResolveStridedSliceAttributes);
  transformations->Add(new ResolveSliceAttributes);
  transformations->Add(new ResolveMeanAttributes);
  transformations->Add(new ResolveConstantShapeOrRank);
  transformations->Add(new MakeInitialDequantizeOperator);
  transformations->Add(new ResolveConstantFakeQuant);
}
105
106bool SupportsQuantization(FileFormat format) {
107  return (format == GRAPHVIZ_DOT || format == TFLITE);
108}
109
110bool SupportsFusedActivationFunction(FileFormat format) {
111  return (format == GRAPHVIZ_DOT || format == TFLITE);
112}
113
114bool SupportsLstmCell(FileFormat format) {
115  return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT ||
116          format == TFLITE);
117}
118
119bool SupportsPreallocatedWorkspace(FileFormat format) {
120  return (format == TFLITE);
121}
122
123bool IsRealValued(toco::ArrayDataType type) {
124  return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
125                           type == toco::ArrayDataType::kUint8);
126}
127
128void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
129  const FileFormat output_format = toco_flags.output_format();
130  ArrayDataType type;
131  if (toco_flags.has_inference_input_type()) {
132    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
133  } else if (toco_flags.has_inference_type()) {
134    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
135  } else if (!SupportsQuantization(output_format)) {
136    // Data type is implicitly float for non-quantized formats
137    type = ArrayDataType::kFloat;
138  } else {
139    // Nothing to do. Data types stay as-is.
140    return;
141  }
142
143  for (int i = 0; i < model->flags.input_arrays_size(); i++) {
144    string const& array_name = model->flags.input_arrays(i).name();
145    auto* array = &model->GetArray(array_name);
146    // Note that the notion of changing data types only applies to real-numbers
147    // arrays (see the documentation for inference_input_type).
148    // TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
149    // i.e. represent real numbers by means of quantization parameters,
150    // and not plain integer uint8 input arrays.
151    if (!IsRealValued(array->data_type)) {
152      // Ignore non-real data types.
153      continue;
154    }
155
156    array->final_data_type = type;
157  }
158}
159
160}  // namespace
161
162std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
163                              const ModelFlags& model_flags,
164                              const string& input_file_contents) {
165  std::unique_ptr<Model> model;
166  switch (toco_flags.input_format()) {
167    case TENSORFLOW_GRAPHDEF: {
168      TensorFlowImportFlags tf_import_flags;
169      tf_import_flags.drop_control_dependency =
170          toco_flags.has_drop_control_dependency()
171              ? toco_flags.drop_control_dependency()
172              : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
173      model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
174                                       input_file_contents);
175      break;
176    }
177    case TFLITE:
178      model = toco::tflite::Import(model_flags, input_file_contents);
179      ResolveModelFlags(model_flags, model.get());
180      CheckInvariants(*model);
181      break;
182    default:
183      LOG(FATAL) << "Unhandled input_format";
184  }
185
186  LogDump(kLogLevelModelChanged, "AT IMPORT", *model);
187
188  return model;
189}
190
// Runs the full transformation pipeline on |model| as configured by
// |toco_flags|: import cleanup, the general transformation set, then either
// quantization or dequantization passes depending on the requested inference
// type, and finally transient-array allocation for formats that support a
// preallocated workspace. The pass ordering below is significant.
void Transform(const TocoFlags& toco_flags, Model* model) {
  // Clean up after import.
  SetFinalDataTypeOnInputs(toco_flags, model);
  UseArraysExtraInfo(model);
  FinishBuildingRNNStates(model);

  const FileFormat output_format = toco_flags.output_format();
  const IODataType inference_type = toco_flags.inference_type();

  // Quantize only when the output format can carry quantized models AND
  // quantized uint8 inference was explicitly requested.
  const bool quantize_output =
      SupportsQuantization(output_format) && inference_type == QUANTIZED_UINT8;

  if (quantize_output) {
    QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
        << "Quantized inference is not allowed with float inputs.";
  }

  // Remove unused ops before performing any other optimizations. This is to
  // stop optimizations from crossing the input/output boundaries. For example
  // this will stop BatchNorm fusing if the output node is in between a conv
  // and BatchNorm layers.
  RunGraphTransformations(model, "Removing unused ops",
                          {new toco::RemoveUnusedOp});

  // Assemble the main transformation set, specialized by output format and
  // command-line flags.
  GraphTransformationsSet transformations;
  MakeGeneralGraphTransformationsSet(&transformations);
  auto* remove_trivial_reshape = new RemoveTrivialReshape;
  transformations.Add(remove_trivial_reshape);
  if (SupportsFusedActivationFunction(output_format)) {
    transformations.Add(new FuseActivationFunctions);
  } else {
    transformations.Add(new UnfuseActivationFunctions);
  }
  if (toco_flags.drop_fake_quant()) {
    transformations.Add(new DropFakeQuant);
  } else {
    // See the doc for --reorder_across_fake_quant: that flag is needed to
    // support some existing models, e.g. WordLens, that have FakeQuant
    // nodes in the wrong places.
    // TODO(benoitjacob): drop special casing when we can.
    if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
      transformations.Add(new DropFakeQuant);
    }
  }
  transformations.Add(new ConvertPureConvToDepthwise);
  if (SupportsLstmCell(output_format)) {
    transformations.Add(new IdentifyLstmCell);
    // TFLITE consumes LSTM cells with split inputs; other formats use
    // merged inputs.
    if (output_format == TFLITE) {
      transformations.Add(new toco::SplitLstmCellInputs);
    } else {
      transformations.Add(new toco::MergeLstmCellInputs);
    }
  }
  transformations.Add(new ResolveConstantConcatenation);
  RunGraphTransformations(model, "general graph transformations",
                          transformations);

  if (quantize_output) {
    RunGraphTransformations(model, "pre-quantization graph transformations",
                            {new HardcodeMinMax, new DropFakeQuant});
  }

  if (quantize_output) {
    // Optionally apply user-provided default min/max ranges to arrays that
    // still lack them, then run the actual quantization passes.
    if (toco_flags.has_default_ranges_min() &&
        toco_flags.has_default_ranges_max()) {
      UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(),
                                  toco_flags.default_ranges_max());
      // The new MinMax info may need to be propagated a bit.
      RunGraphTransformations(
          model, "default min-max range propagation graph transformations",
          {new HardcodeMinMax});
    }
    CheckIsReadyForQuantization(*model);
    RunGraphTransformations(
        model, "quantization graph transformations",
        {new Quantize, new RemoveTrivialQuantizedActivationFunc,
         new RemoveFinalDequantizeOp});
  } else {
    GraphTransformationsSet dequantization_transformations{new Dequantize};
    // Dequantize creates FakeQuant nodes. We may want to discard
    // those immediately.
    if (toco_flags.drop_fake_quant()) {
      dequantization_transformations.Add(new DropFakeQuant);
    }

    RunGraphTransformations(model, "dequantization graph transformations",
                            dequantization_transformations);
  }

  if (output_format == TENSORFLOW_GRAPHDEF) {
    EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model);
  }

  LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);

  if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
    // By now there shouldn't be any unsupported ops when exporting to
    // TensorFlow GraphDef.
    CheckUnsupportedOperations(*model);
  }

  // Plan transient-array storage now so the exported model carries
  // allocation information.
  if (SupportsPreallocatedWorkspace(output_format)) {
    AllocateTransientArrays(model, kDefaultTransientDataAlignment);
    LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
  }

  CheckModelCounts(*model);
  CheckFinalDataTypesSatisfied(*model);

  // Informational only: rough arithmetic-op-count estimate for the model.
  int64 ops_count;
  if (EstimateArithmeticOpsCount(*model, &ops_count)) {
    LOG(INFO) << "Estimated count of arithmetic ops: " << 1e-9 * ops_count
              << " billion (note that a multiply-add is counted as 2 ops).";
  }
}
306
307void Export(const TocoFlags& toco_flags, const Model& model,
308            bool allow_custom_ops, string* output_file_contents) {
309  switch (toco_flags.output_format()) {
310    case TENSORFLOW_GRAPHDEF:
311      ExportTensorFlowGraphDef(model, output_file_contents);
312      break;
313    case TFLITE:
314      toco::tflite::Export(model, allow_custom_ops, output_file_contents);
315      break;
316    case GRAPHVIZ_DOT:
317      DumpGraphviz(model, output_file_contents);
318      break;
319    default:
320      LOG(FATAL) << "Unhandled output_format";
321  }
322}
323
324}  // namespace toco
325