1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" 17 18#include <algorithm> 19#include <unordered_set> 20 21#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" 22#include "tensorflow/compiler/xla/service/hlo_computation.h" 23#include "tensorflow/compiler/xla/service/hlo_instruction.h" 24#include "tensorflow/compiler/xla/service/hlo_opcode.h" 25#include "tensorflow/compiler/xla/test_helpers.h" 26#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 27#include "tensorflow/compiler/xla/types.h" 28 29namespace xla { 30namespace gpu { 31 32class HloScheduleTest : public HloTestBase { 33 protected: 34 using HloVec = std::vector<const HloInstruction*>; 35 36 // Pre-canned shapes. 37 Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); 38 39 static std::unique_ptr<HloSchedule> BuildHloSchedule( 40 const HloModule& module, const StreamAssignment& streams) { 41 return HloSchedule::Build(module, streams, /*pointer_size=*/8) 42 .ConsumeValueOrDie(); 43 } 44 45 HloVec RemoveHlo(const HloVec& input, 46 const std::unordered_set<const HloInstruction*>& remove) { 47 HloVec result(input); 48 result.erase(std::remove_if(result.begin(), result.end(), 49 [&remove](const HloInstruction* x) { 50 return remove.count(x) > 0; 51 }), 52 result.end()); 53 return result; 54 } 55}; 56 57// Test of a single stream, where data dependencies fully determine the 58// execution order. 59TEST_F(HloScheduleTest, SequentialMatMul) { 60 HloComputation::Builder builder("entry_computation"); 61 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 62 /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); 63 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 64 /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); 65 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( 66 /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); 67 HloInstruction* dot1 = builder.AddInstruction( 68 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); 69 HloInstruction* dot2 = builder.AddInstruction( 70 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); 71 72 auto module = CreateNewModule(); 73 module->AddEntryComputation(builder.Build(dot2)); 74 75 std::unique_ptr<StreamAssignment> streams = AssignStreams(*module); 76 EXPECT_EQ(streams->StreamNumberForHlo(*dot1), 77 streams->StreamNumberForHlo(*dot2)); 78 79 auto schedule = BuildHloSchedule(*module, *streams); 80 // Remove parameters, which are unordered. 81 EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}), 82 HloVec({dot1, dot2})); 83 84 // Parameters x,y,z are mutually unordered, while dot1 and dot2 are 85 // transitively ordered by operands. 86 auto order = schedule->ConsumeHloOrdering(); 87 EXPECT_TRUE(order->ExecutesBefore(x, dot1)); 88 EXPECT_TRUE(order->ExecutesBefore(x, dot2)); 89 EXPECT_TRUE(order->ExecutesBefore(y, dot1)); 90 EXPECT_TRUE(order->ExecutesBefore(y, dot2)); 91 EXPECT_TRUE(order->ExecutesBefore(z, dot2)); 92 EXPECT_TRUE(order->ExecutesBefore(dot1, dot2)); 93 94 EXPECT_FALSE(order->ExecutesBefore(x, x)); 95 EXPECT_FALSE(order->ExecutesBefore(x, y)); 96 EXPECT_FALSE(order->ExecutesBefore(x, z)); 97 EXPECT_FALSE(order->ExecutesBefore(y, x)); 98 EXPECT_FALSE(order->ExecutesBefore(y, y)); 99 EXPECT_FALSE(order->ExecutesBefore(y, z)); 100 EXPECT_FALSE(order->ExecutesBefore(z, x)); 101 EXPECT_FALSE(order->ExecutesBefore(z, y)); 102 EXPECT_FALSE(order->ExecutesBefore(z, z)); 103 EXPECT_FALSE(order->ExecutesBefore(z, dot1)); 104 EXPECT_FALSE(order->ExecutesBefore(dot1, x)); 105 EXPECT_FALSE(order->ExecutesBefore(dot1, y)); 106 EXPECT_FALSE(order->ExecutesBefore(dot1, z)); 107 EXPECT_FALSE(order->ExecutesBefore(dot1, dot1)); 108 EXPECT_FALSE(order->ExecutesBefore(dot2, x)); 109 EXPECT_FALSE(order->ExecutesBefore(dot2, y)); 110 EXPECT_FALSE(order->ExecutesBefore(dot2, z)); 111 EXPECT_FALSE(order->ExecutesBefore(dot2, dot1)); 112 EXPECT_FALSE(order->ExecutesBefore(dot2, dot2)); 113} 114 115// Test of a single stream, where data dependencies do not fully determine the 116// execution order, but the stream assignment does. 117TEST_F(HloScheduleTest, SequentialAdd) { 118 HloComputation::Builder builder("entry_computation"); 119 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 120 /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); 121 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 122 /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); 123 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( 124 /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); 125 HloInstruction* add1 = builder.AddInstruction( 126 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y)); 127 HloInstruction* add2 = builder.AddInstruction( 128 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, y, z)); 129 HloInstruction* add3 = builder.AddInstruction( 130 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2)); 131 132 auto module = CreateNewModule(); 133 module->AddEntryComputation(builder.Build(add3)); 134 135 std::unique_ptr<StreamAssignment> streams = AssignStreams(*module); 136 EXPECT_EQ(streams->StreamNumberForHlo(*add1), 137 streams->StreamNumberForHlo(*add2)); 138 EXPECT_EQ(streams->StreamNumberForHlo(*add1), 139 streams->StreamNumberForHlo(*add3)); 140 141 auto schedule = BuildHloSchedule(*module, *streams); 142 // Remove parameters, which are unordered. 143 EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}), 144 HloVec({add1, add2, add3})); 145 146 // Parameters x,y,z are mutually unordered, while add1, add2 and add3 are 147 // transitively ordered by operands. 148 auto order = schedule->ConsumeHloOrdering(); 149 EXPECT_TRUE(order->ExecutesBefore(x, add1)); 150 EXPECT_TRUE(order->ExecutesBefore(x, add2)); 151 EXPECT_TRUE(order->ExecutesBefore(x, add3)); 152 EXPECT_TRUE(order->ExecutesBefore(y, add1)); 153 EXPECT_TRUE(order->ExecutesBefore(y, add2)); 154 EXPECT_TRUE(order->ExecutesBefore(y, add3)); 155 EXPECT_TRUE(order->ExecutesBefore(z, add2)); 156 EXPECT_TRUE(order->ExecutesBefore(z, add3)); 157 EXPECT_TRUE(order->ExecutesBefore(add1, add3)); 158 EXPECT_TRUE(order->ExecutesBefore(add2, add3)); 159 // The HLO graph does not define an ordering for add1 and add2, but their 160 // assignment onto the same stream does define an ordering. 161 if (order->ExecutesBefore(add1, add2)) { 162 EXPECT_FALSE(order->ExecutesBefore(add2, add1)); 163 } else { 164 EXPECT_TRUE(order->ExecutesBefore(add2, add1)); 165 EXPECT_FALSE(order->ExecutesBefore(add1, add2)); 166 } 167 168 EXPECT_FALSE(order->ExecutesBefore(x, x)); 169 EXPECT_FALSE(order->ExecutesBefore(x, y)); 170 EXPECT_FALSE(order->ExecutesBefore(x, z)); 171 EXPECT_FALSE(order->ExecutesBefore(y, x)); 172 EXPECT_FALSE(order->ExecutesBefore(y, y)); 173 EXPECT_FALSE(order->ExecutesBefore(y, z)); 174 EXPECT_FALSE(order->ExecutesBefore(z, x)); 175 EXPECT_FALSE(order->ExecutesBefore(z, y)); 176 EXPECT_FALSE(order->ExecutesBefore(z, z)); 177 EXPECT_FALSE(order->ExecutesBefore(z, add1)); 178 EXPECT_FALSE(order->ExecutesBefore(add1, x)); 179 EXPECT_FALSE(order->ExecutesBefore(add1, y)); 180 EXPECT_FALSE(order->ExecutesBefore(add1, z)); 181 EXPECT_FALSE(order->ExecutesBefore(add1, add1)); 182 EXPECT_FALSE(order->ExecutesBefore(add2, x)); 183 EXPECT_FALSE(order->ExecutesBefore(add2, y)); 184 EXPECT_FALSE(order->ExecutesBefore(add2, z)); 185 EXPECT_FALSE(order->ExecutesBefore(add2, add2)); 186} 187 188// Test of two streams. 189TEST_F(HloScheduleTest, ConcurrentMatMul) { 190 HloComputation::Builder builder("entry_computation"); 191 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 192 /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); 193 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 194 /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); 195 HloInstruction* dot1 = builder.AddInstruction( 196 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); 197 HloInstruction* dot2 = builder.AddInstruction( 198 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); 199 HloInstruction* add = builder.AddInstruction( 200 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); 201 202 auto module = CreateNewModule(); 203 module->AddEntryComputation(builder.Build(add)); 204 205 std::unique_ptr<StreamAssignment> streams = AssignStreams(*module); 206 EXPECT_NE(streams->StreamNumberForHlo(*dot1), 207 streams->StreamNumberForHlo(*dot2)); 208 209 auto schedule = BuildHloSchedule(*module, *streams); 210 // Remove parameters, which are unordered. 211 HloVec thunk_launch_order = RemoveHlo(schedule->ThunkLaunchOrder(), {x, y}); 212 EXPECT_TRUE(thunk_launch_order == HloVec({dot1, dot2, add}) || 213 thunk_launch_order == HloVec({dot2, dot1, add})); 214 215 // Parameters x,y are mutually unordered, while dot1, dot2 and add are 216 // transitively ordered by operands. 217 auto order = schedule->ConsumeHloOrdering(); 218 EXPECT_TRUE(order->ExecutesBefore(x, dot1)); 219 EXPECT_TRUE(order->ExecutesBefore(x, dot2)); 220 EXPECT_TRUE(order->ExecutesBefore(y, dot1)); 221 EXPECT_TRUE(order->ExecutesBefore(y, dot2)); 222 EXPECT_TRUE(order->ExecutesBefore(dot1, add)); 223 EXPECT_TRUE(order->ExecutesBefore(dot2, add)); 224 225 EXPECT_FALSE(order->ExecutesBefore(x, x)); 226 EXPECT_FALSE(order->ExecutesBefore(x, y)); 227 EXPECT_FALSE(order->ExecutesBefore(y, x)); 228 EXPECT_FALSE(order->ExecutesBefore(y, y)); 229 EXPECT_FALSE(order->ExecutesBefore(dot1, x)); 230 EXPECT_FALSE(order->ExecutesBefore(dot1, y)); 231 EXPECT_FALSE(order->ExecutesBefore(dot1, dot1)); 232 EXPECT_FALSE(order->ExecutesBefore(dot1, dot2)); 233 EXPECT_FALSE(order->ExecutesBefore(dot2, x)); 234 EXPECT_FALSE(order->ExecutesBefore(dot2, y)); 235 EXPECT_FALSE(order->ExecutesBefore(dot2, dot1)); 236 EXPECT_FALSE(order->ExecutesBefore(dot2, dot2)); 237 EXPECT_FALSE(order->ExecutesBefore(add, x)); 238 EXPECT_FALSE(order->ExecutesBefore(add, y)); 239 EXPECT_FALSE(order->ExecutesBefore(add, dot1)); 240 EXPECT_FALSE(order->ExecutesBefore(add, dot2)); 241 EXPECT_FALSE(order->ExecutesBefore(add, add)); 242} 243 244// Test of multiple streams. 245TEST_F(HloScheduleTest, LatticeMatMul) { 246 // d00 -- layer 0 247 // / \ 248 // d10 d11 -- layer 1 249 // / \ / \ 250 // d20 d21 d22 -- layer 2 251 // \ / \ / 252 // d30 d31 -- layer 3 253 // \ / 254 // d40 -- layer 4 255 HloComputation::Builder builder("entry_computation"); 256 std::vector<HloInstruction*> params; 257 params.reserve(6); 258 for (int i = 0; i < 6; ++i) { 259 params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( 260 i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); 261 } 262 HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( 263 f32_2x2_, HloOpcode::kDot, params[2], params[3])); 264 HloInstruction* d10 = builder.AddInstruction( 265 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); 266 HloInstruction* d11 = builder.AddInstruction( 267 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); 268 HloInstruction* d20 = builder.AddInstruction( 269 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); 270 HloInstruction* d21 = builder.AddInstruction( 271 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); 272 HloInstruction* d22 = builder.AddInstruction( 273 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); 274 HloInstruction* d30 = builder.AddInstruction( 275 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); 276 HloInstruction* d31 = builder.AddInstruction( 277 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); 278 HloInstruction* d40 = builder.AddInstruction( 279 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); 280 281 auto module = CreateNewModule(); 282 module->AddEntryComputation(builder.Build(d40)); 283 284 std::unique_ptr<StreamAssignment> streams = AssignStreams(*module); 285 // The two dots on layer 1 are concurrent. 286 EXPECT_NE(streams->StreamNumberForHlo(*d10), 287 streams->StreamNumberForHlo(*d11)); 288 // The three dots on layer 2 are concurrent. 289 EXPECT_NE(streams->StreamNumberForHlo(*d20), 290 streams->StreamNumberForHlo(*d21)); 291 EXPECT_NE(streams->StreamNumberForHlo(*d20), 292 streams->StreamNumberForHlo(*d22)); 293 EXPECT_NE(streams->StreamNumberForHlo(*d21), 294 streams->StreamNumberForHlo(*d22)); 295 // The two dots on layer 3 are concurrent. 296 EXPECT_NE(streams->StreamNumberForHlo(*d30), 297 streams->StreamNumberForHlo(*d31)); 298 299 // We don't check the thunk launch order, since there are many valid total 300 // orders, and it's annoying to express. 301 auto schedule = BuildHloSchedule(*module, *streams); 302 303 auto order = schedule->ConsumeHloOrdering(); 304 const HloVec all_params( 305 {params[0], params[1], params[2], params[3], params[4], params[5]}); 306 const HloVec all_ops({d00, d10, d11, d20, d21, d22, d30, d31, d40}); 307 308 // Parameters are mutually unordered, and never execute before ops. 309 for (const HloInstruction* param : all_params) { 310 for (const HloInstruction* param2 : all_params) { 311 EXPECT_FALSE(order->ExecutesBefore(param, param2)); 312 } 313 for (const HloInstruction* op : all_ops) { 314 EXPECT_FALSE(order->ExecutesBefore(op, param)); 315 } 316 } 317 318 // Check ordering of params before ops. 319 for (const HloInstruction* op : all_ops) { 320 if (op == d20 || op == d30 || op == d40) { 321 EXPECT_TRUE(order->ExecutesBefore(params[0], op)); 322 } else { 323 EXPECT_FALSE(order->ExecutesBefore(params[0], op)); 324 } 325 if (op != d00 && op != d11 && op != d22) { 326 EXPECT_TRUE(order->ExecutesBefore(params[1], op)); 327 } else { 328 EXPECT_FALSE(order->ExecutesBefore(params[1], op)); 329 } 330 EXPECT_TRUE(order->ExecutesBefore(params[2], op)); 331 EXPECT_TRUE(order->ExecutesBefore(params[3], op)); 332 if (op != d00 && op != d10 && op != d20) { 333 EXPECT_TRUE(order->ExecutesBefore(params[4], op)); 334 } else { 335 EXPECT_FALSE(order->ExecutesBefore(params[4], op)); 336 } 337 if (op == d22 || op == d31 || op == d40) { 338 EXPECT_TRUE(order->ExecutesBefore(params[5], op)); 339 } else { 340 EXPECT_FALSE(order->ExecutesBefore(params[5], op)); 341 } 342 } 343 344 // Check ordering of ops before ops. 345 for (const HloInstruction* op : all_ops) { 346 if (op != d00) { 347 EXPECT_TRUE(order->ExecutesBefore(d00, op)); 348 } else { 349 EXPECT_FALSE(order->ExecutesBefore(d00, op)); 350 } 351 352 if (op == d20 || op == d21 || op == d30 || op == d31 || op == d40) { 353 EXPECT_TRUE(order->ExecutesBefore(d10, op)); 354 } else { 355 EXPECT_FALSE(order->ExecutesBefore(d10, op)); 356 } 357 358 if (op == d21 || op == d22 || op == d30 || op == d31 || op == d40) { 359 EXPECT_TRUE(order->ExecutesBefore(d11, op)); 360 } else { 361 EXPECT_FALSE(order->ExecutesBefore(d11, op)); 362 } 363 364 if (op == d30 || op == d40) { 365 EXPECT_TRUE(order->ExecutesBefore(d20, op)); 366 } else { 367 EXPECT_FALSE(order->ExecutesBefore(d20, op)); 368 } 369 370 if (op == d30 || op == d31 || op == d40) { 371 EXPECT_TRUE(order->ExecutesBefore(d21, op)); 372 } else { 373 EXPECT_FALSE(order->ExecutesBefore(d21, op)); 374 } 375 376 if (op == d31 || op == d40) { 377 EXPECT_TRUE(order->ExecutesBefore(d22, op)); 378 } else { 379 EXPECT_FALSE(order->ExecutesBefore(d22, op)); 380 } 381 382 if (op == d40) { 383 EXPECT_TRUE(order->ExecutesBefore(d30, op)); 384 EXPECT_TRUE(order->ExecutesBefore(d31, op)); 385 } else { 386 EXPECT_FALSE(order->ExecutesBefore(d30, op)); 387 EXPECT_FALSE(order->ExecutesBefore(d31, op)); 388 } 389 390 EXPECT_FALSE(order->ExecutesBefore(d40, op)); 391 } 392} 393 394} // namespace gpu 395} // namespace xla 396