1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ 17#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ 18 19#include <string> 20 21#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22#include "tensorflow/core/framework/types.pb.h" 23#include "tensorflow/core/lib/core/errors.h" 24#include "tensorflow/core/lib/core/status.h" 25#include "tensorflow/core/lib/core/stringpiece.h" 26#include "tensorflow/core/lib/gtl/array_slice.h" 27#include "tensorflow/core/lib/gtl/inlined_vector.h" 28#include "tensorflow/core/lib/strings/strcat.h" 29#include "tensorflow/core/platform/logging.h" 30 31namespace tensorflow { 32 33// START_SKIP_DOXYGEN 34template <class Shape> 35class TensorShapeIter; 36class TensorShape; 37class TensorShapeProto; 38class PartialTensorShape; 39// END_SKIP_DOXYGEN 40 41/// Internal representation for both TensorShape and PartialTensorShape. 42class TensorShapeRep { 43 public: 44 ~TensorShapeRep(); 45 46 /// Copy the specified shape 47 TensorShapeRep(const TensorShapeRep& b); 48 void operator=(const TensorShapeRep& b); 49 50 /// Move the specified shape. 
After moving, <b> is safe for destruction and 51 // can be reassigned into, but its dimensions and number of elements can be 52 // nonsensical (e.g., negative dimension sizes, or number of elements not 53 // properly recomputed). 54 TensorShapeRep(TensorShapeRep&& b); 55 void operator=(TensorShapeRep&& b); 56 57 /// Clear a tensor shape, producing the scalar shape. 58 void Clear(); 59 60 // Maximum number of dimensions in a tensor. 61 // It's 254 because 255 = kUnknownRank is used to represent unknown rank. 62 static constexpr int MaxDimensions() { return 254; } 63 64 /// \brief Returns the number of elements in the tensor. 65 /// 66 /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor` 67 /// which uses `ptrdiff_t`. For PartialTensorShape, -1 means not fully 68 /// defined. 69 int64 num_elements() const { return num_elements_; } 70 71 /// For error messages. 72 string DebugString() const; 73 static string DebugString(const TensorShapeProto& proto); 74 75 void DumpRep() const; // XXX 76 77 protected: 78 // Constructable only via TensorShapeBase 79 TensorShapeRep() = default; 80 81 void ClearAllButDataType(); 82 83 // We use 16 bytes to represent a TensorShape. Because we need to 84 // be able to support full 64-bit dimension sizes and an arbitrary 85 // number of dimensions for a Tensor, but most tensor dimensions are 86 // significantly smaller than 64 bits and most tensors are 1, 2, or 3 87 // dimensions, we have several representations. 88 // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1 89 // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1 90 // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using 91 // an out of line vector. 92 // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown. 93 // This value is not allowed in TensorShape either for format compatibility. 
94 struct Rep16 { 95 uint16 dims_[6]; 96 }; 97 struct Rep32 { 98 uint32 dims_[3]; 99 }; 100 struct Rep64 { 101 gtl::InlinedVector<int64, 4>* dims_; 102 }; 103 104 // We use the max value of uint16 or uint32 to represent unknown shapes, so 105 // the maximum representable valid shape in these representations is one less. 106 static const int64 kMaxRep16 = std::numeric_limits<uint16>::max() - 1; 107 static const int64 kMaxRep32 = std::numeric_limits<uint32>::max() - 1; 108 static const uint16 kUnknownRep16 = std::numeric_limits<uint16>::max(); 109 static const uint32 kUnknownRep32 = std::numeric_limits<uint32>::max(); 110 111 Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); } 112 Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); } 113 Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); } 114 115 const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); } 116 const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); } 117 const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); } 118 119 enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 }; 120 121 // Since we have a convenient extra byte available, we allow the 122 // Tensor class to store an 8-bit value in this extra storage. This 123 // allows it to store the Tensor's datatype enum value here and avoid 124 // an extra word of storage. 125 friend class Tensor; 126 friend class TensorShapeTestHelper; 127 DataType data_type() const { return static_cast<DataType>(buf()[13]); } 128 void set_data_type(DataType dt) { 129 // We only have 8 bits available to store DataType, so make sure it fits 130 DCHECK_LT(static_cast<uint32>(dt), 256u); 131 buf()[13] = static_cast<uint8>(dt); 132 } 133 134 // We store the number of dimensions in byte 14, and the RepTag in byte 15. 135 // Bytes [0..13] vary depending on the representation. 136 // A value of 255 indicates unknown rank in the PartialTensorShape case. 
137 static const uint8 kUnknownRank = 255; 138 uint8 ndims_byte() const { return buf()[14]; } 139 void set_ndims_byte(uint8 nd) { buf()[14] = nd; } 140 141 RepTag tag() const { return static_cast<RepTag>(buf()[15]); } 142 void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); } 143 144 void set_num_elements(int64 n) { num_elements_ = n; } 145 146 private: 147 void DestructorOutOfLine(); 148 void SlowCopyFrom(const TensorShapeRep& b); 149 150 uint8* buf() { return &u_.buf[0]; } 151 const uint8* buf() const { return &u_.buf[0]; } 152 153 union { 154 uint8 buf[16]; 155 // Force data to be aligned enough for a pointer. 156 Rep64* unused_aligner; 157 } u_; 158 int64 num_elements_; 159}; 160 161/// Base class for TensorShape and PartialTensorShape. 162/// The class is templatized by either TensorShape or PartialTensorShape to 163/// allow skipping known/unknown checks in the TensorShape case, but the 164/// representation is shared exactly for fast conversion. 165template <class Shape> 166class TensorShapeBase : public TensorShapeRep { 167 public: 168 /// \brief Construct a `TensorShapeBase` from the provided sizes. 169 /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape) 170 explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes); 171 TensorShapeBase(std::initializer_list<int64> dim_sizes) 172 : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {} 173 174 /// Construct an empty TensorShape, or an unknown rank PartialTensorShape 175 TensorShapeBase(); 176 177 TensorShapeBase(const TensorShapeProto& proto); 178 179 /// Returns `true` iff `proto` is a valid tensor shape. 180 // For TensorShape, the proto shape must be fully defined. 181 static bool IsValid(const TensorShapeProto& proto); 182 183 /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error 184 /// status otherwise. 185 static Status IsValidShape(const TensorShapeProto& proto); 186 187 /// \brief Add a dimension to the end ("inner-most"). 
188 /// REQUIRES: `size >= 0` 189 void AddDim(int64 size); 190 191 /// Appends all the dimensions from `shape`. 192 void AppendShape(const TensorShapeBase& shape); 193 194 /// \brief Insert a dimension somewhere in the `TensorShape`. 195 /// REQUIRES: `0 <= d <= dims()` 196 /// REQUIRES: `size >= 0` 197 void InsertDim(int d, int64 size); 198 199 /// \brief Modifies the size of the dimension `d` to be `size` 200 /// REQUIRES: `0 <= d < dims()` 201 /// REQUIRES: `size >= 0` 202 void set_dim(int d, int64 size); 203 204 /// \brief Removes dimension `d` from the `TensorShape`. 205 /// REQUIRES: `0 <= d < dims()` 206 void RemoveDim(int d) { 207 CHECK_GE(d, 0); 208 RemoveDimRange(d, d + 1); 209 } 210 211 /// \brief Removes last `n` dimensions from the `TensorShape`. 212 /// REQUIRES: `0 <= n <= dims()` 213 void RemoveLastDims(int n) { 214 CHECK_LE(n, dims()); 215 RemoveDimRange(dims() - n, dims()); 216 } 217 218 /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`. 219 /// Negative values of `end` are interpreted as `dims() + end + 1` (as in 220 /// Python). The same is true for negative values of `begin`. REQUIRES: 221 /// `-(dims()+1) <= begin <= dims()` REQUIRES: `-(dims()+1) <= end <= dims()` 222 void RemoveDimRange(int begin, int end); 223 224 /// Return whether the rank is unknown 225 bool unknown_rank() const { 226 return kIsPartial && ndims_byte() == kUnknownRank; 227 } 228 229 /// Return the number of dimensions in the tensor. 230 /// Can be -1 meaning unknown rank for PartialTensorShape. 231 int dims() const { 232 uint8 dims = ndims_byte(); 233 return kIsPartial && dims == kUnknownRank ? -1 : dims; 234 } 235 236 /// \brief Returns the number of elements in dimension `d`. 237 /// REQUIRES: `0 <= d < dims()` 238 // TODO(touts): Rename to `dimension()` to match 239 // `Eigen::Tensor::dimension()`? 240 int64 dim_size(int d) const; 241 242 /// Returns sizes of all dimensions. 243 // Returns an empty list for unknown rank PartialTensorShape. 
244 gtl::InlinedVector<int64, 4> dim_sizes() const; 245 246 /// Return true iff the rank and all of the dimensions are well defined 247 // TODO(irving): Rename to is_fully_defined now that it's fast. 248 bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; } 249 250 /// Fill `*proto` from `*this`. 251 void AsProto(TensorShapeProto* proto) const; 252 253 /// For iterating through the dimensions. 254 TensorShapeIter<Shape> begin() const; 255 TensorShapeIter<Shape> end() const; 256 257 private: 258 void RecomputeNumElements(); 259 260 // True for PartialTensorShape, false for TensorShape 261 static constexpr bool kIsPartial = 262 std::is_same<Shape, PartialTensorShape>::value; 263 static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value, 264 "Shape is neither TensorShape nor PartialTensorShape"); 265 266 // Used by AddDim and MakeShapeHelper. Does no error checking. 267 void UnsafeAddDim(int64 size, int64 new_num_elements); 268 269 // For use by TensorShapeUtils::MakeShape 270 template <class T, class S> 271 friend Status MakeShapeHelper(const T*, int64, S*); 272}; 273 274/// Represents the shape of a Tensor. 275/// 276/// A tensor's shape is denoted by its number of dimensions and a size for each 277/// dimension. For example, a Tensor represented by a 3 x 4 matrix would have 278/// a shape of 2-D, [3,4]. 279/// 280/// If you know the exact shape of your Tensor when you create the TensorShape 281/// object, you can specify it then, or you can create a TensorShape with 282/// zero dimensions and one element, and call AddDim() to add dimensions later. 283class TensorShape : public TensorShapeBase<TensorShape> { 284 public: 285 using TensorShapeBase<TensorShape>::TensorShapeBase; 286 287 /// Allow a TensorShape to be used as a PartialTensorShape without copying 288 operator const PartialTensorShape&() const; // NOLINT(runtime/explicit) 289 290 /// Returns true if `*this` and `b` have the same sizes. Ignores 291 /// dimension names. 
292 bool IsSameSize(const TensorShape& b) const; 293 bool operator==(const TensorShape& b) const { return IsSameSize(b); } 294 bool operator!=(const TensorShape& b) const { return !IsSameSize(b); } 295 296 /// Fill `*dsizes` from `*this`. 297 template <int NDIMS> 298 Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const; 299 300 /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in 301 /// which case we pad the rest of the sizes with 1. 302 template <int NDIMS> 303 Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const; 304 305 private: 306 // These CHECK fail to ease debugging. 307 // REQUIRES: dims() == NDIMS 308 void CheckDimsEqual(int NDIMS) const; 309 // REQUIRES: dims() >= NDIMS 310 void CheckDimsAtLeast(int NDIMS) const; 311}; 312 313/// Represents the value of one dimension in a TensorShape. 314struct TensorShapeDim { 315 explicit TensorShapeDim(int64 s) : size(s) {} 316 int64 size; 317}; 318 319// START_SKIP_DOXYGEN 320template <class Shape> 321class TensorShapeIter { 322 public: 323 TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {} 324 bool operator==(const TensorShapeIter& rhs) { 325 DCHECK(shape_ == rhs.shape_); 326 return d_ == rhs.d_; 327 } 328 bool operator!=(const TensorShapeIter& rhs) { 329 DCHECK(shape_ == rhs.shape_); 330 return d_ != rhs.d_; 331 } 332 void operator++() { ++d_; } 333 TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); } 334 335 private: 336 const Shape* shape_; 337 int d_; 338}; 339// END_SKIP_DOXYGEN 340 341/// \brief Static helper routines for `TensorShape`. Includes a few common 342/// predicates on a tensor shape. 
class TensorShapeUtils {
 public:
  // Rank predicates; all are O(1).
  static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }

  static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }

  static bool IsVectorOrHigher(const TensorShape& shape) {
    return shape.dims() >= 1;
  }

  static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }

  static bool IsSquareMatrix(const TensorShape& shape) {
    return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
  }

  static bool IsMatrixOrHigher(const TensorShape& shape) {
    return shape.dims() >= 2;
  }

  /// \brief Returns a `TensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
  // Returns an error status (rather than CHECK-failing) on invalid sizes,
  // making these the safe way to build a shape from untrusted input.
  static Status MakeShape(const int32* dims, int64 n, TensorShape* out);
  static Status MakeShape(const int64* dims, int64 n, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out);
  static Status MakeShape(const int32* dims, int64 n, PartialTensorShape* out);
  static Status MakeShape(const int64* dims, int64 n, PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape,
                          PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64> shape,
                          PartialTensorShape* out);

  // Human-readable rendering of a list of shapes, for error messages.
  static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);

  /// \brief Returns true iff `shape` starts with `prefix`.
  static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);

  /// \brief Returns true iff `shape` ends with `suffix`.
  static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);

  /// \brief Returns the product of values in an int64 array,
  /// or a failing Status if the array represents a value larger than
  /// a `TensorShape` can hold.
  static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements);
};

