/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_UTIL_TENSOR_FORMAT_H_
#define TENSORFLOW_UTIL_TENSOR_FORMAT_H_

#include <array>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Tensor format for input/output activations used in convolution operations.
// The mnemonics specify the meaning of each tensor dimension sorted from
// largest to smallest memory stride.
// N = Batch, H = Image Height, W = Image Width, C = Number of Channels.
enum TensorFormat {
  // FORMAT_NHWC is the default format in TensorFlow.
  FORMAT_NHWC = 0,

  // FORMAT_NCHW often improves performance on GPUs.
  FORMAT_NCHW = 1,

  // NCHW_VECT_C is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is laid out in the same order
  // as NCHW, except that the size of the Channels dimension is divided by 4,
  // and a new dimension of size 4 is appended, which packs 4 adjacent channel
  // activations for the same pixel into an int32. Thus an NCHW format tensor
  // with dimensions [N, C, H, W] would have dimensions [N, C/4, H, W, 4] in
  // NCHW_VECT_C format.
  // A pre-condition of this format is that C must be a multiple of 4.
  FORMAT_NCHW_VECT_C = 2,
};

// Tensor format for convolutional filters.
// The mnemonics specify the meaning of each tensor dimension sorted
// from largest to smallest memory stride.
// H = Kernel Height, W = Kernel Width, I = Input Channels, O = Output Channels.
// Note: In cudnnGetFilter4dDescriptor(), 'O' is called 'K', 'I' is called 'C'.
enum FilterTensorFormat {
  // FORMAT_HWIO is the default filter format in TensorFlow.
  // Ops that do not have a 'filter_format' attribute will assume this format.
  FORMAT_HWIO = 0,

  // FORMAT_OIHW often improves performance on GPUs.
  FORMAT_OIHW = 1,

  // OIHW_VECT_I is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is analogous to the NCHW_VECT_C
  // data format. It is laid out in the same order as OIHW, except that the size
  // of the Input Channels dimension is divided by 4, and a new dimension of
  // size 4 is appended, which packs 4 adjacent input channel weights into an
  // int32. Thus an OIHW format filter with dimensions [O, I, H, W] would have
  // dimensions [O, I/4, H, W, 4] in OIHW_VECT_I format.
  // A pre-condition of this format is that I must be a multiple of 4.
  FORMAT_OIHW_VECT_I = 2,
};

// Parse tensor format from the given string and store it in '*format'.
// Return true if the parsing succeeds, and false if it fails.
bool FormatFromString(const string& format_str, TensorFormat* format);

// Parse filter tensor format from the given string and store it in '*format'.
// Return true if the parsing succeeds, and false if it fails.
bool FilterFormatFromString(const string& format_str,
                            FilterTensorFormat* format);

// Convert a tensor format into string.
string ToString(TensorFormat format);

