1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/tuple_simplifier.h" 17 18#include <queue> 19 20#include "tensorflow/compiler/xla/service/hlo_computation.h" 21#include "tensorflow/compiler/xla/service/hlo_instruction.h" 22#include "tensorflow/compiler/xla/service/hlo_opcode.h" 23#include "tensorflow/compiler/xla/status_macros.h" 24#include "tensorflow/compiler/xla/types.h" 25#include "tensorflow/compiler/xla/util.h" 26#include "tensorflow/core/lib/core/errors.h" 27#include "tensorflow/core/lib/core/status.h" 28#include "tensorflow/core/platform/logging.h" 29#include "tensorflow/core/platform/types.h" 30 31namespace xla { 32 33StatusOr<bool> TupleSimplifier::Run(HloModule* module) { 34 // Initially add all GTE and Tuple instructions to the worklist. 35 std::queue<HloInstruction*> worklist; 36 for (auto* computation : module->computations()) { 37 for (auto* instruction : computation->instructions()) { 38 if (instruction->opcode() == HloOpcode::kTuple || 39 instruction->opcode() == HloOpcode::kGetTupleElement) { 40 worklist.push(instruction); 41 } 42 } 43 } 44 45 bool changed = false; 46 while (!worklist.empty()) { 47 HloInstruction* instruction = worklist.front(); 48 worklist.pop(); 49 50 if (instruction->user_count() == 0 && 51 instruction != instruction->parent()->root_instruction()) { 52 // Tuple simplification works by replacing users of optimized away 53 // instructions with a simpler form. If there is no user of the 54 // instruction (including being the root), then there is nothing to do. 55 continue; 56 } 57 58 if (instruction->opcode() == HloOpcode::kTuple) { 59 // Collapse the following structure into just 'Tuple-shaped Op': 60 // 61 // Tuple-shaped Op 62 // | 63 // +-----+-----+ 64 // | | | 65 // GTE GTE GTE 66 // | | | 67 // +-----+-----+ 68 // | 69 // Tuple 70 // 71 HloInstruction* top_tuple = nullptr; 72 bool can_simplify = true; 73 for (int64 operand_number = 0; 74 operand_number < instruction->operand_count(); ++operand_number) { 75 HloInstruction* operand = instruction->mutable_operand(operand_number); 76 if (operand->opcode() != HloOpcode::kGetTupleElement || 77 operand->tuple_index() != operand_number) { 78 can_simplify = false; 79 break; 80 } 81 82 if (top_tuple == nullptr) { 83 top_tuple = operand->mutable_operand(0); 84 if (!ShapeUtil::Compatible(top_tuple->shape(), 85 instruction->shape())) { 86 can_simplify = false; 87 break; 88 } 89 } else if (top_tuple != operand->operand(0)) { 90 can_simplify = false; 91 break; 92 } 93 } 94 if (can_simplify && top_tuple != nullptr) { 95 changed = true; 96 TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(top_tuple)); 97 // No need to add anything to the worklist. 98 } 99 } else { 100 CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement); 101 // If possible replace a GTE with the operation which produces the 102 // element. For example, replace uses of GTE with below with just 'Op' 103 // (assuming 'Op' is at the index of the GTE instruction): 104 // 105 // ... Op ... 106 // \ | / 107 // Tuple 108 // | 109 // GTE 110 if (instruction->operand(0)->opcode() == HloOpcode::kTuple) { 111 changed = true; 112 HloInstruction* element_source = 113 instruction->mutable_operand(0)->mutable_operand( 114 instruction->tuple_index()); 115 TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); 116 for (HloInstruction* user : element_source->users()) { 117 if (user->opcode() == HloOpcode::kTuple || 118 user->opcode() == HloOpcode::kGetTupleElement) { 119 worklist.push(user); 120 } 121 } 122 } 123 } 124 } 125 126 return changed; 127} 128 129} // namespace xla 130