1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_FRAMEWORK_TYPES_H_
17#define TENSORFLOW_FRAMEWORK_TYPES_H_
18
19#include <map>
20#include <set>
21#include <string>
22
23#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24// Disable clang-format to prevent 'FixedPoint' header from being included
25// before 'Tensor' header on which it depends.
26// clang-format off
27#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
28// clang-format on
29#include "tensorflow/core/framework/bfloat16.h"
30#include "tensorflow/core/framework/numeric_types.h"
31#include "tensorflow/core/framework/resource_handle.h"
32#include "tensorflow/core/framework/types.pb.h"
33#include "tensorflow/core/framework/variant.h"
34#include "tensorflow/core/lib/core/stringpiece.h"
35#include "tensorflow/core/lib/gtl/array_slice.h"
36#include "tensorflow/core/lib/gtl/inlined_vector.h"
37#include "tensorflow/core/platform/logging.h"
38#include "tensorflow/core/platform/types.h"
39
40namespace tensorflow {
41
// Where an OpKernel's input or output Tensors should reside:
//   DEVICE_MEMORY - the device's own memory (CPU memory for CPU devices,
//                   GPU memory for GPU devices);
//   HOST_MEMORY   - host memory (e.g. CPU memory).
enum MemoryType { DEVICE_MEMORY = 0, HOST_MEMORY = 1 };
50
51// A DeviceType is just a string, but we wrap it up in a class to give
52// some type checking as we're passing these around
53class DeviceType {
54 public:
55  DeviceType(const char* type)  // NOLINT(runtime/explicit)
56      : type_(type) {}
57
58  explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {}
59
60  const char* type() const { return type_.c_str(); }
61  const string& type_string() const { return type_; }
62
63  bool operator<(const DeviceType& other) const;
64  bool operator==(const DeviceType& other) const;
65  bool operator!=(const DeviceType& other) const { return !(*this == other); }
66
67 private:
68  string type_;
69};
70std::ostream& operator<<(std::ostream& os, const DeviceType& d);
71
// Convenient constants that can be passed to a DeviceType constructor
// (DeviceType's const char* constructor is implicit, so these can be used
// directly wherever a DeviceType is expected). The trailing comments give the
// values; the definitions live out of line.
TF_EXPORT extern const char* const DEVICE_CPU;   // "CPU"
TF_EXPORT extern const char* const DEVICE_GPU;   // "GPU"
TF_EXPORT extern const char* const DEVICE_SYCL;  // "SYCL"
76
// DeviceName<Device>::value names the device type that corresponds to an Eigen
// device class. The primary template is empty, so only the specializations
// below provide `value`. The GPU specialization exists only when GOOGLE_CUDA
// is set, and the SYCL one only when TENSORFLOW_USE_SYCL is defined. The
// `value` members are defined out of line (presumably as DEVICE_CPU /
// DEVICE_GPU / DEVICE_SYCL; confirm against the definitions).
template <typename Device>
struct DeviceName {};

template <>
struct DeviceName<Eigen::ThreadPoolDevice> {
  static const std::string value;
};

#if GOOGLE_CUDA
template <>
struct DeviceName<Eigen::GpuDevice> {
  static const std::string value;
};
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
template <>
struct DeviceName<Eigen::SyclDevice> {
  static const std::string value;
};
#endif  // TENSORFLOW_USE_SYCL
98
99typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
100typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice;
101
102typedef gtl::InlinedVector<DataType, 4> DataTypeVector;
103typedef gtl::ArraySlice<DataType> DataTypeSlice;
104
105typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector;
106
107// Convert the enums to strings for errors:
108string DataTypeString(DataType dtype);
109string DeviceTypeString(const DeviceType& device_type);
110string DataTypeSliceString(const DataTypeSlice dtypes);
111inline string DataTypeVectorString(const DataTypeVector& dtypes) {
112  return DataTypeSliceString(dtypes);
113}
114
// DataTypeSet represents a set of DataType values as a simple and efficient
// bit mask.  Note that DataTypeSet cannot represent all DataType values; it
// cannot represent any of the DT_*_REF values.
//
// Bit i of the mask is set iff static_cast<DataType>(i) is a member, so only
// DataType values in [0, 32) can be members. Sets are typically built with
// ToSet() and combined with operator|. Because mask_ is const, a DataTypeSet
// cannot be reassigned after construction.
class DataTypeSet {
 private:
  // Bit i is set iff static_cast<DataType>(i) is a member.
  const uint32 mask_;

  // Width of mask_ in bits; also the exclusive upper bound on member values.
  static constexpr uint32 kNumBits = 32;

 public:
  constexpr DataTypeSet(const DataTypeSet& other) : mask_(other.mask_) {}
  // Wraps a raw bit mask (bit i <=> DataType value i).
  explicit constexpr DataTypeSet(uint32 mask) : mask_(mask) {}

  // Returns true iff `dt` is a member. Values at or above kNumBits, which
  // includes every DT_*_REF value, are never members. The range check also
  // keeps the shift below well-defined.
  constexpr bool Contains(DataType dt) const {
    return (static_cast<uint32>(dt) < kNumBits) &&
           ((mask_ >> static_cast<uint32>(dt)) & 1u) != 0u;
  }

  // Visits the members in increasing DataType value. `pos_` is the bit index
  // of the current member. The end iterator's `pos_` is one past the highest
  // set bit (see end()). The iterator holds a reference to the set, so the
  // set must outlive it.
  class Iterator {
    const DataTypeSet& set_;
    uint32 pos_;

   public:
    Iterator(const DataTypeSet& set, uint32 pos) : set_(set), pos_(pos) {
      DCHECK_LE(pos, kNumBits);
    }
    DataType operator*() const { return static_cast<DataType>(pos_); }
    // Moves to the next set bit above the current position. If there is
    // none, pos_ is left at the old position plus one. When iterating from
    // begin(), the old position was then the highest set bit, so this equals
    // end()'s position.
    Iterator& operator++() {
      ++pos_;
      DCHECK_LE(pos_, kNumBits);
      // Only shift while pos_ < 32: shifting a uint32 by 32 is undefined.
      if (pos_ < kNumBits) {
        uint32 remaining_mask = set_.mask_ >> pos_;
        if (remaining_mask != 0u) {
          // Skip the run of zero bits up to the next member.
          pos_ += ctz_uint32(remaining_mask);
        }
      }
      DCHECK_LE(pos_, kNumBits);
      return *this;
    }
    // Compares positions only; the referenced sets are not compared.
    bool operator==(const Iterator& other) const { return pos_ == other.pos_; }
    bool operator!=(const Iterator& other) const { return !(*this == other); }
    // NOTE(review): this is the difference of bit positions, not the number
    // of members between the two iterators. For example, for the set {1, 5},
    // end() - begin() is 5 while size() is 2. Confirm that callers expect
    // this before treating it as an element count.
    size_t operator-(const Iterator& other) const {
      return this->pos_ - other.pos_;
    }
  };

  // Number of trailing zero bits in `x`. `x` must be nonzero. This is only
  // DCHECKed; __builtin_ctz(0) is undefined, and the fallback loop would not
  // terminate.
  static uint32 ctz_uint32(uint32 x) {
    DCHECK_NE(x, 0u);
#ifdef __GNUC__
    return __builtin_ctz(x);
#else
    uint32 n = 0u;
    while ((x & 1u) == 0u) {
      x >>= 1;
      ++n;
    }
    return n;
#endif
  }

  // Number of leading zero bits in `x`. `x` must be nonzero. This is only
  // DCHECKed; __builtin_clz(0) is undefined, and the fallback loop would not
  // terminate.
  static uint32 clz_uint32(uint32 x) {
    DCHECK_NE(x, 0u);
#ifdef __GNUC__
    return __builtin_clz(x);
#else
    uint32 n = 0u;
    while ((x >> (kNumBits - 1u)) == 0u) {
      x <<= 1;
      ++n;
    }
    return n;
#endif
  }

  Iterator begin() const {
    // The begin position is the index of the first bit set to 1 in the entire
    // bit mask. If there are no bits set to 1, then the index is 0.
    if (mask_ != 0) {
      return Iterator(*this, ctz_uint32(mask_));
    }
    // The set is empty.
    return Iterator(*this, 0);
  }

  Iterator end() const {
    // The end position is the index of the highest bit that is set, plus 1.
    // If there are no bits set to 1, then the index is 0.
    if (mask_ != 0) {
      return Iterator(*this, kNumBits - clz_uint32(mask_));
    }
    // The set is empty.
    return Iterator(*this, 0);
  }

  // Number of members (the population count of mask_).
  size_t size() const {
#if defined(__GNUC__)
    return __builtin_popcount(mask_);
#else
    size_t n = 0;
    uint32 x = mask_;
    while (x > 0) {
      n += x & 1u;
      x >>= 1;
    }
    return n;
#endif
  }

  // Set union.
  constexpr DataTypeSet operator|(const DataTypeSet& other) const {
    return DataTypeSet(mask_ | other.mask_);
  }
};
227
// If "sp" names a valid type, store it in "*dt" and return true.  Otherwise,
// return false.
// NOTE(review): whether "*dt" is left untouched on failure is not visible
// here; check the out-of-line definition before relying on it.
bool DataTypeFromString(StringPiece sp, DataType* dt);
231
232constexpr inline DataTypeSet ToSet(DataType dt) {
233  return DataTypeSet(1u << static_cast<uint32>(dt));
234}
235
236// DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc.
237enum { kDataTypeRefOffset = 100 };
238inline bool IsRefType(DataType dtype) {
239  return dtype > static_cast<DataType>(kDataTypeRefOffset);
240}
241inline DataType MakeRefType(DataType dtype) {
242  DCHECK(!IsRefType(dtype));
243  return static_cast<DataType>(dtype + kDataTypeRefOffset);
244}
245inline DataType RemoveRefType(DataType dtype) {
246  DCHECK(IsRefType(dtype));
247  return static_cast<DataType>(dtype - kDataTypeRefOffset);
248}
249inline DataType BaseType(DataType dtype) {
250  return IsRefType(dtype) ? RemoveRefType(dtype) : dtype;
251}
252
253// Returns true if the actual type is the same as or ref of the expected type.
254inline bool TypesCompatible(DataType expected, DataType actual) {
255  return expected == actual || expected == BaseType(actual);
256}
257
258// Does not include _ref types.
259constexpr DataTypeSet kAllTypes =
260    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT8) |
261    ToSet(DT_INT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | ToSet(DT_STRING) |
262    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) |
263    ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) |
264    ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_RESOURCE) |
265    ToSet(DT_VARIANT) | ToSet(DT_UINT32) | ToSet(DT_UINT64) |
266    ToSet(DT_BFLOAT16);
267inline const DataTypeSet& AllTypes() { return kAllTypes; }
268
269#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION)
270
271// Types that support '<' and '>'.
272constexpr DataTypeSet kRealNumberTypes =
273    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) |
274    ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_INT8) | ToSet(DT_UINT16) |
275    ToSet(DT_HALF) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16);
276inline const DataTypeSet RealNumberTypes() { return kRealNumberTypes; }
277
278// Return the list of all numeric types.
279// Includes complex and quantized types.
280// NOTE: On Android, we only include the float and int32 types for now.
281const DataTypeSet kNumberTypes =
282    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT64) | ToSet(DT_INT32) |
283    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
284    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_QINT8) |
285    ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_UINT32) |
286    ToSet(DT_UINT64) | ToSet(DT_BFLOAT16);
287inline const DataTypeSet& NumberTypes() { return kNumberTypes; }
288
289constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
290                                        ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
291                                        ToSet(DT_QINT32);
292inline const DataTypeSet& QuantizedTypes() { return kQuantizedTypes; }
293
294// Types that support '<' and '>', including quantized types.
295const DataTypeSet kRealAndQuantizedTypes =
296    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) |
297    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) |
298    ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
299    ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_BFLOAT16);
300inline const DataTypeSet& RealAndQuantizedTypes() {
301  return kRealAndQuantizedTypes;
302}
303
304#elif defined(__ANDROID_TYPES_FULL__)
305
306constexpr DataTypeSet kRealNumberTypes =
307    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_HALF);
308inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; }
309
310constexpr DataTypeSet kNumberTypes =
311    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) |
312    ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF);
313inline DataTypeSet NumberTypes() { return kNumberTypes; }
314
315constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
316                                        ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
317                                        ToSet(DT_QINT32);
318inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; }
319
320constexpr DataTypeSet kRealAndQuantizedTypes =
321    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) |
322    ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) |
323    ToSet(DT_HALF);
324inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; }
325
326#else  // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__)
327
328constexpr DataTypeSet kRealNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32);
329inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; }
330
331constexpr DataTypeSet kNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32) |
332                                     ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
333                                     ToSet(DT_QINT32);
334inline DataTypeSet NumberTypes() { return kNumberTypes; }
335
336constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
337                                        ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
338                                        ToSet(DT_QINT32);
339inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; }
340
341constexpr DataTypeSet kRealAndQuantizedTypes =
342    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
343    ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32);
344inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; }
345
346#endif  // defined(IS_MOBILE_PLATFORM)
347
// Validates type T for whether it is a supported DataType.
// IsValidDataType<T>::value is true for the types registered through
// MATCH_TYPE_AND_ENUM below and false for every other type. The primary
// template is only declared here and is defined after those specializations.
template <class T>
struct IsValidDataType;

