1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <cstring>
#include <string>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
27
28namespace toco {
29
30bool ParseTocoFlagsFromCommandLineFlags(
31    int* argc, char* argv[], string* msg,
32    ParsedTocoFlags* parsed_toco_flags_ptr) {
33  using tensorflow::Flag;
34  ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr;
35  std::vector<tensorflow::Flag> flags = {
36      Flag("input_file", parsed_flags.input_file.bind(),
37           parsed_flags.input_file.default_value(),
38           "Input file (model of any supported format). For Protobuf "
39           "formats, both text and binary are supported regardless of file "
40           "extension."),
41      Flag("output_file", parsed_flags.output_file.bind(),
42           parsed_flags.output_file.default_value(),
43           "Output file. "
44           "For Protobuf formats, the binary format will be used."),
45      Flag("input_format", parsed_flags.input_format.bind(),
46           parsed_flags.input_format.default_value(),
47           "Input file format. One of: TENSORFLOW_GRAPHDEF, TFLITE."),
48      Flag("output_format", parsed_flags.output_format.bind(),
49           parsed_flags.output_format.default_value(),
50           "Output file format. "
51           "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."),
52      Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
53           parsed_flags.default_ranges_min.default_value(),
54           "If defined, will be used as the default value for the min bound "
55           "of min/max ranges used for quantization."),
56      Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(),
57           parsed_flags.default_ranges_max.default_value(),
58           "If defined, will be used as the default value for the max bound "
59           "of min/max ranges used for quantization."),
60      Flag("inference_type", parsed_flags.inference_type.bind(),
61           parsed_flags.inference_type.default_value(),
62           "Target data type of arrays in the output file (for input_arrays, "
63           "this may be overridden by inference_input_type). "
64           "One of FLOAT, QUANTIZED_UINT8."),
65      Flag("inference_input_type", parsed_flags.inference_input_type.bind(),
66           parsed_flags.inference_input_type.default_value(),
67           "Target data type of input arrays. "
68           "If not specified, inference_type is used. "
69           "One of FLOAT, QUANTIZED_UINT8."),
70      Flag("input_type", parsed_flags.input_type.bind(),
71           parsed_flags.input_type.default_value(),
72           "Deprecated ambiguous flag that set both --input_data_types and "
73           "--inference_input_type."),
74      Flag("input_types", parsed_flags.input_types.bind(),
75           parsed_flags.input_types.default_value(),
76           "Deprecated ambiguous flag that set both --input_data_types and "
77           "--inference_input_type. Was meant to be a "
78           "comma-separated list, but this was deprecated before "
79           "multiple-input-types was ever properly supported."),
80
81      Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(),
82           parsed_flags.drop_fake_quant.default_value(),
83           "Ignore and discard FakeQuant nodes. For instance, to "
84           "generate plain float code without fake-quantization from a "
85           "quantized graph."),
86      Flag(
87          "reorder_across_fake_quant",
88          parsed_flags.reorder_across_fake_quant.bind(),
89          parsed_flags.reorder_across_fake_quant.default_value(),
90          "Normally, FakeQuant nodes must be strict boundaries for graph "
91          "transformations, in order to ensure that quantized inference has "
92          "the exact same arithmetic behavior as quantized training --- which "
93          "is the whole point of quantized training and of FakeQuant nodes in "
94          "the first place. "
95          "However, that entails subtle requirements on where exactly "
96          "FakeQuant nodes must be placed in the graph. Some quantized graphs "
97          "have FakeQuant nodes at unexpected locations, that prevent graph "
98          "transformations that are necessary in order to generate inference "
99          "code for these graphs. Such graphs should be fixed, but as a "
100          "temporary work-around, setting this reorder_across_fake_quant flag "
101          "allows TOCO to perform necessary graph transformaitons on them, "
102          "at the cost of no longer faithfully matching inference and training "
103          "arithmetic."),
104      Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(),
105           parsed_flags.allow_custom_ops.default_value(),
106           "If true, allow TOCO to create TF Lite Custom operators for all the "
107           "unsupported TensorFlow ops."),
108      Flag(
109          "drop_control_dependency",
110          parsed_flags.drop_control_dependency.bind(),
111          parsed_flags.drop_control_dependency.default_value(),
112          "If true, ignore control dependency requirements in input TensorFlow "
113          "GraphDef. Otherwise an error will be raised upon control dependency "
114          "inputs."),
115  };
116  bool asked_for_help =
117      *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
118  if (asked_for_help) {
119    *msg += tensorflow::Flags::Usage(argv[0], flags);
120    return false;
121  } else {
122    return tensorflow::Flags::Parse(argc, argv, flags);
123  }
124}
125
126void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
127                                       TocoFlags* toco_flags) {
128  namespace port = toco::port;
129  port::CheckInitGoogleIsDone("InitGoogle is not done yet");
130
131  enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified };
132
133#define ENFORCE_FLAG_REQUIREMENT(name, requirement)                          \
134  do {                                                                       \
135    if (requirement == FlagRequirement::kMustBeSpecified) {                  \
136      QCHECK(parsed_toco_flags.name.specified())                             \
137          << "Missing required flag: " << #name;                             \
138    }                                                                        \
139    if (requirement == FlagRequirement::kMustNotBeSpecified) {               \
140      QCHECK(!parsed_toco_flags.name.specified())                            \
141          << "Given other flags, this flag should not have been specified: " \
142          << #name;                                                          \
143    }                                                                        \
144  } while (false)
145#define READ_TOCO_FLAG(name, requirement)                     \
146  ENFORCE_FLAG_REQUIREMENT(name, requirement);                \
147  do {                                                        \
148    if (parsed_toco_flags.name.specified()) {                 \
149      toco_flags->set_##name(parsed_toco_flags.name.value()); \
150    }                                                         \
151  } while (false)
152
153#define PARSE_TOCO_FLAG(Type, name, requirement)               \
154  ENFORCE_FLAG_REQUIREMENT(name, requirement);                 \
155  do {                                                         \
156    if (parsed_toco_flags.name.specified()) {                  \
157      Type x;                                                  \
158      QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \
159          << "Unrecognized " << #Type << " value "             \
160          << parsed_toco_flags.name.value();                   \
161      toco_flags->set_##name(x);                               \
162    }                                                          \
163  } while (false)
164
165  PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified);
166  PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified);
167  PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone);
168  PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone);
169  READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
170  READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone);
171  READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
172  READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
173  READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
174  READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
175
176  // Deprecated flag handling.
177  if (parsed_toco_flags.input_type.specified()) {
178    LOG(WARNING)
179        << "--input_type is deprecated. It was an ambiguous flag that set both "
180           "--input_data_types and --inference_input_type. If you are trying "
181           "to complement the input file with information about the type of "
182           "input arrays, use --input_data_type. If you are trying to control "
183           "the quantization/dequantization of real-numbers input arrays in "
184           "the output file, use --inference_input_type.";
185    toco::IODataType input_type;
186    QCHECK(toco::IODataType_Parse(parsed_toco_flags.input_type.value(),
187                                  &input_type));
188    toco_flags->set_inference_input_type(input_type);
189  }
190  if (parsed_toco_flags.input_types.specified()) {
191    LOG(WARNING)
192        << "--input_types is deprecated. It was an ambiguous flag that set "
193           "both --input_data_types and --inference_input_type. If you are "
194           "trying to complement the input file with information about the "
195           "type of input arrays, use --input_data_type. If you are trying to "
196           "control the quantization/dequantization of real-numbers input "
197           "arrays in the output file, use --inference_input_type.";
198    std::vector<string> input_types =
199        absl::StrSplit(parsed_toco_flags.input_types.value(), ',');
200    QCHECK(!input_types.empty());
201    for (int i = 1; i < input_types.size(); i++) {
202      QCHECK_EQ(input_types[i], input_types[0]);
203    }
204    toco::IODataType input_type;
205    QCHECK(toco::IODataType_Parse(input_types[0], &input_type));
206    toco_flags->set_inference_input_type(input_type);
207  }
208
209#undef READ_TOCO_FLAG
210#undef PARSE_TOCO_FLAG
211}
212}  // namespace toco
213