1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <fcntl.h>
16#include <stdint.h>
17#include <stdio.h>
18#include <stdlib.h>
19#include <sys/mman.h>
20#include <sys/stat.h>
21#include <sys/types.h>
22#include <unistd.h>
23
24#include "tensorflow/contrib/lite/allocation.h"
25#include "tensorflow/contrib/lite/builtin_op_data.h"
26#include "tensorflow/contrib/lite/error_reporter.h"
27#include "tensorflow/contrib/lite/model.h"
28#include "tensorflow/contrib/lite/nnapi_delegate.h"
29#include "tensorflow/contrib/lite/version.h"
30
31namespace tflite {
32
// Name handed to the interpreter for tensors whose flatbuffer entry has no
// name. It is a string literal, so it outlives any interpreter.
const char* kEmptyTensorName{""};
34
35std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
36    const char* filename, ErrorReporter* error_reporter) {
37  std::unique_ptr<FlatBufferModel> model;
38  model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter,
39                                  /*use_nnapi=*/true));
40  if (!model->initialized()) model.reset();
41  return model;
42}
43
44std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
45    const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
46  std::unique_ptr<FlatBufferModel> model;
47  model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter));
48  if (!model->initialized()) model.reset();
49  return model;
50}
51
52std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
53    const tflite::Model* model_spec, ErrorReporter* error_reporter) {
54  std::unique_ptr<FlatBufferModel> model;
55  model.reset(new FlatBufferModel(model_spec, error_reporter));
56  if (!model->initialized()) model.reset();
57  return model;
58}
59
60FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
61                                 ErrorReporter* error_reporter, bool use_nnapi)
62    : error_reporter_(error_reporter ? error_reporter
63                                     : DefaultErrorReporter()) {
64  if (mmap_file) {
65    if (use_nnapi && NNAPIExists())
66      allocation_ = new NNAPIAllocation(filename, error_reporter);
67    else
68      allocation_ = new MMAPAllocation(filename, error_reporter);
69  } else {
70    allocation_ = new FileCopyAllocation(filename, error_reporter);
71  }
72  if (!allocation_->valid() || !CheckModelIdentifier()) return;
73
74  model_ = ::tflite::GetModel(allocation_->base());
75}
76
77bool FlatBufferModel::CheckModelIdentifier() const {
78  if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
79    const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
80    error_reporter_->Report(
81        "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
82        ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
83    return false;
84  }
85  return true;
86}
87
88FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
89                                 ErrorReporter* error_reporter)
90    : error_reporter_(error_reporter ? error_reporter
91                                     : DefaultErrorReporter()) {
92  allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
93  if (!allocation_->valid()) return;
94
95  model_ = ::tflite::GetModel(allocation_->base());
96}
97
98FlatBufferModel::FlatBufferModel(const Model* model,
99                                 ErrorReporter* error_reporter)
100    : error_reporter_(error_reporter ? error_reporter
101                                     : DefaultErrorReporter()) {
102  model_ = model;
103}
104
105FlatBufferModel::~FlatBufferModel() { delete allocation_; }
106
107InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
108                                       const OpResolver& op_resolver)
109    : model_(model.GetModel()),
110      op_resolver_(op_resolver),
111      error_reporter_(model.error_reporter()),
112      allocation_(model.allocation()) {}
113
114InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
115                                       const OpResolver& op_resolver,
116                                       ErrorReporter* error_reporter)
117    : model_(model),
118      op_resolver_(op_resolver),
119      error_reporter_(error_reporter ? error_reporter
120                                     : DefaultErrorReporter()) {}
121
122TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
123  TfLiteStatus status = kTfLiteOk;
124  auto opcodes = model_->operator_codes();
125  for (const OperatorCode* opcode : *opcodes) {
126    TfLiteRegistration* registration = nullptr;
127
128    if (opcode->builtin_code() != BuiltinOperator_CUSTOM) {
129      auto x = opcode->builtin_code();
130      flatbuffer_op_index_to_registration_types_.push_back(x);
131      registration = op_resolver_.FindOp(x);
132      if (registration == nullptr) {
133        error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
134                                EnumNameBuiltinOperator(x));
135        status = kTfLiteError;
136      }
137    } else if (!opcode->custom_code()) {
138      error_reporter_->Report(
139          "Operator with CUSTOM builtin_code has no custom_code.\n");
140      status = kTfLiteError;
141    } else {
142      const char* name = opcode->custom_code()->c_str();
143      registration = op_resolver_.FindOp(name);
144      flatbuffer_op_index_to_registration_types_.push_back(
145          BuiltinOperator_CUSTOM);
146      if (registration == nullptr) {
147        error_reporter_->Report("Didn't find custom op for name '%s'\n", name);
148        status = kTfLiteError;
149      }
150    }
151    flatbuffer_op_index_to_registration_.push_back(registration);
152  }
153  return status;
154}
155
156namespace {
// Copies an integer flatbuffer vector (any type exposing Length() and Get(i))
// into a std::vector<int>.
//
// A null `flat_array` — how flatbuffers represents an absent optional vector
// field such as a tensor's shape or an operator's inputs — yields an empty
// vector instead of being dereferenced.
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
  if (flat_array == nullptr) {
    return {};
  }
  const int length = static_cast<int>(flat_array->Length());
  std::vector<int> ret(static_cast<size_t>(length));
  for (int i = 0; i < length; ++i) {
    ret[i] = static_cast<int>(flat_array->Get(i));
  }
  return ret;
}
165
166// Copies the contents from the flatbuffer int vector `flatbuffer` into the
167// int array `buffer`. `flat_vector` and `buffer` represent the same
168// configuration operation for a given operation.
169void FlatBufferIntVectorToArray(int max_size_of_buffer,
170                                const flatbuffers::Vector<int32_t>* flat_vector,
171                                int* buffer, ErrorReporter* error_reporter) {
172  if (!flat_vector) {
173    error_reporter->Report("Input array not provided for operation.\n");
174  } else {
175    int num_dimensions = flat_vector->Length();
176    if (num_dimensions > max_size_of_buffer / sizeof(int)) {
177      error_reporter->Report(
178          "Found too many dimensions in the operation's input array.\n");
179    } else {
180      for (int i = 0; i < num_dimensions; ++i) {
181        buffer[i] = flat_vector->Get(i);
182      }
183    }
184  }
185}
186
// Allocate a zero-initialized structure using C calloc, but make sure the
// structure is a POD structure that doesn't require constructors to run. The
// reason we do this, is that Interpreter's C extension part will take
// ownership and wants to use malloc() and free().
//
// Zero-filling matters because ParseOpData only assigns fields when the
// operator carries an options table; without it the returned params would
// otherwise hold indeterminate values. Returns nullptr if allocation fails.
template <class T>
T* MallocPOD() {
  static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
  return static_cast<T*>(calloc(1, sizeof(T)));
}
196
197// Parse the appropriate data out of the op.
198//
199// This handles builtin data explicitly as there are flatbuffer schemas.
200//
201// Returns memory that must be feed.
202//
203// TODO(nupurgarg): Pass in void ** and return TfLiteStatus to ensure program
204// crashes if error reporter is called.
205void* ParseOpData(const Operator* op, BuiltinOperator op_type,
206                  ErrorReporter* error_reporter) {
207  auto parse_padding = [](Padding padding) {
208    switch (padding) {
209      case Padding_SAME:
210        return kTfLitePaddingSame;
211      case Padding_VALID:
212        return kTfLitePaddingValid;
213    }
214    return kTfLitePaddingUnknown;
215  };
216  auto parse_activation = [](ActivationFunctionType activation) {
217    switch (activation) {
218      case ActivationFunctionType_NONE:
219        return kTfLiteActNone;
220      case ActivationFunctionType_RELU:
221        return kTfLiteActRelu;
222      case ActivationFunctionType_RELU_N1_TO_1:
223        return kTfLiteActRelu1;
224      case ActivationFunctionType_RELU6:
225        return kTfLiteActRelu6;
226      case ActivationFunctionType_TANH:
227        return kTfLiteActTanh;
228      case ActivationFunctionType_SIGN_BIT:
229        return kTfLiteActSignBit;
230    }
231    return kTfLiteActNone;
232  };
233  auto parseLSHProjectionType = [](LSHProjectionType type) {
234    switch (type) {
235      case LSHProjectionType_SPARSE:
236        return kTfLiteLshProjectionSparse;
237      case LSHProjectionType_DENSE:
238        return kTfLiteLshProjectionDense;
239      default:
240        return kTfLiteLshProjectionUnknown;
241    }
242  };
243  auto parseCombinerType = [](CombinerType type) {
244    switch (type) {
245      case CombinerType_MEAN:
246        return kTfLiteCombinerTypeMean;
247      case CombinerType_SQRTN:
248        return kTfLiteCombinerTypeSqrtn;
249      case CombinerType_SUM:
250      default:
251        return kTfLiteCombinerTypeSum;
252    }
253  };
254
255  void* builtin_data = nullptr;
256  switch (op_type) {
257    case BuiltinOperator_CALL:
258      // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
259      // ok for now, since there is no call implementation either.
260      break;
261    case BuiltinOperator_CUSTOM:
262      break;
263    case BuiltinOperator_CONV_2D: {
264      TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
265      if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
266        params->padding = parse_padding(conv_params->padding());
267        params->stride_width = conv_params->stride_w();
268        params->stride_height = conv_params->stride_h();
269        params->activation =
270            parse_activation(conv_params->fused_activation_function());
271      }
272      builtin_data = reinterpret_cast<void*>(params);
273      break;
274    }
275    case BuiltinOperator_TANH:
276    case BuiltinOperator_LOGISTIC:
277    case BuiltinOperator_RELU:
278    case BuiltinOperator_RELU_N1_TO_1:
279    case BuiltinOperator_RELU6:
280    case BuiltinOperator_CONCAT_EMBEDDINGS:
281    case BuiltinOperator_EXP:
282    case BuiltinOperator_TOPK_V2:
283      break;
284    case BuiltinOperator_LSH_PROJECTION: {
285      TfLiteLSHProjectionParams* params =
286          MallocPOD<TfLiteLSHProjectionParams>();
287      if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
288        params->type = parseLSHProjectionType(lshParams->type());
289      }
290      builtin_data = reinterpret_cast<void*>(params);
291      break;
292    }
293    case BuiltinOperator_AVERAGE_POOL_2D:
294    case BuiltinOperator_MAX_POOL_2D:
295    case BuiltinOperator_L2_POOL_2D: {
296      TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
297      if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
298        params->padding = parse_padding(pool_params->padding());
299        params->stride_width = pool_params->stride_w();
300        params->stride_height = pool_params->stride_h();
301        params->filter_width = pool_params->filter_width();
302        params->filter_height = pool_params->filter_height();
303        params->activation =
304            parse_activation(pool_params->fused_activation_function());
305      }
306      builtin_data = reinterpret_cast<void*>(params);
307      break;
308    }
309    case BuiltinOperator_DEPTHWISE_CONV_2D: {
310      TfLiteDepthwiseConvParams* params =
311          MallocPOD<TfLiteDepthwiseConvParams>();
312      if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
313        params->padding = parse_padding(conv_params->padding());
314        params->stride_width = conv_params->stride_w();
315        params->stride_height = conv_params->stride_h();
316        params->depth_multiplier = conv_params->depth_multiplier();
317        params->activation =
318            parse_activation(conv_params->fused_activation_function());
319      }
320      builtin_data = reinterpret_cast<void*>(params);
321      break;
322    }
323    case BuiltinOperator_SVDF: {
324      TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
325      if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
326        params->rank = svdf_params->rank();
327        params->activation =
328            parse_activation(svdf_params->fused_activation_function());
329      }
330      builtin_data = reinterpret_cast<void*>(params);
331      break;
332    }
333    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
334    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
335      TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
336      if (auto* sequence_rnn_params =
337              op->builtin_options_as_SequenceRNNOptions()) {
338        params->activation =
339            parse_activation(sequence_rnn_params->fused_activation_function());
340        params->time_major = sequence_rnn_params->time_major();
341      }
342      builtin_data = reinterpret_cast<void*>(params);
343      break;
344    }
345    case BuiltinOperator_RNN: {
346      TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
347      if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
348        params->activation =
349            parse_activation(rnn_params->fused_activation_function());
350      }
351      builtin_data = reinterpret_cast<void*>(params);
352      break;
353    }
354    case BuiltinOperator_EMBEDDING_LOOKUP:
355      // no-op.
356      break;
357    case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
358      TfLiteEmbeddingLookupSparseParams* params =
359          MallocPOD<TfLiteEmbeddingLookupSparseParams>();
360      if (auto* embedding_params =
361              op->builtin_options_as_EmbeddingLookupSparseOptions()) {
362        params->combiner = parseCombinerType(embedding_params->combiner());
363      }
364      builtin_data = reinterpret_cast<void*>(params);
365      break;
366    }
367    case BuiltinOperator_FULLY_CONNECTED: {
368      TfLiteFullyConnectedParams* params =
369          MallocPOD<TfLiteFullyConnectedParams>();
370      if (auto* fully_connected_params =
371              op->builtin_options_as_FullyConnectedOptions()) {
372        params->activation = parse_activation(
373            fully_connected_params->fused_activation_function());
374      }
375      builtin_data = reinterpret_cast<void*>(params);
376      break;
377    }
378    case BuiltinOperator_HASHTABLE_LOOKUP:
379      // no-op.
380      break;
381    case BuiltinOperator_SOFTMAX: {
382      TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
383      if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
384        params->beta = softmax_params->beta();
385      }
386      builtin_data = reinterpret_cast<void*>(params);
387      break;
388    }
389    case BuiltinOperator_CONCATENATION: {
390      TfLiteConcatenationParams* params =
391          MallocPOD<TfLiteConcatenationParams>();
392      if (auto* concatenation_params =
393              op->builtin_options_as_ConcatenationOptions()) {
394        params->activation =
395            parse_activation(concatenation_params->fused_activation_function());
396        params->axis = concatenation_params->axis();
397      }
398      builtin_data = reinterpret_cast<void*>(params);
399      break;
400    }
401    case BuiltinOperator_MUL: {
402      auto* params = MallocPOD<TfLiteMulParams>();
403      if (auto* schema_params = op->builtin_options_as_MulOptions()) {
404        params->activation =
405            parse_activation(schema_params->fused_activation_function());
406      }
407      builtin_data = reinterpret_cast<void*>(params);
408      break;
409    }
410    case BuiltinOperator_ADD: {
411      auto* params = MallocPOD<TfLiteAddParams>();
412      if (auto* schema_params = op->builtin_options_as_AddOptions()) {
413        params->activation =
414            parse_activation(schema_params->fused_activation_function());
415      }
416      builtin_data = reinterpret_cast<void*>(params);
417      break;
418    }
419    case BuiltinOperator_DIV: {
420      auto* params = MallocPOD<TfLiteDivParams>();
421      if (auto* schema_params = op->builtin_options_as_DivOptions()) {
422        params->activation =
423            parse_activation(schema_params->fused_activation_function());
424      }
425      builtin_data = reinterpret_cast<void*>(params);
426      break;
427    }
428    case BuiltinOperator_SUB: {
429      auto* params = MallocPOD<TfLiteSubParams>();
430      if (auto* schema_params = op->builtin_options_as_SubOptions()) {
431        params->activation =
432            parse_activation(schema_params->fused_activation_function());
433      }
434      builtin_data = reinterpret_cast<void*>(params);
435      break;
436    }
437    case BuiltinOperator_L2_NORMALIZATION: {
438      auto* params = MallocPOD<TfLiteL2NormParams>();
439      if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
440        params->activation =
441            parse_activation(schema_params->fused_activation_function());
442      }
443      builtin_data = reinterpret_cast<void*>(params);
444      break;
445    }
446    case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
447      auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
448      if (auto* schema_params =
449              op->builtin_options_as_LocalResponseNormalizationOptions()) {
450        params->radius = schema_params->radius();
451        params->bias = schema_params->bias();
452        params->alpha = schema_params->alpha();
453        params->beta = schema_params->beta();
454      }
455      builtin_data = reinterpret_cast<void*>(params);
456      break;
457    }
458    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
459    case BuiltinOperator_LSTM: {
460      TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
461      if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
462        params->activation =
463            parse_activation(lstm_params->fused_activation_function());
464        params->cell_clip = lstm_params->cell_clip();
465        params->proj_clip = lstm_params->proj_clip();
466      }
467      builtin_data = reinterpret_cast<void*>(params);
468      break;
469    }
470    case BuiltinOperator_RESIZE_BILINEAR: {
471      auto* params = MallocPOD<TfLiteResizeBilinearParams>();
472      if (auto* schema_params =
473              op->builtin_options_as_ResizeBilinearOptions()) {
474        params->align_corners = schema_params->align_corners();
475      }
476      builtin_data = reinterpret_cast<void*>(params);
477      break;
478    }
479    case BuiltinOperator_PAD: {
480      break;
481    }
482    case BuiltinOperator_RESHAPE: {
483      auto* params = MallocPOD<TfLiteReshapeParams>();
484      if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
485        auto* new_shape = schema_params->new_shape();
486        FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
487                                   params->shape, error_reporter);
488        params->num_dimensions = new_shape->Length();
489      }
490      builtin_data = reinterpret_cast<void*>(params);
491      break;
492    }
493    case BuiltinOperator_SKIP_GRAM: {
494      TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
495      if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
496        params->ngram_size = skip_gram_params->ngram_size();
497        params->max_skip_size = skip_gram_params->max_skip_size();
498        params->include_all_ngrams = skip_gram_params->include_all_ngrams();
499      }
500      builtin_data = reinterpret_cast<void*>(params);
501      break;
502    }
503    case BuiltinOperator_SPACE_TO_DEPTH: {
504      auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
505      if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
506        params->block_size = schema_params->block_size();
507      }
508      builtin_data = reinterpret_cast<void*>(params);
509      break;
510    }
511    case BuiltinOperator_GATHER: {
512      TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
513      params->axis = 0;
514      if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
515        params->axis = gather_params->axis();
516      }
517
518      builtin_data = reinterpret_cast<void*>(params);
519      break;
520    }
521    case BuiltinOperator_SPACE_TO_BATCH_ND: {
522      break;
523    }
524    case BuiltinOperator_BATCH_TO_SPACE_ND: {
525      break;
526    }
527    case BuiltinOperator_TRANSPOSE: {
528      break;
529    }
530    case BuiltinOperator_MEAN: {
531      auto* params = MallocPOD<TfLiteMeanParams>();
532      if (auto* schema_params = op->builtin_options_as_MeanOptions()) {
533        params->keep_dims = schema_params->keep_dims();
534      }
535      builtin_data = reinterpret_cast<void*>(params);
536      break;
537    }
538    case BuiltinOperator_SPLIT: {
539      auto* params = MallocPOD<TfLiteSplitParams>();
540      if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
541        params->num_splits = schema_params->num_splits();
542      }
543      builtin_data = reinterpret_cast<void*>(params);
544      break;
545    }
546    case BuiltinOperator_SQUEEZE: {
547      auto* params = MallocPOD<TfLiteSqueezeParams>();
548      if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
549        const auto& squeeze_dims = schema_params->squeeze_dims();
550        FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
551                                   params->squeeze_dims, error_reporter);
552        params->num_squeeze_dims = squeeze_dims->Length();
553      }
554      builtin_data = reinterpret_cast<void*>(params);
555      break;
556    }
557    case BuiltinOperator_STRIDED_SLICE: {
558      auto* params = MallocPOD<TfLiteStridedSliceParams>();
559      if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
560        params->begin_mask = schema_params->begin_mask();
561        params->end_mask = schema_params->end_mask();
562        params->ellipsis_mask = schema_params->ellipsis_mask();
563        params->new_axis_mask = schema_params->new_axis_mask();
564        params->shrink_axis_mask = schema_params->shrink_axis_mask();
565      }
566      builtin_data = reinterpret_cast<void*>(params);
567      break;
568    }
569  }
570  return builtin_data;
571}
572
573}  // namespace
574
575TfLiteStatus InterpreterBuilder::ParseNodes(
576    const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
577    Interpreter* interpreter) {
578  TfLiteStatus status = kTfLiteOk;
579  for (int i = 0; i < operators->Length(); ++i) {
580    const auto* op = operators->Get(i);
581    int index = op->opcode_index();
582    if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
583      error_reporter_->Report("Missing registration for opcode_index %d\n",
584                              index);
585      status = kTfLiteError;
586      continue;
587    }
588    const TfLiteRegistration* reg =
589        flatbuffer_op_index_to_registration_[op->opcode_index()];
590    if (reg == nullptr) {
591      error_reporter_->Report("Skipping op for opcode_index %d\n", index);
592      status = kTfLiteError;
593      continue;
594    }
595
596    auto op_type =
597        flatbuffer_op_index_to_registration_types_[op->opcode_index()];
598    if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
599      error_reporter_->Report(
600          "Found builtin operator %s with custom options.\n",
601          EnumNameBuiltinOperator(op_type));
602    }
603    if (op->custom_options()) {
604      interpreter->AddNodeWithParameters(
605          FlatBufferIntArrayToVector(op->inputs()),
606          FlatBufferIntArrayToVector(op->outputs()),
607          reinterpret_cast<const char*>(op->custom_options()->data()),
608          op->custom_options()->size(), nullptr, reg);
609    } else {
610      interpreter->AddNodeWithParameters(
611          FlatBufferIntArrayToVector(op->inputs()),
612          FlatBufferIntArrayToVector(op->outputs()), nullptr, 0,
613          ParseOpData(op, op_type, error_reporter_), reg);
614    }
615  }
616
617  return status;
618}
619
620TfLiteStatus InterpreterBuilder::ParseTensors(
621    const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
622    const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
623    Interpreter* interpreter) {
624  TfLiteStatus status = kTfLiteOk;
625
626  // A little helper to get the names of inputs and outputs. Note that they
627  // must outlive the interpreter.
628  auto get_name = [](const tflite::Tensor* t) -> const char* {
629    auto name = t->name();
630    if (name) return name->c_str();
631    return kEmptyTensorName;
632  };
633
634  for (int i = 0; i < tensors->Length(); ++i) {
635    const auto* tensor = tensors->Get(i);
636    std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());
637
638    TfLiteQuantizationParams quantization;
639    quantization.scale = 0;
640    quantization.zero_point = 0;
641    auto* q_params = tensor->quantization();
642    if (q_params) {
643      // Note that the schema could hold per-channel quantization parameters
644      // but we really only support one value for the whole tensor.
645      // TODO(aselle): This breaks as well if these are nullptr's.
646      // TODO(aselle): This assumes non per-channel quantization.
647      if (q_params->scale()) quantization.scale = q_params->scale()->Get(0);
648      if (q_params->zero_point())
649        quantization.zero_point = q_params->zero_point()->Get(0);
650    }
651
652    TfLiteType type;
653    switch (tensor->type()) {
654      case TensorType_FLOAT32:
655        type = kTfLiteFloat32;
656        break;
657      case TensorType_INT32:
658        type = kTfLiteInt32;
659        break;
660      case TensorType_UINT8:
661        type = kTfLiteUInt8;
662        break;
663      case TensorType_INT64:
664        type = kTfLiteInt64;
665        break;
666      case TensorType_STRING:
667        type = kTfLiteString;
668        break;
669      default:
670        // tensorType = ArrayType::NONE;
671        error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n",
672                                EnumNameTensorType(tensor->type()),
673                                tensor->type());
674        status = kTfLiteError;
675        continue;
676    }
677    auto get_readonly_data = [&](const char** buffer_data,
678                                 size_t* buffer_size) {
679      // TODO(aselle): Check what happens if we have an unspecified size
680      // constant.
681      *buffer_data = nullptr;
682      if (tensor->buffer() == 0) return kTfLiteOk;
683      if (tensor->buffer() >= buffers->size()) {
684        error_reporter_->Report(
685            "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
686            i, tensor->buffer(), buffers->size());
687        return kTfLiteError;
688      }
689      if (auto* buffer = (*buffers)[tensor->buffer()]) {
690        if (auto* array = buffer->data()) {
691          if (size_t size = array->size()) {
692            *buffer_size = size;
693            *buffer_data = reinterpret_cast<const char*>(array->data());
694            return kTfLiteOk;
695          }
696        }
697      }
698      return kTfLiteOk;
699    };
700    size_t buffer_size = 0;
701    const char* buffer_ptr;
702    TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
703
704    if (buffer_ptr) {
705      if (interpreter->SetTensorParametersReadOnly(
706              i, type, get_name(tensor), dims, quantization, buffer_ptr,
707              buffer_size, allocation_) != kTfLiteOk) {
708        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
709                                i);
710        status = kTfLiteError;
711      }
712    } else {
713      if (interpreter->SetTensorParametersReadWrite(
714              i, type, get_name(tensor), dims, quantization) != kTfLiteOk) {
715        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
716                                i);
717        status = kTfLiteError;
718      }
719    }
720  }
721
722  return status;
723}
724
725TfLiteStatus InterpreterBuilder::operator()(
726    std::unique_ptr<Interpreter>* interpreter) {
727  if (!interpreter) {
728    error_reporter_->Report(
729        "Null output pointer passed to InterpreterBuilder.");
730    return kTfLiteError;
731  }
732
733  // Safe exit by deleting partially created interpreter, to reduce verbosity
734  // on error conditions. Use by return cleanup_on_error();
735  auto cleanup_and_error = [&interpreter]() {
736    interpreter->reset();
737    return kTfLiteError;
738  };
739
740  if (!model_) {
741    error_reporter_->Report("Null pointer passed in as model.");
742    return cleanup_and_error();
743  }
744
745  if (model_->version() != TFLITE_SCHEMA_VERSION) {
746    error_reporter_->Report(
747        "Model provided is schema version %d not equal "
748        "to supported version %d.\n",
749        model_->version(), TFLITE_SCHEMA_VERSION);
750    return cleanup_and_error();
751  }
752
753  if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
754    error_reporter_->Report("Registration failed.\n");
755    return cleanup_and_error();
756  }
757
758  // Flatbuffer model schemas define a list of opcodes independent of the graph.
759  // We first map those to registrations. This reduces string lookups for custom
760  // ops since we only do it once per custom op rather than once per custom op
761  // invocation in the model graph.
762  // Construct interpreter with correct number of tensors and operators.
763  auto* subgraphs = model_->subgraphs();
764  auto* buffers = model_->buffers();
765  if (subgraphs->size() != 1) {
766    error_reporter_->Report("Only 1 subgraph is currently supported.\n");
767    return cleanup_and_error();
768  }
769  const tflite::SubGraph* subgraph = (*subgraphs)[0];
770  auto operators = subgraph->operators();
771  auto tensors = subgraph->tensors();
772  if (!operators || !tensors || !buffers) {
773    error_reporter_->Report(
774        "Did not get operators, tensors, or buffers in input flat buffer.\n");
775    return cleanup_and_error();
776  }
777  interpreter->reset(new Interpreter(error_reporter_));
778  if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
779    return cleanup_and_error();
780  }
781
782  // Parse inputs/outputs
783  (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
784  (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
785
786  // Finally setup nodes and tensors
787  if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
788    return cleanup_and_error();
789  if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
790    return cleanup_and_error();
791
792  return kTfLiteOk;
793}
794
795}  // namespace tflite
796