// DataTypeToEnum<T>::v() and DataTypeToEnum<T>::value are the DataType
// constants for T, e.g. DataTypeToEnum<float>::v() is DT_FLOAT.
// DataTypeToEnum<T>::ref() gives the corresponding DT_*_REF type.
// Instantiating this primary template for an unsupported T fails the
// static_assert at compile time.
template <class T>
struct DataTypeToEnum {
  static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
};  // Specializations below
358
// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
// EnumToDataType<DT_FLOAT>::Type is float.
// The primary template is empty, so naming ::Type for a VALUE that has no
// MATCH_TYPE_AND_ENUM entry is a compile error.
template <DataType VALUE>
struct EnumToDataType {};  // Specializations below
363
// Template specialization for both DataTypeToEnum and EnumToDataType.
// MATCH_TYPE_AND_ENUM(TYPE, ENUM) specializes three templates at once:
//   DataTypeToEnum<TYPE>  : v() and value (== ENUM), ref() (== ENUM's ref type)
//   IsValidDataType<TYPE> : value == true
//   EnumToDataType<ENUM>  : Type (== TYPE)
// The invocation supplies the final semicolon. Each TYPE and each ENUM may
// appear at most once, since duplicate explicit specializations do not
// compile. The macro is #undef'd right after the list, so new mappings belong
// in this list.
#define MATCH_TYPE_AND_ENUM(TYPE, ENUM)                 \
  template <>                                           \
  struct DataTypeToEnum<TYPE> {                         \
    static DataType v() { return ENUM; }                \
    static DataType ref() { return MakeRefType(ENUM); } \
    static constexpr DataType value = ENUM;             \
  };                                                    \
  template <>                                           \
  struct IsValidDataType<TYPE> {                        \
    static constexpr bool value = true;                 \
  };                                                    \
  template <>                                           \
  struct EnumToDataType<ENUM> {                         \
    typedef TYPE Type;                                  \
  }

MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
MATCH_TYPE_AND_ENUM(double, DT_DOUBLE);
MATCH_TYPE_AND_ENUM(int32, DT_INT32);
MATCH_TYPE_AND_ENUM(uint32, DT_UINT32);
MATCH_TYPE_AND_ENUM(uint16, DT_UINT16);
MATCH_TYPE_AND_ENUM(uint8, DT_UINT8);
MATCH_TYPE_AND_ENUM(int16, DT_INT16);
MATCH_TYPE_AND_ENUM(int8, DT_INT8);
MATCH_TYPE_AND_ENUM(string, DT_STRING);
MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64);
MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128);
MATCH_TYPE_AND_ENUM(int64, DT_INT64);
MATCH_TYPE_AND_ENUM(uint64, DT_UINT64);
MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
MATCH_TYPE_AND_ENUM(qint8, DT_QINT8);
MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8);
MATCH_TYPE_AND_ENUM(qint16, DT_QINT16);
MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16);
MATCH_TYPE_AND_ENUM(qint32, DT_QINT32);
MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16);
MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF);
MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE);
MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT);

#undef MATCH_TYPE_AND_ENUM
406
// All types not specialized are marked invalid.
// This is the primary definition of IsValidDataType (declared earlier). Every
// type without a MATCH_TYPE_AND_ENUM specialization gets value == false.
template <class T>
struct IsValidDataType {
  static constexpr bool value = false;
};