/// Manages the partially known dimensions of a Tensor and their sizes.
class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
 public:
  // Default-constructed PartialTensorShape has unknown rank (see the
  // TensorShapeBase() default constructor).
  PartialTensorShape() {}
  using TensorShapeBase<PartialTensorShape>::TensorShapeBase;

  /// Add a dimension to the end ("inner-most"), returns a new
  /// PartialTensorShape.
  /// REQUIRES: `size >= -1`, where -1 means unknown
  PartialTensorShape Concatenate(int64 size) const;

  /// Appends all the dimensions from `shape`.  Returns a new
  /// PartialTensorShape.
  PartialTensorShape Concatenate(const PartialTensorShape& shape) const;

  /// Merges all the dimensions from `shape`.  Returns
  /// `InvalidArgument` error if either `shape` has a different rank
  /// or if any of the dimensions are incompatible.
  Status MergeWith(const PartialTensorShape& shape,
                   PartialTensorShape* result) const;

  /// Exact equality test.  Returns true iff the ranks match (i.e., both are
  /// unknown, or both are known and equal), and all dimensions are equal
  /// (i.e., both dimensions are known, or both are known and equal).  This is
  /// a stronger condition that IsCompatibleWith.
  bool IsIdenticalTo(const PartialTensorShape& shape) const;

  /// Return true iff the ranks match, and if the
  /// dimensions all either match or one is unknown.
  bool IsCompatibleWith(const PartialTensorShape& shape) const;

  // Fill `*shape` from `*this`.
  // If `*this` is not fully defined, returns false and
  // `*shape` is left in an intermediate state.  Otherwise
  // returns true.
  bool AsTensorShape(TensorShape* shape) const;

  /// \brief Returns a `PartialTensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.  Values of -1 are
  /// considered "unknown".
  template <class T>
  static Status MakePartialShape(const T* dims, int n,
                                 PartialTensorShape* out) {
    return TensorShapeUtils::MakeShape(dims, n, out);
  }
};

