1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/contrib/lite/toco/tflite/types.h"
16
17#include <gmock/gmock.h>
18#include <gtest/gtest.h>
19
20namespace toco {
21
22namespace tflite {
23namespace {
24
25using flatbuffers::FlatBufferBuilder;
26using flatbuffers::Offset;
27using flatbuffers::Vector;
28
29// These are types that exist in TF Mini but don't have a correspondence
30// in TF Lite.
31static const ArrayDataType kUnsupportedTocoTypes[] = {ArrayDataType::kNone,
32                                                      ArrayDataType::kBool};
33
34// These are TF Lite types for which there is no correspondence in TF Mini.
35static const ::tflite::TensorType kUnsupportedTfLiteTypes[] = {
36    ::tflite::TensorType_FLOAT16};
37
38// A little helper to match flatbuffer offsets.
39MATCHER_P(HasOffset, value, "") { return arg.o == value; }
40
41// Helper function that creates an array, writes it into a flatbuffer, and then
42// reads it back in.
43template <ArrayDataType T>
44Array ToFlatBufferAndBack(std::initializer_list<::toco::DataType<T>> items) {
45  // NOTE: This test does not construct the full buffers list. Since
46  // Deserialize normally takes a buffer, we need to synthesize one and provide
47  // an index that is non-zero so the buffer is not assumed to be emtpy.
48  Array src;
49  src.data_type = T;
50  src.GetMutableBuffer<T>().data = items;
51
52  Array result;
53  flatbuffers::FlatBufferBuilder builder;
54  builder.Finish(CreateTensor(builder, 0, DataType::Serialize(T),
55                              /*buffer*/ 1));  // Can't use 0 which means empty.
56  flatbuffers::FlatBufferBuilder buffer_builder;
57  Offset<Vector<uint8_t>> data_buffer =
58      DataBuffer::Serialize(src, &buffer_builder);
59  buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, data_buffer));
60
61  auto* tensor =
62      flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
63  auto* buffer =
64      flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer());
65  DataBuffer::Deserialize(*tensor, *buffer, &result);
66  return result;
67}
68
69TEST(DataType, SupportedTypes) {
70  std::vector<std::pair<ArrayDataType, ::tflite::TensorType>> testdata = {
71      {ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
72      {ArrayDataType::kInt32, ::tflite::TensorType_INT32},
73      {ArrayDataType::kInt64, ::tflite::TensorType_INT64},
74      {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}};
75  for (auto x : testdata) {
76    EXPECT_EQ(x.second, DataType::Serialize(x.first));
77    EXPECT_EQ(x.first, DataType::Deserialize(x.second));
78  }
79}
80
81TEST(DataType, UnsupportedTypes) {
82  for (::tflite::TensorType t : kUnsupportedTfLiteTypes) {
83    EXPECT_DEATH(DataType::Deserialize(t), "Unhandled tensor type.");
84  }
85
86  // Unsupported types are all serialized as FLOAT32 currently.
87  for (ArrayDataType t : kUnsupportedTocoTypes) {
88    EXPECT_EQ(::tflite::TensorType_FLOAT32, DataType::Serialize(t));
89  }
90}
91
92TEST(DataBuffer, EmptyBuffers) {
93  flatbuffers::FlatBufferBuilder builder;
94  Array array;
95  EXPECT_THAT(DataBuffer::Serialize(array, &builder), HasOffset(0));
96
97  builder.Finish(::tflite::CreateTensor(builder));
98  auto* tensor =
99      flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
100  flatbuffers::FlatBufferBuilder buffer_builder;
101  Offset<Vector<uint8_t>> v = buffer_builder.CreateVector<uint8_t>({});
102  buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v));
103  auto* buffer =
104      flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer());
105
106  DataBuffer::Deserialize(*tensor, *buffer, &array);
107  EXPECT_EQ(nullptr, array.buffer);
108}
109
110TEST(DataBuffer, UnsupportedTypes) {
111  for (ArrayDataType t : kUnsupportedTocoTypes) {
112    flatbuffers::FlatBufferBuilder builder;
113    Array array;
114    array.data_type = t;
115    array.GetMutableBuffer<ArrayDataType::kFloat>();  // This is OK.
116    EXPECT_DEATH(DataBuffer::Serialize(array, &builder),
117                 "Unhandled array data type.");
118  }
119
120  for (::tflite::TensorType t : kUnsupportedTfLiteTypes) {
121    flatbuffers::FlatBufferBuilder builder;
122    builder.Finish(::tflite::CreateTensor(builder, 0, t, /*buffer*/ 1));
123    flatbuffers::FlatBufferBuilder buffer_builder;
124    Offset<Vector<uint8_t>> v = buffer_builder.CreateVector<uint8_t>({1});
125    buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v));
126    auto* buffer = flatbuffers::GetRoot<::tflite::Buffer>(
127        buffer_builder.GetBufferPointer());
128    auto* tensor =
129        flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
130    Array array;
131    EXPECT_DEATH(DataBuffer::Deserialize(*tensor, *buffer, &array),
132                 "Unhandled tensor type.");
133  }
134}
135
136TEST(DataBuffer, Float) {
137  Array recovered = ToFlatBufferAndBack<ArrayDataType::kFloat>({1.0f, 2.0f});
138  EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kFloat>().data,
139              ::testing::ElementsAre(1.0f, 2.0f));
140}
141
142TEST(DataBuffer, Uint8) {
143  Array recovered = ToFlatBufferAndBack<ArrayDataType::kUint8>({127, 244});
144  EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kUint8>().data,
145              ::testing::ElementsAre(127, 244));
146}
147
148TEST(DataBuffer, Int32) {
149  Array recovered = ToFlatBufferAndBack<ArrayDataType::kInt32>({1, 1 << 30});
150  EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kInt32>().data,
151              ::testing::ElementsAre(1, 1 << 30));
152}
153
154TEST(Padding, All) {
155  EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame));
156  EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME));
157
158  EXPECT_EQ(::tflite::Padding_VALID, Padding::Serialize(PaddingType::kValid));
159  EXPECT_EQ(PaddingType::kValid, Padding::Deserialize(::tflite::Padding_VALID));
160
161  EXPECT_DEATH(Padding::Serialize(static_cast<PaddingType>(10000)),
162               "Unhandled padding type.");
163  EXPECT_DEATH(Padding::Deserialize(10000), "Unhandled padding.");
164}
165
166TEST(ActivationFunction, All) {
167  std::vector<
168      std::pair<FusedActivationFunctionType, ::tflite::ActivationFunctionType>>
169      testdata = {{FusedActivationFunctionType::kNone,
170                   ::tflite::ActivationFunctionType_NONE},
171                  {FusedActivationFunctionType::kRelu,
172                   ::tflite::ActivationFunctionType_RELU},
173                  {FusedActivationFunctionType::kRelu6,
174                   ::tflite::ActivationFunctionType_RELU6},
175                  {FusedActivationFunctionType::kRelu1,
176                   ::tflite::ActivationFunctionType_RELU_N1_TO_1}};
177  for (auto x : testdata) {
178    EXPECT_EQ(x.second, ActivationFunction::Serialize(x.first));
179    EXPECT_EQ(x.first, ActivationFunction::Deserialize(x.second));
180  }
181
182  EXPECT_DEATH(ActivationFunction::Serialize(
183                   static_cast<FusedActivationFunctionType>(10000)),
184               "Unhandled fused activation function type.");
185  EXPECT_DEATH(ActivationFunction::Deserialize(10000),
186               "Unhandled fused activation function type.");
187}
188
189}  // namespace
190}  // namespace tflite
191
192}  // namespace toco
193