1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/layout_util.h" 17 18#include <sstream> 19 20#include "tensorflow/compiler/xla/shape_util.h" 21#include "tensorflow/compiler/xla/test.h" 22#include "tensorflow/compiler/xla/test_helpers.h" 23 24namespace xla { 25namespace { 26 27class LayoutUtilTest : public ::testing::Test { 28 protected: 29 Shape MakeShapeWithLayout(PrimitiveType element_type, 30 tensorflow::gtl::ArraySlice<int64> dimensions, 31 tensorflow::gtl::ArraySlice<int64> minor_to_major) { 32 Shape shape = ShapeUtil::MakeShape(element_type, dimensions); 33 *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); 34 return shape; 35 } 36 37 Shape MakeShapeWithSparseLayout(PrimitiveType element_type, 38 tensorflow::gtl::ArraySlice<int64> dimensions, 39 int64 max_sparse_elements) { 40 Shape shape = ShapeUtil::MakeShape(element_type, dimensions); 41 *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); 42 return shape; 43 } 44}; 45 46TEST_F(LayoutUtilTest, TupleLayoutComparison) { 47 Shape shape = 48 ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1})}); 49 Shape other_shape = 50 ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); 51 52 Shape tuple0 = ShapeUtil::MakeTupleShape({}); 53 Shape tuple1 = ShapeUtil::MakeTupleShape({shape}); 54 Shape tuple2 = ShapeUtil::MakeTupleShape({shape, shape}); 55 56 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple0)); 57 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple1)); 58 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple2)); 59 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple0)); 60 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple0)); 61 62 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple1)); 63 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple2)); 64 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple1)); 65 66 Shape other_tuple2 = ShapeUtil::MakeTupleShape({shape, other_shape}); 67 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple2)); 68 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple2, other_tuple2)); 69 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(other_tuple2, tuple2)); 70} 71 72TEST_F(LayoutUtilTest, CopyLayoutArray) { 73 Shape src = MakeShapeWithLayout(F32, {2, 3}, {0, 1}); 74 Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); 75 76 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 77 EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); 78 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 79 80 // Should work if destination has no layout. 81 dst.clear_layout(); 82 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 83 EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); 84 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 85 86 // If source is cleared, then destination should be cleared. 87 src.clear_layout(); 88 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 89 EXPECT_TRUE(dst.has_layout()); 90 EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); 91 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 92 EXPECT_FALSE(dst.has_layout()); 93} 94 95TEST_F(LayoutUtilTest, CopyLayoutSparse) { 96 Shape src = MakeShapeWithSparseLayout(F32, {2, 3}, 2); 97 Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); 98 99 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 100 EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); 101 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 102 103 // Should work if destination has no layout. 104 dst.clear_layout(); 105 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 106 EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); 107 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 108 109 // If source is cleared, then destination should be cleared. 110 src.clear_layout(); 111 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 112 EXPECT_TRUE(dst.has_layout()); 113 EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); 114 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 115 EXPECT_FALSE(dst.has_layout()); 116} 117 118TEST_F(LayoutUtilTest, CopyLayoutTuple) { 119 Shape src = ShapeUtil::MakeTupleShape( 120 {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), 121 MakeShapeWithLayout(F32, {42, 123}, {1, 0}), 122 ShapeUtil::MakeTupleShape( 123 {MakeShapeWithLayout(F32, {}, {}), 124 MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})}); 125 Shape dst = ShapeUtil::MakeTupleShape( 126 {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), 127 MakeShapeWithLayout(F32, {42, 123}, {1, 0}), 128 ShapeUtil::MakeTupleShape( 129 {MakeShapeWithLayout(F32, {}, {}), 130 MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); 131 132 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 133 EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); 134 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 135} 136 137TEST_F(LayoutUtilTest, CopyLayoutTupleSparse) { 138 Shape src = ShapeUtil::MakeTupleShape( 139 {MakeShapeWithSparseLayout(F32, {2, 3}, 4), 140 MakeShapeWithSparseLayout(F32, {42, 123}, 4), 141 ShapeUtil::MakeTupleShape( 142 {MakeShapeWithLayout(F32, {}, {}), 143 MakeShapeWithSparseLayout(F32, {1, 2, 3}, 6)})}); 144 Shape dst = ShapeUtil::MakeTupleShape( 145 {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), 146 MakeShapeWithLayout(F32, {42, 123}, {1, 0}), 147 ShapeUtil::MakeTupleShape( 148 {MakeShapeWithLayout(F32, {}, {}), 149 MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); 150 151 EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 152 EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); 153 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 154} 155 156TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) { 157 Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); 158 Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); 159 ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); 160 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 161} 162 163TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleSameRank) { 164 Shape src = MakeShapeWithSparseLayout(F32, {123, 42, 7}, 6); 165 Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); 166 ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); 167 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); 168} 169 170TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { 171 Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); 172 Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); 173 auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); 174 EXPECT_FALSE(status.ok()); 175 EXPECT_THAT(status.error_message(), 176 ::testing::ContainsRegex("cannot copy layout from shape")); 177} 178 179TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleDifferentRank) { 180 Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); 181 Shape dst = MakeShapeWithSparseLayout(F32, {2, 3}, 4); 182 auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); 183 EXPECT_FALSE(status.ok()); 184 EXPECT_THAT(status.error_message(), 185 ::testing::ContainsRegex("cannot copy layout from shape")); 186} 187 188TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { 189 Shape src = 190 ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}), 191 MakeShapeWithLayout(F32, {42, 123}, {1, 0}), 192 ShapeUtil::MakeTupleShape({MakeShapeWithLayout( 193 F32, {1, 2, 3}, {0, 2, 1})})}); 194 Shape dst = ShapeUtil::MakeTupleShape( 195 {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), 196 MakeShapeWithLayout(F32, {42, 123}, {1, 0}), 197 ShapeUtil::MakeTupleShape( 198 {MakeShapeWithLayout(F32, {}, {}), 199 MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); 200 201 auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); 202 EXPECT_FALSE(status.ok()); 203 EXPECT_THAT(status.error_message(), 204 ::testing::ContainsRegex("cannot copy layout from shape")); 205} 206 207TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { 208 Shape src = ShapeUtil::MakeShape(F32, {2, 3}); 209 Shape dst = ShapeUtil::MakeShape(F32, {2, 3}); 210 // Set layout to invalid value. 211 *src.mutable_layout() = LayoutUtil::MakeLayout({1, 2, 3, 4}); 212 213 auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); 214 EXPECT_FALSE(status.ok()); 215 EXPECT_THAT( 216 status.error_message(), 217 ::testing::ContainsRegex("layout minor_to_major field contains .* " 218 "elements, but shape is rank")); 219} 220 221TEST_F(LayoutUtilTest, ClearLayoutTuple) { 222 Shape shape = ShapeUtil::MakeTupleShape( 223 {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), 224 MakeShapeWithLayout(F32, {42, 123}, {1, 0}), 225 ShapeUtil::MakeTupleShape( 226 {MakeShapeWithLayout(F32, {}, {}), 227 MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); 228 EXPECT_TRUE(LayoutUtil::HasLayout(shape)); 229 EXPECT_TRUE(shape.tuple_shapes(0).has_layout()); 230 EXPECT_TRUE(shape.tuple_shapes(2).tuple_shapes(1).has_layout()); 231 232 LayoutUtil::ClearLayout(&shape); 233 234 EXPECT_FALSE(LayoutUtil::HasLayout(shape)); 235 EXPECT_FALSE(shape.tuple_shapes(0).has_layout()); 236 EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout()); 237} 238 239TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) { 240 Shape shape = ShapeUtil::MakeTupleShape( 241 {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}), 242 MakeShapeWithLayout(F32, {42, 123, 7}, {1, 2, 0}), 243 ShapeUtil::MakeTupleShape( 244 {MakeShapeWithLayout(F32, {}, {}), 245 MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 1, 2, 0})})}); 246 EXPECT_FALSE(LayoutUtil::Equal(shape.tuple_shapes(0).layout(), 247 shape.tuple_shapes(1).layout())); 248 LayoutUtil::SetToDefaultLayout(&shape); 249 EXPECT_TRUE(LayoutUtil::Equal(shape.tuple_shapes(0).layout(), 250 shape.tuple_shapes(1).layout())); 251 EXPECT_TRUE(LayoutUtil::Equal( 252 LayoutUtil::GetDefaultLayoutForShape(shape.tuple_shapes(0)), 253 shape.tuple_shapes(1).layout())); 254} 255 256TEST_F(LayoutUtilTest, IsPadded) { 257 Shape shape_without_layout = ShapeUtil::MakeShape(F32, {2, 3, 4}); 258 LayoutUtil::ClearLayout(&shape_without_layout); 259 EXPECT_FALSE(LayoutUtil::IsPadded(shape_without_layout)); 260 261 Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4}); 262 LayoutUtil::SetToDefaultLayout(&shape_with_layout); 263 EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_layout)); 264 265 // Add padding equal to the dimension sizes. In this case the padding is a 266 // nop. 267 Shape shape_with_degenerate_padding = ShapeUtil::MakeShape(F32, {2, 3, 4}); 268 shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(2); 269 shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(3); 270 shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(4); 271 EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_degenerate_padding)); 272 273 Shape shape_with_padding = ShapeUtil::MakeShape(F32, {2, 3, 4}); 274 shape_with_padding.mutable_layout()->add_padded_dimensions(2); 275 shape_with_padding.mutable_layout()->add_padded_dimensions(14); 276 shape_with_padding.mutable_layout()->add_padded_dimensions(42); 277 EXPECT_TRUE(LayoutUtil::IsPadded(shape_with_padding)); 278} 279 280TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { 281 EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}), 282 LayoutUtil::GetDefaultLayoutForR2())); 283 EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({2, 1, 0}), 284 LayoutUtil::GetDefaultLayoutForR3())); 285 EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({3, 2, 1, 0}), 286 LayoutUtil::GetDefaultLayoutForR4())); 287 EXPECT_TRUE( 288 LayoutUtil::Equal(LayoutUtil::MakeLayout({4, 3, 2, 1, 0}), 289 LayoutUtil::GetDefaultLayoutForShape( 290 ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); 291} 292 293TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { 294 EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), 295 101); 296} 297 298TEST_F(LayoutUtilTest, StreamOut) { 299 std::ostringstream oss; 300 oss << LayoutUtil::MakeLayout({0, 1, 2}); 301 EXPECT_EQ(oss.str(), "{0,1,2}"); 302} 303 304} // namespace 305} // namespace xla 306