1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
17
18#include "tensorflow/compiler/xla/service/hlo_computation.h"
19#include "tensorflow/compiler/xla/service/hlo_instruction.h"
20#include "tensorflow/compiler/xla/service/hlo_opcode.h"
21#include "tensorflow/compiler/xla/test_helpers.h"
22#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23#include "tensorflow/compiler/xla/types.h"
24#include "tensorflow/core/lib/strings/stringprintf.h"
25
26namespace xla {
27namespace gpu {
28
29class StreamAssignmentTest : public HloTestBase {
30 protected:
31  // Pre-canned shapes.
32  Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
33};
34
35TEST_F(StreamAssignmentTest, SequentialMatMul) {
36  HloComputation::Builder builder("entry_computation");
37  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
38      /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
39  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
40      /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
41  HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
42      /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
43  HloInstruction* dot1 = builder.AddInstruction(
44      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
45  HloInstruction* dot2 = builder.AddInstruction(
46      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z));
47
48  auto module = CreateNewModule();
49  module->AddEntryComputation(builder.Build(dot2));
50
51  std::unique_ptr<StreamAssignment> assignment = AssignStreams(*module);
52  EXPECT_EQ(assignment->StreamNumberForHlo(*dot1),
53            assignment->StreamNumberForHlo(*dot2));
54}
55
56TEST_F(StreamAssignmentTest, ConcurrentMatMul) {
57  HloComputation::Builder builder("entry_computation");
58  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
59      /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
60  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
61      /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
62  HloInstruction* dot1 = builder.AddInstruction(
63      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
64  HloInstruction* dot2 = builder.AddInstruction(
65      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x));
66  HloInstruction* add = builder.AddInstruction(
67      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
68
69  auto module = CreateNewModule();
70  module->AddEntryComputation(builder.Build(add));
71
72  std::unique_ptr<StreamAssignment> assignment = AssignStreams(*module);
73  EXPECT_NE(assignment->StreamNumberForHlo(*dot1),
74            assignment->StreamNumberForHlo(*dot2));
75}
76
77TEST_F(StreamAssignmentTest, LatticeMatMul) {
78  //      d00      -- layer 0
79  //     /   \
80  //   d10   d11   -- layer 1
81  //  /   \ /   \
82  // d20  d21  d22 -- layer 2
83  //  \   / \   /
84  //   d30   d31   -- layer 3
85  //     \   /
86  //      d40      -- layer 4
87  HloComputation::Builder builder("entry_computation");
88  std::vector<HloInstruction*> params;
89  params.reserve(6);
90  for (int i = 0; i < 6; ++i) {
91    params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
92        i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
93  }
94  HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary(
95      f32_2x2_, HloOpcode::kDot, params[2], params[3]));
96  HloInstruction* d10 = builder.AddInstruction(
97      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00));
98  HloInstruction* d11 = builder.AddInstruction(
99      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4]));
100  HloInstruction* d20 = builder.AddInstruction(
101      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10));
102  HloInstruction* d21 = builder.AddInstruction(
103      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11));
104  HloInstruction* d22 = builder.AddInstruction(
105      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5]));
106  HloInstruction* d30 = builder.AddInstruction(
107      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21));
108  HloInstruction* d31 = builder.AddInstruction(
109      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22));
110  HloInstruction* d40 = builder.AddInstruction(
111      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31));
112
113  auto module = CreateNewModule();
114  module->AddEntryComputation(builder.Build(d40));
115
116  std::unique_ptr<StreamAssignment> assignment = AssignStreams(*module);
117  // The two dots on layer 1 are concurrent.
118  EXPECT_NE(assignment->StreamNumberForHlo(*d10),
119            assignment->StreamNumberForHlo(*d11));
120  // The three dots on layer 2 are concurrent.
121  EXPECT_NE(assignment->StreamNumberForHlo(*d20),
122            assignment->StreamNumberForHlo(*d21));
123  EXPECT_NE(assignment->StreamNumberForHlo(*d20),
124            assignment->StreamNumberForHlo(*d22));
125  EXPECT_NE(assignment->StreamNumberForHlo(*d21),
126            assignment->StreamNumberForHlo(*d22));
127  // The two dots on layer 3 are concurrent.
128  EXPECT_NE(assignment->StreamNumberForHlo(*d30),
129            assignment->StreamNumberForHlo(*d31));
130}
131
132}  // namespace gpu
133}  // namespace xla
134