1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
17
18#include <algorithm>
19#include <unordered_set>
20
21#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
22#include "tensorflow/compiler/xla/service/hlo_computation.h"
23#include "tensorflow/compiler/xla/service/hlo_instruction.h"
24#include "tensorflow/compiler/xla/service/hlo_opcode.h"
25#include "tensorflow/compiler/xla/test_helpers.h"
26#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27#include "tensorflow/compiler/xla/types.h"
28
29namespace xla {
30namespace gpu {
31
32class HloScheduleTest : public HloTestBase {
33 protected:
34  using HloVec = std::vector<const HloInstruction*>;
35
36  // Pre-canned shapes.
37  Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
38
39  static std::unique_ptr<HloSchedule> BuildHloSchedule(
40      const HloModule& module, const StreamAssignment& streams) {
41    return HloSchedule::Build(module, streams, /*pointer_size=*/8)
42        .ConsumeValueOrDie();
43  }
44
45  HloVec RemoveHlo(const HloVec& input,
46                   const std::unordered_set<const HloInstruction*>& remove) {
47    HloVec result(input);
48    result.erase(std::remove_if(result.begin(), result.end(),
49                                [&remove](const HloInstruction* x) {
50                                  return remove.count(x) > 0;
51                                }),
52                 result.end());
53    return result;
54  }
55};
56
57// Test of a single stream, where data dependencies fully determine the
58// execution order.
59TEST_F(HloScheduleTest, SequentialMatMul) {
60  HloComputation::Builder builder("entry_computation");
61  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
62      /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
63  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
64      /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
65  HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
66      /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
67  HloInstruction* dot1 = builder.AddInstruction(
68      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
69  HloInstruction* dot2 = builder.AddInstruction(
70      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z));
71
72  auto module = CreateNewModule();
73  module->AddEntryComputation(builder.Build(dot2));
74
75  std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
76  EXPECT_EQ(streams->StreamNumberForHlo(*dot1),
77            streams->StreamNumberForHlo(*dot2));
78
79  auto schedule = BuildHloSchedule(*module, *streams);
80  // Remove parameters, which are unordered.
81  EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}),
82            HloVec({dot1, dot2}));
83
84  // Parameters x,y,z are mutually unordered, while dot1 and dot2 are
85  // transitively ordered by operands.
86  auto order = schedule->ConsumeHloOrdering();
87  EXPECT_TRUE(order->ExecutesBefore(x, dot1));
88  EXPECT_TRUE(order->ExecutesBefore(x, dot2));
89  EXPECT_TRUE(order->ExecutesBefore(y, dot1));
90  EXPECT_TRUE(order->ExecutesBefore(y, dot2));
91  EXPECT_TRUE(order->ExecutesBefore(z, dot2));
92  EXPECT_TRUE(order->ExecutesBefore(dot1, dot2));
93
94  EXPECT_FALSE(order->ExecutesBefore(x, x));
95  EXPECT_FALSE(order->ExecutesBefore(x, y));
96  EXPECT_FALSE(order->ExecutesBefore(x, z));
97  EXPECT_FALSE(order->ExecutesBefore(y, x));
98  EXPECT_FALSE(order->ExecutesBefore(y, y));
99  EXPECT_FALSE(order->ExecutesBefore(y, z));
100  EXPECT_FALSE(order->ExecutesBefore(z, x));
101  EXPECT_FALSE(order->ExecutesBefore(z, y));
102  EXPECT_FALSE(order->ExecutesBefore(z, z));
103  EXPECT_FALSE(order->ExecutesBefore(z, dot1));
104  EXPECT_FALSE(order->ExecutesBefore(dot1, x));
105  EXPECT_FALSE(order->ExecutesBefore(dot1, y));
106  EXPECT_FALSE(order->ExecutesBefore(dot1, z));
107  EXPECT_FALSE(order->ExecutesBefore(dot1, dot1));
108  EXPECT_FALSE(order->ExecutesBefore(dot2, x));
109  EXPECT_FALSE(order->ExecutesBefore(dot2, y));
110  EXPECT_FALSE(order->ExecutesBefore(dot2, z));
111  EXPECT_FALSE(order->ExecutesBefore(dot2, dot1));
112  EXPECT_FALSE(order->ExecutesBefore(dot2, dot2));
113}
114
115// Test of a single stream, where data dependencies do not fully determine the
116// execution order, but the stream assignment does.
117TEST_F(HloScheduleTest, SequentialAdd) {
118  HloComputation::Builder builder("entry_computation");
119  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
120      /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
121  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
122      /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
123  HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
124      /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
125  HloInstruction* add1 = builder.AddInstruction(
126      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
127  HloInstruction* add2 = builder.AddInstruction(
128      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, y, z));
129  HloInstruction* add3 = builder.AddInstruction(
130      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
131
132  auto module = CreateNewModule();
133  module->AddEntryComputation(builder.Build(add3));
134
135  std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
136  EXPECT_EQ(streams->StreamNumberForHlo(*add1),
137            streams->StreamNumberForHlo(*add2));
138  EXPECT_EQ(streams->StreamNumberForHlo(*add1),
139            streams->StreamNumberForHlo(*add3));
140
141  auto schedule = BuildHloSchedule(*module, *streams);
142  // Remove parameters, which are unordered.
143  EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}),
144            HloVec({add1, add2, add3}));
145
146  // Parameters x,y,z are mutually unordered, while add1, add2 and add3 are
147  // transitively ordered by operands.
148  auto order = schedule->ConsumeHloOrdering();
149  EXPECT_TRUE(order->ExecutesBefore(x, add1));
150  EXPECT_TRUE(order->ExecutesBefore(x, add2));
151  EXPECT_TRUE(order->ExecutesBefore(x, add3));
152  EXPECT_TRUE(order->ExecutesBefore(y, add1));
153  EXPECT_TRUE(order->ExecutesBefore(y, add2));
154  EXPECT_TRUE(order->ExecutesBefore(y, add3));
155  EXPECT_TRUE(order->ExecutesBefore(z, add2));
156  EXPECT_TRUE(order->ExecutesBefore(z, add3));
157  EXPECT_TRUE(order->ExecutesBefore(add1, add3));
158  EXPECT_TRUE(order->ExecutesBefore(add2, add3));
159  // The HLO graph does not define an ordering for add1 and add2, but their
160  // assignment onto the same stream does define an ordering.
161  if (order->ExecutesBefore(add1, add2)) {
162    EXPECT_FALSE(order->ExecutesBefore(add2, add1));
163  } else {
164    EXPECT_TRUE(order->ExecutesBefore(add2, add1));
165    EXPECT_FALSE(order->ExecutesBefore(add1, add2));
166  }
167
168  EXPECT_FALSE(order->ExecutesBefore(x, x));
169  EXPECT_FALSE(order->ExecutesBefore(x, y));
170  EXPECT_FALSE(order->ExecutesBefore(x, z));
171  EXPECT_FALSE(order->ExecutesBefore(y, x));
172  EXPECT_FALSE(order->ExecutesBefore(y, y));
173  EXPECT_FALSE(order->ExecutesBefore(y, z));
174  EXPECT_FALSE(order->ExecutesBefore(z, x));
175  EXPECT_FALSE(order->ExecutesBefore(z, y));
176  EXPECT_FALSE(order->ExecutesBefore(z, z));
177  EXPECT_FALSE(order->ExecutesBefore(z, add1));
178  EXPECT_FALSE(order->ExecutesBefore(add1, x));
179  EXPECT_FALSE(order->ExecutesBefore(add1, y));
180  EXPECT_FALSE(order->ExecutesBefore(add1, z));
181  EXPECT_FALSE(order->ExecutesBefore(add1, add1));
182  EXPECT_FALSE(order->ExecutesBefore(add2, x));
183  EXPECT_FALSE(order->ExecutesBefore(add2, y));
184  EXPECT_FALSE(order->ExecutesBefore(add2, z));
185  EXPECT_FALSE(order->ExecutesBefore(add2, add2));
186}
187
188// Test of two streams.
189TEST_F(HloScheduleTest, ConcurrentMatMul) {
190  HloComputation::Builder builder("entry_computation");
191  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
192      /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
193  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
194      /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
195  HloInstruction* dot1 = builder.AddInstruction(
196      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
197  HloInstruction* dot2 = builder.AddInstruction(
198      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x));
199  HloInstruction* add = builder.AddInstruction(
200      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
201
202  auto module = CreateNewModule();
203  module->AddEntryComputation(builder.Build(add));
204
205  std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
206  EXPECT_NE(streams->StreamNumberForHlo(*dot1),
207            streams->StreamNumberForHlo(*dot2));
208
209  auto schedule = BuildHloSchedule(*module, *streams);
210  // Remove parameters, which are unordered.
211  HloVec thunk_launch_order = RemoveHlo(schedule->ThunkLaunchOrder(), {x, y});
212  EXPECT_TRUE(thunk_launch_order == HloVec({dot1, dot2, add}) ||
213              thunk_launch_order == HloVec({dot2, dot1, add}));
214
215  // Parameters x,y are mutually unordered, while dot1, dot2 and add are
216  // transitively ordered by operands.
217  auto order = schedule->ConsumeHloOrdering();
218  EXPECT_TRUE(order->ExecutesBefore(x, dot1));
219  EXPECT_TRUE(order->ExecutesBefore(x, dot2));
220  EXPECT_TRUE(order->ExecutesBefore(y, dot1));
221  EXPECT_TRUE(order->ExecutesBefore(y, dot2));
222  EXPECT_TRUE(order->ExecutesBefore(dot1, add));
223  EXPECT_TRUE(order->ExecutesBefore(dot2, add));
224
225  EXPECT_FALSE(order->ExecutesBefore(x, x));
226  EXPECT_FALSE(order->ExecutesBefore(x, y));
227  EXPECT_FALSE(order->ExecutesBefore(y, x));
228  EXPECT_FALSE(order->ExecutesBefore(y, y));
229  EXPECT_FALSE(order->ExecutesBefore(dot1, x));
230  EXPECT_FALSE(order->ExecutesBefore(dot1, y));
231  EXPECT_FALSE(order->ExecutesBefore(dot1, dot1));
232  EXPECT_FALSE(order->ExecutesBefore(dot1, dot2));
233  EXPECT_FALSE(order->ExecutesBefore(dot2, x));
234  EXPECT_FALSE(order->ExecutesBefore(dot2, y));
235  EXPECT_FALSE(order->ExecutesBefore(dot2, dot1));
236  EXPECT_FALSE(order->ExecutesBefore(dot2, dot2));
237  EXPECT_FALSE(order->ExecutesBefore(add, x));
238  EXPECT_FALSE(order->ExecutesBefore(add, y));
239  EXPECT_FALSE(order->ExecutesBefore(add, dot1));
240  EXPECT_FALSE(order->ExecutesBefore(add, dot2));
241  EXPECT_FALSE(order->ExecutesBefore(add, add));
242}
243
244// Test of multiple streams.
245TEST_F(HloScheduleTest, LatticeMatMul) {
246  //      d00      -- layer 0
247  //     /   \
248  //   d10   d11   -- layer 1
249  //  /   \ /   \
250  // d20  d21  d22 -- layer 2
251  //  \   / \   /
252  //   d30   d31   -- layer 3
253  //     \   /
254  //      d40      -- layer 4
255  HloComputation::Builder builder("entry_computation");
256  std::vector<HloInstruction*> params;
257  params.reserve(6);
258  for (int i = 0; i < 6; ++i) {
259    params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
260        i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
261  }
262  HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary(
263      f32_2x2_, HloOpcode::kDot, params[2], params[3]));
264  HloInstruction* d10 = builder.AddInstruction(
265      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00));
266  HloInstruction* d11 = builder.AddInstruction(
267      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4]));
268  HloInstruction* d20 = builder.AddInstruction(
269      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10));
270  HloInstruction* d21 = builder.AddInstruction(
271      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11));
272  HloInstruction* d22 = builder.AddInstruction(
273      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5]));
274  HloInstruction* d30 = builder.AddInstruction(
275      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21));
276  HloInstruction* d31 = builder.AddInstruction(
277      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22));
278  HloInstruction* d40 = builder.AddInstruction(
279      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31));
280
281  auto module = CreateNewModule();
282  module->AddEntryComputation(builder.Build(d40));
283
284  std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
285  // The two dots on layer 1 are concurrent.
286  EXPECT_NE(streams->StreamNumberForHlo(*d10),
287            streams->StreamNumberForHlo(*d11));
288  // The three dots on layer 2 are concurrent.
289  EXPECT_NE(streams->StreamNumberForHlo(*d20),
290            streams->StreamNumberForHlo(*d21));
291  EXPECT_NE(streams->StreamNumberForHlo(*d20),
292            streams->StreamNumberForHlo(*d22));
293  EXPECT_NE(streams->StreamNumberForHlo(*d21),
294            streams->StreamNumberForHlo(*d22));
295  // The two dots on layer 3 are concurrent.
296  EXPECT_NE(streams->StreamNumberForHlo(*d30),
297            streams->StreamNumberForHlo(*d31));
298
299  // We don't check the thunk launch order, since there are many valid total
300  // orders, and it's annoying to express.
301  auto schedule = BuildHloSchedule(*module, *streams);
302
303  auto order = schedule->ConsumeHloOrdering();
304  const HloVec all_params(
305      {params[0], params[1], params[2], params[3], params[4], params[5]});
306  const HloVec all_ops({d00, d10, d11, d20, d21, d22, d30, d31, d40});
307
308  // Parameters are mutually unordered, and never execute before ops.
309  for (const HloInstruction* param : all_params) {
310    for (const HloInstruction* param2 : all_params) {
311      EXPECT_FALSE(order->ExecutesBefore(param, param2));
312    }
313    for (const HloInstruction* op : all_ops) {
314      EXPECT_FALSE(order->ExecutesBefore(op, param));
315    }
316  }
317
318  // Check ordering of params before ops.
319  for (const HloInstruction* op : all_ops) {
320    if (op == d20 || op == d30 || op == d40) {
321      EXPECT_TRUE(order->ExecutesBefore(params[0], op));
322    } else {
323      EXPECT_FALSE(order->ExecutesBefore(params[0], op));
324    }
325    if (op != d00 && op != d11 && op != d22) {
326      EXPECT_TRUE(order->ExecutesBefore(params[1], op));
327    } else {
328      EXPECT_FALSE(order->ExecutesBefore(params[1], op));
329    }
330    EXPECT_TRUE(order->ExecutesBefore(params[2], op));
331    EXPECT_TRUE(order->ExecutesBefore(params[3], op));
332    if (op != d00 && op != d10 && op != d20) {
333      EXPECT_TRUE(order->ExecutesBefore(params[4], op));
334    } else {
335      EXPECT_FALSE(order->ExecutesBefore(params[4], op));
336    }
337    if (op == d22 || op == d31 || op == d40) {
338      EXPECT_TRUE(order->ExecutesBefore(params[5], op));
339    } else {
340      EXPECT_FALSE(order->ExecutesBefore(params[5], op));
341    }
342  }
343
344  // Check ordering of ops before ops.
345  for (const HloInstruction* op : all_ops) {
346    if (op != d00) {
347      EXPECT_TRUE(order->ExecutesBefore(d00, op));
348    } else {
349      EXPECT_FALSE(order->ExecutesBefore(d00, op));
350    }
351
352    if (op == d20 || op == d21 || op == d30 || op == d31 || op == d40) {
353      EXPECT_TRUE(order->ExecutesBefore(d10, op));
354    } else {
355      EXPECT_FALSE(order->ExecutesBefore(d10, op));
356    }
357
358    if (op == d21 || op == d22 || op == d30 || op == d31 || op == d40) {
359      EXPECT_TRUE(order->ExecutesBefore(d11, op));
360    } else {
361      EXPECT_FALSE(order->ExecutesBefore(d11, op));
362    }
363
364    if (op == d30 || op == d40) {
365      EXPECT_TRUE(order->ExecutesBefore(d20, op));
366    } else {
367      EXPECT_FALSE(order->ExecutesBefore(d20, op));
368    }
369
370    if (op == d30 || op == d31 || op == d40) {
371      EXPECT_TRUE(order->ExecutesBefore(d21, op));
372    } else {
373      EXPECT_FALSE(order->ExecutesBefore(d21, op));
374    }
375
376    if (op == d31 || op == d40) {
377      EXPECT_TRUE(order->ExecutesBefore(d22, op));
378    } else {
379      EXPECT_FALSE(order->ExecutesBefore(d22, op));
380    }
381
382    if (op == d40) {
383      EXPECT_TRUE(order->ExecutesBefore(d30, op));
384      EXPECT_TRUE(order->ExecutesBefore(d31, op));
385    } else {
386      EXPECT_FALSE(order->ExecutesBefore(d30, op));
387      EXPECT_FALSE(order->ExecutesBefore(d31, op));
388    }
389
390    EXPECT_FALSE(order->ExecutesBefore(d40, op));
391  }
392}
393
394}  // namespace gpu
395}  // namespace xla
396