1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include <string> 17#include <vector> 18 19#include "absl/strings/numbers.h" 20#include "absl/strings/str_join.h" 21#include "absl/strings/str_split.h" 22#include "absl/strings/strip.h" 23#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" 24#include "tensorflow/contrib/lite/toco/toco_port.h" 25#include "tensorflow/core/platform/logging.h" 26#include "tensorflow/core/util/command_line_flags.h" 27 28namespace toco { 29 30bool ParseTocoFlagsFromCommandLineFlags( 31 int* argc, char* argv[], string* msg, 32 ParsedTocoFlags* parsed_toco_flags_ptr) { 33 using tensorflow::Flag; 34 ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr; 35 std::vector<tensorflow::Flag> flags = { 36 Flag("input_file", parsed_flags.input_file.bind(), 37 parsed_flags.input_file.default_value(), 38 "Input file (model of any supported format). For Protobuf " 39 "formats, both text and binary are supported regardless of file " 40 "extension."), 41 Flag("output_file", parsed_flags.output_file.bind(), 42 parsed_flags.output_file.default_value(), 43 "Output file. " 44 "For Protobuf formats, the binary format will be used."), 45 Flag("input_format", parsed_flags.input_format.bind(), 46 parsed_flags.input_format.default_value(), 47 "Input file format. One of: TENSORFLOW_GRAPHDEF, TFLITE."), 48 Flag("output_format", parsed_flags.output_format.bind(), 49 parsed_flags.output_format.default_value(), 50 "Output file format. " 51 "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."), 52 Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(), 53 parsed_flags.default_ranges_min.default_value(), 54 "If defined, will be used as the default value for the min bound " 55 "of min/max ranges used for quantization."), 56 Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(), 57 parsed_flags.default_ranges_max.default_value(), 58 "If defined, will be used as the default value for the max bound " 59 "of min/max ranges used for quantization."), 60 Flag("inference_type", parsed_flags.inference_type.bind(), 61 parsed_flags.inference_type.default_value(), 62 "Target data type of arrays in the output file (for input_arrays, " 63 "this may be overridden by inference_input_type). " 64 "One of FLOAT, QUANTIZED_UINT8."), 65 Flag("inference_input_type", parsed_flags.inference_input_type.bind(), 66 parsed_flags.inference_input_type.default_value(), 67 "Target data type of input arrays. " 68 "If not specified, inference_type is used. " 69 "One of FLOAT, QUANTIZED_UINT8."), 70 Flag("input_type", parsed_flags.input_type.bind(), 71 parsed_flags.input_type.default_value(), 72 "Deprecated ambiguous flag that set both --input_data_types and " 73 "--inference_input_type."), 74 Flag("input_types", parsed_flags.input_types.bind(), 75 parsed_flags.input_types.default_value(), 76 "Deprecated ambiguous flag that set both --input_data_types and " 77 "--inference_input_type. Was meant to be a " 78 "comma-separated list, but this was deprecated before " 79 "multiple-input-types was ever properly supported."), 80 81 Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(), 82 parsed_flags.drop_fake_quant.default_value(), 83 "Ignore and discard FakeQuant nodes. For instance, to " 84 "generate plain float code without fake-quantization from a " 85 "quantized graph."), 86 Flag( 87 "reorder_across_fake_quant", 88 parsed_flags.reorder_across_fake_quant.bind(), 89 parsed_flags.reorder_across_fake_quant.default_value(), 90 "Normally, FakeQuant nodes must be strict boundaries for graph " 91 "transformations, in order to ensure that quantized inference has " 92 "the exact same arithmetic behavior as quantized training --- which " 93 "is the whole point of quantized training and of FakeQuant nodes in " 94 "the first place. " 95 "However, that entails subtle requirements on where exactly " 96 "FakeQuant nodes must be placed in the graph. Some quantized graphs " 97 "have FakeQuant nodes at unexpected locations, that prevent graph " 98 "transformations that are necessary in order to generate inference " 99 "code for these graphs. Such graphs should be fixed, but as a " 100 "temporary work-around, setting this reorder_across_fake_quant flag " 101 "allows TOCO to perform necessary graph transformaitons on them, " 102 "at the cost of no longer faithfully matching inference and training " 103 "arithmetic."), 104 Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(), 105 parsed_flags.allow_custom_ops.default_value(), 106 "If true, allow TOCO to create TF Lite Custom operators for all the " 107 "unsupported TensorFlow ops."), 108 Flag( 109 "drop_control_dependency", 110 parsed_flags.drop_control_dependency.bind(), 111 parsed_flags.drop_control_dependency.default_value(), 112 "If true, ignore control dependency requirements in input TensorFlow " 113 "GraphDef. Otherwise an error will be raised upon control dependency " 114 "inputs."), 115 }; 116 bool asked_for_help = 117 *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); 118 if (asked_for_help) { 119 *msg += tensorflow::Flags::Usage(argv[0], flags); 120 return false; 121 } else { 122 return tensorflow::Flags::Parse(argc, argv, flags); 123 } 124} 125 126void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, 127 TocoFlags* toco_flags) { 128 namespace port = toco::port; 129 port::CheckInitGoogleIsDone("InitGoogle is not done yet"); 130 131 enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified }; 132 133#define ENFORCE_FLAG_REQUIREMENT(name, requirement) \ 134 do { \ 135 if (requirement == FlagRequirement::kMustBeSpecified) { \ 136 QCHECK(parsed_toco_flags.name.specified()) \ 137 << "Missing required flag: " << #name; \ 138 } \ 139 if (requirement == FlagRequirement::kMustNotBeSpecified) { \ 140 QCHECK(!parsed_toco_flags.name.specified()) \ 141 << "Given other flags, this flag should not have been specified: " \ 142 << #name; \ 143 } \ 144 } while (false) 145#define READ_TOCO_FLAG(name, requirement) \ 146 ENFORCE_FLAG_REQUIREMENT(name, requirement); \ 147 do { \ 148 if (parsed_toco_flags.name.specified()) { \ 149 toco_flags->set_##name(parsed_toco_flags.name.value()); \ 150 } \ 151 } while (false) 152 153#define PARSE_TOCO_FLAG(Type, name, requirement) \ 154 ENFORCE_FLAG_REQUIREMENT(name, requirement); \ 155 do { \ 156 if (parsed_toco_flags.name.specified()) { \ 157 Type x; \ 158 QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \ 159 << "Unrecognized " << #Type << " value " \ 160 << parsed_toco_flags.name.value(); \ 161 toco_flags->set_##name(x); \ 162 } \ 163 } while (false) 164 165 PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified); 166 PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified); 167 PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone); 168 PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone); 169 READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone); 170 READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone); 171 READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone); 172 READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone); 173 READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone); 174 READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone); 175 176 // Deprecated flag handling. 177 if (parsed_toco_flags.input_type.specified()) { 178 LOG(WARNING) 179 << "--input_type is deprecated. It was an ambiguous flag that set both " 180 "--input_data_types and --inference_input_type. If you are trying " 181 "to complement the input file with information about the type of " 182 "input arrays, use --input_data_type. If you are trying to control " 183 "the quantization/dequantization of real-numbers input arrays in " 184 "the output file, use --inference_input_type."; 185 toco::IODataType input_type; 186 QCHECK(toco::IODataType_Parse(parsed_toco_flags.input_type.value(), 187 &input_type)); 188 toco_flags->set_inference_input_type(input_type); 189 } 190 if (parsed_toco_flags.input_types.specified()) { 191 LOG(WARNING) 192 << "--input_types is deprecated. It was an ambiguous flag that set " 193 "both --input_data_types and --inference_input_type. If you are " 194 "trying to complement the input file with information about the " 195 "type of input arrays, use --input_data_type. If you are trying to " 196 "control the quantization/dequantization of real-numbers input " 197 "arrays in the output file, use --inference_input_type."; 198 std::vector<string> input_types = 199 absl::StrSplit(parsed_toco_flags.input_types.value(), ','); 200 QCHECK(!input_types.empty()); 201 for (int i = 1; i < input_types.size(); i++) { 202 QCHECK_EQ(input_types[i], input_types[0]); 203 } 204 toco::IODataType input_type; 205 QCHECK(toco::IODataType_Parse(input_types[0], &input_type)); 206 toco_flags->set_inference_input_type(input_type); 207 } 208 209#undef READ_TOCO_FLAG 210#undef PARSE_TOCO_FLAG 211} 212} // namespace toco 213