tensor_shape.h revision b49aef3c17e39a9c7127477333c7cc76aef57a80
1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 3db7478e8998f7703c57a75a950c905ec0cb59d7bJosh LevenbergLicensed under the Apache License, Version 2.0 (the "License"); 4db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenbergyou may not use this file except in compliance with the License. 5db7478e8998f7703c57a75a950c905ec0cb59d7bJosh LevenbergYou may obtain a copy of the License at 6db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 7db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg http://www.apache.org/licenses/LICENSE-2.0 8db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 9db7478e8998f7703c57a75a950c905ec0cb59d7bJosh LevenbergUnless required by applicable law or agreed to in writing, software 10db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenbergdistributed under the License is distributed on an "AS IS" BASIS, 11db7478e8998f7703c57a75a950c905ec0cb59d7bJosh LevenbergWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12db7478e8998f7703c57a75a950c905ec0cb59d7bJosh LevenbergSee the License for the specific language governing permissions and 13db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberglimitations under the License. 
14db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg==============================================================================*/ 15db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 16db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ 17db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ 18db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 19db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#include <string> 20db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 21db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 229934b9c6841ee56368c7cb1d10053f68a02680baA. Unique TensorFlower#include "tensorflow/core/framework/types.pb.h" 23db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#include "tensorflow/core/lib/core/errors.h" 24db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#include "tensorflow/core/lib/core/status.h" 25db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#include "tensorflow/core/lib/core/stringpiece.h" 26db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#include "tensorflow/core/lib/gtl/array_slice.h" 27db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#include "tensorflow/core/lib/gtl/inlined_vector.h" 28db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#include "tensorflow/core/lib/strings/strcat.h" 29db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#include "tensorflow/core/platform/logging.h" 30db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 31db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenbergnamespace tensorflow { 32db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 33120c3f11f857ce6da7ae90b3eab943896499dca2A. 
// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter;
class TensorShape;
class TensorShapeProto;
class PartialTensorShape;
// END_SKIP_DOXYGEN

