1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" 17 18#include "tensorflow/compiler/xla/service/hlo_computation.h" 19#include "tensorflow/compiler/xla/service/hlo_instruction.h" 20#include "tensorflow/compiler/xla/service/hlo_opcode.h" 21#include "tensorflow/compiler/xla/test_helpers.h" 22#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 23#include "tensorflow/compiler/xla/types.h" 24#include "tensorflow/core/lib/strings/stringprintf.h" 25 26namespace xla { 27namespace gpu { 28 29class StreamAssignmentTest : public HloTestBase { 30 protected: 31 // Pre-canned shapes. 32 Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); 33}; 34 35TEST_F(StreamAssignmentTest, SequentialMatMul) { 36 HloComputation::Builder builder("entry_computation"); 37 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 38 /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); 39 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 40 /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); 41 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( 42 /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); 43 HloInstruction* dot1 = builder.AddInstruction( 44 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); 45 HloInstruction* dot2 = builder.AddInstruction( 46 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); 47 48 auto module = CreateNewModule(); 49 module->AddEntryComputation(builder.Build(dot2)); 50 51 std::unique_ptr<StreamAssignment> assignment = AssignStreams(*module); 52 EXPECT_EQ(assignment->StreamNumberForHlo(*dot1), 53 assignment->StreamNumberForHlo(*dot2)); 54} 55 56TEST_F(StreamAssignmentTest, ConcurrentMatMul) { 57 HloComputation::Builder builder("entry_computation"); 58 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 59 /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); 60 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 61 /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); 62 HloInstruction* dot1 = builder.AddInstruction( 63 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); 64 HloInstruction* dot2 = builder.AddInstruction( 65 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); 66 HloInstruction* add = builder.AddInstruction( 67 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); 68 69 auto module = CreateNewModule(); 70 module->AddEntryComputation(builder.Build(add)); 71 72 std::unique_ptr<StreamAssignment> assignment = AssignStreams(*module); 73 EXPECT_NE(assignment->StreamNumberForHlo(*dot1), 74 assignment->StreamNumberForHlo(*dot2)); 75} 76 77TEST_F(StreamAssignmentTest, LatticeMatMul) { 78 // d00 -- layer 0 79 // / \ 80 // d10 d11 -- layer 1 81 // / \ / \ 82 // d20 d21 d22 -- layer 2 83 // \ / \ / 84 // d30 d31 -- layer 3 85 // \ / 86 // d40 -- layer 4 87 HloComputation::Builder builder("entry_computation"); 88 std::vector<HloInstruction*> params; 89 params.reserve(6); 90 for (int i = 0; i < 6; ++i) { 91 params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( 92 i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); 93 } 94 HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( 95 f32_2x2_, HloOpcode::kDot, params[2], params[3])); 96 HloInstruction* d10 = builder.AddInstruction( 97 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); 98 HloInstruction* d11 = builder.AddInstruction( 99 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); 100 HloInstruction* d20 = builder.AddInstruction( 101 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); 102 HloInstruction* d21 = builder.AddInstruction( 103 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); 104 HloInstruction* d22 = builder.AddInstruction( 105 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); 106 HloInstruction* d30 = builder.AddInstruction( 107 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); 108 HloInstruction* d31 = builder.AddInstruction( 109 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); 110 HloInstruction* d40 = builder.AddInstruction( 111 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); 112 113 auto module = CreateNewModule(); 114 module->AddEntryComputation(builder.Build(d40)); 115 116 std::unique_ptr<StreamAssignment> assignment = AssignStreams(*module); 117 // The two dots on layer 1 are concurrent. 118 EXPECT_NE(assignment->StreamNumberForHlo(*d10), 119 assignment->StreamNumberForHlo(*d11)); 120 // The three dots on layer 2 are concurrent. 121 EXPECT_NE(assignment->StreamNumberForHlo(*d20), 122 assignment->StreamNumberForHlo(*d21)); 123 EXPECT_NE(assignment->StreamNumberForHlo(*d20), 124 assignment->StreamNumberForHlo(*d22)); 125 EXPECT_NE(assignment->StreamNumberForHlo(*d21), 126 assignment->StreamNumberForHlo(*d22)); 127 // The two dots on layer 3 are concurrent. 128 EXPECT_NE(assignment->StreamNumberForHlo(*d30), 129 assignment->StreamNumberForHlo(*d31)); 130} 131 132} // namespace gpu 133} // namespace xla 134