1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_FRAMEWORK_TYPES_H_ 17#define TENSORFLOW_FRAMEWORK_TYPES_H_ 18 19#include <map> 20#include <set> 21#include <string> 22 23#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24// Disable clang-format to prevent 'FixedPoint' header from being included 25// before 'Tensor' header on which it depends. 26// clang-format off 27#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" 28// clang-format on 29#include "tensorflow/core/framework/bfloat16.h" 30#include "tensorflow/core/framework/numeric_types.h" 31#include "tensorflow/core/framework/resource_handle.h" 32#include "tensorflow/core/framework/types.pb.h" 33#include "tensorflow/core/framework/variant.h" 34#include "tensorflow/core/lib/core/stringpiece.h" 35#include "tensorflow/core/lib/gtl/array_slice.h" 36#include "tensorflow/core/lib/gtl/inlined_vector.h" 37#include "tensorflow/core/platform/logging.h" 38#include "tensorflow/core/platform/types.h" 39 40namespace tensorflow { 41 42// MemoryType is used to describe whether input or output Tensors of 43// an OpKernel should reside in "Host memory" (e.g., CPU memory) or 44// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU 45// devices). 46enum MemoryType { 47 DEVICE_MEMORY = 0, 48 HOST_MEMORY = 1, 49}; 50 51// A DeviceType is just a string, but we wrap it up in a class to give 52// some type checking as we're passing these around 53class DeviceType { 54 public: 55 DeviceType(const char* type) // NOLINT(runtime/explicit) 56 : type_(type) {} 57 58 explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {} 59 60 const char* type() const { return type_.c_str(); } 61 const string& type_string() const { return type_; } 62 63 bool operator<(const DeviceType& other) const; 64 bool operator==(const DeviceType& other) const; 65 bool operator!=(const DeviceType& other) const { return !(*this == other); } 66 67 private: 68 string type_; 69}; 70std::ostream& operator<<(std::ostream& os, const DeviceType& d); 71 72// Convenient constants that can be passed to a DeviceType constructor 73TF_EXPORT extern const char* const DEVICE_CPU; // "CPU" 74TF_EXPORT extern const char* const DEVICE_GPU; // "GPU" 75TF_EXPORT extern const char* const DEVICE_SYCL; // "SYCL" 76 77template <typename Device> 78struct DeviceName {}; 79 80template <> 81struct DeviceName<Eigen::ThreadPoolDevice> { 82 static const std::string value; 83}; 84 85#if GOOGLE_CUDA 86template <> 87struct DeviceName<Eigen::GpuDevice> { 88 static const std::string value; 89}; 90#endif // GOOGLE_CUDA 91 92#ifdef TENSORFLOW_USE_SYCL 93template <> 94struct DeviceName<Eigen::SyclDevice> { 95 static const std::string value; 96}; 97#endif // TENSORFLOW_USE_SYCL 98 99typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector; 100typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice; 101 102typedef gtl::InlinedVector<DataType, 4> DataTypeVector; 103typedef gtl::ArraySlice<DataType> DataTypeSlice; 104 105typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector; 106 107// Convert the enums to strings for errors: 108string DataTypeString(DataType dtype); 109string DeviceTypeString(const DeviceType& device_type); 110string DataTypeSliceString(const DataTypeSlice dtypes); 111inline string DataTypeVectorString(const DataTypeVector& dtypes) { 112 return DataTypeSliceString(dtypes); 113} 114 115// DataTypeSet represents a set of DataType values as a simple and efficient 116// bit mask. Note that DataTypeSet cannot represent all DataType values; it 117// cannot represent any of the DT_*_REF values. 118class DataTypeSet { 119 private: 120 const uint32 mask_; 121 122 static constexpr uint32 kNumBits = 32; 123 124 public: 125 constexpr DataTypeSet(const DataTypeSet& other) : mask_(other.mask_) {} 126 explicit constexpr DataTypeSet(uint32 mask) : mask_(mask) {} 127 128 constexpr bool Contains(DataType dt) const { 129 return (static_cast<uint32>(dt) < kNumBits) && 130 ((mask_ >> static_cast<uint32>(dt)) & 1u) != 0u; 131 } 132 133 class Iterator { 134 const DataTypeSet& set_; 135 uint32 pos_; 136 137 public: 138 Iterator(const DataTypeSet& set, uint32 pos) : set_(set), pos_(pos) { 139 DCHECK_LE(pos, kNumBits); 140 } 141 DataType operator*() const { return static_cast<DataType>(pos_); } 142 Iterator& operator++() { 143 ++pos_; 144 DCHECK_LE(pos_, kNumBits); 145 if (pos_ < kNumBits) { 146 uint32 remaining_mask = set_.mask_ >> pos_; 147 if (remaining_mask != 0u) { 148 pos_ += ctz_uint32(remaining_mask); 149 } 150 } 151 DCHECK_LE(pos_, kNumBits); 152 return *this; 153 } 154 bool operator==(const Iterator& other) const { return pos_ == other.pos_; } 155 bool operator!=(const Iterator& other) const { return !(*this == other); } 156 size_t operator-(const Iterator& other) const { 157 return this->pos_ - other.pos_; 158 } 159 }; 160 161 static uint32 ctz_uint32(uint32 x) { 162 DCHECK_NE(x, 0u); 163#ifdef __GNUC__ 164 return __builtin_ctz(x); 165#else 166 uint32 n = 0u; 167 while ((x & 1u) == 0u) { 168 x >>= 1; 169 ++n; 170 } 171 return n; 172#endif 173 } 174 175 static uint32 clz_uint32(uint32 x) { 176 DCHECK_NE(x, 0u); 177#ifdef __GNUC__ 178 return __builtin_clz(x); 179#else 180 uint32 n = 0u; 181 while ((x >> (kNumBits - 1u)) == 0u) { 182 x <<= 1; 183 ++n; 184 } 185 return n; 186#endif 187 } 188 189 Iterator begin() const { 190 // The begin position is the index of the first bit set to 1 in the entire 191 // bit mask. If there are no bits set to 1, then the index is 0. 192 if (mask_ != 0) { 193 return Iterator(*this, ctz_uint32(mask_)); 194 } 195 // The set is empty. 196 return Iterator(*this, 0); 197 } 198 199 Iterator end() const { 200 // The end position is the index of the highest bit that is set, plus 1. 201 // If there are no bits set to 1, then the index is 0. 202 if (mask_ != 0) { 203 return Iterator(*this, kNumBits - clz_uint32(mask_)); 204 } 205 // The set is empty. 206 return Iterator(*this, 0); 207 } 208 209 size_t size() const { 210#if defined(__GNUC__) 211 return __builtin_popcount(mask_); 212#else 213 size_t n = 0; 214 uint32 x = mask_; 215 while (x > 0) { 216 n += x & 1u; 217 x >>= 1; 218 } 219 return n; 220#endif 221 } 222 223 constexpr DataTypeSet operator|(const DataTypeSet& other) const { 224 return DataTypeSet(mask_ | other.mask_); 225 } 226}; 227 228// If "sp" names a valid type, store it in "*dt" and return true. Otherwise, 229// return false. 230bool DataTypeFromString(StringPiece sp, DataType* dt); 231 232constexpr inline DataTypeSet ToSet(DataType dt) { 233 return DataTypeSet(1u << static_cast<uint32>(dt)); 234} 235 236// DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc. 237enum { kDataTypeRefOffset = 100 }; 238inline bool IsRefType(DataType dtype) { 239 return dtype > static_cast<DataType>(kDataTypeRefOffset); 240} 241inline DataType MakeRefType(DataType dtype) { 242 DCHECK(!IsRefType(dtype)); 243 return static_cast<DataType>(dtype + kDataTypeRefOffset); 244} 245inline DataType RemoveRefType(DataType dtype) { 246 DCHECK(IsRefType(dtype)); 247 return static_cast<DataType>(dtype - kDataTypeRefOffset); 248} 249inline DataType BaseType(DataType dtype) { 250 return IsRefType(dtype) ? RemoveRefType(dtype) : dtype; 251} 252 253// Returns true if the actual type is the same as or ref of the expected type. 254inline bool TypesCompatible(DataType expected, DataType actual) { 255 return expected == actual || expected == BaseType(actual); 256} 257 258// Does not include _ref types. 259constexpr DataTypeSet kAllTypes = 260 ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT8) | 261 ToSet(DT_INT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | ToSet(DT_STRING) | 262 ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) | 263 ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | 264 ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_RESOURCE) | 265 ToSet(DT_VARIANT) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | 266 ToSet(DT_BFLOAT16); 267inline const DataTypeSet& AllTypes() { return kAllTypes; } 268 269#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) 270 271// Types that support '<' and '>'. 272constexpr DataTypeSet kRealNumberTypes = 273 ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) | 274 ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_INT8) | ToSet(DT_UINT16) | 275 ToSet(DT_HALF) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16); 276inline const DataTypeSet RealNumberTypes() { return kRealNumberTypes; } 277 278// Return the list of all numeric types. 279// Includes complex and quantized types. 280// NOTE: On Android, we only include the float and int32 types for now. 281const DataTypeSet kNumberTypes = 282 ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT64) | ToSet(DT_INT32) | 283 ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | 284 ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_QINT8) | 285 ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_UINT32) | 286 ToSet(DT_UINT64) | ToSet(DT_BFLOAT16); 287inline const DataTypeSet& NumberTypes() { return kNumberTypes; } 288 289constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 290 ToSet(DT_QINT16) | ToSet(DT_QUINT16) | 291 ToSet(DT_QINT32); 292inline const DataTypeSet& QuantizedTypes() { return kQuantizedTypes; } 293 294// Types that support '<' and '>', including quantized types. 295const DataTypeSet kRealAndQuantizedTypes = 296 ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) | 297 ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | 298 ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | 299 ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_BFLOAT16); 300inline const DataTypeSet& RealAndQuantizedTypes() { 301 return kRealAndQuantizedTypes; 302} 303 304#elif defined(__ANDROID_TYPES_FULL__) 305 306constexpr DataTypeSet kRealNumberTypes = 307 ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_HALF); 308inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; } 309 310constexpr DataTypeSet kNumberTypes = 311 ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) | 312 ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF); 313inline DataTypeSet NumberTypes() { return kNumberTypes; } 314 315constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 316 ToSet(DT_QINT16) | ToSet(DT_QUINT16) | 317 ToSet(DT_QINT32); 318inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; } 319 320constexpr DataTypeSet kRealAndQuantizedTypes = 321 ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) | 322 ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | 323 ToSet(DT_HALF); 324inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; } 325 326#else // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__) 327 328constexpr DataTypeSet kRealNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32); 329inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; } 330 331constexpr DataTypeSet kNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32) | 332 ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 333 ToSet(DT_QINT32); 334inline DataTypeSet NumberTypes() { return kNumberTypes; } 335 336constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 337 ToSet(DT_QINT16) | ToSet(DT_QUINT16) | 338 ToSet(DT_QINT32); 339inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; } 340 341constexpr DataTypeSet kRealAndQuantizedTypes = 342 ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 343 ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32); 344inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; } 345 346#endif // defined(IS_MOBILE_PLATFORM) 347 348// Validates type T for whether it is a supported DataType. 349template <class T> 350struct IsValidDataType; 351 352// DataTypeToEnum<T>::v() and DataTypeToEnum<T>::value are the DataType 353// constants for T, e.g. DataTypeToEnum<float>::v() is DT_FLOAT. 354template <class T> 355struct DataTypeToEnum { 356 static_assert(IsValidDataType<T>::value, "Specified Data Type not supported"); 357}; // Specializations below 358 359// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g. 360// EnumToDataType<DT_FLOAT>::Type is float. 361template <DataType VALUE> 362struct EnumToDataType {}; // Specializations below 363 364// Template specialization for both DataTypeToEnum and EnumToDataType. 365#define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \ 366 template <> \ 367 struct DataTypeToEnum<TYPE> { \ 368 static DataType v() { return ENUM; } \ 369 static DataType ref() { return MakeRefType(ENUM); } \ 370 static constexpr DataType value = ENUM; \ 371 }; \ 372 template <> \ 373 struct IsValidDataType<TYPE> { \ 374 static constexpr bool value = true; \ 375 }; \ 376 template <> \ 377 struct EnumToDataType<ENUM> { \ 378 typedef TYPE Type; \ 379 } 380 381MATCH_TYPE_AND_ENUM(float, DT_FLOAT); 382MATCH_TYPE_AND_ENUM(double, DT_DOUBLE); 383MATCH_TYPE_AND_ENUM(int32, DT_INT32); 384MATCH_TYPE_AND_ENUM(uint32, DT_UINT32); 385MATCH_TYPE_AND_ENUM(uint16, DT_UINT16); 386MATCH_TYPE_AND_ENUM(uint8, DT_UINT8); 387MATCH_TYPE_AND_ENUM(int16, DT_INT16); 388MATCH_TYPE_AND_ENUM(int8, DT_INT8); 389MATCH_TYPE_AND_ENUM(string, DT_STRING); 390MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64); 391MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128); 392MATCH_TYPE_AND_ENUM(int64, DT_INT64); 393MATCH_TYPE_AND_ENUM(uint64, DT_UINT64); 394MATCH_TYPE_AND_ENUM(bool, DT_BOOL); 395MATCH_TYPE_AND_ENUM(qint8, DT_QINT8); 396MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8); 397MATCH_TYPE_AND_ENUM(qint16, DT_QINT16); 398MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16); 399MATCH_TYPE_AND_ENUM(qint32, DT_QINT32); 400MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16); 401MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF); 402MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE); 403MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT); 404 405#undef MATCH_TYPE_AND_ENUM 406 407// All types not specialized are marked invalid. 408template <class T> 409struct IsValidDataType { 410 static constexpr bool value = false; 411}; 412 413// Extra validity checking; not part of public API. 414static_assert(IsValidDataType<int64>::value, "Incorrect impl for int64"); 415static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32"); 416 417// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying 418// is_simple<T> in tensor.cc (and possible choose a more general name?) 419constexpr DataTypeSet kDataTypesCanUseMemcpy = 420 ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT32) | 421 ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | 422 ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) | 423 ToSet(DT_UINT64) | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | 424 ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | 425 ToSet(DT_BFLOAT16) | ToSet(DT_HALF); 426inline bool DataTypeCanUseMemcpy(DataType dt) { 427 return kDataTypesCanUseMemcpy.Contains(dt); 428} 429 430// Returns true iff 'dt' is a real, non-quantized floating point type. 431constexpr DataTypeSet kDataTypeIsFloating = 432 ToSet(DT_HALF) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE); 433inline bool DataTypeIsFloating(DataType dt) { 434 return kDataTypeIsFloating.Contains(dt); 435} 436 437// Returns true iff 'dt' is a complex type. 438constexpr DataTypeSet kDataTypeIsComplex = 439 ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128); 440inline bool DataTypeIsComplex(DataType dt) { 441 return kDataTypeIsComplex.Contains(dt); 442} 443 444inline bool DataTypeIsQuantized(DataType dt) { 445 return kQuantizedTypes.Contains(dt); 446} 447 448// Is the dtype nonquantized integral? 449constexpr DataTypeSet kDataTypeIsInteger = 450 ToSet(DT_INT8) | ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_UINT16) | 451 ToSet(DT_INT32) | ToSet(DT_UINT32) | ToSet(DT_INT64) | ToSet(DT_UINT64); 452inline bool DataTypeIsInteger(DataType dt) { 453 return kDataTypeIsInteger.Contains(dt); 454} 455 456// Is the dtype a signed integral type? 457constexpr DataTypeSet kDataTypeIsSigned = 458 ToSet(DT_INT8) | ToSet(DT_INT16) | ToSet(DT_INT32) | ToSet(DT_INT64); 459inline bool DataTypeIsSigned(DataType dt) { 460 return kDataTypeIsSigned.Contains(dt); 461} 462 463// Is the dtype an unsigned integral type? 464constexpr DataTypeSet kDataTypeIsUnsigned = 465 ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT32) | ToSet(DT_UINT64); 466inline bool DataTypeIsUnsigned(DataType dt) { 467 return kDataTypeIsUnsigned.Contains(dt); 468} 469 470// Returns a 0 on failure 471int DataTypeSize(DataType dt); 472 473// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE. 474// For DT_RESOURCE, the handle always sits on host (even if the underlying 475// object has device-allocated resources). 476bool DataTypeAlwaysOnHost(DataType dt); 477 478} // namespace tensorflow 479 480#endif // TENSORFLOW_FRAMEWORK_TYPES_H_ 481