/// Internal representation for both TensorShape and PartialTensorShape.
///
/// Storage is a 16-byte inline buffer (`u_`), whose interpretation is
/// selected by a tag byte (see `RepTag`), plus a cached element count
/// (`num_elements_`).
class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape.  After moving, <b> is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or number of elements not
  /// properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  /// Maximum number of dimensions in a tensor.
  /// It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`.  For PartialTensorShape, -1 means not fully
  /// defined.
  int64 num_elements() const { return num_elements_; }

  /// For error messages.
  string DebugString() const;
  /// Overload for error messages that takes a `TensorShapeProto` instead of a
  /// constructed shape (defined out of line; presumably uses the same format
  /// as the member version -- TODO confirm).
  static string DebugString(const TensorShapeProto& proto);

  // NOTE(review): debugging aid, defined out of line and flagged "XXX" by its
  // author; its exact output is not visible from this header.
  void DumpRep() const;  // XXX

 protected:
  // Constructable only via TensorShapeBase
  TensorShapeRep() = default;

  // Defined out of line.  Judging by the name it resets the shape while
  // preserving the DataType byte (see set_data_type) -- TODO confirm.
  void ClearAllButDataType();

  // We use a 16-byte buffer (`u_`) to represent the dimensions of a
  // TensorShape.  Because we need to be able to support full 64-bit
  // dimension sizes and an arbitrary number of dimensions for a Tensor, but
  // most tensor dimensions are significantly smaller than 64 bits and most
  // tensors are 1, 2, or 3 dimensions, we have several representations.
  //   Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  //   Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  //   Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //          an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];  // 6 * 2 = 12 bytes: occupies buf()[0..11].
  };
  struct Rep32 {
    uint32 dims_[3];  // 3 * 4 = 12 bytes: occupies buf()[0..11].
  };
  struct Rep64 {
    // Pointer to the out-of-line dimension vector.
    gtl::InlinedVector<int64, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one less.
  static const int64 kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static const int64 kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static const uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static const uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();

  // Views of the inline buffer as each representation.  Each view is only
  // meaningful when it matches the representation selected by tag().
  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  // Which representation the buffer currently holds; stored in buf()[15].
  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage.  This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Bytes [0..11] hold the representation-specific dimension data (Rep16 and
  // Rep32 fill all twelve; Rep64 uses only the leading pointer-sized bytes),
  // byte 12 is not referenced by this header, and byte 13 holds the DataType
  // managed by set_data_type() above.
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static const uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  // Overwrites the cached element count; no recomputation is performed here.
  void set_num_elements(int64 n) { num_elements_ = n; }

 private:
  // Both are defined out of line (not visible in this header).
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  // Raw byte access to the 16-byte inline buffer.
  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  // 16-byte inline storage shared by every representation (see above).
  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  // Cached element count; -1 for a not-fully-defined PartialTensorShape
  // (see num_elements()).
  int64 num_elements_;
};

/// Base class for TensorShape and PartialTensorShape.
/// The class is templatized by either TensorShape or PartialTensorShape to
/// allow skipping known/unknown checks in the TensorShape case, but the
/// representation is shared exactly for fast conversion.
165f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvingtemplate <class Shape> 166f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvingclass TensorShapeBase : public TensorShapeRep { 167f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving public: 168f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// \brief Construct a `TensorShapeBase` from the provided sizes. 169f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape) 170f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes); 171f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving TensorShapeBase(std::initializer_list<int64> dim_sizes) 172f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {} 173f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 174f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Construct an empty TensorShape, or an unknown rank PartialTensorShape 175f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving TensorShapeBase(); 176f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 177f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving TensorShapeBase(const TensorShapeProto& proto); 178f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 179f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Returns `true` iff `proto` is a valid tensor shape. 180f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // For TensorShape, the proto shape must be fully defined. 181f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static bool IsValid(const TensorShapeProto& proto); 182f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 183f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error 184f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// status otherwise. 
185f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static Status IsValidShape(const TensorShapeProto& proto); 186f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 187f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// \brief Add a dimension to the end ("inner-most"). 188f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// REQUIRES: `size >= 0` 189f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving void AddDim(int64 size); 190f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 191f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Appends all the dimensions from `shape`. 192f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving void AppendShape(const TensorShapeBase& shape); 193f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 194f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // Maximum number of dimensions in a tensor. 195f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static constexpr int MaxDimensions() { return 254; } 196f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 197f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// \brief Insert a dimension somewhere in the `TensorShape`. 
198f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// REQUIRES: `0 <= d <= dims()` 199f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// REQUIRES: `size >= 0` 200f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving void InsertDim(int d, int64 size); 201f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 202f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// \brief Modifies the size of the dimension `d` to be `size` 203f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// REQUIRES: `0 <= d < dims()` 204f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// REQUIRES: `size >= 0` 205f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving void set_dim(int d, int64 size); 206f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 207f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// \brief Removes dimension `d` from the `TensorShape`. 208f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// REQUIRES: `0 <= d < dims()` 209b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower void RemoveDim(int d) { 210b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower CHECK_GE(d, 0); 211b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower RemoveDimRange(d, d + 1); 212b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower } 213b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower 214b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower /// \brief Removes last `n` dimensions from the `TensorShape`. 215b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower /// REQUIRES: `0 <= n <= dims()` 216b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower void RemoveLastDims(int n) { 217b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower CHECK_LE(n, dims()); 218b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower RemoveDimRange(dims() - n, dims()); 219b49aef3c17e39a9c7127477333c7cc76aef57a80A. 
Unique TensorFlower } 220b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower 221b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`. 222b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower /// Negative values of `end` are interpreted as `dims() + end + 1` (as in 223b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower /// Python). The same is true for negative values of `begin`. REQUIRES: 224b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower /// `-(dims()+1) <= begin <= dims()` REQUIRES: `-(dims()+1) <= end <= dims()` 225b49aef3c17e39a9c7127477333c7cc76aef57a80A. Unique TensorFlower void RemoveDimRange(int begin, int end); 226f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 227f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Return whether the rank is unknown 228f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving bool unknown_rank() const { 229f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving return kIsPartial && ndims_byte() == kUnknownRank; 230f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving } 231f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 232f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Return the number of dimensions in the tensor. 233f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Can be -1 meaning unknown rank for PartialTensorShape. 234f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving int dims() const { 235f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving uint8 dims = ndims_byte(); 236f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving return kIsPartial && dims == kUnknownRank ? -1 : dims; 237f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving } 238f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 239f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// \brief Returns the number of elements in dimension `d`. 
240f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// REQUIRES: `0 <= d < dims()` 241f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // TODO(touts): Rename to `dimension()` to match 242f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // `Eigen::Tensor::dimension()`? 243f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving int64 dim_size(int d) const; 244f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 245f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Returns sizes of all dimensions. 246f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // Returns an empty list for unknown rank PartialTensorShape. 247f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving gtl::InlinedVector<int64, 4> dim_sizes() const; 248f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 249f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Return true iff the rank and all of the dimensions are well defined 250f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // TODO(irving): Rename to is_fully_defined now that it's fast. 251f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; } 252f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 253f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Fill `*proto` from `*this`. 254f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving void AsProto(TensorShapeProto* proto) const; 255f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 256f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// For iterating through the dimensions. 
257f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving TensorShapeIter<Shape> begin() const; 258f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving TensorShapeIter<Shape> end() const; 259f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 260f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving private: 261f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving void RecomputeNumElements(); 262f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 263f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // True for PartialTensorShape, false for TensorShape 264f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static constexpr bool kIsPartial = 265f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving std::is_same<Shape, PartialTensorShape>::value; 266f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value, 267f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving "Shape is neither TensorShape nor PartialTensorShape"); 268f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 269f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // Used by AddDim and MakeShapeHelper. Does no error checking. 270f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving void UnsafeAddDim(int64 size, int64 new_num_elements); 271f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 272f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // For use by TensorShapeUtils::MakeShape 273f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving template <class T, class S> 274f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving friend Status MakeShapeHelper(const T*, int64, S*); 275f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving}; 276f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 277f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// Represents the shape of a Tensor. 
278f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// 279f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// A tensor's shape is denoted by its number of dimensions and a size for each 280f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// dimension. For example, a Tensor represented by a 3 x 4 matrix would have 281f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// a shape of 2-D, [3,4]. 282f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// 283f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// If you know the exact shape of your Tensor when you create the TensorShape 284f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// object, you can specify it then, or you can create a TensorShape with 285f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// zero dimensions and one element, and call AddDim() to add dimensions later. 286f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvingclass TensorShape : public TensorShapeBase<TensorShape> { 287f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving public: 288f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving using TensorShapeBase<TensorShape>::TensorShapeBase; 289f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 290f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Allow a TensorShape to be used as a PartialTensorShape without copying 291f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving operator const PartialTensorShape&() const; // NOLINT(runtime/explicit) 292f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 293f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Returns true if `*this` and `b` have the same sizes. Ignores 294f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// dimension names. 
295f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving bool IsSameSize(const TensorShape& b) const; 296f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving bool operator==(const TensorShape& b) const { return IsSameSize(b); } 297f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving bool operator!=(const TensorShape& b) const { return !IsSameSize(b); } 298f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 299f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Fill `*dsizes` from `*this`. 300f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving template <int NDIMS> 301f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const; 302f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 303f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in 304f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// which case we pad the rest of the sizes with 1. 305f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving template <int NDIMS> 306f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const; 307f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 308f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving private: 309f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // These CHECK fail to ease debugging. 310f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // REQUIRES: dims() == NDIMS 311f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving void CheckDimsEqual(int NDIMS) const; 312f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // REQUIRES: dims() >= NDIMS 313f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving void CheckDimsAtLeast(int NDIMS) const; 314f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving}; 315f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 316120c3f11f857ce6da7ae90b3eab943896499dca2A. 
Unique TensorFlower/// Represents the value of one dimension in a TensorShape. 317db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenbergstruct TensorShapeDim { 318db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg explicit TensorShapeDim(int64 s) : size(s) {} 319d33346a801181cb601a4cfc95f1087e165b19703Vijay Vasudevan int64 size; 320db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg}; 321db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 322120c3f11f857ce6da7ae90b3eab943896499dca2A. Unique TensorFlower// START_SKIP_DOXYGEN 323f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvingtemplate <class Shape> 324db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenbergclass TensorShapeIter { 325db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg public: 326f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {} 327db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg bool operator==(const TensorShapeIter& rhs) { 328db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg DCHECK(shape_ == rhs.shape_); 329db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg return d_ == rhs.d_; 330db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg } 331db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg bool operator!=(const TensorShapeIter& rhs) { 332db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg DCHECK(shape_ == rhs.shape_); 333db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg return d_ != rhs.d_; 334db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg } 335db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg void operator++() { ++d_; } 336db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); } 337db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 338db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg private: 339f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving const Shape* shape_; 
340db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg int d_; 341db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg}; 342120c3f11f857ce6da7ae90b3eab943896499dca2A. Unique TensorFlower// END_SKIP_DOXYGEN 343db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 344db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg/// \brief Static helper routines for `TensorShape`. Includes a few common 345db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg/// predicates on a tensor shape. 346db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenbergclass TensorShapeUtils { 347db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg public: 348db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; } 349db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 350db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; } 351db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 352db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg static bool IsVectorOrHigher(const TensorShape& shape) { 353db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg return shape.dims() >= 1; 354db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg } 355db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 356db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; } 357db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 358053fc50b2cbb1979cbed1662c3d964c36fdafecdA. Unique TensorFlower static bool IsSquareMatrix(const TensorShape& shape) { 359053fc50b2cbb1979cbed1662c3d964c36fdafecdA. Unique TensorFlower return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1); 360053fc50b2cbb1979cbed1662c3d964c36fdafecdA. Unique TensorFlower } 361053fc50b2cbb1979cbed1662c3d964c36fdafecdA. 
Unique TensorFlower 362db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg static bool IsMatrixOrHigher(const TensorShape& shape) { 363db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg return shape.dims() >= 2; 364db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg } 365db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 366db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg /// \brief Returns a `TensorShape` whose dimensions are 367db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. 36896a1afa21c59999c68887550b0c8c27cdf171339David G. Andersen static Status MakeShape(const int32* dims, int64 n, TensorShape* out); 36996a1afa21c59999c68887550b0c8c27cdf171339David G. Andersen static Status MakeShape(const int64* dims, int64 n, TensorShape* out); 3709d6b7680cfcc5b784401d17fef997f0c089038b1A. Unique TensorFlower static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out); 3719d6b7680cfcc5b784401d17fef997f0c089038b1A. 
Unique TensorFlower static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out); 372f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static Status MakeShape(const int32* dims, int64 n, PartialTensorShape* out); 373f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static Status MakeShape(const int64* dims, int64 n, PartialTensorShape* out); 374f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static Status MakeShape(gtl::ArraySlice<int32> shape, 375f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving PartialTensorShape* out); 376f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static Status MakeShape(gtl::ArraySlice<int64> shape, 377f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving PartialTensorShape* out); 378db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 379725e968a934419eec45b9acaf4bf8dc5f5f0574eGeoffrey Irving static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes); 380db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 381b5388091998dbeb6641b5d29dbceb825ee359acdA. Unique TensorFlower /// \brief Returns true iff `shape` starts with `prefix`. 382b5388091998dbeb6641b5d29dbceb825ee359acdA. Unique TensorFlower static bool StartsWith(const TensorShape& shape, const TensorShape& prefix); 383b5388091998dbeb6641b5d29dbceb825ee359acdA. Unique TensorFlower 384b5388091998dbeb6641b5d29dbceb825ee359acdA. Unique TensorFlower /// \brief Returns true iff `shape` ends with `suffix`. 385b5388091998dbeb6641b5d29dbceb825ee359acdA. Unique TensorFlower static bool EndsWith(const TensorShape& shape, const TensorShape& suffix); 386cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo 387cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo /// \brief Returns the product of values in an int64 array, 388cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo /// or a failing Status if the array represents a value larger than 389cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo /// a `TensorShape` can hold. 
390cade141580c76b41ba71bdc4b019722e674ab954Eugene Brevdo static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements); 391db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg}; 392db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 393f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// Manages the partially known dimensions of a Tensor and their sizes. 394f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvingclass PartialTensorShape : public TensorShapeBase<PartialTensorShape> { 395f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving public: 396f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving PartialTensorShape() {} 397f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving using TensorShapeBase<PartialTensorShape>::TensorShapeBase; 398f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 399f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Add a dimension to the end ("inner-most"), returns a new 400f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// PartialTensorShape. 401f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// REQUIRES: `size >= -1`, where -1 means unknown 402f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving PartialTensorShape Concatenate(int64 size) const; 403f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 404f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Appends all the dimensions from `shape`. Returns a new 405f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// PartialTensorShape. 406f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving PartialTensorShape Concatenate(const PartialTensorShape& shape) const; 407f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 408f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Merges all the dimensions from `shape`. 
Returns 409f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// `InvalidArgument` error if either `shape` has a different rank 410f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// or if any of the dimensions are incompatible. 411f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving Status MergeWith(const PartialTensorShape& shape, 412f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving PartialTensorShape* result) const; 413f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 414f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Exact equality test. Returns true iff the ranks match (i.e., both are 415f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// unknown, or both are known and equal), and all dimensions are equal (i.e., 416f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// both dimensions are known, or both are known and equal). This is a 417f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// stronger condition that IsCompatibleWith. 418f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving bool IsIdenticalTo(const PartialTensorShape& shape) const; 419f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 420f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// Return true iff the ranks match, and if the 421f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// dimensions all either match or one is unknown. 422f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving bool IsCompatibleWith(const PartialTensorShape& shape) const; 423f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 424f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // Fill `*shape` from `*this`. 425f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // If `*this` is not fully defined, returns false and 426f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // `*shape` is left in an intermediate state. Otherwise 427f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // returns true. 
428f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving bool AsTensorShape(TensorShape* shape) const; 429f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 430f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// \brief Returns a `PartialTensorShape` whose dimensions are 431f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are 432f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving /// considered "unknown". 433f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving template <class T> 434f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static Status MakePartialShape(const T* dims, int n, 435f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving PartialTensorShape* out) { 436f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving return TensorShapeUtils::MakeShape(dims, n, out); 437f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving } 438f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving}; 439f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 440f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// \brief Static helper routines for `PartialTensorShape`. Includes a few 441f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving/// common predicates on a partially known tensor shape. 
442f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvingclass PartialTensorShapeUtils { 443f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving public: 444f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static string PartialShapeListString( 445f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving const gtl::ArraySlice<PartialTensorShape>& shapes); 446f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 447f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0, 448f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving const gtl::ArraySlice<PartialTensorShape>& shapes1); 449f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 450f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0, 451f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving const gtl::ArraySlice<PartialTensorShape>& shapes1); 452f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving}; 453f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 454db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg// ---------------------------------------------------------------------------- 455db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg// Template method implementation details below 456db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg// ---------------------------------------------------------------------------- 457db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 458db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenbergtemplate <int NDIMS> 459db7478e8998f7703c57a75a950c905ec0cb59d7bJosh LevenbergEigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const { 460821920c1f25968f5dfcd2f8999b293ebedf85957A. 
Unique TensorFlower CheckDimsEqual(NDIMS); 461db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg return AsEigenDSizesWithPadding<NDIMS>(); 462db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg} 463db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 464db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenbergtemplate <int NDIMS> 465db7478e8998f7703c57a75a950c905ec0cb59d7bJosh LevenbergEigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding() 466db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg const { 467821920c1f25968f5dfcd2f8999b293ebedf85957A. Unique TensorFlower CheckDimsAtLeast(NDIMS); 4684d9ec5ece5771a1982352574ce2cad587644fadaDavid G. Andersen static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions"); 469db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes; 470db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg for (int d = 0; d < dims(); d++) { 471db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg dsizes[d] = dim_size(d); 472db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg } 473db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg for (int d = dims(); d < NDIMS; d++) { 474db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg dsizes[d] = 1; 475db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg } 476db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg return dsizes; 477db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg} 478db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 479df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower// ---------------------------------------------------------------------------- 480df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower// Inlining of some performance critical routines 481df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower// ---------------------------------------------------------------------------- 482df66b9fe049cb17e58454f309b662bdaf0d14fdbA. 
Unique TensorFlower 483f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvinginline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) { 484df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower num_elements_ = b.num_elements_; 485df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower if (b.tag() != REP_OUT_OF_LINE) { 486df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower memcpy(buf(), b.buf(), sizeof(u_.buf)); 487df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower // memcpy above Implicitly does: 488df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower // set_ndims_byte(b.ndims_byte()); 489df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower // set_tag(b.tag()); 490df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower } else { 491df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower set_tag(REP16); // So that SlowCopyFrom does not try to deallocate 492df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower SlowCopyFrom(b); 493df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower } 494df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower} 495df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower 496f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvinginline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) { 497f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower num_elements_ = b.num_elements_; 498f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower memcpy(buf(), b.buf(), sizeof(u_.buf)); 499f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower // memcpy above Implicitly does: 500f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower // set_ndims_byte(b.ndims_byte()); 501f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower // set_tag(b.tag()); 502f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower b.set_tag(REP16); // other shape no longer owns out-of-line data, if any. 
503f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower} 504f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower 505f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvinginline TensorShapeRep::~TensorShapeRep() { 506df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower if (tag() == REP_OUT_OF_LINE) { 507b76812c8a84c0efd45342121bfc1cff2b6bb1051A. Unique TensorFlower DestructorOutOfLine(); 508df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower } 509df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower} 510df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower 511f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvinginline void TensorShapeRep::operator=(const TensorShapeRep& b) { 512df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower num_elements_ = b.num_elements_; 513df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) { 514df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower memcpy(buf(), b.buf(), sizeof(u_.buf)); 515df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower // memcpy above implicitly also does: 516df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower // set_tag(b.tag()); 517df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower // set_ndims_byte(b.ndims_byte()); 518df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower } else { 519df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower SlowCopyFrom(b); 520df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower } 521df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower} 522df66b9fe049cb17e58454f309b662bdaf0d14fdbA. Unique TensorFlower 523f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvinginline void TensorShapeRep::operator=(TensorShapeRep&& b) { 524f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower if (tag() == REP_OUT_OF_LINE) { 525f7a662e595f1631c12d15173344ee0f50d2cd9f8A. 
Unique TensorFlower DestructorOutOfLine(); 526f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower } 527f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower num_elements_ = b.num_elements_; 528f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower memcpy(buf(), b.buf(), sizeof(u_.buf)); 529f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower // memcpy above Implicitly does: 530f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower // set_ndims_byte(b.ndims_byte()); 531f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower // set_tag(b.tag()); 532f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower b.set_tag(REP16); // other shape no longer owns out-of-line data, if any. 533f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower} 534f7a662e595f1631c12d15173344ee0f50d2cd9f8A. Unique TensorFlower 535f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvinginline TensorShape::operator const PartialTensorShape&() const { 536f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving // Downcast to the shared representation and upcast to PartialTensorShape 537f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving const TensorShapeRep* rep = this; 538f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving return *static_cast<const PartialTensorShape*>(rep); 539f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving} 540f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 541f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving// Declare explicit instantiations in .cc file 542f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvingextern template class TensorShapeBase<TensorShape>; 543f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irvingextern template class TensorShapeBase<PartialTensorShape>; 544f5bbcdf2d05ddc3b24a17ac7fc2cfdf36ef2ee2bGeoffrey Irving 545db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg} // namespace tensorflow 546db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg 
547db7478e8998f7703c57a75a950c905ec0cb59d7bJosh Levenberg#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ 548