/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
17
18#include <queue>
19
20#include "tensorflow/compiler/xla/service/hlo_computation.h"
21#include "tensorflow/compiler/xla/service/hlo_instruction.h"
22#include "tensorflow/compiler/xla/service/hlo_opcode.h"
23#include "tensorflow/compiler/xla/status_macros.h"
24#include "tensorflow/compiler/xla/types.h"
25#include "tensorflow/compiler/xla/util.h"
26#include "tensorflow/core/lib/core/errors.h"
27#include "tensorflow/core/lib/core/status.h"
28#include "tensorflow/core/platform/logging.h"
29#include "tensorflow/core/platform/types.h"
30
31namespace xla {
32
33StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
34  // Initially add all GTE and Tuple instructions to the worklist.
35  std::queue<HloInstruction*> worklist;
36  for (auto* computation : module->computations()) {
37    for (auto* instruction : computation->instructions()) {
38      if (instruction->opcode() == HloOpcode::kTuple ||
39          instruction->opcode() == HloOpcode::kGetTupleElement) {
40        worklist.push(instruction);
41      }
42    }
43  }
44
45  bool changed = false;
46  while (!worklist.empty()) {
47    HloInstruction* instruction = worklist.front();
48    worklist.pop();
49
50    if (instruction->user_count() == 0 &&
51        instruction != instruction->parent()->root_instruction()) {
52      // Tuple simplification works by replacing users of optimized away
53      // instructions with a simpler form. If there is no user of the
54      // instruction (including being the root), then there is nothing to do.
55      continue;
56    }
57
58    if (instruction->opcode() == HloOpcode::kTuple) {
59      // Collapse the following structure into just 'Tuple-shaped Op':
60      //
61      //   Tuple-shaped Op
62      //         |
63      //   +-----+-----+
64      //   |     |     |
65      //  GTE   GTE   GTE
66      //   |     |     |
67      //   +-----+-----+
68      //         |
69      //       Tuple
70      //
71      HloInstruction* top_tuple = nullptr;
72      bool can_simplify = true;
73      for (int64 operand_number = 0;
74           operand_number < instruction->operand_count(); ++operand_number) {
75        HloInstruction* operand = instruction->mutable_operand(operand_number);
76        if (operand->opcode() != HloOpcode::kGetTupleElement ||
77            operand->tuple_index() != operand_number) {
78          can_simplify = false;
79          break;
80        }
81
82        if (top_tuple == nullptr) {
83          top_tuple = operand->mutable_operand(0);
84          if (!ShapeUtil::Compatible(top_tuple->shape(),
85                                     instruction->shape())) {
86            can_simplify = false;
87            break;
88          }
89        } else if (top_tuple != operand->operand(0)) {
90          can_simplify = false;
91          break;
92        }
93      }
94      if (can_simplify && top_tuple != nullptr) {
95        changed = true;
96        TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(top_tuple));
97        // No need to add anything to the worklist.
98      }
99    } else {
100      CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement);
101      // If possible replace a GTE with the operation which produces the
102      // element. For example, replace uses of GTE with below with just 'Op'
103      // (assuming 'Op' is at the index of the GTE instruction):
104      //
105      //     ...  Op ...
106      //       \  |   /
107      //        Tuple
108      //          |
109      //         GTE
110      if (instruction->operand(0)->opcode() == HloOpcode::kTuple) {
111        changed = true;
112        HloInstruction* element_source =
113            instruction->mutable_operand(0)->mutable_operand(
114                instruction->tuple_index());
115        TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
116        for (HloInstruction* user : element_source->users()) {
117          if (user->opcode() == HloOpcode::kTuple ||
118              user->opcode() == HloOpcode::kGetTupleElement) {
119            worklist.push(user);
120          }
121        }
122      }
123    }
124  }
125
126  return changed;
127}
128
129}  // namespace xla
130