1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/contrib/lite/toco/tflite/operator.h"
16
17#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
18#include "tensorflow/contrib/lite/toco/tflite/custom_operator.h"
19#include "tensorflow/contrib/lite/toco/tflite/simple_operator.h"
20#include "tensorflow/contrib/lite/toco/tflite/types.h"
21
22#include "tensorflow/core/framework/attr_value.pb.h"
23#include "tensorflow/core/framework/node_def.pb.h"
24
25namespace toco {
26
27namespace tflite {
28
29class AveragePool
30    : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
31                             ::tflite::BuiltinOptions_Pool2DOptions> {
32 public:
33  using BuiltinOperator::BuiltinOperator;
34
35  flatbuffers::Offset<TfLiteOptions> WriteOptions(
36      const TocoOperator& op,
37      flatbuffers::FlatBufferBuilder* builder) const override {
38    auto padding = Padding::Serialize(op.padding.type);
39    auto activation_function =
40        ActivationFunction::Serialize(op.fused_activation_function);
41    return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
42                                         op.stride_height, op.kwidth,
43                                         op.kheight, activation_function);
44  }
45
46  void ReadOptions(const TfLiteOptions& options,
47                   TocoOperator* op) const override {
48    op->padding.type = Padding::Deserialize(options.padding());
49    op->stride_width = options.stride_w();
50    op->stride_height = options.stride_h();
51    op->kwidth = options.filter_width();
52    op->kheight = options.filter_height();
53    op->fused_activation_function =
54        ActivationFunction::Deserialize(options.fused_activation_function());
55  }
56};
57
58class Convolution
59    : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
60                             ::tflite::BuiltinOptions_Conv2DOptions> {
61 public:
62  using BuiltinOperator::BuiltinOperator;
63
64  flatbuffers::Offset<TfLiteOptions> WriteOptions(
65      const TocoOperator& op,
66      flatbuffers::FlatBufferBuilder* builder) const override {
67    auto padding = Padding::Serialize(op.padding.type);
68    auto activation_function =
69        ActivationFunction::Serialize(op.fused_activation_function);
70    return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
71                                         op.stride_height, activation_function);
72  }
73
74  void ReadOptions(const TfLiteOptions& options,
75                   TocoOperator* op) const override {
76    op->padding.type = Padding::Deserialize(options.padding());
77    op->stride_width = options.stride_w();
78    op->stride_height = options.stride_h();
79    op->fused_activation_function =
80        ActivationFunction::Deserialize(options.fused_activation_function());
81  }
82};
83
84class DepthwiseConvolution
85    : public BuiltinOperator<DepthwiseConvOperator,
86                             ::tflite::DepthwiseConv2DOptions,
87                             ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
88 public:
89  using BuiltinOperator::BuiltinOperator;
90
91  flatbuffers::Offset<TfLiteOptions> WriteOptions(
92      const TocoOperator& op,
93      flatbuffers::FlatBufferBuilder* builder) const override {
94    auto padding = Padding::Serialize(op.padding.type);
95    auto activation_function =
96        ActivationFunction::Serialize(op.fused_activation_function);
97    return ::tflite::CreateDepthwiseConv2DOptions(
98        *builder, padding, op.stride_width, op.stride_height,
99        op.depth_multiplier, activation_function);
100  }
101
102  void ReadOptions(const TfLiteOptions& options,
103                   TocoOperator* op) const override {
104    op->padding.type = Padding::Deserialize(options.padding());
105    op->stride_width = options.stride_w();
106    op->stride_height = options.stride_h();
107    op->depth_multiplier = options.depth_multiplier();
108    op->fused_activation_function =
109        ActivationFunction::Deserialize(options.fused_activation_function());
110  }
111};
112
113class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
114                                   ::tflite::BuiltinOptions_AddOptions> {
115 public:
116  using BuiltinOperator::BuiltinOperator;
117
118  flatbuffers::Offset<TfLiteOptions> WriteOptions(
119      const TocoOperator& op,
120      flatbuffers::FlatBufferBuilder* builder) const override {
121    auto activation_function =
122        ActivationFunction::Serialize(op.fused_activation_function);
123    return ::tflite::CreateAddOptions(*builder, activation_function);
124  }
125
126  void ReadOptions(const TfLiteOptions& options,
127                   TocoOperator* op) const override {
128    op->fused_activation_function =
129        ActivationFunction::Deserialize(options.fused_activation_function());
130  }
131};
132
133class SpaceToBatchND
134    : public BuiltinOperator<SpaceToBatchNDOperator,
135                             ::tflite::SpaceToBatchNDOptions,
136                             ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
137 public:
138  using BuiltinOperator::BuiltinOperator;
139
140  flatbuffers::Offset<TfLiteOptions> WriteOptions(
141      const TocoOperator& op,
142      flatbuffers::FlatBufferBuilder* builder) const override {
143    return ::tflite::CreateSpaceToBatchNDOptions(*builder);
144  }
145
146  void ReadOptions(const TfLiteOptions& options,
147                   TocoOperator* op) const override {}
148};
149
150class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
151                                   ::tflite::BuiltinOptions_SubOptions> {
152 public:
153  using BuiltinOperator::BuiltinOperator;
154
155  flatbuffers::Offset<TfLiteOptions> WriteOptions(
156      const TocoOperator& op,
157      flatbuffers::FlatBufferBuilder* builder) const override {
158    auto activation_function =
159        ActivationFunction::Serialize(op.fused_activation_function);
160    return ::tflite::CreateSubOptions(*builder, activation_function);
161  }
162
163  void ReadOptions(const TfLiteOptions& options,
164                   TocoOperator* op) const override {
165    op->fused_activation_function =
166        ActivationFunction::Deserialize(options.fused_activation_function());
167  }
168};
169
170class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
171                                   ::tflite::BuiltinOptions_DivOptions> {
172 public:
173  using BuiltinOperator::BuiltinOperator;
174
175  flatbuffers::Offset<TfLiteOptions> WriteOptions(
176      const TocoOperator& op,
177      flatbuffers::FlatBufferBuilder* builder) const override {
178    auto activation_function =
179        ActivationFunction::Serialize(op.fused_activation_function);
180    return ::tflite::CreateDivOptions(*builder, activation_function);
181  }
182
183  void ReadOptions(const TfLiteOptions& options,
184                   TocoOperator* op) const override {
185    op->fused_activation_function =
186        ActivationFunction::Deserialize(options.fused_activation_function());
187  }
188};
189
190class BatchToSpaceND
191    : public BuiltinOperator<BatchToSpaceNDOperator,
192                             ::tflite::BatchToSpaceNDOptions,
193                             ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
194 public:
195  using BuiltinOperator::BuiltinOperator;
196
197  flatbuffers::Offset<TfLiteOptions> WriteOptions(
198      const TocoOperator& op,
199      flatbuffers::FlatBufferBuilder* builder) const override {
200    return ::tflite::CreateBatchToSpaceNDOptions(*builder);
201  }
202
203  void ReadOptions(const TfLiteOptions& options,
204                   TocoOperator* op) const override {}
205};
206
207class Cast : public CustomOperator<CastOperator> {
208 public:
209  using CustomOperator::CustomOperator;
210  void WriteOptions(const TocoOperator& op,
211                    flexbuffers::Builder* fbb) const override {
212    fbb->Int("src_data_type", DataType::Serialize(op.src_data_type));
213    fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type));
214  }
215  void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
216    op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64());
217    op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64());
218  }
219};
220
221class Concatenation
222    : public BuiltinOperator<ConcatenationOperator,
223                             ::tflite::ConcatenationOptions,
224                             ::tflite::BuiltinOptions_ConcatenationOptions> {
225 public:
226  using BuiltinOperator::BuiltinOperator;
227  flatbuffers::Offset<TfLiteOptions> WriteOptions(
228      const TocoOperator& op,
229      flatbuffers::FlatBufferBuilder* builder) const override {
230    return ::tflite::CreateConcatenationOptions(*builder, op.axis);
231  }
232
233  void ReadOptions(const TfLiteOptions& options,
234                   TocoOperator* op) const override {
235    op->axis = options.axis();
236  }
237};
238
239class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
240 public:
241  using CustomOperator::CustomOperator;
242  void WriteOptions(const TocoOperator& op,
243                    flexbuffers::Builder* fbb) const override {
244    fbb->Int("block_size", op.block_size);
245  }
246  void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
247    op->block_size = m["block_size"].AsInt64();
248  }
249};
250
251class FakeQuant : public CustomOperator<FakeQuantOperator> {
252 public:
253  using CustomOperator::CustomOperator;
254  void WriteOptions(const TocoOperator& op,
255                    flexbuffers::Builder* fbb) const override {
256    fbb->Float("min", op.minmax->min);
257    fbb->Float("max", op.minmax->max);
258  }
259  void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
260    auto* minmax = new MinMax;
261    minmax->min = m["min"].AsFloat();
262    minmax->max = m["max"].AsFloat();
263    op->minmax.reset(minmax);
264  }
265};
266
267class FullyConnected
268    : public BuiltinOperator<FullyConnectedOperator,
269                             ::tflite::FullyConnectedOptions,
270                             ::tflite::BuiltinOptions_FullyConnectedOptions> {
271 public:
272  using BuiltinOperator::BuiltinOperator;
273  flatbuffers::Offset<TfLiteOptions> WriteOptions(
274      const TocoOperator& op,
275      flatbuffers::FlatBufferBuilder* builder) const override {
276    auto activation_function =
277        ActivationFunction::Serialize(op.fused_activation_function);
278    return ::tflite::CreateFullyConnectedOptions(*builder, activation_function);
279  }
280
281  void ReadOptions(const TfLiteOptions& options,
282                   TocoOperator* op) const override {
283    op->fused_activation_function =
284        ActivationFunction::Deserialize(options.fused_activation_function());
285  }
286};
287
288class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
289                                      ::tflite::BuiltinOptions_GatherOptions> {
290 public:
291  using BuiltinOperator::BuiltinOperator;
292  flatbuffers::Offset<TfLiteOptions> WriteOptions(
293      const TocoOperator& op,
294      flatbuffers::FlatBufferBuilder* builder) const override {
295    return ::tflite::CreateGatherOptions(*builder, op.axis);
296  }
297
298  void ReadOptions(const TfLiteOptions& options,
299                   TocoOperator* op) const override {
300    op->axis = options.axis();
301  }
302};
303
304class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
305                                    ::tflite::BuiltinOptions_SVDFOptions> {
306 public:
307  using BuiltinOperator::BuiltinOperator;
308  flatbuffers::Offset<TfLiteOptions> WriteOptions(
309      const TocoOperator& op,
310      flatbuffers::FlatBufferBuilder* builder) const override {
311    auto activation_function =
312        ActivationFunction::Serialize(op.fused_activation_function);
313    return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
314  }
315
316  void ReadOptions(const TfLiteOptions& options,
317                   TocoOperator* op) const override {
318    op->fused_activation_function =
319        ActivationFunction::Deserialize(options.fused_activation_function());
320    op->rank = options.rank();
321  }
322};
323
324class L2Normalization
325    : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
326                             ::tflite::BuiltinOptions_L2NormOptions> {
327 public:
328  using BuiltinOperator::BuiltinOperator;
329  flatbuffers::Offset<TfLiteOptions> WriteOptions(
330      const TocoOperator& op,
331      flatbuffers::FlatBufferBuilder* builder) const override {
332    auto activation_function =
333        ActivationFunction::Serialize(op.fused_activation_function);
334    return ::tflite::CreateL2NormOptions(*builder, activation_function);
335  }
336
337  void ReadOptions(const TfLiteOptions& options,
338                   TocoOperator* op) const override {
339    op->fused_activation_function =
340        ActivationFunction::Deserialize(options.fused_activation_function());
341  }
342};
343
344class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
345                                      ::tflite::BuiltinOptions_Pool2DOptions> {
346 public:
347  using BuiltinOperator::BuiltinOperator;
348  flatbuffers::Offset<TfLiteOptions> WriteOptions(
349      const TocoOperator& op,
350      flatbuffers::FlatBufferBuilder* builder) const override {
351    auto padding = Padding::Serialize(op.padding.type);
352    auto activation_function =
353        ActivationFunction::Serialize(op.fused_activation_function);
354    return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
355                                         op.stride_height, op.kwidth,
356                                         op.kheight, activation_function);
357  }
358
359  void ReadOptions(const TfLiteOptions& options,
360                   TocoOperator* op) const override {
361    op->padding.type = Padding::Deserialize(options.padding());
362    op->stride_width = options.stride_w();
363    op->stride_height = options.stride_h();
364    op->kwidth = options.filter_width();
365    op->kheight = options.filter_height();
366    op->fused_activation_function =
367        ActivationFunction::Deserialize(options.fused_activation_function());
368  }
369};
370
371class LocalResponseNormalization
372    : public BuiltinOperator<
373          LocalResponseNormalizationOperator,
374          ::tflite::LocalResponseNormalizationOptions,
375          ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
376 public:
377  using BuiltinOperator::BuiltinOperator;
378  flatbuffers::Offset<TfLiteOptions> WriteOptions(
379      const TocoOperator& op,
380      flatbuffers::FlatBufferBuilder* builder) const override {
381    return ::tflite::CreateLocalResponseNormalizationOptions(
382        *builder, op.range, op.bias, op.alpha, op.beta);
383  }
384
385  void ReadOptions(const TfLiteOptions& options,
386                   TocoOperator* op) const override {
387    op->range = options.radius();
388    op->bias = options.bias();
389    op->alpha = options.alpha();
390    op->beta = options.beta();
391  }
392};
393
394class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
395                                       ::tflite::BuiltinOptions_Pool2DOptions> {
396 public:
397  using BuiltinOperator::BuiltinOperator;
398  flatbuffers::Offset<TfLiteOptions> WriteOptions(
399      const TocoOperator& op,
400      flatbuffers::FlatBufferBuilder* builder) const override {
401    auto padding = Padding::Serialize(op.padding.type);
402    auto activation_function =
403        ActivationFunction::Serialize(op.fused_activation_function);
404    return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
405                                         op.stride_height, op.kwidth,
406                                         op.kheight, activation_function);
407  }
408
409  void ReadOptions(const TfLiteOptions& options,
410                   TocoOperator* op) const override {
411    op->padding.type = Padding::Deserialize(options.padding());
412    op->stride_width = options.stride_w();
413    op->stride_height = options.stride_h();
414    op->kwidth = options.filter_width();
415    op->kheight = options.filter_height();
416    op->fused_activation_function =
417        ActivationFunction::Deserialize(options.fused_activation_function());
418  }
419};
420
421class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
422                                   ::tflite::BuiltinOptions_MulOptions> {
423 public:
424  using BuiltinOperator::BuiltinOperator;
425
426  flatbuffers::Offset<TfLiteOptions> WriteOptions(
427      const TocoOperator& op,
428      flatbuffers::FlatBufferBuilder* builder) const override {
429    auto activation_function =
430        ActivationFunction::Serialize(op.fused_activation_function);
431    return ::tflite::CreateMulOptions(*builder, activation_function);
432  }
433
434  void ReadOptions(const TfLiteOptions& options,
435                   TocoOperator* op) const override {
436    op->fused_activation_function =
437        ActivationFunction::Deserialize(options.fused_activation_function());
438  }
439};
440
441class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
442                                   ::tflite::BuiltinOptions_PadOptions> {
443 public:
444  using BuiltinOperator::BuiltinOperator;
445
446  flatbuffers::Offset<TfLiteOptions> WriteOptions(
447      const TocoOperator& op,
448      flatbuffers::FlatBufferBuilder* builder) const override {
449    return ::tflite::CreatePadOptions(*builder);
450  }
451
452  void ReadOptions(const TfLiteOptions& options,
453                   TocoOperator* op) const override {}
454};
455
456class Reshape
457    : public BuiltinOperator<TensorFlowReshapeOperator,
458                             ::tflite::ReshapeOptions,
459                             ::tflite::BuiltinOptions_ReshapeOptions> {
460 public:
461  using BuiltinOperator::BuiltinOperator;
462
463  flatbuffers::Offset<TfLiteOptions> WriteOptions(
464      const TocoOperator& op,
465      flatbuffers::FlatBufferBuilder* builder) const override {
466    return ::tflite::CreateReshapeOptions(*builder,
467                                          builder->CreateVector(op.shape));
468  }
469
470  void ReadOptions(const TfLiteOptions& options,
471                   TocoOperator* op) const override {
472    op->shape.insert(op->shape.end(), options.new_shape()->begin(),
473                     options.new_shape()->end());
474  }
475};
476
477class Softmax
478    : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
479                             ::tflite::BuiltinOptions_SoftmaxOptions> {
480 public:
481  using BuiltinOperator::BuiltinOperator;
482  flatbuffers::Offset<TfLiteOptions> WriteOptions(
483      const TocoOperator& op,
484      flatbuffers::FlatBufferBuilder* builder) const override {
485    return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
486  }
487
488  void ReadOptions(const TfLiteOptions& options,
489                   TocoOperator* op) const override {
490    op->beta = options.beta();
491  }
492};
493
494class SpaceToDepth
495    : public BuiltinOperator<SpaceToDepthOperator,
496                             ::tflite::SpaceToDepthOptions,
497                             ::tflite::BuiltinOptions_SpaceToDepthOptions> {
498 public:
499  using BuiltinOperator::BuiltinOperator;
500  flatbuffers::Offset<TfLiteOptions> WriteOptions(
501      const TocoOperator& op,
502      flatbuffers::FlatBufferBuilder* builder) const override {
503    return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
504  }
505
506  void ReadOptions(const TfLiteOptions& options,
507                   TocoOperator* op) const override {
508    op->block_size = options.block_size();
509  }
510};
511
512class Transpose
513    : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
514                             ::tflite::BuiltinOptions_TransposeOptions> {
515 public:
516  using BuiltinOperator::BuiltinOperator;
517  flatbuffers::Offset<TfLiteOptions> WriteOptions(
518      const TocoOperator& op,
519      flatbuffers::FlatBufferBuilder* builder) const override {
520    return ::tflite::CreateTransposeOptions(*builder);
521  }
522
523  void ReadOptions(const TfLiteOptions& options,
524                   TocoOperator* op) const override {}
525};
526
527class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
528                                    ::tflite::BuiltinOptions_LSTMOptions> {
529 public:
530  using BuiltinOperator::BuiltinOperator;
531  flatbuffers::Offset<TfLiteOptions> WriteOptions(
532      const TocoOperator& op,
533      flatbuffers::FlatBufferBuilder* builder) const override {
534    // Current toco converter only supports tanh, no clip.
535    return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
536                                       ::tflite::ActivationFunctionType_TANH,
537                                       /*cell_clip=*/0.0,
538                                       /*proj_clip=*/0.0);
539  }
540
541  void ReadOptions(const TfLiteOptions& options,
542                   TocoOperator* op) const override {
543    // Only support tanh activation, so check that tflite type is tanh.
544    CHECK(options.fused_activation_function() ==
545          ::tflite::ActivationFunctionType_TANH);
546  }
547};
548
549class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
550                                    ::tflite::BuiltinOptions_MeanOptions> {
551 public:
552  using BuiltinOperator::BuiltinOperator;
553  flatbuffers::Offset<TfLiteOptions> WriteOptions(
554      const TocoOperator& op,
555      flatbuffers::FlatBufferBuilder* builder) const override {
556    return ::tflite::CreateMeanOptions(*builder, op.keep_dims);
557  }
558
559  void ReadOptions(const TfLiteOptions& options,
560                   TocoOperator* op) const override {
561    op->keep_dims = options.keep_dims();
562  }
563};
564
565class ResizeBilinear
566    : public BuiltinOperator<ResizeBilinearOperator,
567                             ::tflite::ResizeBilinearOptions,
568                             ::tflite::BuiltinOptions_ResizeBilinearOptions> {
569 public:
570  using BuiltinOperator::BuiltinOperator;
571  flatbuffers::Offset<TfLiteOptions> WriteOptions(
572      const TocoOperator& op,
573      flatbuffers::FlatBufferBuilder* builder) const override {
574    return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners);
575  }
576
577  void ReadOptions(const TfLiteOptions& options,
578                   TocoOperator* op) const override {
579    op->align_corners = options.align_corners();
580  }
581};
582
583class Squeeze
584    : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
585                             ::tflite::BuiltinOptions_SqueezeOptions> {
586 public:
587  using BuiltinOperator::BuiltinOperator;
588
589  flatbuffers::Offset<TfLiteOptions> WriteOptions(
590      const TocoOperator& op,
591      flatbuffers::FlatBufferBuilder* builder) const override {
592    auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
593    return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
594  }
595
596  void ReadOptions(const TfLiteOptions& options,
597                   TocoOperator* op) const override {
598    op->squeeze_dims.insert(op->squeeze_dims.end(),
599                            options.squeeze_dims()->begin(),
600                            options.squeeze_dims()->end());
601  }
602};
603
604class Split
605    : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
606                             ::tflite::BuiltinOptions_SplitOptions> {
607 public:
608  using BuiltinOperator::BuiltinOperator;
609
610  flatbuffers::Offset<TfLiteOptions> WriteOptions(
611      const TocoOperator& op,
612      flatbuffers::FlatBufferBuilder* builder) const override {
613    return ::tflite::CreateSplitOptions(*builder, op.num_split);
614  }
615
616  void ReadOptions(const TfLiteOptions& options,
617                   TocoOperator* op) const override {
618    op->num_split = options.num_splits();
619  }
620};
621
622class StridedSlice
623    : public BuiltinOperator<StridedSliceOperator,
624                             ::tflite::StridedSliceOptions,
625                             ::tflite::BuiltinOptions_StridedSliceOptions> {
626 public:
627  using BuiltinOperator::BuiltinOperator;
628  flatbuffers::Offset<TfLiteOptions> WriteOptions(
629      const TocoOperator& op,
630      flatbuffers::FlatBufferBuilder* builder) const override {
631    return ::tflite::CreateStridedSliceOptions(
632        *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
633        op.new_axis_mask, op.shrink_axis_mask);
634  }
635
636  void ReadOptions(const TfLiteOptions& options,
637                   TocoOperator* op) const override {
638    op->begin_mask = options.begin_mask();
639    op->end_mask = options.end_mask();
640    op->ellipsis_mask = options.ellipsis_mask();
641    op->new_axis_mask = options.new_axis_mask();
642    op->shrink_axis_mask = options.shrink_axis_mask();
643  }
644};
645
646class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
647                                       ::tflite::BuiltinOptions_TopKV2Options> {
648 public:
649  using BuiltinOperator::BuiltinOperator;
650  flatbuffers::Offset<TfLiteOptions> WriteOptions(
651      const TocoOperator& op,
652      flatbuffers::FlatBufferBuilder* builder) const override {
653    return ::tflite::CreateTopKV2Options(*builder);
654  }
655
656  void ReadOptions(const TfLiteOptions& options,
657                   TocoOperator* op) const override {}
658};
659
660class TensorFlowUnsupported : public BaseOperator {
661 public:
662  using BaseOperator::BaseOperator;
663
664  Options Serialize(const Operator& op,
665                    flatbuffers::FlatBufferBuilder* builder) const override {
666    auto fbb =
667        WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
668    if (fbb) {
669      return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
670    } else {
671      return Options::Custom(0);
672    }
673  }
674
675  std::unique_ptr<Operator> Deserialize(
676      const BuiltinOptions* builtin_options,
677      const CustomOptions* custom_options) const override {
678    auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
679    if (custom_options) {
680      auto flexbuffer_map =
681          flexbuffers::GetRoot(custom_options->data(), custom_options->size())
682              .AsMap();
683      ReadOptions(flexbuffer_map, op.get());
684    }
685    return std::unique_ptr<Operator>(op.release());
686  }
687
688  std::unique_ptr<flexbuffers::Builder> WriteOptions(
689      const TensorFlowUnsupportedOperator& op) const {
690    auto fbb = absl::make_unique<flexbuffers::Builder>();
691
692    ::tensorflow::NodeDef node_def;
693    if (!node_def.ParseFromString(op.tensorflow_node_def)) {
694      LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
695      return std::unique_ptr<flexbuffers::Builder>();
696    }
697
698    bool has_valid_attr = false;
699    size_t map_start = fbb->StartMap();
700    for (const auto& pair : node_def.attr()) {
701      const char* key = pair.first.c_str();
702      const auto& attr = pair.second;
703      switch (attr.value_case()) {
704        case ::tensorflow::AttrValue::kS:
705          fbb->String(key, attr.s());
706          has_valid_attr = true;
707          break;
708        case ::tensorflow::AttrValue::kI:
709          fbb->Int(key, attr.i());
710          has_valid_attr = true;
711          break;
712        case ::tensorflow::AttrValue::kF:
713          fbb->Float(key, attr.f());
714          has_valid_attr = true;
715          break;
716        case ::tensorflow::AttrValue::kB:
717          fbb->Bool(key, attr.b());
718          has_valid_attr = true;
719          break;
720        default:
721          LOG(WARNING) << "Ignoring unsupported attribute type with key '"
722                       << key << "'";
723          break;
724      }
725    }
726    if (!has_valid_attr) {
727      return std::unique_ptr<flexbuffers::Builder>();
728    }
729    fbb->EndMap(map_start);
730    fbb->Finish();
731    return std::unique_ptr<flexbuffers::Builder>(fbb.release());
732  }
733
734  void ReadOptions(const flexbuffers::Map& m,
735                   TensorFlowUnsupportedOperator* op) const {
736    ::tensorflow::NodeDef node_def;
737    auto attr = node_def.mutable_attr();
738
739    const auto& keys = m.Keys();
740    for (size_t i = 0; i < keys.size(); ++i) {
741      const auto key = keys[i].AsKey();
742      const auto& value = m[key];
743      switch (value.GetType()) {
744        case flexbuffers::TYPE_STRING:
745          (*attr)[key].set_s(value.AsString().c_str());
746          break;
747        case flexbuffers::TYPE_INT:
748          (*attr)[key].set_i(value.AsInt64());
749          break;
750        case flexbuffers::TYPE_FLOAT:
751          (*attr)[key].set_f(value.AsFloat());
752          break;
753        case flexbuffers::TYPE_BOOL:
754          (*attr)[key].set_b(value.AsBool());
755          break;
756        default:
757          LOG(WARNING) << "Ignoring unsupported attribute type with key '"
758                       << key << "'";
759          break;
760      }
761    }
762    node_def.SerializeToString(&op->tensorflow_node_def);
763  }
764};
765
766namespace {
767// Build a vector containing all the known operators.
768std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
769  std::vector<std::unique_ptr<BaseOperator>> ops;
770
771  // Builtin Operators.
772  ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
773  ops.emplace_back(new Div(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
774  ops.emplace_back(new Sub(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
775  ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D,
776                                   OperatorType::kAveragePool));
777  ops.emplace_back(
778      new SpaceToBatchND(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
779                         OperatorType::kSpaceToBatchND));
780  ops.emplace_back(
781      new BatchToSpaceND(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
782                         OperatorType::kBatchToSpaceND));
783  ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION,
784                                     OperatorType::kConcatenation));
785  ops.emplace_back(
786      new Convolution(::tflite::BuiltinOperator_CONV_2D, OperatorType::kConv));
787  ops.emplace_back(
788      new DepthwiseConvolution(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
789                               OperatorType::kDepthwiseConv));
790  ops.emplace_back(new FullyConnected(::tflite::BuiltinOperator_FULLY_CONNECTED,
791                                      OperatorType::kFullyConnected));
792  ops.emplace_back(
793      new Gather(::tflite::BuiltinOperator_GATHER, OperatorType::kGather));
794  ops.emplace_back(
795      new L2Normalization(::tflite::BuiltinOperator_L2_NORMALIZATION,
796                          OperatorType::kL2Normalization));
797  ops.emplace_back(
798      new L2Pool(::tflite::BuiltinOperator_L2_POOL_2D, OperatorType::kL2Pool));
799  ops.emplace_back(new LocalResponseNormalization(
800      ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
801      OperatorType::kLocalResponseNormalization));
802  ops.emplace_back(new MaxPool(::tflite::BuiltinOperator_MAX_POOL_2D,
803                               OperatorType::kMaxPool));
804  ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
805  ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
806  ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE,
807                               OperatorType::kTensorFlowReshape));
808  ops.emplace_back(
809      new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax));
810  ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH,
811                                    OperatorType::kSpaceToDepth));
812  ops.emplace_back(
813      new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
814  ops.emplace_back(new Transpose(::tflite::BuiltinOperator_TRANSPOSE,
815                                 OperatorType::kTranspose));
816  ops.emplace_back(
817      new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
818  ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR,
819                                      OperatorType::kResizeBilinear));
820  ops.emplace_back(
821      new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
822  ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT,
823                             OperatorType::kTensorFlowSplit));
824  ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
825                                    OperatorType::kStridedSlice));
826  ops.emplace_back(
827      new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2));
828  ops.emplace_back(
829      new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell));
830
831  // Custom Operators.
832  ops.emplace_back(new Cast("CAST", OperatorType::kCast));
833  ops.emplace_back(
834      new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
835  ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
836  ops.emplace_back(new TensorFlowUnsupported(
837      "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported));
838
839  // There operators are supported by Toco, but not by TF Lite, and has no
840  // attributes.
841  ops.emplace_back(
842      new SimpleOperator<AddNOperator>("ADDN", OperatorType::kAddN));
843  ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
844  ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
845      "RSQRT", OperatorType::kTensorFlowRsqrt));
846  // Simple Operators.
847  ops.emplace_back(new SimpleOperator<DequantizeOperator>(
848      "DEQUANTIZE", OperatorType::kDequantize));
849  ops.emplace_back(
850      new SimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor));
851  ops.emplace_back(
852      new SimpleOperator<ReluOperator>("RELU", OperatorType::kRelu));
853  ops.emplace_back(
854      new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1));
855  ops.emplace_back(
856      new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
857  ops.emplace_back(new SimpleOperator<LogisticOperator>(
858      "LOGISTIC", OperatorType::kLogistic));
859  ops.emplace_back(
860      new SimpleOperator<TanhOperator>("TANH", OperatorType::kTanh));
861  ops.emplace_back(new SimpleOperator<ExpOperator>("EXP", OperatorType::kExp));
862
863  return ops;
864}
865}  // namespace
866
867std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
868  std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
869
870  std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
871  for (auto& op : ops) {
872    result[op->type()] = std::move(op);
873  }
874
875  return result;
876}
877
878std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
879  std::map<string, std::unique_ptr<BaseOperator>> result;
880
881  std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
882  for (auto& op : ops) {
883    result[op->name()] = std::move(op);
884  }
885
886  return result;
887}
888
889}  // namespace tflite
890
891}  // namespace toco
892