/// \brief Static helper routines for `PartialTensorShape`.  Includes a few
/// common predicates on a partially known tensor shape.
class PartialTensorShapeUtils {
 public:
  // Human-readable rendering of a list of partial shapes, for error messages.
  static string PartialShapeListString(
      const gtl::ArraySlice<PartialTensorShape>& shapes);

  // Element-wise IsIdenticalTo over two equal-length lists of shapes.
  static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                           const gtl::ArraySlice<PartialTensorShape>& shapes1);

  // Element-wise IsCompatibleWith over two equal-length lists of shapes.
  static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
};

// ----------------------------------------------------------------------------
// Template method implementation details below
// ----------------------------------------------------------------------------

template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const {
  // CHECK-fails unless dims() == NDIMS exactly.
  CheckDimsEqual(NDIMS);
  return AsEigenDSizesWithPadding<NDIMS>();
}

template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding()
    const {
  // CHECK-fails unless dims() <= NDIMS; remaining slots are padded with 1.
  CheckDimsAtLeast(NDIMS);
  static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
  for (int d = 0; d < dims(); d++) {
    dsizes[d] = dim_size(d);
  }
  for (int d = dims(); d < NDIMS; d++) {
    dsizes[d] = 1;
  }
  return dsizes;
}

// ----------------------------------------------------------------------------
// Inlining of some performance critical routines
// ----------------------------------------------------------------------------

inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (b.tag() != REP_OUT_OF_LINE) {
    // Inline representations are plain bytes: a raw copy of the 16-byte
    // buffer reproduces the dimensions, ndims byte, and tag in one shot.
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly does:
    //   set_ndims_byte(b.ndims_byte());
    //   set_tag(b.tag());
  } else {
    set_tag(REP16);  // So that SlowCopyFrom does not try to deallocate
    SlowCopyFrom(b);
  }
}

inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
  num_elements_ = b.num_elements_;
  // Copy all 16 bytes; for REP_OUT_OF_LINE this transfers the Rep64 pointer.
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShapeRep::~TensorShapeRep() {
  // Only the out-of-line representation owns heap memory.
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
}

inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
    // Neither side owns heap memory, so a raw byte copy is sufficient.
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
  } else {
    // SlowCopyFrom handles deallocating/allocating the out-of-line vector.
    SlowCopyFrom(b);
  }
}

inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
  // Release our own out-of-line storage before overwriting the buffer.
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
}

inline TensorShape::operator const PartialTensorShape&() const {
  // Downcast to the shared representation and upcast to PartialTensorShape.
  // Safe because both classes add no data members to TensorShapeRep; see the
  // TensorShapeBase comment about the representation being shared exactly.
  const TensorShapeRep* rep = this;
  return *static_cast<const PartialTensorShape*>(rep);
}

// Declare explicit instantiations in .cc file
extern template class TensorShapeBase<TensorShape>;
extern template class TensorShapeBase<PartialTensorShape>;

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_