/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/toco/tooling_util.h"

#include <functional>
#include <iterator>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

// Find the longest common prefix of two strings.
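// For example, FindLongestCommonPrefix("conv1/weights", "conv1/bias")
// returns "conv1/" (the names here are just illustrative).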
absl::string_view FindLongestCommonPrefix(absl::string_view a,
                                          absl::string_view b) {
  if (a.empty() || b.empty()) return absl::string_view();

  const char* pa = a.data();
  const char* pb = b.data();
  size_t count = 0;
  const size_t limit = std::min(a.size(), b.size());
  while (count < limit && *pa == *pb) {
    ++pa;
    ++pb;
    ++count;
  }

  return absl::string_view(a.data(), count);
}

string LogName(const Operator& op) {
  const string& opname = HelpfulOperatorTypeName(op);
  if (op.outputs.empty()) {
    return toco::port::StringF("{%s operator}", opname);
  } else {
    return toco::port::StringF("{%s operator with output %s}", opname,
                               op.outputs[0]);
  }
}

bool IsInputArray(const Model& model, const string& name) {
  for (const auto& input_array : model.flags.input_arrays()) {
    if (input_array.name() == name) {
      return true;
    }
  }
  return false;
}

bool IsArrayConsumed(const Model& model, const string& name) {
  if (GetOpWithInput(model, name)) {
    return true;
  }
  for (const string& model_output : model.flags.output_arrays()) {
    if (model_output == name) {
      return true;
    }
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (rnn_state.back_edge_source_array() == name) {
      return true;
    }
  }
  return false;
}

int CountTrueOutputs(const Model& model, const Operator& op) {
  int count = 0;
  for (const string& output : op.outputs) {
    if (IsArrayConsumed(model, output)) {
      ++count;
    }
  }
  return count;
}

int CountOpsWithInput(const Model& model, const string& array_name) {
  int count = 0;
  for (const auto& op : model.operators) {
    for (auto& input : op->inputs) {
      if (input == array_name) {
        count++;
      }
    }
  }
  return count;
}

bool DeleteArrayIfUnused(const string& array_name, Model* model) {
  if (IsDiscardableArray(*model, array_name) &&
      CountOpsWithInput(*model, array_name) == 0) {
    model->EraseArray(array_name);
    return true;
  }
  return false;
}

bool DeleteArrayIfUsedOnce(const string& array_name, Model* model) {
  if (IsDiscardableArray(*model, array_name) &&
      CountOpsWithInput(*model, array_name) == 1) {
    model->EraseArray(array_name);
    return true;
  }
  return false;
}

std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
    const Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& output : it->get()->outputs) {
      if (output == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
    Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& output : it->get()->outputs) {
      if (output == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

Operator* GetOpWithOutput(const Model& model, const string& array_name) {
  auto it = FindOpWithOutput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

// Note: GetFirstOpWithInput relies on this returning the first matching op,
// in model.operators order.
std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
    const Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& input : it->get()->inputs) {
      if (input == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput(
    Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& input : it->get()->inputs) {
      if (input == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
    const Model& model, const Operator* op) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    if (it->get() == op) {
      return it;
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
                                                        const Operator* op) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    if (it->get() == op) {
      return it;
    }
  }
  return model.operators.end();
}

Operator* GetOpWithInput(const Model& model, const string& array_name) {
  auto it = FindOpWithInput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

Operator* GetFirstOpWithInput(const Model& model, const string& array_name) {
  auto it = FindOpWithInput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

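// Formats a list of array names for logging: "[]" when the list is empty,
// the bare name for a single entry, and e.g. "[ a, b ]" for several entries
// (where a, b stand for illustrative array names).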
string FormatArraysList(const Model& model, const std::vector<string>& list) {
  if (list.empty()) {
    return "[]";
  }
  string result = "";
  if (list.size() > 1) {
    result += "[ ";
  }
  for (std::size_t i = 0; i < list.size(); i++) {
    if (i > 0) {
      result += ", ";
    }
    result += list[i];
  }
  if (list.size() > 1) {
    result += " ]";
  }
  return result;
}

const char* OperatorTypeName(OperatorType type) {
  switch (type) {
#define HANDLE_OPERATORTYPENAME_CASE(c) \
  case OperatorType::k##c:              \
    return #c;
    HANDLE_OPERATORTYPENAME_CASE(Add)
    HANDLE_OPERATORTYPENAME_CASE(AddN)
    HANDLE_OPERATORTYPENAME_CASE(AveragePool)
    HANDLE_OPERATORTYPENAME_CASE(BatchMatMul)
    HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Conv)
    HANDLE_OPERATORTYPENAME_CASE(Concatenation)
    HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
    HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
    HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
    HANDLE_OPERATORTYPENAME_CASE(Dequantize)
    HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
    HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Logistic)
    HANDLE_OPERATORTYPENAME_CASE(LstmCell)
    HANDLE_OPERATORTYPENAME_CASE(MaxPool)
    HANDLE_OPERATORTYPENAME_CASE(L2Pool)
    HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
    HANDLE_OPERATORTYPENAME_CASE(Mul)
    HANDLE_OPERATORTYPENAME_CASE(Relu)
    HANDLE_OPERATORTYPENAME_CASE(Relu1)
    HANDLE_OPERATORTYPENAME_CASE(Relu6)
    HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
    HANDLE_OPERATORTYPENAME_CASE(Softmax)
    HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
    HANDLE_OPERATORTYPENAME_CASE(Div)
    HANDLE_OPERATORTYPENAME_CASE(Tanh)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert)
    HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
    HANDLE_OPERATORTYPENAME_CASE(Fill)
    HANDLE_OPERATORTYPENAME_CASE(FloorMod)
    HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum)
    HANDLE_OPERATORTYPENAME_CASE(Neg)
    HANDLE_OPERATORTYPENAME_CASE(Pad)
    HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
    HANDLE_OPERATORTYPENAME_CASE(Stack)
    HANDLE_OPERATORTYPENAME_CASE(Range)
    HANDLE_OPERATORTYPENAME_CASE(Rank)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape)
    HANDLE_OPERATORTYPENAME_CASE(Squeeze)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape)
    HANDLE_OPERATORTYPENAME_CASE(Slice)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch)
    HANDLE_OPERATORTYPENAME_CASE(Sub)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile)
    HANDLE_OPERATORTYPENAME_CASE(Transpose)
    HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2)
    HANDLE_OPERATORTYPENAME_CASE(Cast)
    HANDLE_OPERATORTYPENAME_CASE(Floor)
    HANDLE_OPERATORTYPENAME_CASE(Gather)
    HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
    HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
    HANDLE_OPERATORTYPENAME_CASE(Mean)
    HANDLE_OPERATORTYPENAME_CASE(Svdf)
    HANDLE_OPERATORTYPENAME_CASE(ArgMax)
    HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported)
    HANDLE_OPERATORTYPENAME_CASE(Exp)
    default:
      LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
  }
}

string HelpfulOperatorTypeName(const Operator& op) {
  if (op.type == OperatorType::kTensorFlowUnsupported) {
    return toco::port::StringF(
        "(Unsupported TensorFlow op: %s)",
        static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
  }
  return OperatorTypeName(op.type);
}

bool OperatorSupportsFusedActivation(OperatorType type) {
  switch (type) {
    case OperatorType::kConcatenation:
    case OperatorType::kGather:
    case OperatorType::kSlice:
    case OperatorType::kSqueeze:
    case OperatorType::kTensorFlowReshape:
    case OperatorType::kTensorFlowSplit:
      return false;
    default:
      return true;
  }
}

void LogSummary(int log_level, const Model& model) {
  VLOG(log_level) << "Operators summary (" << model.operators.size()
                  << " operators):";
  std::unordered_multiset<OperatorType> ops_by_type;
  for (const auto& op : model.operators) {
    ops_by_type.insert(op->type);
  }
  auto it = ops_by_type.begin();
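  // Equal keys are stored adjacently in the multiset, so advancing the
  // iterator by count(*it) below jumps to the next distinct operator type.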
  while (it != ops_by_type.end()) {
    int count = ops_by_type.count(*it);
    VLOG(log_level) << "    " << OperatorTypeName(*it) << ": " << count;
    std::advance(it, count);
  }
}

void LogArray(int log_level, const Model& model, const string& name) {
  const auto& array = model.GetArray(name);
  VLOG(log_level) << "Array: " << name;
  switch (array.data_type) {
    case ArrayDataType::kNone:
      VLOG(log_level) << "  Data type:";
      break;
    case ArrayDataType::kFloat:
      VLOG(log_level) << "  Data type: kFloat";
      break;
    case ArrayDataType::kInt32:
      VLOG(log_level) << "  Data type: kInt32";
      break;
    case ArrayDataType::kUint8:
      VLOG(log_level) << "  Data type: kUint8";
      break;
    case ArrayDataType::kString:
      VLOG(log_level) << "  Data type: kString";
      break;
    default:
      VLOG(log_level) << "  Data type: other (numerical value: "
                      << static_cast<int>(array.data_type) << ")";
      break;
  }
  switch (array.final_data_type) {
    case ArrayDataType::kNone:
      VLOG(log_level) << "  Final type:";
      break;
    case ArrayDataType::kFloat:
      VLOG(log_level) << "  Final type: kFloat";
      break;
    case ArrayDataType::kInt32:
      VLOG(log_level) << "  Final type: kInt32";
      break;
    case ArrayDataType::kUint8:
      VLOG(log_level) << "  Final type: kUint8";
      break;
    case ArrayDataType::kString:
      VLOG(log_level) << "  Final type: kString";
      break;
    default:
      VLOG(log_level) << "  Final type: other (numerical value: "
                      << static_cast<int>(array.final_data_type) << ")";
      break;
  }
  if (array.buffer) {
    VLOG(log_level) << "  Constant Buffer";
  }
  if (array.alloc) {
    VLOG(log_level) << "  Transient Alloc";
  }
  if (array.has_shape()) {
    const Shape& array_shape = array.shape();
    if (array_shape.dimensions_count() == 0) {
      VLOG(log_level) << "  (Zero dimensions)";
    } else {
      string message = "  Dims: ";
      bool first = true;
      for (const int dim : array_shape.dims()) {
        if (!first) {
          message += ", ";
        }
        first = false;
        toco::port::AppendF(&message, "%d", dim);
      }
      VLOG(log_level) << message;
    }
  }
  if (array.minmax) {
    VLOG(log_level) << "  MinMax: " << array.minmax->min << " .. "
                    << array.minmax->max;
  }
  if (array.quantization_params) {
    VLOG(log_level) << "  QuantizationParams: zero_point="
                    << static_cast<int>(array.quantization_params->zero_point)
                    << ", scale=" << array.quantization_params->scale;
  }
}

void DumpGraphvizVideoFrame(const Model& model) {
  namespace port = toco::port;

  const auto& dump_options = *GraphVizDumpOptions::singleton();
  if (!dump_options.dump_graphviz_video) {
    return;
  }
  CHECK(!dump_options.dump_graphviz.empty());
  // TODO(benoitjacob): the static data here means that this function
  // is stateful, not reentrant, and effectively leaks memory till exit
  // (since dump_hashes can only grow in size). It also means that it
  // is really only intended to be called for a single model during the
  // process' lifetime. So it's not great design at all. The overriding
  // design aspect here is to make the video-dumping code as unintrusive
  // and self-contained as possible. Eventually, we'll want to have that
  // cleaned up, but that will require some form of general statefulness
  // in toco (some kind of 'tooling state' data structure) that does
  // not exist at present, and would be premature to design here just for
  // this new video-dumping feature.
  static int dump_id = 0;
  static std::unordered_set<std::size_t> dump_hashes;
  string graphviz_dump;
  DumpGraphviz(model, &graphviz_dump);
  std::size_t hash = std::hash<string>{}(graphviz_dump);
  if (!dump_hashes.count(hash)) {
    LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
    dump_hashes.insert(hash);
    CHECK(port::file::SetContents(
              port::file::JoinPath(
                  dump_options.dump_graphviz,
                  toco::port::StringF("toco_video_%05d.dot", dump_id)),
              graphviz_dump, port::file::Defaults())
              .ok());
    dump_id++;
  }
}

void LogDump(int log_level, const string& message, const Model& model) {
  namespace port = toco::port;
  const auto& dump_options = *GraphVizDumpOptions::singleton();

  DumpGraphvizVideoFrame(model);
  if (!dump_options.dump_graphviz.empty()) {
    string graphviz_dump;

    DumpGraphviz(model, &graphviz_dump);
    CHECK(port::file::SetContents(
              port::file::JoinPath(
                  dump_options.dump_graphviz,
                  absl::StrCat("toco_",
                               absl::StrReplaceAll(message, {{" ", "_"}}),
                               ".dot")),
              graphviz_dump, port::file::Defaults())
              .ok());
  }

  if (!VLOG_IS_ON(log_level)) {
    return;
  }
  VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
  LogSummary(log_level, model);
  std::unordered_set<string> already_printed_arrays;
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      if (!already_printed_arrays.count(input)) {
        already_printed_arrays.insert(input);
        LogArray(log_level, model, input);
      }
    }
    VLOG(log_level) << HelpfulOperatorTypeName(*op) << " :";
    VLOG(log_level) << "  " << FormatArraysList(model, op->inputs) << " -> "
                    << FormatArraysList(model, op->outputs);
    if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
      VLOG(log_level) << "    (with fused activation function)";
    }
    for (const auto& output : op->outputs) {
      if (!already_printed_arrays.count(output)) {
        already_printed_arrays.insert(output);
        LogArray(log_level, model, output);
      }
    }
  }
  VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
}

// Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
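// Prepends 1s until the shape has new_shape_size dimensions; e.g. extending
// a [3, 4] shape to 4 dimensions yields [1, 1, 3, 4].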
void ExtendShape(Shape* shape, int new_shape_size) {
  CHECK_GE(new_shape_size, shape->dimensions_count());
  const int size_increase = new_shape_size - shape->dimensions_count();
  auto* shape_dims = shape->mutable_dims();
  shape_dims->insert(shape_dims->begin(), size_increase, 1);
}

// TODO(b/62904716) Remove along with remaining uses.
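// Strips leading 1s until the shape has new_shape_size dimensions, checking
// that each stripped dimension is 1; e.g. unextending a [1, 1, 3, 4] shape
// to 2 dimensions yields [3, 4].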
void UnextendShape(Shape* shape, int new_shape_size) {
  CHECK_LE(new_shape_size, shape->dimensions_count());
  const int size_reduction = shape->dimensions_count() - new_shape_size;
  for (int i = 0; i < size_reduction; i++) {
    CHECK_EQ(shape->dims(i), 1);
  }
  std::vector<int>& shape_dims = *shape->mutable_dims();
  shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
}

void CheckShapeDimensions(const Shape& shape) {
  for (int i = 0; i < shape.dimensions_count(); ++i) {
    CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index " << i
                                 << ". shape = " << ShapeToString(shape);
  }
}

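// Returns whether the two shapes are compatible under broadcasting: walking
// from the back, each pair of dimensions must be equal or one of them must
// be 1. E.g. [2, 3, 4] and [1, 4] agree, while [2, 3, 4] and [2, 4] do not.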
bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
  CheckShapeDimensions(shape0);
  CheckShapeDimensions(shape1);

  const Shape* longer = &shape0;
  const Shape* shorter = &shape1;
  if (shape1.dimensions_count() > shape0.dimensions_count()) {
    longer = &shape1;
    shorter = &shape0;
  }

  // Walk dimensions back to front until we run out of dimensions in the
  // shorter shape.
  int longer_index = longer->dimensions_count() - 1;
  int shorter_index = shorter->dimensions_count() - 1;
  while (shorter_index >= 0) {
    const int d_long = longer->dims(longer_index);
    const int d_short = shorter->dims(shorter_index);
    // Broadcasting fails if the dimensions are different *and* neither is 1.
    if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) {
      return false;
    }
    longer_index--;
    shorter_index--;
  }
  return true;
}

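// Returns whether the shorter shape, once extended with leading 1s, equals
// the longer shape exactly. E.g. [1, 1, 3, 4] and [3, 4] agree, while
// [2, 3, 4] and [3, 4] do not (the extra leading dimension is 2, not 1).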
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
  CheckShapeDimensions(shape0);
  CheckShapeDimensions(shape1);

  const Shape* longer = &shape0;
  const Shape* shorter = &shape1;
  if (shape1.dimensions_count() > shape0.dimensions_count()) {
    longer = &shape1;
    shorter = &shape0;
  }

  // Walk dimensions back to front until we run out of dimensions in the
  // shorter shape.
  int longer_index = longer->dimensions_count() - 1;
  int shorter_index = shorter->dimensions_count() - 1;
  while (shorter_index >= 0) {
    const int d_long = longer->dims(longer_index);
    const int d_short = shorter->dims(shorter_index);
    // Extending fails if the dimensions are different.
    if (d_long != d_short) {
      return false;
    }
    longer_index--;
    shorter_index--;
  }

  // The remaining dimensions in the longer shape must be 1.
  while (longer_index >= 0) {
    const int d_long = longer->dims(longer_index);
    if (d_long != 1) {
      return false;
    }
    longer_index--;
  }

  return true;
}

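// Returns the number of elements that the shape implies, i.e. the product
// of all its dimensions; e.g. a [2, 3, 4] shape requires 24 elements.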
int RequiredBufferSizeForShape(const Shape& shape) {
  int max_offset = 1;
  for (const auto& dim : shape.dims()) {
    CHECK_GE(dim, 1);
    max_offset *= dim;
  }
  return max_offset;
}

bool IsConstantParameterArray(const Model& model, const string& name) {
  if (!model.HasArray(name)) {
    return false;
  }

  return !!model.GetArray(name).buffer;
}

namespace {
// Take an array name, which may be something like "name:3_5", and make it
// acceptable as a TF node name, say "name_3_5".
string SanitizeNameForTFNode(const string& array_name) {
  auto node_name = array_name;
  std::replace(node_name.begin(), node_name.end(), ':', '_');
  return node_name;
}

void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) {
  for (const auto& input_array : model_flags.input_arrays()) {
    for (const string& output_array : model_flags.output_arrays()) {
      QCHECK_NE(input_array.name(), output_array)
          << "The array " << output_array
          << " is listed in both --input_arrays and --output_arrays.";
    }
  }
}

bool IsAsciiPrintable(const string& name) {
  for (char c : name) {
    if (!absl::ascii_isprint(c)) {
      return false;
    }
  }
  return true;
}

string DumpAscii(const string& name) {
  string result;
  port::AppendF(&result, "ASCII | Hex\n");
  port::AppendF(&result, "------+----\n");
  for (char c : name) {
    if (absl::ascii_isprint(c)) {
      port::AppendF(&result, "%c     | %x\n", c, c);
    } else {
      port::AppendF(&result, "      | %x   Not ASCII printable!\n", c);
    }
  }
  return result;
}

void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
  if (model_flags.allow_nonascii_arrays()) {
    return;
  }
  for (const auto& input_array : model_flags.input_arrays()) {
    QCHECK(IsAsciiPrintable(input_array.name()))
        << "Non-ASCII-printable character found in --input_arrays: "
        << input_array.name()
        << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(input_array.name());
  }
  for (const string& output_array : model_flags.output_arrays()) {
    QCHECK(IsAsciiPrintable(output_array))
        << "Non-ASCII-printable character found in --output_arrays: "
        << output_array << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(output_array);
  }
}

void CheckNonExistentIOArrays(const Model& model) {
  if (model.flags.allow_nonexistent_arrays()) {
    return;
  }
  for (const auto& input_array : model.flags.input_arrays()) {
    CHECK(model.HasArray(input_array.name()))
        << "Input array not found: " << input_array.name();
  }
  for (const string& output_array : model.flags.output_arrays()) {
    CHECK(model.HasArray(output_array))
        << "Output array not found: " << output_array;
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (!rnn_state.discardable()) {
      CHECK(model.HasArray(rnn_state.state_array()));
      CHECK(model.HasArray(rnn_state.back_edge_source_array()));
    }
  }
}
}  // namespace

void CheckNoMissingArray(const Model& model) {
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      CHECK(model.HasArray(input) || model.optional_arrays.count(input))
          << "Input: " << input << " missing for op: " << op->outputs[0] << ".";
    }
    for (const auto& output : op->outputs) {
      CHECK(model.HasArray(output)) << "Output: " << output << " missing.";
    }
  }
  CheckNonExistentIOArrays(model);
}

void FixNoMissingArray(Model* model) {
  for (const auto& op : model->operators) {
    for (const auto& input : op->inputs) {
      if (!model->HasArray(input)) {
        model->GetOrCreateArray(input);
      }
    }
    for (const auto& output : op->outputs) {
      if (!model->HasArray(output)) {
        model->GetOrCreateArray(output);
      }
    }
  }
  if (model->flags.allow_nonexistent_arrays()) {
    for (const string& output_array : model->flags.output_arrays()) {
      model->GetOrCreateArray(output_array);
    }
    for (const auto& rnn_state : model->flags.rnn_states()) {
      model->GetOrCreateArray(rnn_state.state_array());
      model->GetOrCreateArray(rnn_state.back_edge_source_array());
    }
  }
}

void CheckNoOrphanedArray(const Model& model) {
  std::unordered_set<string> arrays_without_known_use;
  for (const auto& array : model.GetArrayMap()) {
    if (IsDiscardableArray(model, array.first)) {
      arrays_without_known_use.insert(array.first);
    }
  }
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      arrays_without_known_use.erase(input);
    }
    for (const auto& output : op->outputs) {
      arrays_without_known_use.erase(output);
    }
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    arrays_without_known_use.erase(rnn_state.state_array());
    arrays_without_known_use.erase(rnn_state.back_edge_source_array());
  }
  if (!arrays_without_known_use.empty()) {
    for (const auto& array : arrays_without_known_use) {
      LOG(INFO) << "Error: Orphaned array: " << array;
    }
  }
  CHECK(arrays_without_known_use.empty());
}

void FixNoOrphanedArray(Model* model) {
  std::unordered_set<string> arrays_without_known_use;
  for (const auto& array : model->GetArrayMap()) {
    arrays_without_known_use.insert(array.first);
  }
  for (const auto& op : model->operators) {
    for (const auto& input : op->inputs) {
      arrays_without_known_use.erase(input);
    }
    for (const auto& output : op->outputs) {
      arrays_without_known_use.erase(output);
    }
  }
  for (const auto& rnn_state : model->flags.rnn_states()) {
    arrays_without_known_use.erase(rnn_state.state_array());
    arrays_without_known_use.erase(rnn_state.back_edge_source_array());
  }
  for (const auto& array : arrays_without_known_use) {
    if (IsDiscardableArray(*model, array)) {
      model->EraseArray(array);
    }
  }
}

// Apply checks to arrays individually (for-each fashion).
//
// Check consistency of array fields, check name.
void CheckEachArray(const Model& model) {
  for (const auto& array_entry : model.GetArrayMap()) {
    const auto& array = array_entry.second;
    if (array->has_shape()) {
      for (int d : array->shape().dims()) {
        CHECK_GE(d, 1);
      }
    }
    // It's OK to have a buffer or an alloc, but not both.
    // (Since allocs are for transient arrays without a buffer).
    CHECK(!array->buffer || !array->alloc);
    // If there is a buffer, its type should be consistent with data_type.
    if (array->buffer) {
      CHECK(array->buffer->type == array->data_type);
    }

    // Check the name: either "name_with_suffix_8" or "name_with_port:3" is
    // acceptable, but not "name_with_both:3_8".
    const string& name = array_entry.first;
    auto colon_pos = name.find_first_of(":");
    if (colon_pos != string::npos) {
      CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
               string::npos)
          << "Array name must only have digits after colon";
    }
    CHECK_GT(colon_pos, 0)
        << "First character of array name must not be a colon.";
  }
}

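// Checks that model.operators is in a valid topological order: every
// non-constant input of an operator must already be available (a graph
// input or the output of an earlier operator) when that operator runs.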
void CheckOperatorOrdering(const Model& model) {
  std::unordered_set<string> arrays_behind_us;
  for (const auto& array_entry : model.GetArrayMap()) {
    if (!GetOpWithOutput(model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model.optional_arrays.begin(),
                          model.optional_arrays.end());
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      if (!IsConstantParameterArray(model, input)) {
        CHECK(arrays_behind_us.count(input));
      }
    }
    for (const auto& output : op->outputs) {
      CHECK(!arrays_behind_us.count(output));
      arrays_behind_us.insert(output);
    }
  }
  for (const string& output_array : model.flags.output_arrays()) {
    CHECK(arrays_behind_us.count(output_array));
  }
}

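// Restores a valid topological order by greedily re-emitting, on each pass,
// any remaining operator whose inputs are all available; if no such operator
// is left, logs a trace of the problematic subgraph and aborts.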
void FixOperatorOrdering(Model* model) {
  std::unordered_set<string> arrays_behind_us;
  for (const auto& array_entry : model->GetArrayMap()) {
    if (!GetOpWithOutput(*model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model->optional_arrays.begin(),
                          model->optional_arrays.end());
  std::vector<std::unique_ptr<Operator>> old_operators;
  std::swap(old_operators, model->operators);
  std::set<std::size_t> remaining;
  for (std::size_t i = 0; i < old_operators.size(); i++) {
    remaining.insert(i);
  }
  std::unordered_map<string, string> reason_why_leftover;
  while (true) {
    bool inserted_something = false;
    for (auto i : remaining) {
      bool can_insert = true;
      auto& op = old_operators[i];
      CHECK(op.get());
      for (const auto& input : op->inputs) {
        if (!IsConstantParameterArray(*model, input) &&
            !arrays_behind_us.count(input)) {
          for (const string& output : op->outputs) {
            reason_why_leftover[output] = input;
          }
          can_insert = false;
          break;
        }
      }
      if (can_insert) {
        model->operators.emplace_back(nullptr);
        for (const auto& output : op->outputs) {
          arrays_behind_us.insert(output);
        }
        std::swap(op, model->operators.back());
        remaining.erase(i);
        inserted_something = true;
        break;
      }
    }
    if (!inserted_something) {
      break;
    }
  }
  if (!remaining.empty()) {
    LOG(ERROR)
        << "No viable ordering of operators was found. "
        << "Here is a 'backtrace' of at least one part of the graph that is "
        << "problematic. It starts with the first operator that has a "
        << "problematic input array, and then walks back the graph to "
        << "the operator that produced that input array, etc., until we find "
        << "the root cause:";
    LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT";
    LOG(ERROR) << "Here is the first-encountered operator with a bad input: ";
    const Operator* bad_op = old_operators[*remaining.begin()].get();
    std::unordered_set<string> bad_inputs_already_traced;
    // The following while(true) loop should always end with a LOG(FATAL).
    while (true) {
      LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
                 << FormatArraysList(*model, bad_op->inputs) << " -> "
                 << FormatArraysList(*model, bad_op->outputs);
      bool found_bad_output = false;
      string bad_output;
      for (const string& output : bad_op->outputs) {
        if (reason_why_leftover.count(output)) {
          found_bad_output = true;
          bad_output = output;
          break;
        }
      }
      CHECK(found_bad_output);
      const string& bad_input = reason_why_leftover[bad_output];
      LOG(ERROR) << "The bad input here is: " << bad_input;
      if (bad_inputs_already_traced.count(bad_input)) {
        LOG(FATAL)
            << "Cycle found! We already encountered that "
            << "input array, " << bad_input << ", earlier in the "
            << "above trace! We expect graphs to be acyclic, even "
            << "RNNs. Let us know if some graph actually needs to have "
            << "cycles, but first, please check if it really is "
            << "an *inference* graph. *Training* graphs are out-of-scope "
            << "for toco.";
      }
      bad_inputs_already_traced.insert(bad_input);
      bad_op = nullptr;
      for (auto i : remaining) {
        const Operator* op = old_operators[i].get();
        for (const string& output : op->outputs) {
          if (bad_input == output) {
            bad_op = op;
            break;
          }
        }
        if (bad_op) {
          break;
        }
      }
      if (!bad_op) {
        LOG(ERROR) << "And that's the root cause: "
                   << "that array, " << bad_input << ", isn't produced by any "
                   << "operator, or provided in any other way.";
        LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT";
        LOG(FATAL) << "(The above was a multi-line fatal error)";
      }
      LOG(ERROR) << "And that array is the output of the following operator:";
    }
  }
  CHECK(remaining.empty())
      << "Should never get here! In case of bad graph, "
      << "the above code should have generated a FATAL error already!";
}

void CheckInvariants(const Model& model) {
  CheckInputArraysAreNotOutputArrays(model.flags);
  CheckNonAsciiIOArrays(model.flags);
  CheckNoMissingArray(model);
  CheckNoOrphanedArray(model);
  CheckEachArray(model);
  CheckOperatorOrdering(model);
}

void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check,
                       const int count, const string& count_description) {
  if (model_check.count_min() >= 0) {
    CHECK_GE(count, model_check.count_min())
        << "Mismatch in " << count_description << ": count was " << count
        << ", but the specified "
        << (model_check.count_max() > model_check.count_min() ? "minimum"
                                                              : "value")
        << " was " << model_check.count_min() << ".";
  }
  if (model_check.count_max() > model_check.count_min()) {
    CHECK_LE(count, model_check.count_max())
        << "Mismatch in " << count_description << ": count was " << count
        << ", but the specified maximum was " << model_check.count_max() << ".";
  }
}

void CheckModelCounts(const Model& model) {
  std::unordered_multiset<OperatorType> ops_by_type;
  std::unordered_map<string, OperatorType> op_type_by_name;
  if (model.flags.model_checks_size() == 0) {
    return;
  }

  for (const auto& op : model.operators) {
    ops_by_type.insert(op->type);
    op_type_by_name[OperatorTypeName(op->type)] = op->type;
  }
  for (const auto& model_check : model.flags.model_checks()) {
    string count_type = model_check.count_type();
    if (count_type == "None") {
      continue;
    } else if (count_type == "Arrays") {
      CheckCountInRange(model_check, model.GetArrayMap().size(),
                        "count of arrays");
    } else if (count_type == "Total") {
      CheckCountInRange(model_check, model.operators.size(),
                        "count of all operator instances");
    } else {
      // The check type is not itself checked against the set of valid
      // operators, mainly because the enum set cannot be iterated in C++.
      const int found_count =
          op_type_by_name.count(count_type) > 0
              ? ops_by_type.count(op_type_by_name[count_type])
              : 0;
      CheckCountInRange(model_check, found_count,
                        "count of instances of " + count_type + " operator");
    }
  }
}

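// Builds a dims vector of the requested rank from (batch, height, width,
// depth); e.g. MakeArrayDims(4, 1, 8, 8, 3, &dims) yields {1, 8, 8, 3},
// while MakeArrayDims(2, 1, 8, 8, 3, &dims) yields {1, 3}.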
void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
                   std::vector<int>* out_dims) {
  CHECK(out_dims->empty());
  if (num_dims == 0) {
    return;
  } else if (num_dims == 1) {
    CHECK_EQ(batch, 1);
    *out_dims = {depth};
  } else if (num_dims == 2) {
    *out_dims = {batch, depth};
  } else if (num_dims == 3) {
    CHECK_EQ(batch, 1);
    *out_dims = {height, width, depth};
  } else if (num_dims == 4) {
    *out_dims = {batch, height, width, depth};
  } else {
    LOG(FATAL) << "Should not get here: " << num_dims;
  }
}

void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) {
  int batch = 1;
  int num_dims = -1;
  for (const auto& input_array : model->flags.input_arrays()) {
    // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
    // a better match by name.
    if (input_array.name() == name || num_dims == -1) {
      num_dims = input_array.shape().dims_size();
      if (num_dims > 0) {
        batch = input_array.shape().dims(0);
      }
    }
  }
  Array& array = model->GetOrCreateArray(name);
  if (array.has_shape()) {
    num_dims = array.shape().dimensions_count();
  }
  if (!array.has_shape() && num_dims >= 0) {
    Shape* shape = array.mutable_shape();
    std::vector<int> dims;
    MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
    *shape->mutable_dims() = dims;
  }
}

void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
  // Merge info about input_arrays from model_flags into model->flags
  for (const auto& specified_input_array : model_flags.input_arrays()) {
    toco::InputArray* dst_input_array = nullptr;
    for (int i = 0; i < model->flags.input_arrays_size(); i++) {
      toco::InputArray* candidate_dst_input_array =
          model->flags.mutable_input_arrays(i);
      if (candidate_dst_input_array->name() == specified_input_array.name()) {
        // specified_input_array from model_flags maps to dst_input_array
        // in model->flags
        dst_input_array = candidate_dst_input_array;
        break;
      }
    }
    if (!dst_input_array) {
      // specified_input_array from model_flags is not found in model->flags.
      // Match a name-less specified input array when there can be no ambiguity
      // as there is only 1 input array.
      if (model->flags.input_arrays_size() == 1 &&
          model_flags.input_arrays_size() == 1 &&
          !specified_input_array.has_name()) {
        dst_input_array = model->flags.mutable_input_arrays(0);
      }
    }
    if (!dst_input_array) {
      // Still no match, so create a new input array to copy
      // specified_input_array into.
      dst_input_array = model->flags.add_input_arrays();
      dst_input_array->set_name(specified_input_array.name());
    }

#define RESOLVE_MODEL_FLAG(field_name)                                       \
  if (specified_input_array.has_##field_name()) {                            \
    if (dst_input_array->has_##field_name()) {                               \
      QCHECK_EQ(dst_input_array->field_name(),                               \
                specified_input_array.field_name())                          \
          << "For input array '" << dst_input_array->name() << "', "         \
          << "specified " #field_name " flag with value: "                   \
          << specified_input_array.field_name()                              \
          << " does not agree with already defined " #field_name             \
             " of this model, with value: "                                  \
          << dst_input_array->field_name();                                  \
    } else {                                                                 \
      dst_input_array->set_##field_name(specified_input_array.field_name()); \
    }                                                                        \
  }
    RESOLVE_MODEL_FLAG(std_value);
    RESOLVE_MODEL_FLAG(mean_value);
#undef RESOLVE_MODEL_FLAG

    if (specified_input_array.has_shape()) {
      if (dst_input_array->has_shape()) {
        QCHECK_EQ(specified_input_array.shape().dims_size(),
                  dst_input_array->shape().dims_size())
            << "For input array '" << specified_input_array.name() << "', "
            << "size of specified input shape flag with size: "
            << specified_input_array.shape().dims_size()
            << " does not agree with already defined input shape"
               " of this model, with size: "
            << dst_input_array->shape().dims_size();
        // We treat the first dimension as a special case, since it is often
        // a batch size and the input_shape flag is effectively overriding
        // the model.
        for (int i = 1; i < specified_input_array.shape().dims_size(); i++) {
          QCHECK_EQ(specified_input_array.shape().dims(i),
                    dst_input_array->shape().dims(i))
              << "At dimension number " << i << " of input array "
              << specified_input_array.name() << ", the specified shape's "
              << "dimension flag with dimension: "
              << specified_input_array.shape().dims(i)
              << " does not agree with already defined shape"
              << " of this model, with dimension: "
              << dst_input_array->shape().dims(i);
        }
      } else {
        *dst_input_array->mutable_shape() = specified_input_array.shape();
      }
    }

    if (specified_input_array.has_data_type()) {
      QCHECK(!dst_input_array->has_data_type());
      dst_input_array->set_data_type(specified_input_array.data_type());
    }
  }

  if (model_flags.output_arrays_size() > 0) {
    model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays());
  }

#define RESOLVE_MODEL_FLAG(name)                                           \
  if (model_flags.has_##name()) {                                          \
    if (model->flags.has_##name()) {                                       \
      QCHECK_EQ(model_flags.name(), model->flags.name())                   \
          << "Specified " #name " flag with value: " << model_flags.name() \
          << " does not agree with already defined " #name                 \
             " of this model, with value: "                                \
          << model->flags.name();                                          \
    } else {                                                               \
      model->flags.set_##name(model_flags.name());                         \
    }                                                                      \
  }

  RESOLVE_MODEL_FLAG(variable_batch)

#undef RESOLVE_MODEL_FLAG

  if (!model_flags.rnn_states().empty()) {
    model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
  }

  if (model->flags.model_checks_size() == 0) {
    model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks());
  }

  QCHECK_GT(model->flags.output_arrays_size(), 0)
      << "This model does not define output arrays, so a "
         "--output_arrays flag must be given on the command-line.";

  for (const auto& input_array_proto : model->flags.input_arrays()) {
    auto& input_array = model->GetOrCreateArray(input_array_proto.name());
    if (input_array_proto.has_data_type()) {
      const ArrayDataType specified_type =
          ConvertIODataTypeToArrayDataType(input_array_proto.data_type());
      QCHECK(specified_type != ArrayDataType::kNone);
      if (input_array.data_type != ArrayDataType::kNone) {
        QCHECK(specified_type == input_array.data_type)
            << "For input array " << input_array_proto.name()
            << " the specified input data type "
            << IODataType_Name(input_array_proto.data_type())
            << " conflicts with the existing type.";
      }
      input_array.data_type = specified_type;
    }

    if (input_array.data_type == ArrayDataType::kNone) {
      // We start out with a float input array;
      // that may get replaced by a uint8 array later, by
      // MakeInitialDequantizeOp.
      input_array.data_type = ArrayDataType::kFloat;
    }

    // Compare/merge the model->flags describing the input_shape with
    // the actual input array's shape.
    if (!input_array.has_shape()) {
      if (input_array_proto.has_shape()) {
        auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
        for (auto dim : input_array_proto.shape().dims()) {
          CHECK_GE(dim, 1);
          input_array_dims.push_back(dim);
        }
      }
    } else {
      if (input_array_proto.has_shape()) {
        // If an input shape was specified on the flags ensure that it matches
        // the actual shape in the model.
        const auto& input_array_dims =
            *input_array.mutable_shape()->mutable_dims();
        CHECK_EQ(input_array_dims.size(),
                 input_array_proto.shape().dims_size());
        for (int i = 0; i < input_array_dims.size(); i++) {
          CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i));
        }
      }
    }

    const float mean_value = input_array_proto.mean_value();
    const float std_value = input_array_proto.std_value();
    MinMax input_minmax;
    input_minmax.min = (0.f - mean_value) / std_value;
    input_minmax.max = (255.f - mean_value) / std_value;
    if (input_array.minmax) {
      if (input_array_proto.has_mean_value() ||
          input_array_proto.has_std_value()) {
        CHECK(input_minmax == *input_array.minmax)
            << input_minmax.min << ", " << input_minmax.max
            << " != " << input_array.minmax->min << ", "
            << input_array.minmax->max;
      }
    } else {
      input_array.GetOrCreateMinMax() = input_minmax;
    }
  }
  // Creation of the RNN state arrays
  for (const auto& rnn_state : model->flags.rnn_states()) {
    CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
                               model);
  }

  for (const auto& input_array : model->flags.input_arrays()) {
    if (input_array.has_shape()) {
      CHECK(input_array.shape().dims_size());
    }
  }

  model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
  model->flags.set_allow_nonexistent_arrays(
      model_flags.allow_nonexistent_arrays());

  CHECK(!model->flags.has_arrays_extra_info());
  *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
}

void CheckIsReadyForQuantization(const Model& model) {
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      const auto& input_array = model.GetArray(input);
      if (input_array.data_type != ArrayDataType::kFloat) {
        // The array is not floats, no quantization needed.
        continue;
      }
      if (input_array.minmax) {
        // The array has minmax, we're good.
        continue;
      }
      if (input_array.buffer) {
        // The array has a constant buffer, so we can
        // fall back to computing the minmax from actual array entries
        // (with a WARNING about possible accuracy implications).
        continue;
      }
      LOG(FATAL)
          << "Array " << input << ", which is an input to the "
          << HelpfulOperatorTypeName(*op) << " operator producing the output "
          << "array " << op->outputs[0] << ", is lacking min/max data, "
          << "which is necessary for quantization. Either target a "
          << "non-quantized output format, or change the input graph to "
          << "contain min/max information, or pass --default_ranges_min= and "
          << "--default_ranges_max= if you do not care about the accuracy of "
          << "results.";
    }
  }
}

void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
                                 double default_ranges_max) {
  for (const auto& op : model->operators) {
    for (const auto& input : op->inputs) {
      auto& input_array = model->GetArray(input);
      if (!input_array.minmax && !input_array.buffer) {
        auto& minmax = input_array.GetOrCreateMinMax();
        minmax.min = default_ranges_min;
        minmax.max = default_ranges_max;
      }
    }
    for (const auto& output : op->outputs) {
      auto& output_array = model->GetArray(output);
      if (!output_array.minmax && !output_array.buffer) {
        auto& minmax = output_array.GetOrCreateMinMax();
        minmax.min = default_ranges_min;
        minmax.max = default_ranges_max;
      }
    }
  }
}

int ElementSize(ArrayDataType data_type) {
  switch (data_type) {
    case ArrayDataType::kFloat:
      return 4;
    case ArrayDataType::kInt8:
      return 1;
    case ArrayDataType::kUint8:
      return 1;
    case ArrayDataType::kInt16:
      return 2;
    case ArrayDataType::kUint16:
      return 2;
    case ArrayDataType::kInt32:
      return 4;
    case ArrayDataType::kUint32:
      return 4;
    case ArrayDataType::kInt64:
      return 8;
    case ArrayDataType::kUint64:
      return 8;

    // Usually not critical limitation because strings are only input and/or
    // output.
    case ArrayDataType::kString:
      LOG(FATAL) << "Transient arrays with strings are not supported yet";
      return 0;
    default:
      LOG(FATAL) << "Should not get here.";
      return 0;
  }
}

void DropMinMax(Model* model, const string& array_name) {
  auto& array = model->GetArray(array_name);
  if (!!array.minmax) {
    LOG(WARNING) << "Dropping MinMax information in array " << array_name
                 << ". Expect inaccuracy in quantized inference.";
    array.minmax = nullptr;
  }
}

bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
  // An optional array is not transient.
  if (model.IsOptionalArray(array_name)) return false;
  // The model's input and output arrays are externally allocated.
  // They are not transient arrays.
  if (IsInputArray(model, array_name)) {
    return false;
  }
  for (const string& output_array : model.flags.output_arrays()) {
    if (array_name == output_array) {
      return false;
    }
  }
  const auto* array = &model.GetArray(array_name);
  // An array with a constant buffer isn't a transient array.
  if (!!array->buffer) {
    return false;
  }
  // An array without shape isn't allocatable.
  if (!array->has_shape()) {
    return false;
  }
  return true;
}

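// Returns a name derived from 'name' that no existing array uses: first the
// sanitized name itself, then sanitized_name_0, sanitized_name_1, and so on.
// E.g. "add:0" becomes "add_0", or "add_0_0" if "add_0" is already taken.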
string AvailableArrayName(const Model& model, const string& name) {
  string sanitized_name = SanitizeNameForTFNode(name);
  if (!model.HasArray(sanitized_name) &&
      !model.IsOptionalArray(sanitized_name)) {
    return sanitized_name;
  }
  const int kNumSuffixesToTry = 1000;
  for (int i = 0; i < kNumSuffixesToTry; i++) {
    const string& name_with_suffix =
        toco::port::StringF("%s_%d", sanitized_name, i);
    if (!model.HasArray(name_with_suffix) &&
        !model.IsOptionalArray(name_with_suffix)) {
      return name_with_suffix;
    }
  }
  LOG(FATAL) << "Could not find an available array name starting with "
             << sanitized_name << ". Tried " << kNumSuffixesToTry
             << " suffixes, all were taken!";
  return "";
}

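// Renders a shape for logging; e.g. a [2, 3, 4] shape becomes "[ 2, 3, 4 ]"
// and a zero-dimensional shape becomes "[]".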
string ShapeToString(const Shape& shape) {
  if (shape.dimensions_count() == 0) {
    return "[]";
  }

  return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
}

void PrintArrayShape(Model* model, const string& name) {
  if (!model->GetArray(name).has_shape()) {
    LOG(INFO) << name << " has no shape";
    return;
  }
  LOG(INFO) << name
            << " has shape: " << ShapeToString(model->GetArray(name).shape());
}

bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
  bool is_fc_weights = false;
  bool is_something_else = false;
  for (const auto& op : model.operators) {
    for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
      if (op->inputs[input_index] == name) {
        if (op->type == OperatorType::kFullyConnected && input_index == 1) {
          is_fc_weights = true;
        } else {
          is_something_else = true;
        }
      }
    }
  }
  CHECK(!(is_fc_weights && is_something_else));
  return is_fc_weights;
}

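// Creates a constant int32 array holding 'value', named after 'param_name'
// (uniquified via AvailableArrayName), and returns the chosen name. E.g.,
// with illustrative arguments, CreateInt32Array(model, "perm", {0, 2, 1})
// yields a shape-[3] array whose buffer is {0, 2, 1}.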
string CreateInt32Array(Model* model, const string& param_name,
                        const std::vector<int>& value) {
  auto param_array_name = AvailableArrayName(*model, param_name);
  auto& param_array = model->GetOrCreateArray(param_array_name);
  param_array.mutable_shape()->ReplaceDims({static_cast<int>(value.size())});
  param_array.data_type = ArrayDataType::kInt32;
  auto& param_array_data =
      param_array.GetMutableBuffer<ArrayDataType::kInt32>().data;
  param_array_data.resize(RequiredBufferSizeForShape(param_array.shape()));
  for (size_t i = 0; i < value.size(); ++i) {
    param_array_data[i] = value[i];
  }
  return param_array_name;
}

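// Estimates the total number of scalar arithmetic ops needed to evaluate the
// model, counting each multiply-accumulate as 2 ops. Returns false if some
// required shape isn't known yet. Worked example (illustrative numbers): a
// Conv with weights of shape [32, 3, 3, 3] and output of shape
// [1, 112, 112, 32] contributes 2 * 864 * (1 * 112 * 112) = 21,676,032 ops,
// plus 401,408 more if it has a bias vector.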
bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
  int64 total = 0;
  for (const auto& op : model.operators) {
    switch (op->type) {
      case OperatorType::kFullyConnected:
      case OperatorType::kConv:
      case OperatorType::kDepthwiseConv: {
        const auto& output_array = model.GetArray(op->outputs[0]);
        const auto& weights_array = model.GetArray(op->inputs[1]);
        if (!output_array.has_shape() || !weights_array.has_shape()) {
          return false;
        }
        // Number of output columns: the product of all output dims except
        // the last (channels) one.
        int cols = 1;
        for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
          cols *= output_array.shape().dims(i);
        }
        // Each weight contributes one multiply-accumulate per column, which
        // we count as 2 ops.
        const int64 cost_per_col =
            2 * RequiredBufferSizeForShape(weights_array.shape());
        total += cost_per_col * cols;
        if (op->inputs.size() > 2) {
          // There is a bias vector. One more op per output value.
          total += RequiredBufferSizeForShape(output_array.shape());
        }
        break;
      }
      case OperatorType::kAdd:
      case OperatorType::kSub:
      case OperatorType::kMul: {
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        total += RequiredBufferSizeForShape(output_array.shape());
        break;
      }
      case OperatorType::kAddN: {
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        // AddN costs roughly the same as N-1 Adds.
        const int num_adds = op->inputs.size() - 1;
        total += num_adds * RequiredBufferSizeForShape(output_array.shape());
        break;
      }
      case OperatorType::kLogistic:
      case OperatorType::kSoftmax:
      case OperatorType::kLogSoftmax:
      case OperatorType::kTanh: {
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        // As a very rough ballpark, evaluating a math function such as tanh
        // or logistic costs about 32 multiplications and about as many
        // additions/subtractions (a power-of-two order-of-magnitude estimate
        // based on the actual implementations in runtime/ code).
        total += 64 * RequiredBufferSizeForShape(output_array.shape());
        break;
      }
      case OperatorType::kMaxPool: {
        const auto& maxpool = *static_cast<const MaxPoolOperator*>(op.get());
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        total += RequiredBufferSizeForShape(output_array.shape()) *
                 maxpool.kheight * maxpool.kwidth;
        break;
      }
      case OperatorType::kAveragePool: {
        const auto& avgpool =
            *static_cast<const AveragePoolOperator*>(op.get());
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        total += RequiredBufferSizeForShape(output_array.shape()) *
                 avgpool.kheight * avgpool.kwidth;
        break;
      }
      case OperatorType::kL2Pool: {
        // Note: this is an L2Pool op, so cast to L2PoolOperator (the previous
        // cast to MaxPoolOperator was incorrect).
        const auto& l2pool = *static_cast<const L2PoolOperator*>(op.get());
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        // The sum of squares requires (kheight*kwidth) multiply-adds,
        // and then there is the sqrt which we ballpark at 32 ops.
        const int64 cost_per_val = 2 * l2pool.kheight * l2pool.kwidth + 32;
        total +=
            RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
        break;
      }
      case OperatorType::kL2Normalization: {
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        // Computing the squared L2 norm is N multiply-adds, so 2N ops;
        // the single inverse-sqrt is negligible; then each value is
        // multiplied by the resulting multiplier, an extra N ops.
        // Total: 3N ops.
        total += 3 * RequiredBufferSizeForShape(output_array.shape());
        break;
      }
      default:
        break;
    }
  }
  *result = total;
  return true;
}

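// Computes the permutation taking input_axes_order to output_axes_order:
// output axis i corresponds to input axis (*shuffle)[i]. For example,
// kOHWI -> kHWIO yields {1, 2, 3, 0}.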
void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
                     std::vector<int>* shuffle) {
  CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
  shuffle->resize(4);
  for (int i = 0; i < 4; i++) {
    (*shuffle)[i] = i;
  }
  if (input_axes_order == output_axes_order) {
    // Identity permutation: nothing to do.
  } else if (AxesCount(input_axes_order) == 2) {
    // Any 2D reordering (kRC <-> kCR) is a transposition.
    shuffle->resize(2);
    (*shuffle)[0] = 1;
    (*shuffle)[1] = 0;
  } else if (input_axes_order == AxesOrder::kOHWI &&
             output_axes_order == AxesOrder::kHWIO) {
    // Output axes (H, W, I, O) come from input axes (O, H, W, I) at
    // positions (1, 2, 3, 0).
    (*shuffle)[0] = 1;
    (*shuffle)[1] = 2;
    (*shuffle)[2] = 3;
    (*shuffle)[3] = 0;
  } else if (input_axes_order == AxesOrder::kHWIO &&
             output_axes_order == AxesOrder::kOHWI) {
    // Output axes (O, H, W, I) come from input axes (H, W, I, O) at
    // positions (3, 0, 1, 2).
    (*shuffle)[0] = 3;
    (*shuffle)[1] = 0;
    (*shuffle)[2] = 1;
    (*shuffle)[3] = 2;
  } else {
    LOG(FATAL) << "Bad shuffle";
  }
}

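// Extends a shuffle to rank 'newdim' by prepending identity axes. For
// example, extending the 2D transposition {1, 0} to newdim=4 yields
// {0, 1, 3, 2}.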
void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
                   std::vector<int>* extended_shuffle) {
  *extended_shuffle = input_shuffle;
  CHECK_GE(newdim, static_cast<int>(input_shuffle.size()));
  const int pad_size = newdim - input_shuffle.size();
  extended_shuffle->resize(newdim);
  for (int i = 0; i < pad_size; i++) {
    (*extended_shuffle)[i] = i;
  }
  for (int i = pad_size; i < newdim; i++) {
    (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size;
  }
}

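// Permutes the dims of input_shape from input_axes_order to
// output_axes_order. For example, an OHWI shape [8, 3, 3, 16] shuffled to
// HWIO becomes [3, 3, 16, 8]. The special case kHWIM -> k1HWO isn't a pure
// permutation: e.g. an HWIM shape [3, 3, 8, 2] becomes [1, 3, 3, 16].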
void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
                 AxesOrder output_axes_order, Shape* output_shape) {
  if (input_axes_order == AxesOrder::kHWIM &&
      output_axes_order == AxesOrder::k1HWO) {
    // This special case isn't just a permutation: the I and M dims get
    // merged into the O dim, so we have to special-case it.
    *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
                           input_shape.dims(3) * input_shape.dims(2)});
  } else {
    std::vector<int> shuffle;
    GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
    std::vector<int>* output_dims = output_shape->mutable_dims();
    output_dims->resize(input_shape.dimensions_count());
    for (int i = 0; i < input_shape.dimensions_count(); i++) {
      (*output_dims)[i] = input_shape.dims(shuffle[i]);
    }
  }
}

template <typename T>
void ShuffleArrayTemplate(const Shape& input_shape, AxesOrder input_axes_order,
                          AxesOrder output_axes_order,
                          const Shape& output_shape, const T* input_data,
                          T* output_data) {
  if (input_axes_order == AxesOrder::kHWIM &&
      output_axes_order == AxesOrder::k1HWO) {
    // This special case isn't just a permutation: the I and M dims get
    // merged into the O dim. Fortunately, as far as array shuffling is
    // concerned, it's just the identity transformation.
    memcpy(output_data, input_data,
           RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
    return;
  }
  CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
  const int dim = input_shape.dimensions_count();
  CHECK_LE(dim, 4);
  std::vector<int> shuffle;
  GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
  CHECK(shuffle.size() >= dim);
  for (int i = 0; i < dim; i++) {
    CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
    CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
  }
  Shape extended_input_shape = input_shape;
  ExtendShape(&extended_input_shape, 4);
  Shape extended_output_shape = output_shape;
  ExtendShape(&extended_output_shape, 4);
  std::vector<int> extended_shuffle;
  ExtendShuffle(shuffle, 4, &extended_shuffle);

  const std::vector<int>& extended_input_dims = extended_input_shape.dims();
  const std::vector<int>& extended_output_dims = extended_output_shape.dims();

  // TODO(starka): Rework to handle different numbers of dimensions.
  // Row-major strides of the input, indexed by input axis (axis 3 is
  // innermost).
  int input_strides[4];
  input_strides[3] = 1;
  input_strides[2] = extended_input_dims[3];
  input_strides[1] = input_strides[2] * extended_input_dims[2];
  input_strides[0] = input_strides[1] * extended_input_dims[1];
  // Input strides reordered to match the output iteration order below:
  // input_stride_0 drives the innermost output loop, which walks along
  // input axis extended_shuffle[3], and so on outwards.
  const int input_stride_0 = input_strides[extended_shuffle[3]];
  const int input_stride_1 = input_strides[extended_shuffle[2]];
  const int input_stride_2 = input_strides[extended_shuffle[1]];
  const int input_stride_3 = input_strides[extended_shuffle[0]];

  // The output is written contiguously in row-major order: size_0/stride_0
  // correspond to the innermost (last) output axis.
  const int output_size_0 = extended_output_dims[3];
  const int output_size_1 = extended_output_dims[2];
  const int output_size_2 = extended_output_dims[1];
  const int output_size_3 = extended_output_dims[0];
  const int output_stride_0 = 1;
  const int output_stride_1 = output_size_0;
  const int output_stride_2 = output_stride_1 * output_size_1;
  const int output_stride_3 = output_stride_2 * output_size_2;

  for (int i3 = 0; i3 < output_size_3; i3++) {
    const T* const input_ptr_3 = input_data + i3 * input_stride_3;
    T* const output_ptr_3 = output_data + i3 * output_stride_3;
    for (int i2 = 0; i2 < output_size_2; i2++) {
      const T* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
      T* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
      for (int i1 = 0; i1 < output_size_1; i1++) {
        const T* input_ptr = input_ptr_2 + i1 * input_stride_1;
        T* output_ptr = output_ptr_2 + i1 * output_stride_1;
        T* const output_ptr_end = output_ptr + output_size_0 * output_stride_0;
        while (output_ptr != output_ptr_end) {
          *output_ptr = *input_ptr;
          input_ptr += input_stride_0;
          output_ptr += output_stride_0;
        }
      }
    }
  }
}

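// Example usage (illustrative; 'ohwi_shape', 'ohwi_data' and 'hwio_data' are
// hypothetical local variables, with the destination buffer already sized to
// hold the permuted data):
//   Shape hwio_shape;
//   ShuffleDims(ohwi_shape, AxesOrder::kOHWI, AxesOrder::kHWIO, &hwio_shape);
//   ShuffleArray(ohwi_shape, AxesOrder::kOHWI, AxesOrder::kHWIO, hwio_shape,
//                ohwi_data, hwio_data);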
void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
                  AxesOrder output_axes_order, const Shape& output_shape,
                  const uint8* input_data, uint8* output_data) {
  ShuffleArrayTemplate<uint8>(input_shape, input_axes_order, output_axes_order,
                              output_shape, input_data, output_data);
}

void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
                  AxesOrder output_axes_order, const Shape& output_shape,
                  const float* input_data, float* output_data) {
  ShuffleArrayTemplate<float>(input_shape, input_axes_order, output_axes_order,
                              output_shape, input_data, output_data);
}

int AxesCount(AxesOrder axes_order) {
  switch (axes_order) {
    case AxesOrder::kOneAxis:
      return 1;
    case AxesOrder::kRC:
    case AxesOrder::kCR:
      return 2;
    case AxesOrder::kHWIO:
    case AxesOrder::kOHWI:
    case AxesOrder::kHWIM:
    case AxesOrder::k1HWO:
    case AxesOrder::kNHWC:
      return 4;
    default:
      LOG(FATAL) << "Bad AxesOrder";
      return 0;
  }
}

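// An array is "discardable" if a graph transformation is allowed to remove
// it: i.e. it is neither a model input/output nor a state array of a
// non-discardable RNN state.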
bool IsDiscardableArray(const Model& model, const string& array_name) {
  for (const auto& input_array : model.flags.input_arrays()) {
    if (array_name == input_array.name()) {
      return false;
    }
  }
  for (const string& output_array : model.flags.output_arrays()) {
    if (array_name == output_array) {
      return false;
    }
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (!rnn_state.discardable()) {
      if (array_name == rnn_state.state_array()) {
        return false;
      }
      if (array_name == rnn_state.back_edge_source_array()) {
        return false;
      }
    }
  }
  return true;
}

void CheckFinalDataTypesSatisfied(const Model& model) {
  for (const auto& array_entry : model.GetArrayMap()) {
    const auto& array = *array_entry.second;
    if (array.final_data_type != ArrayDataType::kNone) {
      CHECK(array.final_data_type == array.data_type)
          << "Array \"" << array_entry.first
          << "\" has mismatched actual and final data types ("
          << static_cast<int>(array.data_type) << ", "
          << static_cast<int>(array.final_data_type) << ").";
    }
  }
}

ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
  switch (type) {
    case FLOAT:
      return ArrayDataType::kFloat;
    case QUANTIZED_UINT8:
      return ArrayDataType::kUint8;
    case INT32:
      return ArrayDataType::kInt32;
    case INT64:
      return ArrayDataType::kInt64;
    default:
      return ArrayDataType::kNone;
  }
}

void FinishBuildingRNNStates(Model* model) {
  for (const auto& rnn_state : model->flags.rnn_states()) {
    // Both arrays must exist. CHECK them separately so the failure message
    // names the missing one. (The previous if/continue wrapper was dead code:
    // CHECK is always fatal, so execution never reached the continue.)
    CHECK(model->HasArray(rnn_state.back_edge_source_array()))
        << "Missing back-edge source array: "
        << rnn_state.back_edge_source_array();
    CHECK(model->HasArray(rnn_state.state_array()))
        << "Missing state array: " << rnn_state.state_array();
    const auto& src_array =
        model->GetArray(rnn_state.back_edge_source_array());
    auto& dst_array = model->GetArray(rnn_state.state_array());
    if (src_array.data_type == ArrayDataType::kNone &&
        dst_array.data_type == ArrayDataType::kNone) {
      dst_array.data_type = ArrayDataType::kFloat;
    }
  }
}

void UseArraysExtraInfo(Model* model) {
  for (const auto& entry : model->flags.arrays_extra_info().entries()) {
    QCHECK(model->HasArray(entry.name()))
        << "ArraysExtraInfo refers to a non-existent array name: "
        << entry.name();
    auto& minmax = model->GetArray(entry.name()).GetOrCreateMinMax();
    minmax.min = entry.min();
    minmax.max = entry.max();
  }
}

}  // namespace toco