10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License");
40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License.
50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at
60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    http://www.apache.org/licenses/LICENSE-2.0
80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software
100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS,
110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and
130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License.
140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tflite/types.h"
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <gmock/gmock.h>
180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <gtest/gtest.h>
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco {
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace tflite {
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace {
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleusing flatbuffers::FlatBufferBuilder;
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleusing flatbuffers::Offset;
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleusing flatbuffers::Vector;
280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// These are types that exist in TF Mini but don't have a correspondence
300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// in TF Lite.
3136635f4a389c812bb328821e5e533feeef7d26edYu-Cheng Lingstatic const ArrayDataType kUnsupportedTocoTypes[] = {ArrayDataType::kNone,
3236635f4a389c812bb328821e5e533feeef7d26edYu-Cheng Ling                                                      ArrayDataType::kBool};
330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// These are TF Lite types for which there is no correspondence in TF Mini.
350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellestatic const ::tflite::TensorType kUnsupportedTfLiteTypes[] = {
360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    ::tflite::TensorType_FLOAT16};
370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// A little helper to match flatbuffer offsets.
390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleMATCHER_P(HasOffset, value, "") { return arg.o == value; }
400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// Helper function that creates an array, writes it into a flatbuffer, and then
420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// reads it back in.
430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selletemplate <ArrayDataType T>
440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleArray ToFlatBufferAndBack(std::initializer_list<::toco::DataType<T>> items) {
450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // NOTE: This test does not construct the full buffers list. Since
460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Deserialize normally takes a buffer, we need to synthesize one and provide
470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // an index that is non-zero so the buffer is not assumed to be emtpy.
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Array src;
490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  src.data_type = T;
500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  src.GetMutableBuffer<T>().data = items;
510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Array result;
530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  flatbuffers::FlatBufferBuilder builder;
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  builder.Finish(CreateTensor(builder, 0, DataType::Serialize(T),
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                              /*buffer*/ 1));  // Can't use 0 which means empty.
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  flatbuffers::FlatBufferBuilder buffer_builder;
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Offset<Vector<uint8_t>> data_buffer =
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      DataBuffer::Serialize(src, &buffer_builder);
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, data_buffer));
600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* tensor =
620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* buffer =
640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer());
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DataBuffer::Deserialize(*tensor, *buffer, &result);
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return result;
670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST(DataType, SupportedTypes) {
700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  std::vector<std::pair<ArrayDataType, ::tflite::TensorType>> testdata = {
710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      {ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      {ArrayDataType::kInt32, ::tflite::TensorType_INT32},
7336635f4a389c812bb328821e5e533feeef7d26edYu-Cheng Ling      {ArrayDataType::kInt64, ::tflite::TensorType_INT64},
740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}};
750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (auto x : testdata) {
760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    EXPECT_EQ(x.second, DataType::Serialize(x.first));
770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    EXPECT_EQ(x.first, DataType::Deserialize(x.second));
780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST(DataType, UnsupportedTypes) {
820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (::tflite::TensorType t : kUnsupportedTfLiteTypes) {
830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    EXPECT_DEATH(DataType::Deserialize(t), "Unhandled tensor type.");
840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Unsupported types are all serialized as FLOAT32 currently.
870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (ArrayDataType t : kUnsupportedTocoTypes) {
880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    EXPECT_EQ(::tflite::TensorType_FLOAT32, DataType::Serialize(t));
890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST(DataBuffer, EmptyBuffers) {
930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  flatbuffers::FlatBufferBuilder builder;
940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Array array;
950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(DataBuffer::Serialize(array, &builder), HasOffset(0));
960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  builder.Finish(::tflite::CreateTensor(builder));
980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* tensor =
990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  flatbuffers::FlatBufferBuilder buffer_builder;
1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Offset<Vector<uint8_t>> v = buffer_builder.CreateVector<uint8_t>({});
1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v));
1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* buffer =
1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer());
1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DataBuffer::Deserialize(*tensor, *buffer, &array);
1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_EQ(nullptr, array.buffer);
1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST(DataBuffer, UnsupportedTypes) {
1110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (ArrayDataType t : kUnsupportedTocoTypes) {
1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    flatbuffers::FlatBufferBuilder builder;
1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    Array array;
1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    array.data_type = t;
1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    array.GetMutableBuffer<ArrayDataType::kFloat>();  // This is OK.
1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    EXPECT_DEATH(DataBuffer::Serialize(array, &builder),
1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                 "Unhandled array data type.");
1180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (::tflite::TensorType t : kUnsupportedTfLiteTypes) {
1210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    flatbuffers::FlatBufferBuilder builder;
1220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    builder.Finish(::tflite::CreateTensor(builder, 0, t, /*buffer*/ 1));
1230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    flatbuffers::FlatBufferBuilder buffer_builder;
1240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    Offset<Vector<uint8_t>> v = buffer_builder.CreateVector<uint8_t>({1});
1250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v));
1260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    auto* buffer = flatbuffers::GetRoot<::tflite::Buffer>(
1270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        buffer_builder.GetBufferPointer());
1280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    auto* tensor =
1290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer());
1300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    Array array;
1310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    EXPECT_DEATH(DataBuffer::Deserialize(*tensor, *buffer, &array),
1320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                 "Unhandled tensor type.");
1330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST(DataBuffer, Float) {
1370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Array recovered = ToFlatBufferAndBack<ArrayDataType::kFloat>({1.0f, 2.0f});
1380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kFloat>().data,
1390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              ::testing::ElementsAre(1.0f, 2.0f));
1400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST(DataBuffer, Uint8) {
1430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Array recovered = ToFlatBufferAndBack<ArrayDataType::kUint8>({127, 244});
1440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kUint8>().data,
1450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              ::testing::ElementsAre(127, 244));
1460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST(DataBuffer, Int32) {
1490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  Array recovered = ToFlatBufferAndBack<ArrayDataType::kInt32>({1, 1 << 30});
1500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kInt32>().data,
1510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              ::testing::ElementsAre(1, 1 << 30));
1520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST(Padding, All) {
1550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame));
1560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME));
1570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_EQ(::tflite::Padding_VALID, Padding::Serialize(PaddingType::kValid));
1590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_EQ(PaddingType::kValid, Padding::Deserialize(::tflite::Padding_VALID));
1600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_DEATH(Padding::Serialize(static_cast<PaddingType>(10000)),
1620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle               "Unhandled padding type.");
1630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_DEATH(Padding::Deserialize(10000), "Unhandled padding.");
1640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleTEST(ActivationFunction, All) {
1670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  std::vector<
1680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      std::pair<FusedActivationFunctionType, ::tflite::ActivationFunctionType>>
1690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      testdata = {{FusedActivationFunctionType::kNone,
1700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                   ::tflite::ActivationFunctionType_NONE},
1710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                  {FusedActivationFunctionType::kRelu,
1720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                   ::tflite::ActivationFunctionType_RELU},
1730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                  {FusedActivationFunctionType::kRelu6,
1740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                   ::tflite::ActivationFunctionType_RELU6},
1750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                  {FusedActivationFunctionType::kRelu1,
176a4973345351a14a786987cd7f648a99c029fdc1dA. Unique TensorFlower                   ::tflite::ActivationFunctionType_RELU_N1_TO_1}};
1770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (auto x : testdata) {
1780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    EXPECT_EQ(x.second, ActivationFunction::Serialize(x.first));
1790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    EXPECT_EQ(x.first, ActivationFunction::Deserialize(x.second));
1800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_DEATH(ActivationFunction::Serialize(
1830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                   static_cast<FusedActivationFunctionType>(10000)),
1840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle               "Unhandled fused activation function type.");
1850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  EXPECT_DEATH(ActivationFunction::Deserialize(10000),
1860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle               "Unhandled fused activation function type.");
1870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace
1900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace tflite
1910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace toco
193