1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/layout_util.h"
17
18#include <sstream>
19
20#include "tensorflow/compiler/xla/shape_util.h"
21#include "tensorflow/compiler/xla/test.h"
22#include "tensorflow/compiler/xla/test_helpers.h"
23
24namespace xla {
25namespace {
26
27class LayoutUtilTest : public ::testing::Test {
28 protected:
29  Shape MakeShapeWithLayout(PrimitiveType element_type,
30                            tensorflow::gtl::ArraySlice<int64> dimensions,
31                            tensorflow::gtl::ArraySlice<int64> minor_to_major) {
32    Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
33    *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
34    return shape;
35  }
36
37  Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
38                                  tensorflow::gtl::ArraySlice<int64> dimensions,
39                                  int64 max_sparse_elements) {
40    Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
41    *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
42    return shape;
43  }
44};
45
46TEST_F(LayoutUtilTest, TupleLayoutComparison) {
47  Shape shape =
48      ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1})});
49  Shape other_shape =
50      ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
51
52  Shape tuple0 = ShapeUtil::MakeTupleShape({});
53  Shape tuple1 = ShapeUtil::MakeTupleShape({shape});
54  Shape tuple2 = ShapeUtil::MakeTupleShape({shape, shape});
55
56  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple0));
57  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple1));
58  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple2));
59  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple0));
60  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple0));
61
62  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple1));
63  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple2));
64  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple1));
65
66  Shape other_tuple2 = ShapeUtil::MakeTupleShape({shape, other_shape});
67  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple2));
68  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple2, other_tuple2));
69  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(other_tuple2, tuple2));
70}
71
72TEST_F(LayoutUtilTest, CopyLayoutArray) {
73  Shape src = MakeShapeWithLayout(F32, {2, 3}, {0, 1});
74  Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
75
76  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
77  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
78  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
79
80  // Should work if destination has no layout.
81  dst.clear_layout();
82  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
83  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
84  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
85
86  // If source is cleared, then destination should be cleared.
87  src.clear_layout();
88  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
89  EXPECT_TRUE(dst.has_layout());
90  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
91  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
92  EXPECT_FALSE(dst.has_layout());
93}
94
95TEST_F(LayoutUtilTest, CopyLayoutSparse) {
96  Shape src = MakeShapeWithSparseLayout(F32, {2, 3}, 2);
97  Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
98
99  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
100  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
101  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
102
103  // Should work if destination has no layout.
104  dst.clear_layout();
105  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
106  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
107  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
108
109  // If source is cleared, then destination should be cleared.
110  src.clear_layout();
111  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
112  EXPECT_TRUE(dst.has_layout());
113  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
114  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
115  EXPECT_FALSE(dst.has_layout());
116}
117
118TEST_F(LayoutUtilTest, CopyLayoutTuple) {
119  Shape src = ShapeUtil::MakeTupleShape(
120      {MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
121       MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
122       ShapeUtil::MakeTupleShape(
123           {MakeShapeWithLayout(F32, {}, {}),
124            MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})});
125  Shape dst = ShapeUtil::MakeTupleShape(
126      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
127       MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
128       ShapeUtil::MakeTupleShape(
129           {MakeShapeWithLayout(F32, {}, {}),
130            MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
131
132  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
133  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
134  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
135}
136
137TEST_F(LayoutUtilTest, CopyLayoutTupleSparse) {
138  Shape src = ShapeUtil::MakeTupleShape(
139      {MakeShapeWithSparseLayout(F32, {2, 3}, 4),
140       MakeShapeWithSparseLayout(F32, {42, 123}, 4),
141       ShapeUtil::MakeTupleShape(
142           {MakeShapeWithLayout(F32, {}, {}),
143            MakeShapeWithSparseLayout(F32, {1, 2, 3}, 6)})});
144  Shape dst = ShapeUtil::MakeTupleShape(
145      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
146       MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
147       ShapeUtil::MakeTupleShape(
148           {MakeShapeWithLayout(F32, {}, {}),
149            MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
150
151  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
152  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
153  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
154}
155
156TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) {
157  Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
158  Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0});
159  ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
160  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
161}
162
163TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleSameRank) {
164  Shape src = MakeShapeWithSparseLayout(F32, {123, 42, 7}, 6);
165  Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0});
166  ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
167  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
168}
169
170TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) {
171  Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
172  Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
173  auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst);
174  EXPECT_FALSE(status.ok());
175  EXPECT_THAT(status.error_message(),
176              ::testing::ContainsRegex("cannot copy layout from shape"));
177}
178
179TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleDifferentRank) {
180  Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
181  Shape dst = MakeShapeWithSparseLayout(F32, {2, 3}, 4);
182  auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst);
183  EXPECT_FALSE(status.ok());
184  EXPECT_THAT(status.error_message(),
185              ::testing::ContainsRegex("cannot copy layout from shape"));
186}
187
188TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) {
189  Shape src =
190      ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
191                                 MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
192                                 ShapeUtil::MakeTupleShape({MakeShapeWithLayout(
193                                     F32, {1, 2, 3}, {0, 2, 1})})});
194  Shape dst = ShapeUtil::MakeTupleShape(
195      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
196       MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
197       ShapeUtil::MakeTupleShape(
198           {MakeShapeWithLayout(F32, {}, {}),
199            MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
200
201  auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst);
202  EXPECT_FALSE(status.ok());
203  EXPECT_THAT(status.error_message(),
204              ::testing::ContainsRegex("cannot copy layout from shape"));
205}
206
207TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) {
208  Shape src = ShapeUtil::MakeShape(F32, {2, 3});
209  Shape dst = ShapeUtil::MakeShape(F32, {2, 3});
210  // Set layout to invalid value.
211  *src.mutable_layout() = LayoutUtil::MakeLayout({1, 2, 3, 4});
212
213  auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst);
214  EXPECT_FALSE(status.ok());
215  EXPECT_THAT(
216      status.error_message(),
217      ::testing::ContainsRegex("layout minor_to_major field contains .* "
218                               "elements, but shape is rank"));
219}
220
221TEST_F(LayoutUtilTest, ClearLayoutTuple) {
222  Shape shape = ShapeUtil::MakeTupleShape(
223      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
224       MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
225       ShapeUtil::MakeTupleShape(
226           {MakeShapeWithLayout(F32, {}, {}),
227            MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
228  EXPECT_TRUE(LayoutUtil::HasLayout(shape));
229  EXPECT_TRUE(shape.tuple_shapes(0).has_layout());
230  EXPECT_TRUE(shape.tuple_shapes(2).tuple_shapes(1).has_layout());
231
232  LayoutUtil::ClearLayout(&shape);
233
234  EXPECT_FALSE(LayoutUtil::HasLayout(shape));
235  EXPECT_FALSE(shape.tuple_shapes(0).has_layout());
236  EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout());
237}
238
239TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
240  Shape shape = ShapeUtil::MakeTupleShape(
241      {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}),
242       MakeShapeWithLayout(F32, {42, 123, 7}, {1, 2, 0}),
243       ShapeUtil::MakeTupleShape(
244           {MakeShapeWithLayout(F32, {}, {}),
245            MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 1, 2, 0})})});
246  EXPECT_FALSE(LayoutUtil::Equal(shape.tuple_shapes(0).layout(),
247                                 shape.tuple_shapes(1).layout()));
248  LayoutUtil::SetToDefaultLayout(&shape);
249  EXPECT_TRUE(LayoutUtil::Equal(shape.tuple_shapes(0).layout(),
250                                shape.tuple_shapes(1).layout()));
251  EXPECT_TRUE(LayoutUtil::Equal(
252      LayoutUtil::GetDefaultLayoutForShape(shape.tuple_shapes(0)),
253      shape.tuple_shapes(1).layout()));
254}
255
256TEST_F(LayoutUtilTest, IsPadded) {
257  Shape shape_without_layout = ShapeUtil::MakeShape(F32, {2, 3, 4});
258  LayoutUtil::ClearLayout(&shape_without_layout);
259  EXPECT_FALSE(LayoutUtil::IsPadded(shape_without_layout));
260
261  Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4});
262  LayoutUtil::SetToDefaultLayout(&shape_with_layout);
263  EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_layout));
264
265  // Add padding equal to the dimension sizes. In this case the padding is a
266  // nop.
267  Shape shape_with_degenerate_padding = ShapeUtil::MakeShape(F32, {2, 3, 4});
268  shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(2);
269  shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(3);
270  shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(4);
271  EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_degenerate_padding));
272
273  Shape shape_with_padding = ShapeUtil::MakeShape(F32, {2, 3, 4});
274  shape_with_padding.mutable_layout()->add_padded_dimensions(2);
275  shape_with_padding.mutable_layout()->add_padded_dimensions(14);
276  shape_with_padding.mutable_layout()->add_padded_dimensions(42);
277  EXPECT_TRUE(LayoutUtil::IsPadded(shape_with_padding));
278}
279
280TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
281  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
282                                LayoutUtil::GetDefaultLayoutForR2()));
283  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({2, 1, 0}),
284                                LayoutUtil::GetDefaultLayoutForR3()));
285  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({3, 2, 1, 0}),
286                                LayoutUtil::GetDefaultLayoutForR4()));
287  EXPECT_TRUE(
288      LayoutUtil::Equal(LayoutUtil::MakeLayout({4, 3, 2, 1, 0}),
289                        LayoutUtil::GetDefaultLayoutForShape(
290                            ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25}))));
291}
292
293TEST_F(LayoutUtilTest, SparseLayoutMaxElements) {
294  EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)),
295            101);
296}
297
298TEST_F(LayoutUtilTest, StreamOut) {
299  std::ostringstream oss;
300  oss << LayoutUtil::MakeLayout({0, 1, 2});
301  EXPECT_EQ(oss.str(), "{0,1,2}");
302}
303
304}  // namespace
305}  // namespace xla
306