// Convert a filter tensor format into string.
87string ToString(FilterTensorFormat format); 88 89// Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor 90// format 'format'. 91inline int GetTensorSpatialDims(int num_dims, TensorFormat format) { 92 if (format == FORMAT_NCHW_VECT_C) { 93 return num_dims - 3; // Exclude N,C,InnerC. 94 } else { 95 return num_dims - 2; // Exclude N,C. 96 } 97} 98 99inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) { 100 if (format == FORMAT_OIHW_VECT_I) { 101 return num_dims - 3; // Exclude O,I,InnerI. 102 } else { 103 return num_dims - 2; // Exclude O,I. 104 } 105} 106 107// Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and 108// tensor format 'format'. This is the inverse of GetTensorSpatialDims. 109inline int GetTensorDimsFromSpatialDims(int num_spatial_dims, 110 TensorFormat format) { 111 if (format == FORMAT_NCHW_VECT_C) { 112 return num_spatial_dims + 3; // Include N,C,InnerC. 113 } else { 114 return num_spatial_dims + 2; // Include N,C. 115 } 116} 117 118// Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and 119// filter tensor format 'format'. 120inline int GetFilterTensorDimsFromSpatialDims(int num_spatial_dims, 121 FilterTensorFormat format) { 122 if (format == FORMAT_OIHW_VECT_I) { 123 return num_spatial_dims + 3; // Include O,I,InnerI. 124 } else { 125 return num_spatial_dims + 2; // Include O,I. 126 } 127} 128 129// Returns the index of the batch dimension. 130inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) { 131 switch (format) { 132 case FORMAT_NHWC: 133 case FORMAT_NCHW: 134 case FORMAT_NCHW_VECT_C: 135 return 0; 136 default: 137 LOG(FATAL) << "Unknown format " << format; 138 return -1; // Avoid compiler warning about missing return value 139 } 140} 141 142// Returns the index of the feature dimension. If format is NCHW_VECT_C, returns 143// the index of the outer feature dimension (i.e. 
dimension 1, whose size would 144// be num_features / 4 in this case). 145inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) { 146 switch (format) { 147 case FORMAT_NHWC: 148 return num_dims - 1; 149 case FORMAT_NCHW: 150 case FORMAT_NCHW_VECT_C: 151 return 1; 152 default: 153 LOG(FATAL) << "Unknown format " << format; 154 return -1; // Avoid compiler warning about missing return value 155 } 156} 157 158// Returns the index of the inner feature dimension. 159inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) { 160 DCHECK_EQ(format, FORMAT_NCHW_VECT_C); 161 return num_dims - 1; 162} 163 164// Returns the index of the `dim`-th spatial dimension. 165inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format, 166 int dim) { 167 CHECK(dim >= 0 && dim < GetTensorSpatialDims(num_dims, format)) 168 << dim << " " << num_dims << " " << ToString(format); 169 switch (format) { 170 case FORMAT_NHWC: 171 return dim + 1; 172 case FORMAT_NCHW: 173 case FORMAT_NCHW_VECT_C: 174 return dim + 2; 175 default: 176 LOG(FATAL) << "Unknown format " << format; 177 return -1; // Avoid compiler warning about missing return value 178 } 179} 180 181// Returns the index of the `dim`-th spatial dimension. 182inline int GetFilterTensorSpatialDimIndex(int num_dims, 183 FilterTensorFormat format, int dim) { 184 CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format)) 185 << dim << " " << num_dims << " " << ToString(format); 186 switch (format) { 187 case FORMAT_HWIO: 188 return dim; 189 case FORMAT_OIHW: 190 case FORMAT_OIHW_VECT_I: 191 return dim + 2; 192 default: 193 LOG(FATAL) << "Unknown format " << format; 194 return -1; // Avoid compiler warning about missing return value 195 } 196} 197 198// Returns the index of the inner input channels dimension. 
199inline int GetFilterTensorInnerInputChannelsDimIndex( 200 int num_dims, FilterTensorFormat format) { 201 DCHECK_EQ(format, FORMAT_OIHW_VECT_I); 202 return num_dims - 1; 203} 204 205// Returns the index of the input channels dimension. 206// If 'format' is FORMAT_OIHW_VECT_I, returns the dimension index of the 207// outer input channel (i.e. 1), which holds num_input_channels / 4. 208inline int GetFilterTensorInputChannelsDimIndex(int num_dims, 209 FilterTensorFormat format) { 210 switch (format) { 211 case FORMAT_HWIO: 212 return num_dims - 2; 213 case FORMAT_OIHW: 214 case FORMAT_OIHW_VECT_I: 215 return 1; 216 default: 217 LOG(FATAL) << "Unknown format " << format; 218 return -1; // Avoid compiler warning about missing return value 219 } 220} 221 222// Returns the index of the output channels dimension. 223inline int GetFilterTensorOutputChannelsDimIndex(int num_dims, 224 FilterTensorFormat format) { 225 switch (format) { 226 case FORMAT_HWIO: 227 return num_dims - 1; 228 case FORMAT_OIHW: 229 case FORMAT_OIHW_VECT_I: 230 return 0; 231 default: 232 LOG(FATAL) << "Unknown format " << format; 233 return -1; // Avoid compiler warning about missing return value 234 } 235} 236 237// TODO(pauldonnelly): Replace these tensor dimension index functions with 238// constant structs to improve performance and reduce code size in Compute() 239// functions. 240 241// Return the dimension index for the specified 'dimension' of the specified 242// data 'tensor_format'. 'dimension' is a char that can be 'N' (batch size), 243// 'C' (channels), 'H' (height), 'W' (width), or a numbered spatial dimension: 244// '0', .. (NUM_SPATIAL_DIMS-1).. 245// If 'format' is NCHW_VECT_C and 'dimension' is 'C', returns the index of 246// the outer channel dimension (i.e. 1). 
247template <int NUM_SPATIAL_DIMS> 248inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { 249 if (format == FORMAT_NHWC) { 250 // clang-format off 251 switch (dimension) { 252 case 'N': return 0; 253 case '0': return 1; 254 case '1': return 2; 255 case '2': return 3; 256 case 'H': return NUM_SPATIAL_DIMS - 1; 257 case 'W': return NUM_SPATIAL_DIMS; 258 case 'C': return NUM_SPATIAL_DIMS + 1; 259 default: 260 LOG(FATAL) << "Invalid dimension: " << dimension; 261 return -1; // Avoid compiler warning about missing return value 262 } 263 } else if (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) { 264 switch (dimension) { 265 case 'N': return 0; 266 case 'C': return 1; 267 case '0': return 2; 268 case '1': return 3; 269 case '2': return 4; 270 case 'H': return NUM_SPATIAL_DIMS; 271 case 'W': return NUM_SPATIAL_DIMS + 1; 272 default: 273 LOG(FATAL) << "Invalid dimension: " << dimension; 274 return -1; // Avoid compiler warning about missing return value 275 } 276 } else { 277 LOG(FATAL) << "Invalid format: " << static_cast<int>(format); 278 return -1; // Avoid compiler warning about missing return value 279 } 280 // clang-format on 281} 282 283// Return the dimension index for the specified 'dimension' of the specified 284// 'filter_tensor_format'. 'dimension' is a char that can be 'O' (num output 285// channels), 'I' (num input channels), 'H' (height), 'W' (width), or a 286// numbered spatial dimension: '0', .. (NUM_SPATIAL_DIMS-1). 287// If 'format' is OIHW_VECT_I and 'dimension' is 'I', returns the index of the 288// outer input channels dimension (i.e. 1). 
289template <int NUM_SPATIAL_DIMS> 290inline int GetFilterDimIndex(FilterTensorFormat filter_tensor_format, 291 char dimension) { 292 // clang-format off 293 if (filter_tensor_format == FORMAT_HWIO) { 294 switch (dimension) { 295 case '0': return 0; 296 case '1': return 1; 297 case '2': return 2; 298 case 'H': return NUM_SPATIAL_DIMS - 2; 299 case 'W': return NUM_SPATIAL_DIMS - 1; 300 case 'I': return NUM_SPATIAL_DIMS; 301 case 'O': return NUM_SPATIAL_DIMS + 1; 302 default: 303 LOG(FATAL) << "Invalid dimension: " << dimension; 304 return -1; // Avoid compiler warning about missing return value 305 } 306 } else if (filter_tensor_format == FORMAT_OIHW || 307 filter_tensor_format == FORMAT_OIHW_VECT_I) { 308 switch (dimension) { 309 case 'O': return 0; 310 case 'I': return 1; 311 case '0': return 2; 312 case '1': return 3; 313 case '2': return 4; 314 case 'H': return NUM_SPATIAL_DIMS; 315 case 'W': return NUM_SPATIAL_DIMS + 1; 316 default: 317 LOG(FATAL) << "Invalid dimension: " << dimension; 318 return -1; // Avoid compiler warning about missing return value 319 } 320 } else { 321 LOG(FATAL) << "Invalid format: " << static_cast<int>(filter_tensor_format); 322 return -1; // Avoid compiler warning about missing return value 323 } 324 // clang-format on 325} 326 327inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { 328 return GetTensorDimIndex<2>(format, dimension); 329} 330 331// Return the element from 'dimension_attributes' that corresponds to the 332// specified 'dimension' according to 'tensor_format'. 333template <typename T> 334T GetTensorDim(gtl::ArraySlice<T> dimension_attributes, 335 TensorFormat tensor_format, char dimension) { 336 int index = 337 (GetTensorSpatialDims(dimension_attributes.size(), tensor_format) == 3) 338 ? 
GetTensorDimIndex<3>(tensor_format, dimension) 339 : GetTensorDimIndex<2>(tensor_format, dimension); 340 CHECK(index >= 0 && index < dimension_attributes.size()) 341 << "Invalid index from the dimension: " << index << ", " << tensor_format 342 << ", " << dimension; 343 return dimension_attributes[index]; 344} 345 346// Return the element from 'dimension_attribute' that corresponds to the 347// specified 'dimension' according to 'filter_tensor_format'. 348template <typename T> 349T GetFilterDim(gtl::ArraySlice<T> dimension_attribute, 350 FilterTensorFormat filter_tensor_format, char dimension) { 351 int index = (GetFilterTensorSpatialDims(dimension_attribute.size(), 352 filter_tensor_format) == 3) 353 ? GetFilterDimIndex<3>(filter_tensor_format, dimension) 354 : GetFilterDimIndex<2>(filter_tensor_format, dimension); 355 CHECK(index >= 0 && index < dimension_attribute.size()) 356 << "Invalid index from the dimension: " << index << ", " 357 << filter_tensor_format << ", " << dimension; 358 return dimension_attribute[index]; 359} 360 361template <typename T> 362T GetTensorDim(const std::vector<T>& attributes, TensorFormat format, 363 char dimension) { 364 return GetTensorDim(gtl::ArraySlice<T>(attributes), format, dimension); 365} 366 367// Return the size of the specified 'dimension' within 'tensor_shape' 368// according to 'tensor_format'. 369inline int64 GetTensorDim(const TensorShape& tensor_shape, 370 TensorFormat tensor_format, char dimension) { 371 return GetTensorDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()), 372 tensor_format, dimension); 373} 374 375// Return the size of the specified 'dimension' within 'tensor_shape' 376// according to 'tensor_filter_format'. 
377inline int64 GetFilterDim(const TensorShape& tensor_shape, 378 FilterTensorFormat tensor_filter_format, 379 char dimension) { 380 return GetFilterDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()), 381 tensor_filter_format, dimension); 382} 383 384// Return the size of the specified 'dimension' of 'tensor' according to 385// 'tensor_format'. 386inline int64 GetTensorDim(const Tensor& tensor, TensorFormat tensor_format, 387 char dimension) { 388 return GetTensorDim(tensor.shape(), tensor_format, dimension); 389} 390 391// Return the size of the specified 'dimension' of 'tensor' according to 392// 'filter_tensor_format'. 393inline int64 GetFilterDim(const Tensor& tensor, 394 FilterTensorFormat filter_tensor_format, 395 char dimension) { 396 return GetFilterDim(tensor.shape(), filter_tensor_format, dimension); 397} 398 399// Return the string that specifies the data format for convnet operations. 400string GetConvnetDataFormatAttrString(); 401string GetConvnet3dDataFormatAttrString(); 402 403// Return the string that specifies the filter format for convnet operations. 404string GetConvnetFilterFormatAttrString(); 405string GetConvnet3dFilterFormatAttrString(); 406 407// Return a tensor shape for the given format. Works for both 2D and 3D 408// operations. If format is FORMAT_NCHW_VECT_C, the output TensorShape has rank 409// spatial.size()+3 (N,C,spatial,InnerC); otherwise, it has rank 410// spatial.size()+2 (e.g. N,C,spatial or N,spatial,C). 
411inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, 412 gtl::ArraySlice<int64> spatial, int64 C) { 413 const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format); 414 gtl::InlinedVector<int64, 6> dim_sizes(dims); 415 dim_sizes[GetTensorBatchDimIndex(dims, format)] = N; 416 for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) { 417 dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = spatial[dim]; 418 } 419 420 int feature_index = GetTensorFeatureDimIndex(dims, format); 421 if (format == FORMAT_NCHW_VECT_C) { 422 CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C=" 423 << C; 424 dim_sizes[feature_index] = C / 4; 425 dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4; 426 } else { 427 dim_sizes[feature_index] = C; 428 } 429 return TensorShape(dim_sizes); 430} 431 432// Return a tensor shape of the specified 'format', and dimensions. 433// Works for both 2D and 3D operations. If 'format' is OIHW_VECT_I, 434// the output TensorShape has spatial.size() + 3 dimensions, otherwise 435// it has spatial.size() + 2 dimensions. 
436inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format, 437 gtl::ArraySlice<int64> spatial, 438 int64 I, int64 O) { 439 const int dims = GetFilterTensorDimsFromSpatialDims(spatial.size(), format); 440 gtl::InlinedVector<int64, 6> dim_sizes(dims); 441 dim_sizes[GetFilterTensorOutputChannelsDimIndex(dims, format)] = O; 442 for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) { 443 dim_sizes[GetFilterTensorSpatialDimIndex(dims, format, dim)] = spatial[dim]; 444 } 445 446 if (format == FORMAT_OIHW_VECT_I) { 447 CHECK_EQ(0, I % 4) << "OIHW_VECT_I requires I to be a multiple of 4, but I=" 448 << I; 449 I /= 4; 450 dim_sizes[GetFilterTensorInnerInputChannelsDimIndex(dims, format)] = 4; 451 } 452 dim_sizes[GetFilterTensorInputChannelsDimIndex(dims, format)] = I; 453 return TensorShape(dim_sizes); 454} 455 456// Return a tensor shape of the specified 'format', and dimensions. 457inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H, 458 int64 W, int64 C) { 459 return ShapeFromFormat(format, N, {H, W}, C); 460} 461 462// Return a filter tensor shape of the specified 'format', and dimensions. 463inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format, 464 int64 H, int64 W, int64 I, 465 int64 O) { 466 return ShapeFromFilterTensorFormat(format, {H, W}, I, O); 467} 468 469// Returns a copy of the specified tensor 'src_shape' converted from 470// 'src_format' to 'dst_format'. 471inline TensorShape ShapeFromFormat(TensorFormat dst_format, 472 const TensorShape& src_shape, 473 TensorFormat src_format) { 474 if (src_format == dst_format) { 475 return src_shape; 476 } 477 478 const int64 batch = GetTensorDim(src_shape, src_format, 'N'); 479 const int64 channels = GetTensorDim(src_shape, src_format, 'C') * 480 (src_format == FORMAT_NCHW_VECT_C ? 
4 : 1); 481 482 if (GetTensorSpatialDims(src_shape.dims(), src_format) == 3) { 483 return ShapeFromFormat(dst_format, batch, 484 {{GetTensorDim(src_shape, src_format, '0'), 485 GetTensorDim(src_shape, src_format, '1'), 486 GetTensorDim(src_shape, src_format, '2')}}, 487 channels); 488 } 489 490 return ShapeFromFormat(dst_format, batch, 491 {{GetTensorDim(src_shape, src_format, 'H'), 492 GetTensorDim(src_shape, src_format, 'W')}}, 493 channels); 494} 495 496// Returns a copy of the specified filter tensor 'src_shape' converted from 497// 'src_filter_format' to 'dst_filter_format'. 498inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format, 499 const TensorShape& src_shape, 500 FilterTensorFormat src_filter_format) { 501 if (src_filter_format == dst_filter_format) { 502 return src_shape; 503 } 504 505 const int64 output_channels = GetFilterDim(src_shape, src_filter_format, 'O'); 506 const int64 input_channels = 507 GetFilterDim(src_shape, src_filter_format, 'I') * 508 (src_filter_format == FORMAT_OIHW_VECT_I ? 4 : 1); 509 510 if (GetFilterTensorSpatialDims(src_shape.dims(), src_filter_format) == 3) { 511 return ShapeFromFilterTensorFormat( 512 dst_filter_format, 513 {{GetFilterDim(src_shape, src_filter_format, '0'), 514 GetFilterDim(src_shape, src_filter_format, '1'), 515 GetFilterDim(src_shape, src_filter_format, '2')}}, 516 input_channels, output_channels); 517 } 518 519 return ShapeFromFilterTensorFormat( 520 dst_filter_format, 521 {{GetFilterDim(src_shape, src_filter_format, 'H'), 522 GetFilterDim(src_shape, src_filter_format, 'W')}}, 523 input_channels, output_channels); 524} 525 526} // namespace tensorflow 527 528#endif // TENSORFLOW_UTIL_TENSOR_FORMAT_H_ 529