1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
16
17#include <string>
18#include <vector>
19
20#include "absl/strings/numbers.h"
21#include "absl/strings/str_join.h"
22#include "absl/strings/str_split.h"
23#include "absl/strings/string_view.h"
24#include "absl/strings/strip.h"
25#include "tensorflow/contrib/lite/toco/args.h"
26#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
27#include "tensorflow/contrib/lite/toco/toco_port.h"
28#include "tensorflow/core/platform/logging.h"
29#include "tensorflow/core/util/command_line_flags.h"
30
31// "batch" flag only exists internally
32#ifdef PLATFORM_GOOGLE
33#include "base/commandlineflags.h"
34#endif
35
36namespace toco {
37
38bool ParseModelFlagsFromCommandLineFlags(
39    int* argc, char* argv[], string* msg,
40    ParsedModelFlags* parsed_model_flags_ptr) {
41  ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
42  using tensorflow::Flag;
43  std::vector<tensorflow::Flag> flags = {
44      Flag("input_array", parsed_flags.input_array.bind(),
45           parsed_flags.input_array.default_value(),
46           "Deprecated: use --input_arrays instead. Name of the input array. "
47           "If not specified, will try to read "
48           "that information from the input file."),
49      Flag("input_arrays", parsed_flags.input_arrays.bind(),
50           parsed_flags.input_arrays.default_value(),
51           "Names of the output arrays, comma-separated. If not specified, "
52           "will try to read that information from the input file."),
53      Flag("output_array", parsed_flags.output_array.bind(),
54           parsed_flags.output_array.default_value(),
55           "Deprecated: use --output_arrays instead. Name of the output array, "
56           "when specifying a unique output array. "
57           "If not specified, will try to read that information from the "
58           "input file."),
59      Flag("output_arrays", parsed_flags.output_arrays.bind(),
60           parsed_flags.output_arrays.default_value(),
61           "Names of the output arrays, comma-separated. "
62           "If not specified, will try to read "
63           "that information from the input file."),
64      Flag("input_shape", parsed_flags.input_shape.bind(),
65           parsed_flags.input_shape.default_value(),
66           "Deprecated: use --input_shapes instead. Input array shape. For "
67           "many models the shape takes the form "
68           "batch size, input array height, input array width, input array "
69           "depth."),
70      Flag("input_shapes", parsed_flags.input_shapes.bind(),
71           parsed_flags.input_shapes.default_value(),
72           "Shapes corresponding to --input_arrays, colon-separated. For "
73           "many models each shape takes the form batch size, input array "
74           "height, input array width, input array depth."),
75      Flag("input_data_type", parsed_flags.input_data_type.bind(),
76           parsed_flags.input_data_type.default_value(),
77           "Deprecated: use --input_data_types instead. Input array type, if "
78           "not already provided in the graph. "
79           "Typically needs to be specified when passing arbitrary arrays "
80           "to --input_array."),
81      Flag("input_data_types", parsed_flags.input_data_types.bind(),
82           parsed_flags.input_data_types.default_value(),
83           "Input arrays types, comma-separated, if not already provided in "
84           "the graph. "
85           "Typically needs to be specified when passing arbitrary arrays "
86           "to --input_arrays."),
87      Flag("mean_value", parsed_flags.mean_value.bind(),
88           parsed_flags.mean_value.default_value(),
89           "Deprecated: use --mean_values instead. mean_value parameter for "
90           "image models, used to compute input "
91           "activations from input pixel data."),
92      Flag("mean_values", parsed_flags.mean_values.bind(),
93           parsed_flags.mean_values.default_value(),
94           "mean_values parameter for image models, comma-separated list of "
95           "doubles, used to compute input activations from input pixel "
96           "data. Each entry in the list should match an entry in "
97           "--input_arrays."),
98      Flag("std_value", parsed_flags.std_value.bind(),
99           parsed_flags.std_value.default_value(),
100           "Deprecated: use --std_values instead. std_value parameter for "
101           "image models, used to compute input "
102           "activations from input pixel data."),
103      Flag("std_values", parsed_flags.std_values.bind(),
104           parsed_flags.std_values.default_value(),
105           "std_value parameter for image models, comma-separated list of "
106           "doubles, used to compute input activations from input pixel "
107           "data. Each entry in the list should match an entry in "
108           "--input_arrays."),
109      Flag("variable_batch", parsed_flags.variable_batch.bind(),
110           parsed_flags.variable_batch.default_value(),
111           "If true, the model accepts an arbitrary batch size. Mutually "
112           "exclusive "
113           "with the 'batch' field: at most one of these two fields can be "
114           "set."),
115      Flag("rnn_states", parsed_flags.rnn_states.bind(),
116           parsed_flags.rnn_states.default_value(), ""),
117      Flag("model_checks", parsed_flags.model_checks.bind(),
118           parsed_flags.model_checks.default_value(),
119           "A list of model checks to be applied to verify the form of the "
120           "model.  Applied after the graph transformations after import."),
121      Flag("graphviz_first_array", parsed_flags.graphviz_first_array.bind(),
122           parsed_flags.graphviz_first_array.default_value(),
123           "If set, defines the start of the sub-graph to be dumped to "
124           "GraphViz."),
125      Flag(
126          "graphviz_last_array", parsed_flags.graphviz_last_array.bind(),
127          parsed_flags.graphviz_last_array.default_value(),
128          "If set, defines the end of the sub-graph to be dumped to GraphViz."),
129      Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(),
130           parsed_flags.dump_graphviz.default_value(),
131           "Dump graphviz during LogDump call. If string is non-empty then "
132           "it defines path to dump, otherwise will skip dumping."),
133      Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
134           parsed_flags.dump_graphviz_video.default_value(),
135           "If true, will dump graphviz at each "
136           "graph transformation, which may be used to generate a video."),
137      Flag("allow_nonexistent_arrays",
138           parsed_flags.allow_nonexistent_arrays.bind(),
139           parsed_flags.allow_nonexistent_arrays.default_value(),
140           "If true, will allow passing inexistent arrays in --input_arrays "
141           "and --output_arrays. This makes little sense, is only useful to "
142           "more easily get graph visualizations."),
143      Flag("allow_nonascii_arrays", parsed_flags.allow_nonascii_arrays.bind(),
144           parsed_flags.allow_nonascii_arrays.default_value(),
145           "If true, will allow passing non-ascii-printable characters in "
146           "--input_arrays and --output_arrays. By default (if false), only "
147           "ascii printable characters are allowed, i.e. character codes "
148           "ranging from 32 to 127. This is disallowed by default so as to "
149           "catch common copy-and-paste issues where invisible unicode "
150           "characters are unwittingly added to these strings."),
151      Flag(
152          "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
153          parsed_flags.arrays_extra_info_file.default_value(),
154          "Path to an optional file containing a serialized ArraysExtraInfo "
155          "proto allowing to pass extra information about arrays not specified "
156          "in the input model file, such as extra MinMax information."),
157  };
158  bool asked_for_help =
159      *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
160  if (asked_for_help) {
161    *msg += tensorflow::Flags::Usage(argv[0], flags);
162    return false;
163  } else {
164    if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
165  }
166  auto& dump_options = *GraphVizDumpOptions::singleton();
167  dump_options.graphviz_first_array = parsed_flags.graphviz_first_array.value();
168  dump_options.graphviz_last_array = parsed_flags.graphviz_last_array.value();
169  dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
170  dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
171
172  return true;
173}
174
175void ReadModelFlagsFromCommandLineFlags(
176    const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
177  toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
178
179// "batch" flag only exists internally
180#ifdef PLATFORM_GOOGLE
181  CHECK(!((base::SpecifiedOnCommandLine("batch") &&
182           parsed_model_flags.variable_batch.specified())))
183      << "The --batch and --variable_batch flags are mutually exclusive.";
184#endif
185  CHECK(!(parsed_model_flags.output_array.specified() &&
186          parsed_model_flags.output_arrays.specified()))
187      << "The --output_array and --vs flags are mutually exclusive.";
188
189  if (parsed_model_flags.output_array.specified()) {
190    model_flags->add_output_arrays(parsed_model_flags.output_array.value());
191  }
192
193  if (parsed_model_flags.output_arrays.specified()) {
194    std::vector<string> output_arrays =
195        absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
196    for (const string& output_array : output_arrays) {
197      model_flags->add_output_arrays(output_array);
198    }
199  }
200
201  const bool uses_single_input_flags =
202      parsed_model_flags.input_array.specified() ||
203      parsed_model_flags.mean_value.specified() ||
204      parsed_model_flags.std_value.specified() ||
205      parsed_model_flags.input_shape.specified();
206
207  const bool uses_multi_input_flags =
208      parsed_model_flags.input_arrays.specified() ||
209      parsed_model_flags.mean_values.specified() ||
210      parsed_model_flags.std_values.specified() ||
211      parsed_model_flags.input_shapes.specified();
212
213  QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
214      << "Use either the singular-form input flags (--input_array, "
215         "--input_shape, --mean_value, --std_value) or the plural form input "
216         "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
217         "but not both forms within the same command line.";
218
219  if (parsed_model_flags.input_array.specified()) {
220    QCHECK(uses_single_input_flags);
221    model_flags->add_input_arrays()->set_name(
222        parsed_model_flags.input_array.value());
223  }
224  if (parsed_model_flags.input_arrays.specified()) {
225    QCHECK(uses_multi_input_flags);
226    for (const auto& input_array :
227         absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
228      model_flags->add_input_arrays()->set_name(string(input_array));
229    }
230  }
231  if (parsed_model_flags.mean_value.specified()) {
232    QCHECK(uses_single_input_flags);
233    model_flags->mutable_input_arrays(0)->set_mean_value(
234        parsed_model_flags.mean_value.value());
235  }
236  if (parsed_model_flags.mean_values.specified()) {
237    QCHECK(uses_multi_input_flags);
238    std::vector<string> mean_values =
239        absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
240    QCHECK(mean_values.size() == model_flags->input_arrays_size());
241    for (int i = 0; i < mean_values.size(); ++i) {
242      char* last = nullptr;
243      model_flags->mutable_input_arrays(i)->set_mean_value(
244          strtod(mean_values[i].data(), &last));
245      CHECK(last != mean_values[i].data());
246    }
247  }
248  if (parsed_model_flags.std_value.specified()) {
249    QCHECK(uses_single_input_flags);
250    model_flags->mutable_input_arrays(0)->set_std_value(
251        parsed_model_flags.std_value.value());
252  }
253  if (parsed_model_flags.std_values.specified()) {
254    QCHECK(uses_multi_input_flags);
255    std::vector<string> std_values =
256        absl::StrSplit(parsed_model_flags.std_values.value(), ',');
257    QCHECK(std_values.size() == model_flags->input_arrays_size());
258    for (int i = 0; i < std_values.size(); ++i) {
259      char* last = nullptr;
260      model_flags->mutable_input_arrays(i)->set_std_value(
261          strtod(std_values[i].data(), &last));
262      CHECK(last != std_values[i].data());
263    }
264  }
265  if (parsed_model_flags.input_data_type.specified()) {
266    QCHECK(uses_single_input_flags);
267    IODataType type;
268    QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type));
269    model_flags->mutable_input_arrays(0)->set_data_type(type);
270  }
271  if (parsed_model_flags.input_data_types.specified()) {
272    QCHECK(uses_multi_input_flags);
273    std::vector<string> input_data_types =
274        absl::StrSplit(parsed_model_flags.input_data_types.value(), ',');
275    QCHECK(input_data_types.size() == model_flags->input_arrays_size());
276    for (int i = 0; i < input_data_types.size(); ++i) {
277      IODataType type;
278      QCHECK(IODataType_Parse(input_data_types[i], &type));
279      model_flags->mutable_input_arrays(i)->set_data_type(type);
280    }
281  }
282  if (parsed_model_flags.input_shape.specified()) {
283    QCHECK(uses_single_input_flags);
284    if (model_flags->input_arrays().empty()) {
285      model_flags->add_input_arrays();
286    }
287    auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
288    shape->clear_dims();
289    const IntList& list = parsed_model_flags.input_shape.value();
290    for (auto& dim : list.elements) {
291      shape->add_dims(dim);
292    }
293  }
294  if (parsed_model_flags.input_shapes.specified()) {
295    QCHECK(uses_multi_input_flags);
296    std::vector<string> input_shapes =
297        absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
298    QCHECK(input_shapes.size() == model_flags->input_arrays_size());
299    for (int i = 0; i < input_shapes.size(); ++i) {
300      auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
301      shape->clear_dims();
302      for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
303        int size;
304        CHECK(absl::SimpleAtoi(dim_str, &size))
305            << "Failed to parse input_shape: " << input_shapes[i];
306        shape->add_dims(size);
307      }
308    }
309  }
310
311#define READ_MODEL_FLAG(name)                                   \
312  do {                                                          \
313    if (parsed_model_flags.name.specified()) {                  \
314      model_flags->set_##name(parsed_model_flags.name.value()); \
315    }                                                           \
316  } while (false)
317
318  READ_MODEL_FLAG(variable_batch);
319
320#undef READ_MODEL_FLAG
321
322  for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
323    auto* rnn_state_proto = model_flags->add_rnn_states();
324    for (const auto& kv_pair : element) {
325      const string& key = kv_pair.first;
326      const string& value = kv_pair.second;
327      if (key == "state_array") {
328        rnn_state_proto->set_state_array(value);
329      } else if (key == "back_edge_source_array") {
330        rnn_state_proto->set_back_edge_source_array(value);
331      } else if (key == "size") {
332        int32 size = 0;
333        CHECK(absl::SimpleAtoi(value, &size));
334        CHECK_GT(size, 0);
335        rnn_state_proto->set_size(size);
336      } else {
337        LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
338      }
339    }
340    CHECK(rnn_state_proto->has_state_array() &&
341          rnn_state_proto->has_back_edge_source_array() &&
342          rnn_state_proto->has_size())
343        << "--rnn_states must include state_array, back_edge_source_array and "
344           "size.";
345  }
346
347  for (const auto& element : parsed_model_flags.model_checks.value().elements) {
348    auto* model_check_proto = model_flags->add_model_checks();
349    for (const auto& kv_pair : element) {
350      const string& key = kv_pair.first;
351      const string& value = kv_pair.second;
352      if (key == "count_type") {
353        model_check_proto->set_count_type(value);
354      } else if (key == "count_min") {
355        int32 count = 0;
356        CHECK(absl::SimpleAtoi(value, &count));
357        CHECK_GE(count, -1);
358        model_check_proto->set_count_min(count);
359      } else if (key == "count_max") {
360        int32 count = 0;
361        CHECK(absl::SimpleAtoi(value, &count));
362        CHECK_GE(count, -1);
363        model_check_proto->set_count_max(count);
364      } else {
365        LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
366      }
367    }
368  }
369
370  model_flags->set_allow_nonascii_arrays(
371      parsed_model_flags.allow_nonascii_arrays.value());
372  model_flags->set_allow_nonexistent_arrays(
373      parsed_model_flags.allow_nonexistent_arrays.value());
374
375  if (parsed_model_flags.arrays_extra_info_file.specified()) {
376    string arrays_extra_info_file_contents;
377    port::file::GetContents(parsed_model_flags.arrays_extra_info_file.value(),
378                            &arrays_extra_info_file_contents,
379                            port::file::Defaults());
380    ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
381                                      model_flags->mutable_arrays_extra_info());
382  }
383}
384
385ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
386  static auto* flags = [must_already_exist]() {
387    if (must_already_exist) {
388      fprintf(stderr, __FILE__
389              ":"
390              "GlobalParsedModelFlags() used without initialization\n");
391      fflush(stderr);
392      abort();
393    }
394    return new toco::ParsedModelFlags;
395  }();
396  return flags;
397}
398
399ParsedModelFlags* GlobalParsedModelFlags() {
400  return UncheckedGlobalParsedModelFlags(true);
401}
402
403void ParseModelFlagsOrDie(int* argc, char* argv[]) {
404  // TODO(aselle): in the future allow Google version to use
405  // flags, and only use this mechanism for open source
406  auto* flags = UncheckedGlobalParsedModelFlags(false);
407  string msg;
408  bool model_success =
409      toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags);
410  if (!model_success || !msg.empty()) {
411    // Log in non-standard way since this happens pre InitGoogle.
412    fprintf(stderr, "%s", msg.c_str());
413    fflush(stderr);
414    abort();
415  }
416}
417
418}  // namespace toco
419