1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/defuser.h"
17
18#include "tensorflow/compiler/xla/literal_util.h"
19#include "tensorflow/compiler/xla/service/hlo_matchers.h"
20#include "tensorflow/compiler/xla/shape_util.h"
21#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
22
23namespace op = xla::testing::opcode_matchers;
24
25namespace xla {
26namespace {
27
28class DefuserTest : public HloVerifiedTestBase {
29 protected:
30  // Returns the number of fusion instructions in the module.
31  int FusionCount() {
32    int count = 0;
33    for (HloComputation* computation : module().computations()) {
34      if (computation->IsFusionComputation()) {
35        count++;
36      }
37    }
38    return count;
39  }
40
41  Defuser defuser_;
42  const Shape shape_ = ShapeUtil::MakeShape(F32, {2, 2});
43};
44
45TEST_F(DefuserTest, NoFusionInstruction) {
46  auto builder = HloComputation::Builder(TestName());
47  auto param0 =
48      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
49  auto param1 =
50      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
51  builder.AddInstruction(
52      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
53
54  module().AddEntryComputation(builder.Build());
55  EXPECT_EQ(0, FusionCount());
56
57  EXPECT_FALSE(defuser_.Run(&module()).ValueOrDie());
58}
59
60TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) {
61  auto builder = HloComputation::Builder(TestName());
62  auto param0 =
63      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
64  auto param1 =
65      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
66  auto add = builder.AddInstruction(
67      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
68
69  auto computation = module().AddEntryComputation(builder.Build());
70  computation->CreateFusionInstruction({add},
71                                       HloInstruction::FusionKind::kLoop);
72
73  EXPECT_THAT(computation->root_instruction(), op::Fusion());
74
75  EXPECT_EQ(1, FusionCount());
76  EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie());
77  EXPECT_EQ(0, FusionCount());
78
79  EXPECT_THAT(computation->root_instruction(),
80              op::Add(op::Parameter(), op::Parameter()));
81}
82
83TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) {
84  auto builder = HloComputation::Builder(TestName());
85  auto param0 =
86      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
87  auto param1 =
88      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
89  auto add = builder.AddInstruction(
90      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
91  builder.AddInstruction(
92      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
93
94  auto computation = module().AddEntryComputation(builder.Build());
95  computation->CreateFusionInstruction({add},
96                                       HloInstruction::FusionKind::kLoop);
97
98  EXPECT_THAT(computation->root_instruction(), op::Negate(op::Fusion()));
99
100  EXPECT_EQ(1, FusionCount());
101  EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie());
102  EXPECT_EQ(0, FusionCount());
103
104  EXPECT_THAT(computation->root_instruction(),
105              op::Negate(op::Add(op::Parameter(), op::Parameter())));
106}
107
108TEST_F(DefuserTest, NonTrivialFusionInstruction) {
109  auto builder = HloComputation::Builder(TestName());
110  auto param0 =
111      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
112  auto param1 =
113      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
114  auto param3 =
115      builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2"));
116  auto add = builder.AddInstruction(
117      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
118  auto negate = builder.AddInstruction(
119      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
120  auto sub = builder.AddInstruction(
121      HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate));
122  auto mul = builder.AddInstruction(
123      HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3));
124  auto div = builder.AddInstruction(
125      HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
126  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
127      Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
128  auto add2 = builder.AddInstruction(
129      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
130
131  auto computation = module().AddEntryComputation(builder.Build());
132  computation->CreateFusionInstruction(
133      {add2, constant, div, mul, sub, negate, add},
134      HloInstruction::FusionKind::kLoop);
135
136  EXPECT_THAT(computation->root_instruction(), op::Fusion());
137
138  EXPECT_EQ(1, FusionCount());
139  EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie());
140  EXPECT_EQ(0, FusionCount());
141
142  EXPECT_THAT(computation->root_instruction(),
143              op::Add(op::Constant(), op::Divide()));
144}
145
146TEST_F(DefuserTest, MultipleFusionInstructions) {
147  auto builder = HloComputation::Builder(TestName());
148  auto param0 =
149      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
150  auto param1 =
151      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
152  auto param3 =
153      builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2"));
154  auto add = builder.AddInstruction(
155      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
156  auto negate = builder.AddInstruction(
157      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
158  auto sub = builder.AddInstruction(
159      HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate));
160  auto mul = builder.AddInstruction(
161      HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3));
162  auto div = builder.AddInstruction(
163      HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
164  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
165      Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
166  auto add2 = builder.AddInstruction(
167      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
168
169  auto computation = module().AddEntryComputation(builder.Build());
170  computation->CreateFusionInstruction({add2, constant, div, mul},
171                                       HloInstruction::FusionKind::kLoop);
172  computation->CreateFusionInstruction({sub, negate, add},
173                                       HloInstruction::FusionKind::kLoop);
174
175  EXPECT_THAT(computation->root_instruction(), op::Fusion());
176
177  EXPECT_EQ(2, FusionCount());
178  EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie());
179  EXPECT_EQ(0, FusionCount());
180
181  EXPECT_THAT(computation->root_instruction(),
182              op::Add(op::Constant(), op::Divide()));
183}
184
185TEST_F(DefuserTest, NestedFusionInstructions) {
186  auto builder = HloComputation::Builder(TestName());
187  auto param0 =
188      builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0"));
189  auto param1 =
190      builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1"));
191  auto add = builder.AddInstruction(
192      HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1));
193  auto negate = builder.AddInstruction(
194      HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add));
195
196  auto computation = module().AddEntryComputation(builder.Build());
197  auto outer_fusion = computation->CreateFusionInstruction(
198      {negate, add}, HloInstruction::FusionKind::kLoop);
199  HloInstruction* fused_negate = outer_fusion->fused_expression_root();
200  ASSERT_EQ(fused_negate->opcode(), HloOpcode::kNegate);
201  outer_fusion->fused_instructions_computation()->CreateFusionInstruction(
202      {fused_negate}, HloInstruction::FusionKind::kLoop);
203
204  EXPECT_THAT(computation->root_instruction(), op::Fusion());
205
206  EXPECT_EQ(2, FusionCount());
207  EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie());
208  EXPECT_EQ(0, FusionCount());
209
210  EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add()));
211}
212
213}  // namespace
214}  // namespace xla
215