1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16// Utility functions related to layouts of Shapes. 17 18#ifndef TENSORFLOW_COMPILER_XLA_INDEX_UTIL_H_ 19#define TENSORFLOW_COMPILER_XLA_INDEX_UTIL_H_ 20 21#include <vector> 22 23#include "tensorflow/compiler/xla/types.h" 24#include "tensorflow/compiler/xla/xla_data.pb.h" 25#include "tensorflow/core/lib/gtl/array_slice.h" 26#include "tensorflow/core/platform/macros.h" 27 28namespace xla { 29 30// Namespaced collection of (static) utilities related to indexing into 31// multidimensional arrays. 32class IndexUtil { 33 public: 34 // Converts a multidimensional index (eg {x, y, z}) into a linear index based 35 // on the shape and its layout. The first index in the multi_index is 36 // dimension 0. 37 static int64 MultidimensionalIndexToLinearIndex( 38 const Shape& shape, tensorflow::gtl::ArraySlice<int64> multi_index); 39 40 // Converts a linear index into multidimensional index (eg {x, y, z}) based on 41 // the shape and its layout. The first index in the returned multidimensional 42 // index is dimension 0. 43 static std::vector<int64> LinearIndexToMultidimensionalIndex( 44 const Shape& shape, int64 linear_index); 45 46 // Bumps a sequence of indices; e.g. {0,0,0,0} up by one index value; e.g. to 47 // {0,0,0,1}. This is akin to std::next_permutation. If the index hits a limit 48 // for the provided shape, the next most significant index is bumped, in a 49 // counting-up process. 50 // 51 // E.g. for shape f32[2,3] 52 // {0,0}=>{0,1} 53 // {0,1}=>{0,2} 54 // {0,2}=>{1,0} 55 // etc. 56 // 57 // This is useful for traversing the indices in a literal. 58 // 59 // Returns true iff the indices were successfully bumped; false if we've hit 60 // the limit where it can no longer be bumped in-bounds. 61 static bool BumpIndices(const Shape& shape, 62 tensorflow::gtl::MutableArraySlice<int64> indices); 63 64 // Calculates the stride size (in number of elements, not byte size) of a 65 // given logical shape dimension (from 0 to rank-1). If available, padded 66 // dimensions are used. 67 // Example: 68 // GetDimensionStride(F32[5,8,10,4]{3,2,1,0}, 1) == 69 // sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10 70 static int64 GetDimensionStride(const Shape& shape, int64 dimension); 71 72 // Returns true iff the given multi-index is contained in the bounds for the 73 // shape. 74 static bool IndexInBounds(const Shape& shape, 75 tensorflow::gtl::ArraySlice<int64> index); 76 77 // Compares the given indices in lexicographic order. lhs[0] and rhs[0] are 78 // compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is larger, 79 // then -1 is returned. If rhs is larger, then 1 is returned. Otherwise, 0 is 80 // returned. 81 static int CompareIndices(tensorflow::gtl::ArraySlice<int64> lhs, 82 tensorflow::gtl::ArraySlice<int64> rhs); 83 84 private: 85 TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); 86}; 87 88} // namespace xla 89 90#endif // TENSORFLOW_COMPILER_XLA_INDEX_UTIL_H_ 91