// Extra validity checking; not part of public API.
// Compile-time checks that the int64 and int32 aliases resolve to registered
// types.
static_assert(IsValidDataType<int64>::value, "Incorrect impl for int64");
static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32");
416
417// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying
418// is_simple<T> in tensor.cc (and possible choose a more general name?)
419constexpr DataTypeSet kDataTypesCanUseMemcpy =
420    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT32) |
421    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
422    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) |
423    ToSet(DT_UINT64) | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
424    ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) |
425    ToSet(DT_BFLOAT16) | ToSet(DT_HALF);
426inline bool DataTypeCanUseMemcpy(DataType dt) {
427  return kDataTypesCanUseMemcpy.Contains(dt);
428}
429
430// Returns true iff 'dt' is a real, non-quantized floating point type.
431constexpr DataTypeSet kDataTypeIsFloating =
432    ToSet(DT_HALF) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE);
433inline bool DataTypeIsFloating(DataType dt) {
434  return kDataTypeIsFloating.Contains(dt);
435}
436
437// Returns true iff 'dt' is a complex type.
438constexpr DataTypeSet kDataTypeIsComplex =
439    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128);
440inline bool DataTypeIsComplex(DataType dt) {
441  return kDataTypeIsComplex.Contains(dt);
442}
443
444inline bool DataTypeIsQuantized(DataType dt) {
445  return kQuantizedTypes.Contains(dt);
446}
447
448// Is the dtype nonquantized integral?
449constexpr DataTypeSet kDataTypeIsInteger =
450    ToSet(DT_INT8) | ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_UINT16) |
451    ToSet(DT_INT32) | ToSet(DT_UINT32) | ToSet(DT_INT64) | ToSet(DT_UINT64);
452inline bool DataTypeIsInteger(DataType dt) {
453  return kDataTypeIsInteger.Contains(dt);
454}
455
456// Is the dtype a signed integral type?
457constexpr DataTypeSet kDataTypeIsSigned =
458    ToSet(DT_INT8) | ToSet(DT_INT16) | ToSet(DT_INT32) | ToSet(DT_INT64);
459inline bool DataTypeIsSigned(DataType dt) {
460  return kDataTypeIsSigned.Contains(dt);
461}
462
463// Is the dtype an unsigned integral type?
464constexpr DataTypeSet kDataTypeIsUnsigned =
465    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT32) | ToSet(DT_UINT64);
466inline bool DataTypeIsUnsigned(DataType dt) {
467  return kDataTypeIsUnsigned.Contains(dt);
468}
469
// Returns the per-element size of `dt`, or 0 on failure. The unit is
// presumably bytes, and what counts as failure is decided by the out-of-line
// definition; confirm there.
int DataTypeSize(DataType dt);

// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE.
// For DT_RESOURCE, the handle always sits on host (even if the underlying
// object has device-allocated resources).
bool DataTypeAlwaysOnHost(DataType dt);
477
478}  // namespace tensorflow
479
480#endif  // TENSORFLOW_FRAMEWORK_TYPES_H_
481