liveness_util.cc revision 342d315566211a095a06acb1973b94937dadbc0c
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/liveness_util.h" 17 18#include <algorithm> 19#include <utility> 20#include <vector> 21 22#include "tensorflow/compiler/xla/service/hlo_instruction.h" 23#include "tensorflow/compiler/xla/service/logical_buffer.h" 24#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 25#include "tensorflow/compiler/xla/shape_util.h" 26#include "tensorflow/compiler/xla/types.h" 27#include "tensorflow/compiler/xla/util.h" 28 29namespace xla { 30 31bool DoesNotUseOperandBuffer(const HloInstruction* operand, 32 const ShapeIndex& index, 33 const HloInstruction* user, 34 const TuplePointsToAnalysis& points_to_analysis) { 35 CHECK(user->IsUserOf(operand)) 36 << "user: " << user->ToString() << " operand: " << operand->ToString(); 37 if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { 38 // GetTupleElement instructions only access the top-level buffer of their 39 // operand. 40 return true; 41 } else if (user->opcode() == HloOpcode::kFusion && 42 user->fusion_kind() == HloInstruction::FusionKind::kLoop) { 43 // Find fusion parameter associated with 'operand'. 44 auto it = std::find_if( 45 user->fused_parameters().begin(), user->fused_parameters().end(), 46 [=](HloInstruction* fused_param) { 47 return user->operand(fused_param->parameter_number()) == operand; 48 }); 49 CHECK(it != user->fused_parameters().end()); 50 // Iterate through all users of all buffer aliases of the buffer in the 51 // points-to set of fusion parameter at 'index'. 52 // Return false if any uses are detected at 'index', returns true otherwise. 53 const LogicalBuffer* buffer = 54 points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie(); 55 for (const BufferAlias& alias : 56 points_to_analysis.GetBufferAliases(*buffer)) { 57 for (HloInstruction* alias_user : alias.instruction()->users()) { 58 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), 59 alias_user, points_to_analysis)) { 60 continue; 61 } 62 // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. 63 return false; 64 } 65 } 66 // Return true: found no uses of 'operand' at 'index' in 'user'. 67 return true; 68 } 69 return false; 70} 71 72namespace { 73 74// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. 75// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) 76// where 'user' is a user of an alias of 'intruction' at 'index', and 77// 'operand_index' is the operand index at which the alias appears in the 78// operand list of 'user'. 79std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex( 80 HloInstruction* instruction, const ShapeIndex& index, 81 const TuplePointsToAnalysis& points_to_analysis) { 82 std::vector<std::pair<HloInstruction*, int64>> uses; 83 const std::vector<const LogicalBuffer*>& points_to = 84 points_to_analysis.GetPointsToSet(instruction).element(index); 85 for (const LogicalBuffer* buffer : points_to) { 86 for (const BufferAlias& alias : 87 points_to_analysis.GetBufferAliases(*buffer)) { 88 for (HloInstruction* alias_user : alias.instruction()->users()) { 89 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), 90 alias_user, points_to_analysis)) { 91 continue; 92 } 93 for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { 94 uses.emplace_back(alias_user, op_idx); 95 } 96 } 97 } 98 } 99 return uses; 100} 101 102} // namespace 103 104// User and operand can share buffers iff both instructions emit the same shape 105// and layout, and 'user' meets one of the following qualifications: 106// *) Is element-wise. Or... 107// *) Is a loop fusion instruction where the only use of 'operand' at 'index' 108// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root 109// at operand 0. Or... 110// *) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 0. 111bool CanShareOperandBufferWithUser( 112 HloInstruction* operand, const ShapeIndex& operand_index, 113 HloInstruction* user, const ShapeIndex& user_index, 114 const TuplePointsToAnalysis& points_to_analysis) { 115 CHECK(user->IsUserOf(operand)) 116 << "user: " << user->ToString() << " operand: " << operand->ToString(); 117 Shape operand_subshape = 118 ShapeUtil::GetSubshape(operand->shape(), operand_index); 119 Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index); 120 // Check that operand and user emit the same shape and layout. 121 if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { 122 return false; 123 } 124 // Copy instructions are explicitly added by CopyInsertion to prevent liveness 125 // issues, so they should never re-use their operand buffer. 126 if (user->opcode() == HloOpcode::kCopy) { 127 return false; 128 } 129 // Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice 130 // fused root instruction. 131 if (user->opcode() == HloOpcode::kFusion && 132 user->fusion_kind() == HloInstruction::FusionKind::kLoop && 133 user->fused_expression_root()->opcode() == 134 HloOpcode::kDynamicUpdateSlice) { 135 for (auto& fused_param : user->fused_parameters()) { 136 // Find fusion parameter associated with 'operand'. 137 if (user->operand(fused_param->parameter_number()) != operand) { 138 continue; 139 } 140 // Get all uses of 'operand' at 'index' from 'user.fused_instructions'. 141 auto fused_param_uses = GetAllUsesOfInstructionAtIndex( 142 fused_param, operand_index, points_to_analysis); 143 // Return true iff there is exactly one use of 'operand' at 'index', and 144 // this singleton use is the fused root at operand index 0. 145 if (fused_param_uses.size() == 1 && 146 fused_param_uses[0].first == user->fused_expression_root() && 147 fused_param_uses[0].second == 0) { 148 return true; 149 } 150 break; 151 } 152 return false; 153 } 154 if (user->opcode() == HloOpcode::kDynamicUpdateSlice || 155 user->opcode() == HloOpcode::kWhile) { 156 // We eliminated other users in BufferLiveness::live_range_strictly_before, 157 // so here we just need to check that the use is at operand index 0. 158 std::vector<int64> operand_indices = user->OperandIndices(operand); 159 return operand_indices.size() == 1 && operand_indices[0] == 0; 160 } 161 // Check if 'user' is element-wise. 162 return user->IsElementwise(); 163} 164 165} // namespace xla 166