1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" 16 17#include <string> 18#include <vector> 19 20#include "absl/strings/numbers.h" 21#include "absl/strings/str_join.h" 22#include "absl/strings/str_split.h" 23#include "absl/strings/string_view.h" 24#include "absl/strings/strip.h" 25#include "tensorflow/contrib/lite/toco/args.h" 26#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" 27#include "tensorflow/contrib/lite/toco/toco_port.h" 28#include "tensorflow/core/platform/logging.h" 29#include "tensorflow/core/util/command_line_flags.h" 30 31// "batch" flag only exists internally 32#ifdef PLATFORM_GOOGLE 33#include "base/commandlineflags.h" 34#endif 35 36namespace toco { 37 38bool ParseModelFlagsFromCommandLineFlags( 39 int* argc, char* argv[], string* msg, 40 ParsedModelFlags* parsed_model_flags_ptr) { 41 ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr; 42 using tensorflow::Flag; 43 std::vector<tensorflow::Flag> flags = { 44 Flag("input_array", parsed_flags.input_array.bind(), 45 parsed_flags.input_array.default_value(), 46 "Deprecated: use --input_arrays instead. Name of the input array. " 47 "If not specified, will try to read " 48 "that information from the input file."), 49 Flag("input_arrays", parsed_flags.input_arrays.bind(), 50 parsed_flags.input_arrays.default_value(), 51 "Names of the output arrays, comma-separated. If not specified, " 52 "will try to read that information from the input file."), 53 Flag("output_array", parsed_flags.output_array.bind(), 54 parsed_flags.output_array.default_value(), 55 "Deprecated: use --output_arrays instead. Name of the output array, " 56 "when specifying a unique output array. " 57 "If not specified, will try to read that information from the " 58 "input file."), 59 Flag("output_arrays", parsed_flags.output_arrays.bind(), 60 parsed_flags.output_arrays.default_value(), 61 "Names of the output arrays, comma-separated. " 62 "If not specified, will try to read " 63 "that information from the input file."), 64 Flag("input_shape", parsed_flags.input_shape.bind(), 65 parsed_flags.input_shape.default_value(), 66 "Deprecated: use --input_shapes instead. Input array shape. For " 67 "many models the shape takes the form " 68 "batch size, input array height, input array width, input array " 69 "depth."), 70 Flag("input_shapes", parsed_flags.input_shapes.bind(), 71 parsed_flags.input_shapes.default_value(), 72 "Shapes corresponding to --input_arrays, colon-separated. For " 73 "many models each shape takes the form batch size, input array " 74 "height, input array width, input array depth."), 75 Flag("input_data_type", parsed_flags.input_data_type.bind(), 76 parsed_flags.input_data_type.default_value(), 77 "Deprecated: use --input_data_types instead. Input array type, if " 78 "not already provided in the graph. " 79 "Typically needs to be specified when passing arbitrary arrays " 80 "to --input_array."), 81 Flag("input_data_types", parsed_flags.input_data_types.bind(), 82 parsed_flags.input_data_types.default_value(), 83 "Input arrays types, comma-separated, if not already provided in " 84 "the graph. " 85 "Typically needs to be specified when passing arbitrary arrays " 86 "to --input_arrays."), 87 Flag("mean_value", parsed_flags.mean_value.bind(), 88 parsed_flags.mean_value.default_value(), 89 "Deprecated: use --mean_values instead. mean_value parameter for " 90 "image models, used to compute input " 91 "activations from input pixel data."), 92 Flag("mean_values", parsed_flags.mean_values.bind(), 93 parsed_flags.mean_values.default_value(), 94 "mean_values parameter for image models, comma-separated list of " 95 "doubles, used to compute input activations from input pixel " 96 "data. Each entry in the list should match an entry in " 97 "--input_arrays."), 98 Flag("std_value", parsed_flags.std_value.bind(), 99 parsed_flags.std_value.default_value(), 100 "Deprecated: use --std_values instead. std_value parameter for " 101 "image models, used to compute input " 102 "activations from input pixel data."), 103 Flag("std_values", parsed_flags.std_values.bind(), 104 parsed_flags.std_values.default_value(), 105 "std_value parameter for image models, comma-separated list of " 106 "doubles, used to compute input activations from input pixel " 107 "data. Each entry in the list should match an entry in " 108 "--input_arrays."), 109 Flag("variable_batch", parsed_flags.variable_batch.bind(), 110 parsed_flags.variable_batch.default_value(), 111 "If true, the model accepts an arbitrary batch size. Mutually " 112 "exclusive " 113 "with the 'batch' field: at most one of these two fields can be " 114 "set."), 115 Flag("rnn_states", parsed_flags.rnn_states.bind(), 116 parsed_flags.rnn_states.default_value(), ""), 117 Flag("model_checks", parsed_flags.model_checks.bind(), 118 parsed_flags.model_checks.default_value(), 119 "A list of model checks to be applied to verify the form of the " 120 "model. Applied after the graph transformations after import."), 121 Flag("graphviz_first_array", parsed_flags.graphviz_first_array.bind(), 122 parsed_flags.graphviz_first_array.default_value(), 123 "If set, defines the start of the sub-graph to be dumped to " 124 "GraphViz."), 125 Flag( 126 "graphviz_last_array", parsed_flags.graphviz_last_array.bind(), 127 parsed_flags.graphviz_last_array.default_value(), 128 "If set, defines the end of the sub-graph to be dumped to GraphViz."), 129 Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(), 130 parsed_flags.dump_graphviz.default_value(), 131 "Dump graphviz during LogDump call. If string is non-empty then " 132 "it defines path to dump, otherwise will skip dumping."), 133 Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(), 134 parsed_flags.dump_graphviz_video.default_value(), 135 "If true, will dump graphviz at each " 136 "graph transformation, which may be used to generate a video."), 137 Flag("allow_nonexistent_arrays", 138 parsed_flags.allow_nonexistent_arrays.bind(), 139 parsed_flags.allow_nonexistent_arrays.default_value(), 140 "If true, will allow passing inexistent arrays in --input_arrays " 141 "and --output_arrays. This makes little sense, is only useful to " 142 "more easily get graph visualizations."), 143 Flag("allow_nonascii_arrays", parsed_flags.allow_nonascii_arrays.bind(), 144 parsed_flags.allow_nonascii_arrays.default_value(), 145 "If true, will allow passing non-ascii-printable characters in " 146 "--input_arrays and --output_arrays. By default (if false), only " 147 "ascii printable characters are allowed, i.e. character codes " 148 "ranging from 32 to 127. This is disallowed by default so as to " 149 "catch common copy-and-paste issues where invisible unicode " 150 "characters are unwittingly added to these strings."), 151 Flag( 152 "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(), 153 parsed_flags.arrays_extra_info_file.default_value(), 154 "Path to an optional file containing a serialized ArraysExtraInfo " 155 "proto allowing to pass extra information about arrays not specified " 156 "in the input model file, such as extra MinMax information."), 157 }; 158 bool asked_for_help = 159 *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); 160 if (asked_for_help) { 161 *msg += tensorflow::Flags::Usage(argv[0], flags); 162 return false; 163 } else { 164 if (!tensorflow::Flags::Parse(argc, argv, flags)) return false; 165 } 166 auto& dump_options = *GraphVizDumpOptions::singleton(); 167 dump_options.graphviz_first_array = parsed_flags.graphviz_first_array.value(); 168 dump_options.graphviz_last_array = parsed_flags.graphviz_last_array.value(); 169 dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value(); 170 dump_options.dump_graphviz = parsed_flags.dump_graphviz.value(); 171 172 return true; 173} 174 175void ReadModelFlagsFromCommandLineFlags( 176 const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) { 177 toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet"); 178 179// "batch" flag only exists internally 180#ifdef PLATFORM_GOOGLE 181 CHECK(!((base::SpecifiedOnCommandLine("batch") && 182 parsed_model_flags.variable_batch.specified()))) 183 << "The --batch and --variable_batch flags are mutually exclusive."; 184#endif 185 CHECK(!(parsed_model_flags.output_array.specified() && 186 parsed_model_flags.output_arrays.specified())) 187 << "The --output_array and --vs flags are mutually exclusive."; 188 189 if (parsed_model_flags.output_array.specified()) { 190 model_flags->add_output_arrays(parsed_model_flags.output_array.value()); 191 } 192 193 if (parsed_model_flags.output_arrays.specified()) { 194 std::vector<string> output_arrays = 195 absl::StrSplit(parsed_model_flags.output_arrays.value(), ','); 196 for (const string& output_array : output_arrays) { 197 model_flags->add_output_arrays(output_array); 198 } 199 } 200 201 const bool uses_single_input_flags = 202 parsed_model_flags.input_array.specified() || 203 parsed_model_flags.mean_value.specified() || 204 parsed_model_flags.std_value.specified() || 205 parsed_model_flags.input_shape.specified(); 206 207 const bool uses_multi_input_flags = 208 parsed_model_flags.input_arrays.specified() || 209 parsed_model_flags.mean_values.specified() || 210 parsed_model_flags.std_values.specified() || 211 parsed_model_flags.input_shapes.specified(); 212 213 QCHECK(!(uses_single_input_flags && uses_multi_input_flags)) 214 << "Use either the singular-form input flags (--input_array, " 215 "--input_shape, --mean_value, --std_value) or the plural form input " 216 "flags (--input_arrays, --input_shapes, --mean_values, --std_values), " 217 "but not both forms within the same command line."; 218 219 if (parsed_model_flags.input_array.specified()) { 220 QCHECK(uses_single_input_flags); 221 model_flags->add_input_arrays()->set_name( 222 parsed_model_flags.input_array.value()); 223 } 224 if (parsed_model_flags.input_arrays.specified()) { 225 QCHECK(uses_multi_input_flags); 226 for (const auto& input_array : 227 absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) { 228 model_flags->add_input_arrays()->set_name(string(input_array)); 229 } 230 } 231 if (parsed_model_flags.mean_value.specified()) { 232 QCHECK(uses_single_input_flags); 233 model_flags->mutable_input_arrays(0)->set_mean_value( 234 parsed_model_flags.mean_value.value()); 235 } 236 if (parsed_model_flags.mean_values.specified()) { 237 QCHECK(uses_multi_input_flags); 238 std::vector<string> mean_values = 239 absl::StrSplit(parsed_model_flags.mean_values.value(), ','); 240 QCHECK(mean_values.size() == model_flags->input_arrays_size()); 241 for (int i = 0; i < mean_values.size(); ++i) { 242 char* last = nullptr; 243 model_flags->mutable_input_arrays(i)->set_mean_value( 244 strtod(mean_values[i].data(), &last)); 245 CHECK(last != mean_values[i].data()); 246 } 247 } 248 if (parsed_model_flags.std_value.specified()) { 249 QCHECK(uses_single_input_flags); 250 model_flags->mutable_input_arrays(0)->set_std_value( 251 parsed_model_flags.std_value.value()); 252 } 253 if (parsed_model_flags.std_values.specified()) { 254 QCHECK(uses_multi_input_flags); 255 std::vector<string> std_values = 256 absl::StrSplit(parsed_model_flags.std_values.value(), ','); 257 QCHECK(std_values.size() == model_flags->input_arrays_size()); 258 for (int i = 0; i < std_values.size(); ++i) { 259 char* last = nullptr; 260 model_flags->mutable_input_arrays(i)->set_std_value( 261 strtod(std_values[i].data(), &last)); 262 CHECK(last != std_values[i].data()); 263 } 264 } 265 if (parsed_model_flags.input_data_type.specified()) { 266 QCHECK(uses_single_input_flags); 267 IODataType type; 268 QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type)); 269 model_flags->mutable_input_arrays(0)->set_data_type(type); 270 } 271 if (parsed_model_flags.input_data_types.specified()) { 272 QCHECK(uses_multi_input_flags); 273 std::vector<string> input_data_types = 274 absl::StrSplit(parsed_model_flags.input_data_types.value(), ','); 275 QCHECK(input_data_types.size() == model_flags->input_arrays_size()); 276 for (int i = 0; i < input_data_types.size(); ++i) { 277 IODataType type; 278 QCHECK(IODataType_Parse(input_data_types[i], &type)); 279 model_flags->mutable_input_arrays(i)->set_data_type(type); 280 } 281 } 282 if (parsed_model_flags.input_shape.specified()) { 283 QCHECK(uses_single_input_flags); 284 if (model_flags->input_arrays().empty()) { 285 model_flags->add_input_arrays(); 286 } 287 auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape(); 288 shape->clear_dims(); 289 const IntList& list = parsed_model_flags.input_shape.value(); 290 for (auto& dim : list.elements) { 291 shape->add_dims(dim); 292 } 293 } 294 if (parsed_model_flags.input_shapes.specified()) { 295 QCHECK(uses_multi_input_flags); 296 std::vector<string> input_shapes = 297 absl::StrSplit(parsed_model_flags.input_shapes.value(), ':'); 298 QCHECK(input_shapes.size() == model_flags->input_arrays_size()); 299 for (int i = 0; i < input_shapes.size(); ++i) { 300 auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape(); 301 shape->clear_dims(); 302 for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) { 303 int size; 304 CHECK(absl::SimpleAtoi(dim_str, &size)) 305 << "Failed to parse input_shape: " << input_shapes[i]; 306 shape->add_dims(size); 307 } 308 } 309 } 310 311#define READ_MODEL_FLAG(name) \ 312 do { \ 313 if (parsed_model_flags.name.specified()) { \ 314 model_flags->set_##name(parsed_model_flags.name.value()); \ 315 } \ 316 } while (false) 317 318 READ_MODEL_FLAG(variable_batch); 319 320#undef READ_MODEL_FLAG 321 322 for (const auto& element : parsed_model_flags.rnn_states.value().elements) { 323 auto* rnn_state_proto = model_flags->add_rnn_states(); 324 for (const auto& kv_pair : element) { 325 const string& key = kv_pair.first; 326 const string& value = kv_pair.second; 327 if (key == "state_array") { 328 rnn_state_proto->set_state_array(value); 329 } else if (key == "back_edge_source_array") { 330 rnn_state_proto->set_back_edge_source_array(value); 331 } else if (key == "size") { 332 int32 size = 0; 333 CHECK(absl::SimpleAtoi(value, &size)); 334 CHECK_GT(size, 0); 335 rnn_state_proto->set_size(size); 336 } else { 337 LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states"; 338 } 339 } 340 CHECK(rnn_state_proto->has_state_array() && 341 rnn_state_proto->has_back_edge_source_array() && 342 rnn_state_proto->has_size()) 343 << "--rnn_states must include state_array, back_edge_source_array and " 344 "size."; 345 } 346 347 for (const auto& element : parsed_model_flags.model_checks.value().elements) { 348 auto* model_check_proto = model_flags->add_model_checks(); 349 for (const auto& kv_pair : element) { 350 const string& key = kv_pair.first; 351 const string& value = kv_pair.second; 352 if (key == "count_type") { 353 model_check_proto->set_count_type(value); 354 } else if (key == "count_min") { 355 int32 count = 0; 356 CHECK(absl::SimpleAtoi(value, &count)); 357 CHECK_GE(count, -1); 358 model_check_proto->set_count_min(count); 359 } else if (key == "count_max") { 360 int32 count = 0; 361 CHECK(absl::SimpleAtoi(value, &count)); 362 CHECK_GE(count, -1); 363 model_check_proto->set_count_max(count); 364 } else { 365 LOG(FATAL) << "Unknown key '" << key << "' in --model_checks"; 366 } 367 } 368 } 369 370 model_flags->set_allow_nonascii_arrays( 371 parsed_model_flags.allow_nonascii_arrays.value()); 372 model_flags->set_allow_nonexistent_arrays( 373 parsed_model_flags.allow_nonexistent_arrays.value()); 374 375 if (parsed_model_flags.arrays_extra_info_file.specified()) { 376 string arrays_extra_info_file_contents; 377 port::file::GetContents(parsed_model_flags.arrays_extra_info_file.value(), 378 &arrays_extra_info_file_contents, 379 port::file::Defaults()); 380 ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents, 381 model_flags->mutable_arrays_extra_info()); 382 } 383} 384 385ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) { 386 static auto* flags = [must_already_exist]() { 387 if (must_already_exist) { 388 fprintf(stderr, __FILE__ 389 ":" 390 "GlobalParsedModelFlags() used without initialization\n"); 391 fflush(stderr); 392 abort(); 393 } 394 return new toco::ParsedModelFlags; 395 }(); 396 return flags; 397} 398 399ParsedModelFlags* GlobalParsedModelFlags() { 400 return UncheckedGlobalParsedModelFlags(true); 401} 402 403void ParseModelFlagsOrDie(int* argc, char* argv[]) { 404 // TODO(aselle): in the future allow Google version to use 405 // flags, and only use this mechanism for open source 406 auto* flags = UncheckedGlobalParsedModelFlags(false); 407 string msg; 408 bool model_success = 409 toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags); 410 if (!model_success || !msg.empty()) { 411 // Log in non-standard way since this happens pre InitGoogle. 412 fprintf(stderr, "%s", msg.c_str()); 413 fflush(stderr); 414 abort(); 415 } 416} 417 418} // namespace toco 419