1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_UTIL_TENSOR_FORMAT_H_
17#define TENSORFLOW_UTIL_TENSOR_FORMAT_H_
18
19#include <array>
20#include <vector>
21
22#include "tensorflow/core/framework/tensor.h"
23#include "tensorflow/core/lib/gtl/inlined_vector.h"
24#include "tensorflow/core/platform/types.h"
25
26namespace tensorflow {
27
// Layout of the input/output activation tensors used by convolution
// operations. Each letter of a mnemonic names one tensor dimension, and the
// letters are listed from the largest memory stride to the smallest:
//   N = Batch, H = Image Height, W = Image Width, C = Number of Channels.
enum TensorFormat {
  // The default format in TensorFlow.
  FORMAT_NHWC = 0,

  // Often improves performance on GPUs.
  FORMAT_NCHW = 1,

  // The most performant tensor format for cudnn6's quantized int8 convolution
  // and fused convolution. The dimension order matches NCHW, but the Channels
  // dimension is divided by 4 and a trailing dimension of size 4 is appended;
  // that trailing dimension packs 4 adjacent channel activations of the same
  // pixel into one int32. An NCHW tensor of dimensions [N, C, H, W] therefore
  // has dimensions [N, C/4, H, W, 4] in NCHW_VECT_C format.
  // Pre-condition: C must be a multiple of 4.
  FORMAT_NCHW_VECT_C = 2,
};
49
// Layout of convolutional filter tensors. Each letter of a mnemonic names one
// tensor dimension, and the letters are listed from the largest memory stride
// to the smallest:
//   H = Kernel Height, W = Kernel Width, I = Input Channels,
//   O = Output Channels.
// Note: In cudnnGetFilter4dDescriptor(), 'O' is called 'K', 'I' is called 'C'.
enum FilterTensorFormat {
  // The default filter format in TensorFlow. Ops that do not have a
  // 'filter_format' attribute will assume this format.
  FORMAT_HWIO = 0,

  // Often improves performance on GPUs.
  FORMAT_OIHW = 1,

  // The most performant tensor format for cudnn6's quantized int8 convolution
  // and fused convolution; analogous to the NCHW_VECT_C data format. The
  // dimension order matches OIHW, but the Input Channels dimension is divided
  // by 4 and a trailing dimension of size 4 is appended; that trailing
  // dimension packs 4 adjacent input channel weights into one int32. An OIHW
  // filter of dimensions [O, I, H, W] therefore has dimensions
  // [O, I/4, H, W, 4] in OIHW_VECT_I format.
  // Pre-condition: I must be a multiple of 4.
  FORMAT_OIHW_VECT_I = 2,
};
73
74// Parse tensor format from the given string.
75// Return true if the parsing succeeds, and false if it fails.
76bool FormatFromString(const string& format_str, TensorFormat* format);
77
// Parse filter tensor format from the given string.
79// Return true if the parsing succeeds, and false if it fails.
80bool FilterFormatFromString(const string& format_str,
81                            FilterTensorFormat* format);
82
83// Convert a tensor format into string.
84string ToString(TensorFormat format);
85
86// Convert a filter tensor format into string.
87string ToString(FilterTensorFormat format);
88
89// Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor
90// format 'format'.
91inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
92  if (format == FORMAT_NCHW_VECT_C) {
93    return num_dims - 3;  // Exclude N,C,InnerC.
94  } else {
95    return num_dims - 2;  // Exclude N,C.
96  }
97}
98
99inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) {
100  if (format == FORMAT_OIHW_VECT_I) {
101    return num_dims - 3;  // Exclude O,I,InnerI.
102  } else {
103    return num_dims - 2;  // Exclude O,I.
104  }
105}
106
107// Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
108// tensor format 'format'. This is the inverse of GetTensorSpatialDims.
109inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
110                                        TensorFormat format) {
111  if (format == FORMAT_NCHW_VECT_C) {
112    return num_spatial_dims + 3;  // Include N,C,InnerC.
113  } else {
114    return num_spatial_dims + 2;  // Include N,C.
115  }
116}
117
118// Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
119// filter tensor format 'format'.
120inline int GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,
121                                              FilterTensorFormat format) {
122  if (format == FORMAT_OIHW_VECT_I) {
123    return num_spatial_dims + 3;  // Include O,I,InnerI.
124  } else {
125    return num_spatial_dims + 2;  // Include O,I.
126  }
127}
128
129// Returns the index of the batch dimension.
130inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
131  switch (format) {
132    case FORMAT_NHWC:
133    case FORMAT_NCHW:
134    case FORMAT_NCHW_VECT_C:
135      return 0;
136    default:
137      LOG(FATAL) << "Unknown format " << format;
138      return -1;  // Avoid compiler warning about missing return value
139  }
140}
141
142// Returns the index of the feature dimension. If format is NCHW_VECT_C, returns
143// the index of the outer feature dimension (i.e. dimension 1, whose size would
144// be num_features / 4 in this case).
145inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
146  switch (format) {
147    case FORMAT_NHWC:
148      return num_dims - 1;
149    case FORMAT_NCHW:
150    case FORMAT_NCHW_VECT_C:
151      return 1;
152    default:
153      LOG(FATAL) << "Unknown format " << format;
154      return -1;  // Avoid compiler warning about missing return value
155  }
156}
157
158// Returns the index of the inner feature dimension.
159inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) {
160  DCHECK_EQ(format, FORMAT_NCHW_VECT_C);
161  return num_dims - 1;
162}
163
164// Returns the index of the `dim`-th spatial dimension.
165inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
166                                    int dim) {
167  CHECK(dim >= 0 && dim < GetTensorSpatialDims(num_dims, format))
168      << dim << " " << num_dims << " " << ToString(format);
169  switch (format) {
170    case FORMAT_NHWC:
171      return dim + 1;
172    case FORMAT_NCHW:
173    case FORMAT_NCHW_VECT_C:
174      return dim + 2;
175    default:
176      LOG(FATAL) << "Unknown format " << format;
177      return -1;  // Avoid compiler warning about missing return value
178  }
179}
180
181// Returns the index of the `dim`-th spatial dimension.
182inline int GetFilterTensorSpatialDimIndex(int num_dims,
183                                          FilterTensorFormat format, int dim) {
184  CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format))
185      << dim << " " << num_dims << " " << ToString(format);
186  switch (format) {
187    case FORMAT_HWIO:
188      return dim;
189    case FORMAT_OIHW:
190    case FORMAT_OIHW_VECT_I:
191      return dim + 2;
192    default:
193      LOG(FATAL) << "Unknown format " << format;
194      return -1;  // Avoid compiler warning about missing return value
195  }
196}
197
198// Returns the index of the inner input channels dimension.
199inline int GetFilterTensorInnerInputChannelsDimIndex(
200    int num_dims, FilterTensorFormat format) {
201  DCHECK_EQ(format, FORMAT_OIHW_VECT_I);
202  return num_dims - 1;
203}
204
205// Returns the index of the input channels dimension.
206// If 'format' is FORMAT_OIHW_VECT_I, returns the dimension index of the
207// outer input channel (i.e. 1), which holds num_input_channels / 4.
208inline int GetFilterTensorInputChannelsDimIndex(int num_dims,
209                                                FilterTensorFormat format) {
210  switch (format) {
211    case FORMAT_HWIO:
212      return num_dims - 2;
213    case FORMAT_OIHW:
214    case FORMAT_OIHW_VECT_I:
215      return 1;
216    default:
217      LOG(FATAL) << "Unknown format " << format;
218      return -1;  // Avoid compiler warning about missing return value
219  }
220}
221
222// Returns the index of the output channels dimension.
223inline int GetFilterTensorOutputChannelsDimIndex(int num_dims,
224                                                 FilterTensorFormat format) {
225  switch (format) {
226    case FORMAT_HWIO:
227      return num_dims - 1;
228    case FORMAT_OIHW:
229    case FORMAT_OIHW_VECT_I:
230      return 0;
231    default:
232      LOG(FATAL) << "Unknown format " << format;
233      return -1;  // Avoid compiler warning about missing return value
234  }
235}
236
237// TODO(pauldonnelly): Replace these tensor dimension index functions with
238// constant structs to improve performance and reduce code size in Compute()
239// functions.
240
241// Return the dimension index for the specified 'dimension' of the specified
242// data 'tensor_format'.  'dimension' is a char that can be 'N' (batch size),
243// 'C' (channels), 'H' (height), 'W' (width),  or a numbered spatial dimension:
244// '0',  .. (NUM_SPATIAL_DIMS-1)..
245// If 'format' is NCHW_VECT_C and 'dimension' is 'C', returns the index of
246// the outer channel dimension (i.e. 1).
247template <int NUM_SPATIAL_DIMS>
248inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
249  if (format == FORMAT_NHWC) {
250    // clang-format off
251    switch (dimension) {
252      case 'N': return 0;
253      case '0': return 1;
254      case '1': return 2;
255      case '2': return 3;
256      case 'H': return NUM_SPATIAL_DIMS - 1;
257      case 'W': return NUM_SPATIAL_DIMS;
258      case 'C': return NUM_SPATIAL_DIMS + 1;
259      default:
260        LOG(FATAL) << "Invalid dimension: " << dimension;
261        return -1;  // Avoid compiler warning about missing return value
262    }
263  } else if (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) {
264    switch (dimension) {
265      case 'N': return 0;
266      case 'C': return 1;
267      case '0': return 2;
268      case '1': return 3;
269      case '2': return 4;
270      case 'H': return NUM_SPATIAL_DIMS;
271      case 'W': return NUM_SPATIAL_DIMS + 1;
272      default:
273        LOG(FATAL) << "Invalid dimension: " << dimension;
274        return -1;  // Avoid compiler warning about missing return value
275    }
276  } else {
277    LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
278    return -1;  // Avoid compiler warning about missing return value
279  }
280  // clang-format on
281}
282
283// Return the dimension index for the specified 'dimension' of the specified
284// 'filter_tensor_format'.  'dimension' is a char that can be 'O' (num output
285// channels), 'I' (num input channels), 'H' (height), 'W' (width), or a
286// numbered spatial dimension: '0',  .. (NUM_SPATIAL_DIMS-1).
287// If 'format' is OIHW_VECT_I and 'dimension' is 'I', returns the index of the
288// outer input channels dimension (i.e. 1).
289template <int NUM_SPATIAL_DIMS>
290inline int GetFilterDimIndex(FilterTensorFormat filter_tensor_format,
291                             char dimension) {
292  // clang-format off
293  if (filter_tensor_format == FORMAT_HWIO) {
294    switch (dimension) {
295      case '0': return 0;
296      case '1': return 1;
297      case '2': return 2;
298      case 'H': return NUM_SPATIAL_DIMS - 2;
299      case 'W': return NUM_SPATIAL_DIMS - 1;
300      case 'I': return NUM_SPATIAL_DIMS;
301      case 'O': return NUM_SPATIAL_DIMS + 1;
302      default:
303        LOG(FATAL) << "Invalid dimension: " << dimension;
304        return -1;  // Avoid compiler warning about missing return value
305    }
306  } else if (filter_tensor_format == FORMAT_OIHW ||
307             filter_tensor_format == FORMAT_OIHW_VECT_I) {
308    switch (dimension) {
309      case 'O': return 0;
310      case 'I': return 1;
311      case '0': return 2;
312      case '1': return 3;
313      case '2': return 4;
314      case 'H': return NUM_SPATIAL_DIMS;
315      case 'W': return NUM_SPATIAL_DIMS + 1;
316      default:
317        LOG(FATAL) << "Invalid dimension: " << dimension;
318        return -1;  // Avoid compiler warning about missing return value
319    }
320  } else {
321    LOG(FATAL) << "Invalid format: " << static_cast<int>(filter_tensor_format);
322    return -1;  // Avoid compiler warning about missing return value
323  }
324  // clang-format on
325}
326
327inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
328  return GetTensorDimIndex<2>(format, dimension);
329}
330
331// Return the element from 'dimension_attributes' that corresponds to the
332// specified 'dimension' according to 'tensor_format'.
333template <typename T>
334T GetTensorDim(gtl::ArraySlice<T> dimension_attributes,
335               TensorFormat tensor_format, char dimension) {
336  int index =
337      (GetTensorSpatialDims(dimension_attributes.size(), tensor_format) == 3)
338          ? GetTensorDimIndex<3>(tensor_format, dimension)
339          : GetTensorDimIndex<2>(tensor_format, dimension);
340  CHECK(index >= 0 && index < dimension_attributes.size())
341      << "Invalid index from the dimension: " << index << ", " << tensor_format
342      << ", " << dimension;
343  return dimension_attributes[index];
344}
345
346// Return the element from 'dimension_attribute' that corresponds to the
347// specified 'dimension' according to 'filter_tensor_format'.
348template <typename T>
349T GetFilterDim(gtl::ArraySlice<T> dimension_attribute,
350               FilterTensorFormat filter_tensor_format, char dimension) {
351  int index = (GetFilterTensorSpatialDims(dimension_attribute.size(),
352                                          filter_tensor_format) == 3)
353                  ? GetFilterDimIndex<3>(filter_tensor_format, dimension)
354                  : GetFilterDimIndex<2>(filter_tensor_format, dimension);
355  CHECK(index >= 0 && index < dimension_attribute.size())
356      << "Invalid index from the dimension: " << index << ", "
357      << filter_tensor_format << ", " << dimension;
358  return dimension_attribute[index];
359}
360
361template <typename T>
362T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
363               char dimension) {
364  return GetTensorDim(gtl::ArraySlice<T>(attributes), format, dimension);
365}
366
367// Return the size of the specified 'dimension' within 'tensor_shape'
368// according to 'tensor_format'.
369inline int64 GetTensorDim(const TensorShape& tensor_shape,
370                          TensorFormat tensor_format, char dimension) {
371  return GetTensorDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
372                      tensor_format, dimension);
373}
374
375// Return the size of the specified 'dimension' within 'tensor_shape'
376// according to 'tensor_filter_format'.
377inline int64 GetFilterDim(const TensorShape& tensor_shape,
378                          FilterTensorFormat tensor_filter_format,
379                          char dimension) {
380  return GetFilterDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
381                      tensor_filter_format, dimension);
382}
383
384// Return the size of the specified 'dimension' of 'tensor' according to
385// 'tensor_format'.
386inline int64 GetTensorDim(const Tensor& tensor, TensorFormat tensor_format,
387                          char dimension) {
388  return GetTensorDim(tensor.shape(), tensor_format, dimension);
389}
390
391// Return the size of the specified 'dimension' of 'tensor' according to
392// 'filter_tensor_format'.
393inline int64 GetFilterDim(const Tensor& tensor,
394                          FilterTensorFormat filter_tensor_format,
395                          char dimension) {
396  return GetFilterDim(tensor.shape(), filter_tensor_format, dimension);
397}
398
399// Return the string that specifies the data format for convnet operations.
400string GetConvnetDataFormatAttrString();
401string GetConvnet3dDataFormatAttrString();
402
403// Return the string that specifies the filter format for convnet operations.
404string GetConvnetFilterFormatAttrString();
405string GetConvnet3dFilterFormatAttrString();
406
407// Return a tensor shape for the given format. Works for both 2D and 3D
408// operations. If format is FORMAT_NCHW_VECT_C, the output TensorShape has rank
409// spatial.size()+3 (N,C,spatial,InnerC); otherwise, it has rank
410// spatial.size()+2 (e.g. N,C,spatial or N,spatial,C).
411inline TensorShape ShapeFromFormat(TensorFormat format, int64 N,
412                                   gtl::ArraySlice<int64> spatial, int64 C) {
413  const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
414  gtl::InlinedVector<int64, 6> dim_sizes(dims);
415  dim_sizes[GetTensorBatchDimIndex(dims, format)] = N;
416  for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
417    dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = spatial[dim];
418  }
419
420  int feature_index = GetTensorFeatureDimIndex(dims, format);
421  if (format == FORMAT_NCHW_VECT_C) {
422    CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C="
423                       << C;
424    dim_sizes[feature_index] = C / 4;
425    dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4;
426  } else {
427    dim_sizes[feature_index] = C;
428  }
429  return TensorShape(dim_sizes);
430}
431
432// Return a tensor shape of the specified 'format', and dimensions.
433// Works for both 2D and 3D operations. If 'format' is OIHW_VECT_I,
434// the output TensorShape has spatial.size() + 3 dimensions, otherwise
435// it has spatial.size() + 2 dimensions.
436inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
437                                               gtl::ArraySlice<int64> spatial,
438                                               int64 I, int64 O) {
439  const int dims = GetFilterTensorDimsFromSpatialDims(spatial.size(), format);
440  gtl::InlinedVector<int64, 6> dim_sizes(dims);
441  dim_sizes[GetFilterTensorOutputChannelsDimIndex(dims, format)] = O;
442  for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
443    dim_sizes[GetFilterTensorSpatialDimIndex(dims, format, dim)] = spatial[dim];
444  }
445
446  if (format == FORMAT_OIHW_VECT_I) {
447    CHECK_EQ(0, I % 4) << "OIHW_VECT_I requires I to be a multiple of 4, but I="
448                       << I;
449    I /= 4;
450    dim_sizes[GetFilterTensorInnerInputChannelsDimIndex(dims, format)] = 4;
451  }
452  dim_sizes[GetFilterTensorInputChannelsDimIndex(dims, format)] = I;
453  return TensorShape(dim_sizes);
454}
455
456// Return a tensor shape of the specified 'format', and dimensions.
457inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
458                                   int64 W, int64 C) {
459  return ShapeFromFormat(format, N, {H, W}, C);
460}
461
462// Return a filter tensor shape of the specified 'format', and dimensions.
463inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
464                                               int64 H, int64 W, int64 I,
465                                               int64 O) {
466  return ShapeFromFilterTensorFormat(format, {H, W}, I, O);
467}
468
469// Returns a copy of the specified tensor 'src_shape' converted from
470// 'src_format' to 'dst_format'.
471inline TensorShape ShapeFromFormat(TensorFormat dst_format,
472                                   const TensorShape& src_shape,
473                                   TensorFormat src_format) {
474  if (src_format == dst_format) {
475    return src_shape;
476  }
477
478  const int64 batch = GetTensorDim(src_shape, src_format, 'N');
479  const int64 channels = GetTensorDim(src_shape, src_format, 'C') *
480                         (src_format == FORMAT_NCHW_VECT_C ? 4 : 1);
481
482  if (GetTensorSpatialDims(src_shape.dims(), src_format) == 3) {
483    return ShapeFromFormat(dst_format, batch,
484                           {{GetTensorDim(src_shape, src_format, '0'),
485                             GetTensorDim(src_shape, src_format, '1'),
486                             GetTensorDim(src_shape, src_format, '2')}},
487                           channels);
488  }
489
490  return ShapeFromFormat(dst_format, batch,
491                         {{GetTensorDim(src_shape, src_format, 'H'),
492                           GetTensorDim(src_shape, src_format, 'W')}},
493                         channels);
494}
495
496// Returns a copy of the specified filter tensor 'src_shape' converted from
497// 'src_filter_format' to 'dst_filter_format'.
498inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,
499                                         const TensorShape& src_shape,
500                                         FilterTensorFormat src_filter_format) {
501  if (src_filter_format == dst_filter_format) {
502    return src_shape;
503  }
504
505  const int64 output_channels = GetFilterDim(src_shape, src_filter_format, 'O');
506  const int64 input_channels =
507      GetFilterDim(src_shape, src_filter_format, 'I') *
508      (src_filter_format == FORMAT_OIHW_VECT_I ? 4 : 1);
509
510  if (GetFilterTensorSpatialDims(src_shape.dims(), src_filter_format) == 3) {
511    return ShapeFromFilterTensorFormat(
512        dst_filter_format,
513        {{GetFilterDim(src_shape, src_filter_format, '0'),
514          GetFilterDim(src_shape, src_filter_format, '1'),
515          GetFilterDim(src_shape, src_filter_format, '2')}},
516        input_channels, output_channels);
517  }
518
519  return ShapeFromFilterTensorFormat(
520      dst_filter_format,
521      {{GetFilterDim(src_shape, src_filter_format, 'H'),
522        GetFilterDim(src_shape, src_filter_format, 'W')}},
523      input_channels, output_channels);
524}
525
526}  // namespace tensorflow
527
528#endif  // TENSORFLOW_UTIL_TENSOR_FORMAT_H_
529