/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_

#include <cstring>
#include <limits>
#include <string>
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter;
class TensorShape;
class TensorShapeProto;
class PartialTensorShape;
// END_SKIP_DOXYGEN

/// Internal representation for both TensorShape and PartialTensorShape.
class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape.  After moving, <b> is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or number of elements not
  /// properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  // Maximum number of dimensions in a tensor.
  // It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`.  For PartialTensorShape, -1 means not fully
  /// defined.
  int64 num_elements() const { return num_elements_; }

  /// For error messages.
  string DebugString() const;
  static string DebugString(const TensorShapeProto& proto);

  void DumpRep() const;  // XXX

 protected:
  // Constructible only via TensorShapeBase
  TensorShapeRep() = default;

  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape.  Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
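  //
  // Illustrative examples (assumed mapping based on the limits above; the
  // actual representation is chosen by the implementation):
  //   [2, 3, 5]             -> fits Rep16 (each dimension < 2^16 - 1)
  //   [2, 100000, 3]        -> needs Rep32 (a dimension needs more than 16 bits)
  //   [1, 2, 3, 4, 5, 6, 7] -> needs Rep64 (more than 6 dimensions)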
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    gtl::InlinedVector<int64, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent an unknown dimension
  // size, so the maximum representable valid dimension size in these
  // representations is one less.
  static const int64 kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static const int64 kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static const uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static const uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();

  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage.  This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Byte 13 holds the DataType (see data_type() above); bytes [0..12] vary
  // depending on the representation.
  // A rank byte of 255 indicates unknown rank in the PartialTensorShape case.
  static const uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  void set_num_elements(int64 n) { num_elements_ = n; }

 private:
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  int64 num_elements_;
};

/// Base class for TensorShape and PartialTensorShape.
/// The class is templatized by either TensorShape or PartialTensorShape to
/// allow skipping known/unknown checks in the TensorShape case, but the
/// representation is shared exactly for fast conversion.
template <class Shape>
class TensorShapeBase : public TensorShapeRep {
 public:
  /// \brief Construct a `TensorShapeBase` from the provided sizes.
  /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
  explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
  TensorShapeBase(std::initializer_list<int64> dim_sizes)
      : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}

  /// Construct an empty TensorShape, or an unknown rank PartialTensorShape
  TensorShapeBase();

  TensorShapeBase(const TensorShapeProto& proto);

  /// Returns `true` iff `proto` is a valid tensor shape.
  /// For TensorShape, the proto shape must be fully defined.
  static bool IsValid(const TensorShapeProto& proto);

  /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
  /// status otherwise.
  static Status IsValidShape(const TensorShapeProto& proto);

  /// \brief Add a dimension to the end ("inner-most").
  /// REQUIRES: `size >= 0`
  void AddDim(int64 size);

  /// Appends all the dimensions from `shape`.
  void AppendShape(const TensorShapeBase& shape);

  /// \brief Insert a dimension somewhere in the `TensorShape`.
  /// REQUIRES: `0 <= d <= dims()`
  /// REQUIRES: `size >= 0`
  void InsertDim(int d, int64 size);

  /// \brief Modifies the size of the dimension `d` to be `size`
  /// REQUIRES: `0 <= d < dims()`
  /// REQUIRES: `size >= 0`
  void set_dim(int d, int64 size);

  /// \brief Removes dimension `d` from the `TensorShape`.
  /// REQUIRES: `0 <= d < dims()`
  void RemoveDim(int d) {
    CHECK_GE(d, 0);
    RemoveDimRange(d, d + 1);
  }

  /// \brief Removes last `n` dimensions from the `TensorShape`.
  /// REQUIRES: `0 <= n <= dims()`
  void RemoveLastDims(int n) {
    CHECK_LE(n, dims());
    RemoveDimRange(dims() - n, dims());
  }

  /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
  /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
  /// Python). The same is true for negative values of `begin`.
  /// REQUIRES: `-(dims()+1) <= begin <= dims()`
  /// REQUIRES: `-(dims()+1) <= end <= dims()`
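  /// For example, with `dims() == 4`, `RemoveDimRange(-2, -1)` is equivalent
  /// to `RemoveDimRange(3, 4)` and removes the last dimension.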
  void RemoveDimRange(int begin, int end);

  /// Return whether the rank is unknown
  bool unknown_rank() const {
    return kIsPartial && ndims_byte() == kUnknownRank;
  }

  /// Return the number of dimensions in the tensor.
  /// Can be -1 meaning unknown rank for PartialTensorShape.
  int dims() const {
    uint8 dims = ndims_byte();
    return kIsPartial && dims == kUnknownRank ? -1 : dims;
  }

  /// \brief Returns the number of elements in dimension `d`.
  /// REQUIRES: `0 <= d < dims()`
  // TODO(touts): Rename to `dimension()` to match
  // `Eigen::Tensor::dimension()`?
  int64 dim_size(int d) const;

  /// Returns sizes of all dimensions.
  /// Returns an empty list for unknown rank PartialTensorShape.
  gtl::InlinedVector<int64, 4> dim_sizes() const;

  /// Return true iff the rank and all of the dimensions are well defined
  // TODO(irving): Rename to is_fully_defined now that it's fast.
  bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }

  /// Fill `*proto` from `*this`.
  void AsProto(TensorShapeProto* proto) const;

  /// For iterating through the dimensions.
  TensorShapeIter<Shape> begin() const;
  TensorShapeIter<Shape> end() const;

 private:
  void RecomputeNumElements();

  // True for PartialTensorShape, false for TensorShape
  static constexpr bool kIsPartial =
      std::is_same<Shape, PartialTensorShape>::value;
  static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
                "Shape is neither TensorShape nor PartialTensorShape");

  // Used by AddDim and MakeShapeHelper.  Does no error checking.
  void UnsafeAddDim(int64 size, int64 new_num_elements);

  // For use by TensorShapeUtils::MakeShape
  template <class T, class S>
  friend Status MakeShapeHelper(const T*, int64, S*);
};

/// Represents the shape of a Tensor.
///
/// A tensor's shape is denoted by its number of dimensions and a size for each
/// dimension.  For example, a Tensor represented by a 3 x 4 matrix would have
/// a shape of 2-D, [3,4].
///
/// If you know the exact shape of your Tensor when you create the TensorShape
/// object, you can specify it then, or you can create a TensorShape with
/// zero dimensions and one element, and call AddDim() to add dimensions later.
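///
/// An illustrative sketch of both styles:
///
///     TensorShape a({3, 4});         // 2-D shape [3, 4]
///     TensorShape b;                 // starts as a scalar: 0 dims, 1 element
///     b.AddDim(3);
///     b.AddDim(4);                   // b is now [3, 4], the same shape as a
///     int64 n = a.num_elements();    // 12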
class TensorShape : public TensorShapeBase<TensorShape> {
 public:
  using TensorShapeBase<TensorShape>::TensorShapeBase;

  /// Allow a TensorShape to be used as a PartialTensorShape without copying
  operator const PartialTensorShape&() const;  // NOLINT(runtime/explicit)

  /// Returns true if `*this` and `b` have the same sizes. Ignores
  /// dimension names.
  bool IsSameSize(const TensorShape& b) const;
  bool operator==(const TensorShape& b) const { return IsSameSize(b); }
  bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }

  /// Fill `*dsizes` from `*this`.
  template <int NDIMS>
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const;

  /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
  /// which case we pad the rest of the sizes with 1.
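  /// For example, a 2-D shape of [3, 4] with `NDIMS == 4` yields Eigen sizes
  /// (3, 4, 1, 1).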
  template <int NDIMS>
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const;

 private:
  // These CHECK-fail (rather than return errors) to ease debugging.
  // REQUIRES: dims() == NDIMS
  void CheckDimsEqual(int NDIMS) const;
  // REQUIRES: dims() >= NDIMS
  void CheckDimsAtLeast(int NDIMS) const;
};

/// Represents the value of one dimension in a TensorShape.
struct TensorShapeDim {
  explicit TensorShapeDim(int64 s) : size(s) {}
  int64 size;
};

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter {
 public:
  TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {}
  bool operator==(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ == rhs.d_;
  }
  bool operator!=(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ != rhs.d_;
  }
  void operator++() { ++d_; }
  TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }

 private:
  const Shape* shape_;
  int d_;
};
// END_SKIP_DOXYGEN

/// \brief Static helper routines for `TensorShape`. Includes a few common
/// predicates on a tensor shape.
class TensorShapeUtils {
 public:
  static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }

  static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }

  static bool IsVectorOrHigher(const TensorShape& shape) {
    return shape.dims() >= 1;
  }

  static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }

  static bool IsSquareMatrix(const TensorShape& shape) {
    return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
  }

  static bool IsMatrixOrHigher(const TensorShape& shape) {
    return shape.dims() >= 2;
  }

  /// \brief Returns a `TensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
  static Status MakeShape(const int32* dims, int64 n, TensorShape* out);
  static Status MakeShape(const int64* dims, int64 n, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out);
  static Status MakeShape(const int32* dims, int64 n, PartialTensorShape* out);
  static Status MakeShape(const int64* dims, int64 n, PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape,
                          PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape,
                          PartialTensorShape* out);
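  //
  // Illustrative use of MakeShape (a sketch; error handling elided):
  //   int64 dims[] = {2, 3};
  //   TensorShape out;
  //   Status s = TensorShapeUtils::MakeShape(dims, 2, &out);  // out is [2, 3]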

  static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);

  /// \brief Returns true iff `shape` starts with `prefix`.
  static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);

  /// \brief Returns true iff `shape` ends with `suffix`.
  static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);

  /// \brief Returns the product of values in an int64 array,
  /// or a failing Status if the array represents a value larger than
  /// a `TensorShape` can hold.
  static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements);
};

/// Manages the partially known dimensions of a Tensor and their sizes.
class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
 public:
  PartialTensorShape() {}
  using TensorShapeBase<PartialTensorShape>::TensorShapeBase;

  /// Add a dimension to the end ("inner-most"), returns a new
  /// PartialTensorShape.
  /// REQUIRES: `size >= -1`, where -1 means unknown
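  /// For example, concatenating `-1` onto `[2, 3]` produces the 3-D partial
  /// shape `[2, 3, ?]`.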
  PartialTensorShape Concatenate(int64 size) const;

  /// Appends all the dimensions from `shape`.  Returns a new
  /// PartialTensorShape.
  PartialTensorShape Concatenate(const PartialTensorShape& shape) const;

  /// Merges all the dimensions from `shape`.  Returns
  /// `InvalidArgument` error if either `shape` has a different rank
  /// or if any of the dimensions are incompatible.
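  /// For example, merging `[?, 2]` with `[3, ?]` produces `[3, 2]`.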
  Status MergeWith(const PartialTensorShape& shape,
                   PartialTensorShape* result) const;

  /// Exact equality test. Returns true iff the ranks match (i.e., both are
  /// unknown, or both are known and equal), and all dimensions are equal
  /// (i.e., both dimensions are unknown, or both are known and equal). This
  /// is a stronger condition than IsCompatibleWith.
  bool IsIdenticalTo(const PartialTensorShape& shape) const;

  /// Return true iff the ranks match, and if the
  /// dimensions all either match or one is unknown.
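  /// For example, `[?, 2]` is compatible with `[1, 2]` and `[3, 2]`, but not
  /// with `[1, 3]`.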
  bool IsCompatibleWith(const PartialTensorShape& shape) const;

  // Fill `*shape` from `*this`.
  // If `*this` is not fully defined, returns false and
  // `*shape` is left in an intermediate state.  Otherwise
  // returns true.
  bool AsTensorShape(TensorShape* shape) const;

  /// \brief Returns a `PartialTensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.  Values of -1 are
  /// considered "unknown".
  template <class T>
  static Status MakePartialShape(const T* dims, int n,
                                 PartialTensorShape* out) {
    return TensorShapeUtils::MakeShape(dims, n, out);
  }
};

/// \brief Static helper routines for `PartialTensorShape`. Includes a few
/// common predicates on a partially known tensor shape.
class PartialTensorShapeUtils {
 public:
  static string PartialShapeListString(
      const gtl::ArraySlice<PartialTensorShape>& shapes);

  static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                           const gtl::ArraySlice<PartialTensorShape>& shapes1);

  static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
};

// ----------------------------------------------------------------------------
// Template method implementation details below
// ----------------------------------------------------------------------------

template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const {
  CheckDimsEqual(NDIMS);
  return AsEigenDSizesWithPadding<NDIMS>();
}

template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding()
    const {
  CheckDimsAtLeast(NDIMS);
  static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
  for (int d = 0; d < dims(); d++) {
    dsizes[d] = dim_size(d);
  }
  for (int d = dims(); d < NDIMS; d++) {
    dsizes[d] = 1;
  }
  return dsizes;
}

// ----------------------------------------------------------------------------
// Inlining of some performance critical routines
// ----------------------------------------------------------------------------

inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly does:
    //   set_ndims_byte(b.ndims_byte());
    //   set_tag(b.tag());
  } else {
    set_tag(REP16);  // So that SlowCopyFrom does not try to deallocate
    SlowCopyFrom(b);
  }
}

inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShapeRep::~TensorShapeRep() {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
}

inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
  } else {
    SlowCopyFrom(b);
  }
}

inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShape::operator const PartialTensorShape&() const {
  // Downcast to the shared representation and upcast to PartialTensorShape
  const TensorShapeRep* rep = this;
  return *static_cast<const PartialTensorShape*>(rep);
}

// Declare explicit instantiations in .cc file
extern template class TensorShapeBase<TensorShape>;
extern template class TensorShapeBase<PartialTensorShape>;

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_