1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/contrib/lite/toco/tflite/operator.h"
16
17#include "flatbuffers/flexbuffers.h"
18#include <gmock/gmock.h>
19#include <gtest/gtest.h>
20#include "tensorflow/contrib/lite/toco/tooling_util.h"
21
22#include "tensorflow/core/framework/attr_value.pb.h"
23#include "tensorflow/core/framework/node_def.pb.h"
24
25namespace toco {
26
27namespace tflite {
28namespace {
29
30class OperatorTest : public ::testing::Test {
31 protected:
32  // Return the operator for the given name and type.
33  const BaseOperator& GetOperator(const string& name, OperatorType type) {
34    using OpsByName = std::map<string, std::unique_ptr<BaseOperator>>;
35    using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>;
36
37    static auto* by_name = new OpsByName(BuildOperatorByNameMap());
38    static auto* by_type = new OpsByType(BuildOperatorByTypeMap());
39
40    // Make sure the two maps were consitently built.
41    CHECK(by_name->count(name)) << "No operator for '" << name << "'.";
42    BaseOperator* op1 = by_name->at(name).get();
43    CHECK(op1->type() == type) << "while verifying '" << name << "'.";
44
45    CHECK(by_type->count(type))
46        << "No operator for '" << OperatorTypeName(type) << "'.";
47    BaseOperator* op2 = by_type->at(type).get();
48    CHECK(op2->name() == name)
49        << "while verifying '" << OperatorTypeName(type) << "'.";
50
51    return *op1;
52  }
53
54  // Use the given BaseOperator to serialize the tf.mini operator into a set of
55  // TF Lite options. Proceed to deserialize the options back into a new
56  // tf.mini operator, which is then returned. If `options` is given, it will
57  // be populated with the serialized options.
58  template <typename T>
59  std::unique_ptr<T> SerializeAndDeserialize(const BaseOperator& op,
60                                             const T& toco_op,
61                                             Options* options = nullptr) {
62    flatbuffers::FlatBufferBuilder builder;
63    Options input_options = op.Serialize(toco_op, &builder);
64
65    if (options) {
66      *options = input_options;
67    }
68
69    builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type,
70                                  input_options.builtin, input_options.custom,
71                                  ::tflite::CustomOptionsFormat_FLEXBUFFERS));
72    auto* output_options =
73        flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer());
74    auto new_toco_op = op.Deserialize(output_options->builtin_options(),
75                                      output_options->custom_options());
76
77    CHECK(dynamic_cast<T*>(new_toco_op.get()))
78        << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to "
79        << HelpfulOperatorTypeName(toco_op);
80
81    return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release()));
82  }
83
84  // Verify serialization and deserialization of simple operators (those
85  // that don't have any configuration parameters).
86  template <typename T>
87  void CheckSimpleOperator(const string& name, OperatorType type) {
88    Options options;
89    auto output_toco_op =
90        SerializeAndDeserialize(GetOperator(name, type), T(), &options);
91
92    ASSERT_EQ(0, options.builtin.o);
93    ASSERT_EQ(0, options.custom.o);
94    ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type);
95
96    ASSERT_NE(nullptr, output_toco_op.get());
97  }
98};
99
100TEST_F(OperatorTest, SimpleOperators) {
101  CheckSimpleOperator<DequantizeOperator>("DEQUANTIZE",
102                                          OperatorType::kDequantize);
103  CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
104  CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
105  CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
106  CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
107  CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
108  CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
109  CheckSimpleOperator<ExpOperator>("EXP", OperatorType::kExp);
110}
111
112TEST_F(OperatorTest, BuiltinAdd) {
113  AddOperator op;
114  op.fused_activation_function = FusedActivationFunctionType::kRelu6;
115  auto output_toco_op =
116      SerializeAndDeserialize(GetOperator("ADD", OperatorType::kAdd), op);
117  EXPECT_EQ(op.fused_activation_function,
118            output_toco_op->fused_activation_function);
119}
120
121TEST_F(OperatorTest, BuiltinMean) {
122  MeanOperator op;
123  op.keep_dims = false;
124
125  auto output_toco_op =
126      SerializeAndDeserialize(GetOperator("MEAN", OperatorType::kMean), op);
127  EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
128}
129
130TEST_F(OperatorTest, CustomCast) {
131  CastOperator op;
132  op.src_data_type = ArrayDataType::kFloat;
133  op.dst_data_type = ArrayDataType::kUint8;
134  auto output_toco_op =
135      SerializeAndDeserialize(GetOperator("CAST", OperatorType::kCast), op);
136  EXPECT_EQ(op.src_data_type, output_toco_op->src_data_type);
137  EXPECT_EQ(op.dst_data_type, output_toco_op->dst_data_type);
138}
139
140TEST_F(OperatorTest, CustomConcatenation) {
141  ConcatenationOperator op;
142  op.axis = 123;
143  auto output_toco_op = SerializeAndDeserialize(
144      GetOperator("CONCATENATION", OperatorType::kConcatenation), op);
145  EXPECT_EQ(op.axis, output_toco_op->axis);
146}
147
148TEST_F(OperatorTest, CustomDepthToSpace) {
149  DepthToSpaceOperator op;
150  op.block_size = 123;
151  auto output_toco_op = SerializeAndDeserialize(
152      GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace), op);
153  EXPECT_EQ(op.block_size, output_toco_op->block_size);
154}
155
156TEST_F(OperatorTest, CustomFakeQuant) {
157  FakeQuantOperator op;
158  auto* minmax = new MinMax;
159  minmax->min = -10;
160  minmax->max = 200;
161  op.minmax.reset(minmax);
162  auto output_toco_op = SerializeAndDeserialize(
163      GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op);
164  EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min);
165  EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max);
166}
167
168TEST_F(OperatorTest, CustomFullyConnected) {
169  FullyConnectedOperator op;
170  op.fused_activation_function = FusedActivationFunctionType::kRelu6;
171  auto output_toco_op = SerializeAndDeserialize(
172      GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected), op);
173  EXPECT_EQ(op.fused_activation_function,
174            output_toco_op->fused_activation_function);
175}
176
177TEST_F(OperatorTest, BuiltinGather) {
178  GatherOperator op;
179  auto output_toco_op =
180      SerializeAndDeserialize(GetOperator("GATHER", OperatorType::kGather), op);
181  ASSERT_NE(nullptr, output_toco_op.get());
182}
183
184TEST_F(OperatorTest, BuiltinL2Pool) {
185  L2PoolOperator op;
186  op.stride_width = 123;
187  op.stride_height = 124;
188  op.padding.type = PaddingType::kValid;
189  op.kwidth = 480;
190  op.kheight = 1080;
191  auto output_toco_op = SerializeAndDeserialize(
192      GetOperator("L2_POOL_2D", OperatorType::kL2Pool), op);
193  EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
194  EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
195  EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
196  EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
197  EXPECT_EQ(op.kheight, output_toco_op->kheight);
198}
199
200TEST_F(OperatorTest, BuiltinLocalResponseNormalization) {
201  LocalResponseNormalizationOperator op;
202  op.range = 123;
203  op.bias = 1.23;
204  op.alpha = 12.3;
205  op.beta = .123;
206  auto output_toco_op = SerializeAndDeserialize(
207      GetOperator("LOCAL_RESPONSE_NORMALIZATION",
208                  OperatorType::kLocalResponseNormalization),
209      op);
210  EXPECT_EQ(op.range, output_toco_op->range);
211  EXPECT_EQ(op.bias, output_toco_op->bias);
212  EXPECT_EQ(op.alpha, output_toco_op->alpha);
213  EXPECT_EQ(op.beta, output_toco_op->beta);
214}
215
216TEST_F(OperatorTest, BuiltinMaxPool) {
217  MaxPoolOperator op;
218  op.stride_width = 123;
219  op.stride_height = 124;
220  op.padding.type = PaddingType::kValid;
221  op.kwidth = 480;
222  op.kheight = 1080;
223  auto output_toco_op = SerializeAndDeserialize(
224      GetOperator("MAX_POOL_2D", OperatorType::kMaxPool), op);
225  EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
226  EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
227  EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
228  EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
229  EXPECT_EQ(op.kheight, output_toco_op->kheight);
230}
231
232TEST_F(OperatorTest, BuiltinReshape) {
233  TensorFlowReshapeOperator op;
234  op.shape = {1, 2, 4, 5, 8};
235  auto output_toco_op = SerializeAndDeserialize(
236      GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op);
237  EXPECT_EQ(op.shape, output_toco_op->shape);
238}
239
240TEST_F(OperatorTest, CustomSoftmax) {
241  SoftmaxOperator op;
242  op.beta = 123.1;
243  auto output_toco_op = SerializeAndDeserialize(
244      GetOperator("SOFTMAX", OperatorType::kSoftmax), op);
245  EXPECT_EQ(op.beta, output_toco_op->beta);
246}
247
248TEST_F(OperatorTest, BuiltinSpaceToDepth) {
249  SpaceToDepthOperator op;
250  op.block_size = 123;
251  auto output_toco_op = SerializeAndDeserialize(
252      GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth), op);
253  EXPECT_EQ(op.block_size, output_toco_op->block_size);
254}
255
256TEST_F(OperatorTest, CustomSplit) {
257  TensorFlowSplitOperator op;
258  op.num_split = 123;
259  auto output_toco_op = SerializeAndDeserialize(
260      GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op);
261  EXPECT_EQ(op.num_split, output_toco_op->num_split);
262}
263
264TEST_F(OperatorTest, BuiltinAveragePool) {
265  AveragePoolOperator op;
266  op.fused_activation_function = FusedActivationFunctionType::kRelu6;
267  op.stride_width = 123;
268  op.stride_height = 124;
269  op.padding.type = PaddingType::kValid;
270  op.kwidth = 480;
271  op.kheight = 1080;
272  auto output_toco_op = SerializeAndDeserialize(
273      GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool), op);
274  EXPECT_EQ(op.fused_activation_function,
275            output_toco_op->fused_activation_function);
276  EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
277  EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
278  EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
279  EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
280  EXPECT_EQ(op.kheight, output_toco_op->kheight);
281}
282
283TEST_F(OperatorTest, BuiltinConvolution) {
284  ConvOperator op;
285  op.stride_width = 123;
286  op.stride_height = 124;
287  op.padding.type = PaddingType::kValid;
288  op.fused_activation_function = FusedActivationFunctionType::kRelu6;
289  auto output_toco_op =
290      SerializeAndDeserialize(GetOperator("CONV_2D", OperatorType::kConv), op);
291  EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
292  EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
293  EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
294  EXPECT_EQ(op.fused_activation_function,
295            output_toco_op->fused_activation_function);
296}
297
298TEST_F(OperatorTest, BuiltinDepthwiseConvolution) {
299  DepthwiseConvOperator op;
300  op.stride_width = 123;
301  op.stride_height = 124;
302  op.padding.type = PaddingType::kValid;
303  op.depth_multiplier = 6;
304  op.fused_activation_function = FusedActivationFunctionType::kRelu6;
305  auto output_toco_op = SerializeAndDeserialize(
306      GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv), op);
307  EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
308  EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
309  EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
310  EXPECT_EQ(op.depth_multiplier, output_toco_op->depth_multiplier);
311  EXPECT_EQ(op.fused_activation_function,
312            output_toco_op->fused_activation_function);
313}
314
315TEST_F(OperatorTest, BuiltinL2Norm) {
316  L2NormalizationOperator op;
317  op.fused_activation_function = FusedActivationFunctionType::kRelu6;
318  auto output_toco_op = SerializeAndDeserialize(
319      GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization), op);
320  EXPECT_EQ(op.fused_activation_function,
321            output_toco_op->fused_activation_function);
322}
323
324TEST_F(OperatorTest, BuiltinMul) {
325  MulOperator op;
326  op.fused_activation_function = FusedActivationFunctionType::kRelu6;
327  auto output_toco_op =
328      SerializeAndDeserialize(GetOperator("MUL", OperatorType::kMul), op);
329  EXPECT_EQ(op.fused_activation_function,
330            output_toco_op->fused_activation_function);
331}
332
333TEST_F(OperatorTest, ResizeBilinear) {
334  ResizeBilinearOperator op;
335  op.align_corners = true;
336  auto output_toco_op = SerializeAndDeserialize(
337      GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op);
338  EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
339}
340
341TEST_F(OperatorTest, Svdf) {
342  SvdfOperator op;
343  op.fused_activation_function = FusedActivationFunctionType::kRelu;
344  op.rank = 1;
345  auto output_toco_op =
346      SerializeAndDeserialize(GetOperator("SVDF", OperatorType::kSvdf), op);
347  EXPECT_EQ(op.fused_activation_function,
348            output_toco_op->fused_activation_function);
349  EXPECT_EQ(op.rank, output_toco_op->rank);
350}
351
352TEST_F(OperatorTest, Squeeze) {
353  SqueezeOperator op;
354  op.squeeze_dims = {-2, -3, 4, 1, 4};
355
356  auto output_toco_op = SerializeAndDeserialize(
357      GetOperator("SQUEEZE", OperatorType::kSqueeze), op);
358  EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims);
359}
360
361TEST_F(OperatorTest, StridedSlice) {
362  StridedSliceOperator op;
363
364  op.begin_mask = 1;
365  op.end_mask = 2;
366  op.ellipsis_mask = 1;
367  op.new_axis_mask = 1;
368  op.shrink_axis_mask = 2;
369
370  auto output_toco_op = SerializeAndDeserialize(
371      GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op);
372  EXPECT_EQ(op.start_indices, output_toco_op->start_indices);
373  EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices);
374  EXPECT_EQ(op.strides, output_toco_op->strides);
375  EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask);
376  EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
377  EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
378  EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask);
379  EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask);
380  EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask);
381}
382
383TEST_F(OperatorTest, BuiltinTopKV2) {
384  TopKV2Operator op;
385  auto output_toco_op = SerializeAndDeserialize(
386      GetOperator("TOPK_V2", OperatorType::kTopK_V2), op);
387  ASSERT_NE(nullptr, output_toco_op.get());
388}
389
390TEST_F(OperatorTest, TensorFlowUnsupported) {
391  TensorFlowUnsupportedOperator op;
392  op.tensorflow_op = "MyCustomUnsupportedOp";
393
394  ::tensorflow::NodeDef node_def;
395  auto attr = node_def.mutable_attr();
396  (*attr)["float_attr"].set_f(2.0);
397  (*attr)["str_attr"].set_s("Hello World");
398  (*attr)["int_attr"].set_i(17);
399  (*attr)["bool_attr"].set_b(true);
400  node_def.SerializeToString(&op.tensorflow_node_def);
401
402  auto output_toco_op =
403      SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
404                                          OperatorType::kTensorFlowUnsupported),
405                              op);
406
407  ::tensorflow::NodeDef output_node_def;
408  output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
409  const auto& output_attr = output_node_def.attr();
410  EXPECT_EQ(2.0, output_attr.at("float_attr").f());
411  EXPECT_EQ("Hello World", output_attr.at("str_attr").s());
412  EXPECT_EQ(17, output_attr.at("int_attr").i());
413  EXPECT_EQ(true, output_attr.at("bool_attr").b());
414}
415
416TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
417  TensorFlowUnsupportedOperator op;
418  op.tensorflow_op = "MyCustomUnsupportedOp";
419  auto output_toco_op =
420      SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
421                                          OperatorType::kTensorFlowUnsupported),
422                              op);
423
424  ::tensorflow::NodeDef output_node_def;
425  output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
426  EXPECT_TRUE(output_node_def.attr().empty());
427}
428
429}  // namespace
430}  // namespace tflite
431
432}  // namespace toco
433