1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
17
18#include <memory>
19#include <utility>
20
21#include "tensorflow/compiler/xla/literal_util.h"
22#include "tensorflow/compiler/xla/service/hlo_computation.h"
23#include "tensorflow/compiler/xla/service/hlo_instruction.h"
24#include "tensorflow/compiler/xla/service/hlo_matchers.h"
25#include "tensorflow/compiler/xla/service/hlo_opcode.h"
26#include "tensorflow/compiler/xla/shape_util.h"
27#include "tensorflow/compiler/xla/test.h"
28#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29#include "tensorflow/compiler/xla/types.h"
30#include "tensorflow/core/lib/core/status_test_util.h"
31
32namespace op = xla::testing::opcode_matchers;
33
34namespace xla {
35namespace {
36
37class TupleSimplifierTest : public HloTestBase {
38 protected:
39  void Run(HloModule* module, bool change_expected) {
40    TupleSimplifier simplifier;
41    auto changed_status = simplifier.Run(module);
42    TF_ASSERT_OK(changed_status.status());
43    EXPECT_EQ(change_expected, changed_status.ValueOrDie());
44  }
45
46  const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
47  const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
48      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {}),
49       ShapeUtil::MakeShape(F32, {})});
50};
51
52TEST_F(TupleSimplifierTest, TupleOfParameters) {
53  // A Tuple constructed of a bunch of parameters should not be changed.
54  HloComputation::Builder builder(TestName());
55  HloInstruction* param0 = builder.AddInstruction(
56      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
57  HloInstruction* param1 = builder.AddInstruction(
58      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
59  HloInstruction* param2 = builder.AddInstruction(
60      HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
61  builder.AddInstruction(HloInstruction::CreateTuple({param0, param1, param2}));
62  auto module = CreateNewModule();
63  module->AddEntryComputation(builder.Build());
64
65  Run(module.get(), /*change_expected=*/false);
66}
67
68TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
69  // A GTE of a tuple parameter should not be changed.
70  HloComputation::Builder builder(TestName());
71  HloInstruction* param = builder.AddInstruction(
72      HloInstruction::CreateParameter(0, tuple_shape_, "param"));
73  builder.AddInstruction(
74      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
75  auto module = CreateNewModule();
76  module->AddEntryComputation(builder.Build());
77
78  Run(module.get(), /*change_expected=*/false);
79}
80
81TEST_F(TupleSimplifierTest, GteOfTuple) {
82  // A GTE of a Tuple should be short-circuited.
83  HloComputation::Builder builder(TestName());
84  HloInstruction* param0 = builder.AddInstruction(
85      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
86  HloInstruction* param1 = builder.AddInstruction(
87      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
88  HloInstruction* param2 = builder.AddInstruction(
89      HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
90  HloInstruction* tuple = builder.AddInstruction(
91      HloInstruction::CreateTuple({param0, param1, param2}));
92  HloInstruction* gte = builder.AddInstruction(
93      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
94
95  auto module = CreateNewModule();
96  auto computation = module->AddEntryComputation(builder.Build());
97
98  EXPECT_THAT(computation->root_instruction(), gte);
99
100  Run(module.get(), /*change_expected=*/true);
101
102  EXPECT_THAT(computation->root_instruction(), param1);
103}
104
105TEST_F(TupleSimplifierTest, GteOfTupleChain) {
106  // Verify a chain of GTE/Tuple instructions is collapsed.
107  HloComputation::Builder builder(TestName());
108  HloInstruction* param = builder.AddInstruction(
109      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
110
111  const int kChainLength = 10;
112  HloInstruction* element = param;
113  for (int i = 0; i < kChainLength; ++i) {
114    HloInstruction* tuple = builder.AddInstruction(
115        HloInstruction::CreateTuple({element, element, element}));
116    element = builder.AddInstruction(
117        HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
118  }
119  builder.AddInstruction(
120      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element));
121
122  auto module = CreateNewModule();
123  auto computation = module->AddEntryComputation(builder.Build());
124
125  EXPECT_THAT(computation->root_instruction(),
126              op::Negate(op::GetTupleElement(op::Tuple())));
127
128  Run(module.get(), /*change_expected=*/true);
129
130  EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
131}
132
133TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
134  // Verify a nesting of GTE/Tuple instructions is collapsed. Tuples are nested
135  // to some depth with a chain of Tuple instructions, then extracted with a
136  // chain of GTE instructions.
137  HloComputation::Builder builder(TestName());
138  HloInstruction* param = builder.AddInstruction(
139      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
140
141  const int kNestingDepth = 5;
142  HloInstruction* nested_tuple = param;
143  for (int i = 0; i < kNestingDepth; ++i) {
144    nested_tuple = builder.AddInstruction(
145        HloInstruction::CreateTuple({nested_tuple, nested_tuple}));
146  }
147
148  HloInstruction* element = nested_tuple;
149  for (int i = 0; i < kNestingDepth; ++i) {
150    element = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
151        ShapeUtil::GetTupleElementShape(element->shape(), 0), element, 0));
152  }
153
154  auto module = CreateNewModule();
155  auto computation = module->AddEntryComputation(builder.Build());
156
157  EXPECT_THAT(computation->root_instruction(), element);
158
159  Run(module.get(), /*change_expected=*/true);
160
161  EXPECT_THAT(computation->root_instruction(), param);
162}
163
164TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
165  // Verify that a tuple constructed of GTE instructions operating on the same
166  // tuple are collapsed.
167  HloComputation::Builder builder(TestName());
168  HloInstruction* tuple_param = builder.AddInstruction(
169      HloInstruction::CreateParameter(0, tuple_shape_, "param"));
170  HloInstruction* gte0 = builder.AddInstruction(
171      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
172  HloInstruction* gte1 = builder.AddInstruction(
173      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
174  HloInstruction* gte2 = builder.AddInstruction(
175      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 2));
176  HloInstruction* tuple =
177      builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
178
179  auto module = CreateNewModule();
180  auto computation = module->AddEntryComputation(builder.Build());
181
182  EXPECT_THAT(computation->root_instruction(), tuple);
183
184  Run(module.get(), /*change_expected=*/true);
185
186  EXPECT_THAT(computation->root_instruction(), tuple_param);
187}
188
189TEST_F(TupleSimplifierTest, IncompatibleTuples) {
190  // Verify that a tuple->GTE->tuple construct is not simplified if the input
191  // and output tuple are not compatible shapes.
192  HloComputation::Builder builder(TestName());
193  HloInstruction* tuple_param = builder.AddInstruction(
194      HloInstruction::CreateParameter(0, tuple_shape_, "param"));
195  HloInstruction* gte0 = builder.AddInstruction(
196      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
197  HloInstruction* gte1 = builder.AddInstruction(
198      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
199  // Output tuple has only two elements. Parameter tuple has three elements so
200  // simplification is not possible.
201  HloInstruction* tuple =
202      builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
203
204  auto module = CreateNewModule();
205  auto computation = module->AddEntryComputation(builder.Build());
206
207  EXPECT_THAT(computation->root_instruction(), tuple);
208
209  Run(module.get(), /*change_expected=*/false);
210
211  EXPECT_THAT(computation->root_instruction(), tuple);
212}
213
214}  // namespace
215}  // namespace xla
216