1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/contrib/lite/toco/toco_tooling.h" 16 17#include <cstdlib> 18#include <memory> 19#include <set> 20 21#include "absl/strings/str_join.h" 22#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h" 23#include "tensorflow/contrib/lite/toco/dump_graphviz.h" 24#include "tensorflow/contrib/lite/toco/export_tensorflow.h" 25#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 26#include "tensorflow/contrib/lite/toco/import_tensorflow.h" 27#include "tensorflow/contrib/lite/toco/model_flags.pb.h" 28#include "tensorflow/contrib/lite/toco/tflite/export.h" 29#include "tensorflow/contrib/lite/toco/tflite/import.h" 30#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" 31#include "tensorflow/contrib/lite/toco/tooling_util.h" 32#include "tensorflow/core/platform/logging.h" 33 34namespace toco { 35namespace { 36// CHECK-fails if the model contains a kTensorFlowUnsupported operation. 37void CheckUnsupportedOperations(const Model& model) { 38 std::set<string> unsupported_ops; 39 for (auto& op : model.operators) { 40 if (op->type == OperatorType::kTensorFlowUnsupported) { 41 unsupported_ops.insert( 42 static_cast<const TensorFlowUnsupportedOperator*>(op.get()) 43 ->tensorflow_op); 44 } 45 } 46 QCHECK(unsupported_ops.empty()) 47 << "These unsupported ops were not removed by graph transformations: " 48 << absl::StrJoin(unsupported_ops, ", "); 49} 50 51void MakeGeneralGraphTransformationsSet( 52 GraphTransformationsSet* transformations) { 53 CHECK(transformations->empty()); 54 transformations->Add(new ConvertExpandDimsToReshape); 55 transformations->Add(new ConvertTrivialAddNToAdd); 56 transformations->Add(new ConvertTrivialStackToReshape); 57 transformations->Add(new ConvertTrivialTransposeToReshape); 58 transformations->Add(new ConvertReorderAxes); 59 transformations->Add(new ResolveReshapeAttributes); 60 transformations->Add(new ResolveTransposeAttributes); 61 transformations->Add(new PropagateArrayDataTypes); 62 transformations->Add(new PropagateFixedSizes); 63 transformations->Add(new RemoveTensorFlowAssert); 64 transformations->Add(new RemoveTensorFlowIdentity); 65 transformations->Add(new RemoveTrivialConcatenation); 66 transformations->Add(new RemoveTrivialConcatenationInput); 67 transformations->Add(new RemoveTrivialSlice); 68 transformations->Add(new RemoveUnusedOp); 69 transformations->Add(new EnsureBiasVectors); 70 transformations->Add(new ResolveReorderAxes); 71 transformations->Add(new UnrollBatchMatMul); 72 transformations->Add(new ResolveTensorFlowMatMul); 73 transformations->Add(new FuseBinaryIntoPrecedingAffine); 74 transformations->Add(new FuseBinaryIntoFollowingAffine); 75 transformations->Add(new ReorderActivationFunctions); 76 transformations->Add(new ResolveBatchNormalization); 77 transformations->Add(new ResolveConstantBinaryOperator); 78 transformations->Add(new ResolveConstantFill); 79 transformations->Add(new ResolveConstantRange); 80 transformations->Add(new ResolveConstantStack); 81 transformations->Add(new ResolveConstantStridedSlice); 82 transformations->Add(new ResolveConstantTranspose); 83 transformations->Add(new ResolveConstantUnaryOperator); 84 transformations->Add(new ResolveTensorFlowMerge); 85 transformations->Add(new ResolveSqueezeAttributes); 86 transformations->Add(new ResolveTensorFlowSwitch); 87 transformations->Add(new ResolveTensorFlowTile); 88 transformations->Add(new ResolveTensorFlowConcat); 89 transformations->Add(new ResolveMultiplyByZero); 90 transformations->Add(new IdentifyL2Normalization); 91 transformations->Add(new IdentifyL2Pool); 92 transformations->Add(new IdentifyRelu1); 93 transformations->Add(new RemoveTrivialBinaryOperator); 94 transformations->Add(new ReadFakeQuantMinMax); 95 transformations->Add(new ResolveSpaceToBatchNDAttributes); 96 transformations->Add(new ResolveBatchToSpaceNDAttributes); 97 transformations->Add(new ResolvePadAttributes); 98 transformations->Add(new ResolveStridedSliceAttributes); 99 transformations->Add(new ResolveSliceAttributes); 100 transformations->Add(new ResolveMeanAttributes); 101 transformations->Add(new ResolveConstantShapeOrRank); 102 transformations->Add(new MakeInitialDequantizeOperator); 103 transformations->Add(new ResolveConstantFakeQuant); 104} 105 106bool SupportsQuantization(FileFormat format) { 107 return (format == GRAPHVIZ_DOT || format == TFLITE); 108} 109 110bool SupportsFusedActivationFunction(FileFormat format) { 111 return (format == GRAPHVIZ_DOT || format == TFLITE); 112} 113 114bool SupportsLstmCell(FileFormat format) { 115 return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT || 116 format == TFLITE); 117} 118 119bool SupportsPreallocatedWorkspace(FileFormat format) { 120 return (format == TFLITE); 121} 122 123bool IsRealValued(toco::ArrayDataType type) { 124 return static_cast<bool>(type == toco::ArrayDataType::kFloat || 125 type == toco::ArrayDataType::kUint8); 126} 127 128void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) { 129 const FileFormat output_format = toco_flags.output_format(); 130 ArrayDataType type; 131 if (toco_flags.has_inference_input_type()) { 132 type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type()); 133 } else if (toco_flags.has_inference_type()) { 134 type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type()); 135 } else if (!SupportsQuantization(output_format)) { 136 // Data type is implicitly float for non-quantized formats 137 type = ArrayDataType::kFloat; 138 } else { 139 // Nothing to do. Data types stay as-is. 140 return; 141 } 142 143 for (int i = 0; i < model->flags.input_arrays_size(); i++) { 144 string const& array_name = model->flags.input_arrays(i).name(); 145 auto* array = &model->GetArray(array_name); 146 // Note that the notion of changing data types only applies to real-numbers 147 // arrays (see the documentation for inference_input_type). 148 // TODO(benoitjacob) this is assuming that uint8 arrays are quantized, 149 // i.e. represent real numbers by means of quantization parameters, 150 // and not plain integer uint8 input arrays. 151 if (!IsRealValued(array->data_type)) { 152 // Ignore non-real data types. 153 continue; 154 } 155 156 array->final_data_type = type; 157 } 158} 159 160} // namespace 161 162std::unique_ptr<Model> Import(const TocoFlags& toco_flags, 163 const ModelFlags& model_flags, 164 const string& input_file_contents) { 165 std::unique_ptr<Model> model; 166 switch (toco_flags.input_format()) { 167 case TENSORFLOW_GRAPHDEF: { 168 TensorFlowImportFlags tf_import_flags; 169 tf_import_flags.drop_control_dependency = 170 toco_flags.has_drop_control_dependency() 171 ? toco_flags.drop_control_dependency() 172 : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF); 173 model = ImportTensorFlowGraphDef(model_flags, tf_import_flags, 174 input_file_contents); 175 break; 176 } 177 case TFLITE: 178 model = toco::tflite::Import(model_flags, input_file_contents); 179 ResolveModelFlags(model_flags, model.get()); 180 CheckInvariants(*model); 181 break; 182 default: 183 LOG(FATAL) << "Unhandled input_format"; 184 } 185 186 LogDump(kLogLevelModelChanged, "AT IMPORT", *model); 187 188 return model; 189} 190 191void Transform(const TocoFlags& toco_flags, Model* model) { 192 // Clean up after import. 193 SetFinalDataTypeOnInputs(toco_flags, model); 194 UseArraysExtraInfo(model); 195 FinishBuildingRNNStates(model); 196 197 const FileFormat output_format = toco_flags.output_format(); 198 const IODataType inference_type = toco_flags.inference_type(); 199 200 const bool quantize_output = 201 SupportsQuantization(output_format) && inference_type == QUANTIZED_UINT8; 202 203 if (quantize_output) { 204 QCHECK_NE(toco_flags.inference_input_type(), FLOAT) 205 << "Quantized inference is not allowed with float inputs."; 206 } 207 208 // Remove unused ops before performing any other optimizations. This is to 209 // stop optimizations from crossing the input/output boundaries. For example 210 // this will stop BatchNorm fusing if the output node is in between a conv 211 // and BatchNorm layers. 212 RunGraphTransformations(model, "Removing unused ops", 213 {new toco::RemoveUnusedOp}); 214 215 GraphTransformationsSet transformations; 216 MakeGeneralGraphTransformationsSet(&transformations); 217 auto* remove_trivial_reshape = new RemoveTrivialReshape; 218 transformations.Add(remove_trivial_reshape); 219 if (SupportsFusedActivationFunction(output_format)) { 220 transformations.Add(new FuseActivationFunctions); 221 } else { 222 transformations.Add(new UnfuseActivationFunctions); 223 } 224 if (toco_flags.drop_fake_quant()) { 225 transformations.Add(new DropFakeQuant); 226 } else { 227 // See the doc for --reorder_across_fake_quant: that flag is needed to 228 // support some existing models, e.g. WordLens, that have FakeQuant 229 // nodes in the wrong places. 230 // TODO(benoitjacob): drop special casing when we can. 231 if ((quantize_output && toco_flags.reorder_across_fake_quant())) { 232 transformations.Add(new DropFakeQuant); 233 } 234 } 235 transformations.Add(new ConvertPureConvToDepthwise); 236 if (SupportsLstmCell(output_format)) { 237 transformations.Add(new IdentifyLstmCell); 238 if (output_format == TFLITE) { 239 transformations.Add(new toco::SplitLstmCellInputs); 240 } else { 241 transformations.Add(new toco::MergeLstmCellInputs); 242 } 243 } 244 transformations.Add(new ResolveConstantConcatenation); 245 RunGraphTransformations(model, "general graph transformations", 246 transformations); 247 248 if (quantize_output) { 249 RunGraphTransformations(model, "pre-quantization graph transformations", 250 {new HardcodeMinMax, new DropFakeQuant}); 251 } 252 253 if (quantize_output) { 254 if (toco_flags.has_default_ranges_min() && 255 toco_flags.has_default_ranges_max()) { 256 UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(), 257 toco_flags.default_ranges_max()); 258 // The new MinMax info may need to be propagated a bit. 259 RunGraphTransformations( 260 model, "default min-max range propagation graph transformations", 261 {new HardcodeMinMax}); 262 } 263 CheckIsReadyForQuantization(*model); 264 RunGraphTransformations( 265 model, "quantization graph transformations", 266 {new Quantize, new RemoveTrivialQuantizedActivationFunc, 267 new RemoveFinalDequantizeOp}); 268 } else { 269 GraphTransformationsSet dequantization_transformations{new Dequantize}; 270 // Dequantize creates FakeQuant nodes. We may want to discard 271 // those immediately. 272 if (toco_flags.drop_fake_quant()) { 273 dequantization_transformations.Add(new DropFakeQuant); 274 } 275 276 RunGraphTransformations(model, "dequantization graph transformations", 277 dequantization_transformations); 278 } 279 280 if (output_format == TENSORFLOW_GRAPHDEF) { 281 EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model); 282 } 283 284 LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model); 285 286 if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) { 287 // By now there shouldn't be any unsupported ops when exporting to 288 // TensorFlow GraphDef. 289 CheckUnsupportedOperations(*model); 290 } 291 292 if (SupportsPreallocatedWorkspace(output_format)) { 293 AllocateTransientArrays(model, kDefaultTransientDataAlignment); 294 LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model); 295 } 296 297 CheckModelCounts(*model); 298 CheckFinalDataTypesSatisfied(*model); 299 300 int64 ops_count; 301 if (EstimateArithmeticOpsCount(*model, &ops_count)) { 302 LOG(INFO) << "Estimated count of arithmetic ops: " << 1e-9 * ops_count 303 << " billion (note that a multiply-add is counted as 2 ops)."; 304 } 305} 306 307void Export(const TocoFlags& toco_flags, const Model& model, 308 bool allow_custom_ops, string* output_file_contents) { 309 switch (toco_flags.output_format()) { 310 case TENSORFLOW_GRAPHDEF: 311 ExportTensorFlowGraphDef(model, output_file_contents); 312 break; 313 case TFLITE: 314 toco::tflite::Export(model, allow_custom_ops, output_file_contents); 315 break; 316 case GRAPHVIZ_DOT: 317 DumpGraphviz(model, output_file_contents); 318 break; 319 default: 320 LOG(FATAL) << "Unhandled output_format"; 321 } 322} 323 324} // namespace toco 325