types_test.cc revision c8b59c046895fa5b6d79f73e0b5817330fcfbfc1
1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/framework/types.h" 17 18#include "tensorflow/core/framework/type_traits.h" 19#include "tensorflow/core/platform/protobuf.h" 20#include "tensorflow/core/platform/test.h" 21 22namespace tensorflow { 23namespace { 24 25TEST(TypesTest, DeviceTypeName) { 26 EXPECT_EQ("CPU", DeviceTypeString(DeviceType(DEVICE_CPU))); 27 EXPECT_EQ("GPU", DeviceTypeString(DeviceType(DEVICE_GPU))); 28} 29 30TEST(TypesTest, kDataTypeRefOffset) { 31 // Basic sanity check 32 EXPECT_EQ(DT_FLOAT + kDataTypeRefOffset, DT_FLOAT_REF); 33 34 // Use the meta-data provided by proto2 to iterate through the basic 35 // types and validate that adding kDataTypeRefOffset gives the 36 // corresponding reference type. 37 const auto* enum_descriptor = DataType_descriptor(); 38 int e = DataType_MIN; 39 if (e == DT_INVALID) ++e; 40 int e_ref = e + kDataTypeRefOffset; 41 EXPECT_FALSE(DataType_IsValid(e_ref - 1)) 42 << "Reference enum " 43 << enum_descriptor->FindValueByNumber(e_ref - 1)->name() 44 << " without corresponding base enum with value " << e - 1; 45 for (; 46 DataType_IsValid(e) && DataType_IsValid(e_ref) && e_ref <= DataType_MAX; 47 ++e, ++e_ref) { 48 string enum_name = enum_descriptor->FindValueByNumber(e)->name(); 49 string enum_ref_name = enum_descriptor->FindValueByNumber(e_ref)->name(); 50 EXPECT_EQ(enum_name + "_REF", enum_ref_name) 51 << enum_name << "_REF should have value " << e_ref << " not " 52 << enum_ref_name; 53 // Validate DataTypeString() as well. 54 DataType dt_e = static_cast<DataType>(e); 55 DataType dt_e_ref = static_cast<DataType>(e_ref); 56 EXPECT_EQ(DataTypeString(dt_e) + "_ref", DataTypeString(dt_e_ref)); 57 58 // Test DataTypeFromString reverse conversion 59 DataType dt_e2, dt_e2_ref; 60 EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e), &dt_e2)); 61 EXPECT_EQ(dt_e, dt_e2); 62 EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e_ref), &dt_e2_ref)); 63 EXPECT_EQ(dt_e_ref, dt_e2_ref); 64 } 65 ASSERT_FALSE(DataType_IsValid(e)) 66 << "Should define " << enum_descriptor->FindValueByNumber(e)->name() 67 << "_REF to be " << e_ref; 68 ASSERT_FALSE(DataType_IsValid(e_ref)) 69 << "Extra reference enum " 70 << enum_descriptor->FindValueByNumber(e_ref)->name() 71 << " without corresponding base enum with value " << e; 72 ASSERT_LT(DataType_MAX, e_ref) << "Gap in reference types, missing value for " 73 << e_ref; 74 75 // Make sure there are no enums defined after the last regular type before 76 // the first reference type. 77 for (; e < DataType_MIN + kDataTypeRefOffset; ++e) { 78 EXPECT_FALSE(DataType_IsValid(e)) 79 << "Discontinuous enum value " 80 << enum_descriptor->FindValueByNumber(e)->name() << " = " << e; 81 } 82} 83 84TEST(TypesTest, DataTypeFromString) { 85 DataType dt; 86 ASSERT_TRUE(DataTypeFromString("int32", &dt)); 87 EXPECT_EQ(DT_INT32, dt); 88 ASSERT_TRUE(DataTypeFromString("int32_ref", &dt)); 89 EXPECT_EQ(DT_INT32_REF, dt); 90 EXPECT_FALSE(DataTypeFromString("int32_ref_ref", &dt)); 91 EXPECT_FALSE(DataTypeFromString("foo", &dt)); 92 EXPECT_FALSE(DataTypeFromString("foo_ref", &dt)); 93 ASSERT_TRUE(DataTypeFromString("int64", &dt)); 94 EXPECT_EQ(DT_INT64, dt); 95 ASSERT_TRUE(DataTypeFromString("int64_ref", &dt)); 96 EXPECT_EQ(DT_INT64_REF, dt); 97 ASSERT_TRUE(DataTypeFromString("quint8_ref", &dt)); 98 EXPECT_EQ(DT_QUINT8_REF, dt); 99 ASSERT_TRUE(DataTypeFromString("bfloat16", &dt)); 100 EXPECT_EQ(DT_BFLOAT16, dt); 101} 102 103template <typename T> 104static bool GetQuantized() { 105 return is_quantized<T>::value; 106} 107 108TEST(TypesTest, QuantizedTypes) { 109 // NOTE: GUnit cannot parse is::quantized<TYPE>::value() within the 110 // EXPECT_TRUE() clause, so we delegate through a template function. 111 EXPECT_TRUE(GetQuantized<qint8>()); 112 EXPECT_TRUE(GetQuantized<quint8>()); 113 EXPECT_TRUE(GetQuantized<qint32>()); 114 115 EXPECT_FALSE(GetQuantized<int8>()); 116 EXPECT_FALSE(GetQuantized<uint8>()); 117 EXPECT_FALSE(GetQuantized<int16>()); 118 EXPECT_FALSE(GetQuantized<int32>()); 119 120 EXPECT_TRUE(DataTypeIsQuantized(DT_QINT8)); 121 EXPECT_TRUE(DataTypeIsQuantized(DT_QUINT8)); 122 EXPECT_TRUE(DataTypeIsQuantized(DT_QINT32)); 123 124 EXPECT_FALSE(DataTypeIsQuantized(DT_INT8)); 125 EXPECT_FALSE(DataTypeIsQuantized(DT_UINT8)); 126 EXPECT_FALSE(DataTypeIsQuantized(DT_UINT16)); 127 EXPECT_FALSE(DataTypeIsQuantized(DT_INT16)); 128 EXPECT_FALSE(DataTypeIsQuantized(DT_INT32)); 129 EXPECT_FALSE(DataTypeIsQuantized(DT_BFLOAT16)); 130} 131 132TEST(TypesTest, IntegerTypes) { 133 for (auto dt : AllTypes()) { 134 const string name = DataTypeString(dt); 135 const StringPiece n = name; 136 EXPECT_EQ(DataTypeIsInteger(dt), 137 n.starts_with("int") || n.starts_with("uint")) 138 << "DataTypeInteger failed for " << name; 139 } 140} 141 142} // namespace 143} // namespace tensorflow 144