1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/contrib/lite/toco/tflite/types.h" 16 17#include <gmock/gmock.h> 18#include <gtest/gtest.h> 19 20namespace toco { 21 22namespace tflite { 23namespace { 24 25using flatbuffers::FlatBufferBuilder; 26using flatbuffers::Offset; 27using flatbuffers::Vector; 28 29// These are types that exist in TF Mini but don't have a correspondence 30// in TF Lite. 31static const ArrayDataType kUnsupportedTocoTypes[] = {ArrayDataType::kNone, 32 ArrayDataType::kBool}; 33 34// These are TF Lite types for which there is no correspondence in TF Mini. 35static const ::tflite::TensorType kUnsupportedTfLiteTypes[] = { 36 ::tflite::TensorType_FLOAT16}; 37 38// A little helper to match flatbuffer offsets. 39MATCHER_P(HasOffset, value, "") { return arg.o == value; } 40 41// Helper function that creates an array, writes it into a flatbuffer, and then 42// reads it back in. 43template <ArrayDataType T> 44Array ToFlatBufferAndBack(std::initializer_list<::toco::DataType<T>> items) { 45 // NOTE: This test does not construct the full buffers list. Since 46 // Deserialize normally takes a buffer, we need to synthesize one and provide 47 // an index that is non-zero so the buffer is not assumed to be emtpy. 48 Array src; 49 src.data_type = T; 50 src.GetMutableBuffer<T>().data = items; 51 52 Array result; 53 flatbuffers::FlatBufferBuilder builder; 54 builder.Finish(CreateTensor(builder, 0, DataType::Serialize(T), 55 /*buffer*/ 1)); // Can't use 0 which means empty. 56 flatbuffers::FlatBufferBuilder buffer_builder; 57 Offset<Vector<uint8_t>> data_buffer = 58 DataBuffer::Serialize(src, &buffer_builder); 59 buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, data_buffer)); 60 61 auto* tensor = 62 flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer()); 63 auto* buffer = 64 flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer()); 65 DataBuffer::Deserialize(*tensor, *buffer, &result); 66 return result; 67} 68 69TEST(DataType, SupportedTypes) { 70 std::vector<std::pair<ArrayDataType, ::tflite::TensorType>> testdata = { 71 {ArrayDataType::kUint8, ::tflite::TensorType_UINT8}, 72 {ArrayDataType::kInt32, ::tflite::TensorType_INT32}, 73 {ArrayDataType::kInt64, ::tflite::TensorType_INT64}, 74 {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}}; 75 for (auto x : testdata) { 76 EXPECT_EQ(x.second, DataType::Serialize(x.first)); 77 EXPECT_EQ(x.first, DataType::Deserialize(x.second)); 78 } 79} 80 81TEST(DataType, UnsupportedTypes) { 82 for (::tflite::TensorType t : kUnsupportedTfLiteTypes) { 83 EXPECT_DEATH(DataType::Deserialize(t), "Unhandled tensor type."); 84 } 85 86 // Unsupported types are all serialized as FLOAT32 currently. 87 for (ArrayDataType t : kUnsupportedTocoTypes) { 88 EXPECT_EQ(::tflite::TensorType_FLOAT32, DataType::Serialize(t)); 89 } 90} 91 92TEST(DataBuffer, EmptyBuffers) { 93 flatbuffers::FlatBufferBuilder builder; 94 Array array; 95 EXPECT_THAT(DataBuffer::Serialize(array, &builder), HasOffset(0)); 96 97 builder.Finish(::tflite::CreateTensor(builder)); 98 auto* tensor = 99 flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer()); 100 flatbuffers::FlatBufferBuilder buffer_builder; 101 Offset<Vector<uint8_t>> v = buffer_builder.CreateVector<uint8_t>({}); 102 buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v)); 103 auto* buffer = 104 flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer()); 105 106 DataBuffer::Deserialize(*tensor, *buffer, &array); 107 EXPECT_EQ(nullptr, array.buffer); 108} 109 110TEST(DataBuffer, UnsupportedTypes) { 111 for (ArrayDataType t : kUnsupportedTocoTypes) { 112 flatbuffers::FlatBufferBuilder builder; 113 Array array; 114 array.data_type = t; 115 array.GetMutableBuffer<ArrayDataType::kFloat>(); // This is OK. 116 EXPECT_DEATH(DataBuffer::Serialize(array, &builder), 117 "Unhandled array data type."); 118 } 119 120 for (::tflite::TensorType t : kUnsupportedTfLiteTypes) { 121 flatbuffers::FlatBufferBuilder builder; 122 builder.Finish(::tflite::CreateTensor(builder, 0, t, /*buffer*/ 1)); 123 flatbuffers::FlatBufferBuilder buffer_builder; 124 Offset<Vector<uint8_t>> v = buffer_builder.CreateVector<uint8_t>({1}); 125 buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v)); 126 auto* buffer = flatbuffers::GetRoot<::tflite::Buffer>( 127 buffer_builder.GetBufferPointer()); 128 auto* tensor = 129 flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer()); 130 Array array; 131 EXPECT_DEATH(DataBuffer::Deserialize(*tensor, *buffer, &array), 132 "Unhandled tensor type."); 133 } 134} 135 136TEST(DataBuffer, Float) { 137 Array recovered = ToFlatBufferAndBack<ArrayDataType::kFloat>({1.0f, 2.0f}); 138 EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kFloat>().data, 139 ::testing::ElementsAre(1.0f, 2.0f)); 140} 141 142TEST(DataBuffer, Uint8) { 143 Array recovered = ToFlatBufferAndBack<ArrayDataType::kUint8>({127, 244}); 144 EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kUint8>().data, 145 ::testing::ElementsAre(127, 244)); 146} 147 148TEST(DataBuffer, Int32) { 149 Array recovered = ToFlatBufferAndBack<ArrayDataType::kInt32>({1, 1 << 30}); 150 EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kInt32>().data, 151 ::testing::ElementsAre(1, 1 << 30)); 152} 153 154TEST(Padding, All) { 155 EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame)); 156 EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME)); 157 158 EXPECT_EQ(::tflite::Padding_VALID, Padding::Serialize(PaddingType::kValid)); 159 EXPECT_EQ(PaddingType::kValid, Padding::Deserialize(::tflite::Padding_VALID)); 160 161 EXPECT_DEATH(Padding::Serialize(static_cast<PaddingType>(10000)), 162 "Unhandled padding type."); 163 EXPECT_DEATH(Padding::Deserialize(10000), "Unhandled padding."); 164} 165 166TEST(ActivationFunction, All) { 167 std::vector< 168 std::pair<FusedActivationFunctionType, ::tflite::ActivationFunctionType>> 169 testdata = {{FusedActivationFunctionType::kNone, 170 ::tflite::ActivationFunctionType_NONE}, 171 {FusedActivationFunctionType::kRelu, 172 ::tflite::ActivationFunctionType_RELU}, 173 {FusedActivationFunctionType::kRelu6, 174 ::tflite::ActivationFunctionType_RELU6}, 175 {FusedActivationFunctionType::kRelu1, 176 ::tflite::ActivationFunctionType_RELU_N1_TO_1}}; 177 for (auto x : testdata) { 178 EXPECT_EQ(x.second, ActivationFunction::Serialize(x.first)); 179 EXPECT_EQ(x.first, ActivationFunction::Deserialize(x.second)); 180 } 181 182 EXPECT_DEATH(ActivationFunction::Serialize( 183 static_cast<FusedActivationFunctionType>(10000)), 184 "Unhandled fused activation function type."); 185 EXPECT_DEATH(ActivationFunction::Deserialize(10000), 186 "Unhandled fused activation function type."); 187} 188 189} // namespace 190} // namespace tflite 191 192} // namespace toco 193