1e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 3e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 4e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFloweryou may not use this file except in compliance with the License. 5e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowerYou may obtain a copy of the License at 6e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 7e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 8e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 9e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 10e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 11e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowerSee the License for the specific language governing permissions and 13e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowerlimitations under the License. 14e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower==============================================================================*/ 15e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 16e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/service/liveness_util.h" 17e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 18e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include <algorithm> 19e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include <utility> 20e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include <vector> 21e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 22e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_instruction.h" 23e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/service/logical_buffer.h" 24e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 25e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/shape_util.h" 26e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/types.h" 27e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/util.h" 28e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 29e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowernamespace xla { 30e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 31342d315566211a095a06acb1973b94937dadbc0cMark Heffernanbool DoesNotUseOperandBuffer(const HloInstruction* operand, 32342d315566211a095a06acb1973b94937dadbc0cMark Heffernan const ShapeIndex& index, 33342d315566211a095a06acb1973b94937dadbc0cMark Heffernan const HloInstruction* user, 34e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower const TuplePointsToAnalysis& points_to_analysis) { 35e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower CHECK(user->IsUserOf(operand)) 36e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower << "user: " << user->ToString() << " operand: " << operand->ToString(); 37e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { 38e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower // GetTupleElement instructions only access the top-level buffer of their 39e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower // operand. 40e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower return true; 41e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } else if (user->opcode() == HloOpcode::kFusion && 42e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower user->fusion_kind() == HloInstruction::FusionKind::kLoop) { 43e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower // Find fusion parameter associated with 'operand'. 44e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower auto it = std::find_if( 45e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower user->fused_parameters().begin(), user->fused_parameters().end(), 46e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower [=](HloInstruction* fused_param) { 47e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower return user->operand(fused_param->parameter_number()) == operand; 48e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower }); 49e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower CHECK(it != user->fused_parameters().end()); 50e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower // Iterate through all users of all buffer aliases of the buffer in the 51e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower // points-to set of fusion parameter at 'index'. 52e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower // Return false if any uses are detected at 'index', returns true otherwise. 53e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower const LogicalBuffer* buffer = 54e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie(); 55e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower for (const BufferAlias& alias : 56e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower points_to_analysis.GetBufferAliases(*buffer)) { 57e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower for (HloInstruction* alias_user : alias.instruction()->users()) { 58e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), 59e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower alias_user, points_to_analysis)) { 60e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower continue; 61e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 62e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. 63e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower return false; 64e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 65e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 66e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower // Return true: found no uses of 'operand' at 'index' in 'user'. 67e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower return true; 68e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 69e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower return false; 70e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower} 71e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 7223da21150d988f7cf5780488f24adbb116675586Mark Heffernanbool DoesNotUseOperandBuffer(const HloInstruction* operand, 7323da21150d988f7cf5780488f24adbb116675586Mark Heffernan const ShapeIndex& index, 7423da21150d988f7cf5780488f24adbb116675586Mark Heffernan const HloInstruction* user, 7523da21150d988f7cf5780488f24adbb116675586Mark Heffernan const HloDataflowAnalysis& dataflow) { 7623da21150d988f7cf5780488f24adbb116675586Mark Heffernan CHECK(user->IsUserOf(operand)) 7723da21150d988f7cf5780488f24adbb116675586Mark Heffernan << "user: " << user->ToString() << " operand: " << operand->ToString(); 7823da21150d988f7cf5780488f24adbb116675586Mark Heffernan if (user->opcode() == HloOpcode::kFusion && 7923da21150d988f7cf5780488f24adbb116675586Mark Heffernan user->fusion_kind() == HloInstruction::FusionKind::kLoop) { 8023da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Find fusion parameter associated with 'operand'. 8123da21150d988f7cf5780488f24adbb116675586Mark Heffernan HloInstruction* fusion_param = 8223da21150d988f7cf5780488f24adbb116675586Mark Heffernan user->fused_parameter(user->operand_index(operand)); 8323da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Iterate through all users of all uses of the fusion parameter value. 8423da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Return false if any uses are detected, returns true otherwise. 8523da21150d988f7cf5780488f24adbb116675586Mark Heffernan const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index); 8623da21150d988f7cf5780488f24adbb116675586Mark Heffernan return value.uses().empty(); 8723da21150d988f7cf5780488f24adbb116675586Mark Heffernan } else { 8823da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Return false if no value at 'operand' and 'index' is used at 'user'. 8923da21150d988f7cf5780488f24adbb116675586Mark Heffernan for (const HloValue* value : 9023da21150d988f7cf5780488f24adbb116675586Mark Heffernan dataflow.GetValueSet(operand, index).values()) { 9123da21150d988f7cf5780488f24adbb116675586Mark Heffernan for (const HloUse& use : value->uses()) { 9223da21150d988f7cf5780488f24adbb116675586Mark Heffernan if (use.instruction == user) { 9323da21150d988f7cf5780488f24adbb116675586Mark Heffernan return false; 9423da21150d988f7cf5780488f24adbb116675586Mark Heffernan } 9523da21150d988f7cf5780488f24adbb116675586Mark Heffernan } 9623da21150d988f7cf5780488f24adbb116675586Mark Heffernan } 9723da21150d988f7cf5780488f24adbb116675586Mark Heffernan } 9823da21150d988f7cf5780488f24adbb116675586Mark Heffernan 9923da21150d988f7cf5780488f24adbb116675586Mark Heffernan return true; 10023da21150d988f7cf5780488f24adbb116675586Mark Heffernan} 10123da21150d988f7cf5780488f24adbb116675586Mark Heffernan 102e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowernamespace { 103e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 104e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. 105e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) 106724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower// where 'user' is a user of an alias of 'instruction' at 'index', and 107e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower// 'operand_index' is the operand index at which the alias appears in the 108e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower// operand list of 'user'. 109e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowerstd::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex( 110e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower HloInstruction* instruction, const ShapeIndex& index, 111e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower const TuplePointsToAnalysis& points_to_analysis) { 112e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower std::vector<std::pair<HloInstruction*, int64>> uses; 1135ead76420dee762a5f710fda6893075f1292d5d3A. Unique TensorFlower const PointsToSet::BufferList& points_to = 114e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower points_to_analysis.GetPointsToSet(instruction).element(index); 115e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower for (const LogicalBuffer* buffer : points_to) { 116e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower for (const BufferAlias& alias : 117e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower points_to_analysis.GetBufferAliases(*buffer)) { 118e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower for (HloInstruction* alias_user : alias.instruction()->users()) { 119e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), 120e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower alias_user, points_to_analysis)) { 121e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower continue; 122e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 123e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { 124e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower uses.emplace_back(alias_user, op_idx); 125e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 126e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 127e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 128e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 129e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower return uses; 130e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower} 131e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 13209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower// Returns true if there is exactly one use of 'operand' at 'operand_index' 13309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower// in 'fusion.fused_instructions', where the singleton use is the fused 13409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower// root at operand index 'use_operand_index'. Returns false otherwise. 13509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower// 13609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower// REQUIRES: 'fusion' opcode is a kFusion instruction. 13709f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlowerbool HasUniqueFusedUseOfOperandAt( 13809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower HloInstruction* operand, const ShapeIndex& operand_index, 13909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower HloInstruction* fusion, const int64 use_operand_index, 14009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower const TuplePointsToAnalysis& points_to_analysis) { 14109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); 14209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // Check that 'operand' is unique in the operand list of 'fusion'. 14309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower if (fusion->OperandIndices(operand).size() > 1) { 14409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower return false; 14509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower } 14609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // Find fusion parameter associated with 'operand'. 14709f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower const auto& fused_params = fusion->fused_parameters(); 14809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower auto fused_param_it = std::find_if( 14909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower fused_params.begin(), fused_params.end(), 15009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower [&](HloInstruction* fused_param) { 15109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower return fusion->operand(fused_param->parameter_number()) == operand; 15209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower }); 15309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower if (fused_param_it == fused_params.end()) { 15409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower return false; 15509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower } 15609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower auto* fused_param = *fused_param_it; 15709f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. 15809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower auto fused_param_uses = GetAllUsesOfInstructionAtIndex( 15909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower fused_param, operand_index, points_to_analysis); 16009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // Return true iff there is exactly one use of 'operand' at 'index', and 16109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // this singleton use is the fused root (at index in 'use_operand_indices'). 16209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower return fused_param_uses.size() == 1 && 16309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower fused_param_uses[0].first == fusion->fused_expression_root() && 16409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower fused_param_uses[0].second == use_operand_index; 16509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower} 16609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower 167e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower} // namespace 168e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 169e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower// User and operand can share buffers iff both instructions emit the same shape 170971b11dcca8942c76b601a1b418ccd98a9b25f4aA. Unique TensorFlower// and layout, and 'user' meets one of the following qualifications: 1715b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// 1725b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// (1) Is element-wise. Or... 1735b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' 1745b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root 1755b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// at operand 0. Or... 1765b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion 1775b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// instruction where the only use of 'operand' at 'index' in the set 1785b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... 1795b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 1805b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// 0. 1815b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// 1825b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// (2) and (3) can only be determined if points-to analysis is available. 183e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowerbool CanShareOperandBufferWithUser( 184e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower HloInstruction* operand, const ShapeIndex& operand_index, 185e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower HloInstruction* user, const ShapeIndex& user_index, 18623da21150d988f7cf5780488f24adbb116675586Mark Heffernan const TuplePointsToAnalysis& points_to_analysis) { 187e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower CHECK(user->IsUserOf(operand)) 188e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower << "user: " << user->ToString() << " operand: " << operand->ToString(); 1897e3d54903037ad44f17c362a510c07cb5190c778Jeffrey A. Dean const Shape& operand_subshape = 190e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower ShapeUtil::GetSubshape(operand->shape(), operand_index); 1917e3d54903037ad44f17c362a510c07cb5190c778Jeffrey A. Dean const Shape& user_subshape = 1927e3d54903037ad44f17c362a510c07cb5190c778Jeffrey A. Dean ShapeUtil::GetSubshape(user->shape(), user_index); 193e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower // Check that operand and user emit the same shape and layout. 194e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { 195e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower return false; 196e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 19723da21150d988f7cf5780488f24adbb116675586Mark Heffernan if (user->opcode() == HloOpcode::kFusion) { 19809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && 19909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower user->fused_expression_root()->opcode() == 20009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower HloOpcode::kDynamicUpdateSlice) { 20109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // Loop fusion with kDynamicUpdateSlice fused root. 20209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // 20309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // Returns true iff there is exactly one use of 'operand' at shape index 20409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // 'operand_index', and this singleton use is the fused root at operand 20509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // index 0. 20609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, 20723da21150d988f7cf5780488f24adbb116675586Mark Heffernan points_to_analysis); 20809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && 20909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower user->fused_expression_root()->opcode() == HloOpcode::kAdd) { 21009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // Output fusion with kAdd fused root. 21109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower 21209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // Check if one operand of kAdd fused root is either kDot, or nested 21309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // kFusion of kind kTransposeDot. 21409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower auto* add = user->fused_expression_root(); 21509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower auto add_operand_it = 21609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower std::find_if(add->operands().begin(), add->operands().end(), 21709f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower [&](HloInstruction* operand) { 21810d1827987b0eca4d0e6f8f56506c93c67e03f83David Majnemer return operand->opcode() == HloOpcode::kConvolution || 21910d1827987b0eca4d0e6f8f56506c93c67e03f83David Majnemer operand->opcode() == HloOpcode::kDot || 22009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower (operand->opcode() == HloOpcode::kFusion && 22109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower operand->fusion_kind() == 22209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower HloInstruction::FusionKind::kTransposeDot); 22309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower }); 22409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower if (add_operand_it == add->operands().end()) { 22509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower return false; 226e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 22709f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower auto* matched_add_operand = *add_operand_it; 22809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // Calculate operand index of 'add' operand which was not matched above. 22909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower const int64 other_add_operand_index = 23009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower matched_add_operand == add->operand(0) ? 1 : 0; 23109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // Returns true iff there is exactly one use of 'operand' at shape index 23209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // 'operand_index', and this singleton use is the fused root (at operand 23309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower // index 'other_add_operand_index'). 23409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 23509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower other_add_operand_index, 23623da21150d988f7cf5780488f24adbb116675586Mark Heffernan points_to_analysis); 23723da21150d988f7cf5780488f24adbb116675586Mark Heffernan } 23823da21150d988f7cf5780488f24adbb116675586Mark Heffernan } 23923da21150d988f7cf5780488f24adbb116675586Mark Heffernan if (user->opcode() == HloOpcode::kDynamicUpdateSlice || 24023da21150d988f7cf5780488f24adbb116675586Mark Heffernan user->opcode() == HloOpcode::kWhile) { 24123da21150d988f7cf5780488f24adbb116675586Mark Heffernan // We eliminated other users in BufferLiveness::live_range_strictly_before, 24223da21150d988f7cf5780488f24adbb116675586Mark Heffernan // so here we just need to check that the use is at operand index 0. 24323da21150d988f7cf5780488f24adbb116675586Mark Heffernan std::vector<int64> operand_indices = user->OperandIndices(operand); 24423da21150d988f7cf5780488f24adbb116675586Mark Heffernan return operand_indices.size() == 1 && operand_indices[0] == 0; 24523da21150d988f7cf5780488f24adbb116675586Mark Heffernan } 246724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower if (user->opcode() == HloOpcode::kCall) { 247724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // TODO(b/62548313): Remove when buffer assignment is module scoped and 248724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // does not assign buffers to calls. 249724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // Find called computation parameter associated with 'operand'. 250724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower const std::vector<int64> operand_indices = user->OperandIndices(operand); 251724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower if (operand_indices.size() > 1) { 252724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower return false; 253724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower } 254724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower CHECK_EQ(1, operand_indices.size()); 255724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); 256724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // Get all uses of 'operand' at 'index' in called computation. 257724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index, 258724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower points_to_analysis); 259724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower 260724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // Return true iff: 261724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // *) There exists exactly one use of 'operand' in called computation. 262724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // *) The unique use is by the root instruction of called computation. 263724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // (Note: we check the root of the called computation, because the 264724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // root result buffer is required to alias with the Call result buffer). 265724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // *) The root instruction of the called computation is element-wise on 266724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // 'operand'. 267724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower auto* callee_root = user->to_apply()->root_instruction(); 268724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower return param_uses.size() == 1 && param_uses[0].first == callee_root && 269724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower callee_root->IsElementwiseOnOperand(param_uses[0].second); 270724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower } 27123da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Check if 'user' is element-wise. 27223da21150d988f7cf5780488f24adbb116675586Mark Heffernan return user->IsElementwise(); 27323da21150d988f7cf5780488f24adbb116675586Mark Heffernan} 27423da21150d988f7cf5780488f24adbb116675586Mark Heffernan 27523da21150d988f7cf5780488f24adbb116675586Mark Heffernanbool CanShareOperandBufferWithUser(HloInstruction* operand, 27623da21150d988f7cf5780488f24adbb116675586Mark Heffernan const ShapeIndex& operand_index, 27723da21150d988f7cf5780488f24adbb116675586Mark Heffernan HloInstruction* user, 27823da21150d988f7cf5780488f24adbb116675586Mark Heffernan const ShapeIndex& user_index, 27923da21150d988f7cf5780488f24adbb116675586Mark Heffernan const HloDataflowAnalysis& dataflow) { 28023da21150d988f7cf5780488f24adbb116675586Mark Heffernan CHECK(user->IsUserOf(operand)) 28123da21150d988f7cf5780488f24adbb116675586Mark Heffernan << "user: " << user->ToString() << " operand: " << operand->ToString(); 28223da21150d988f7cf5780488f24adbb116675586Mark Heffernan const Shape& operand_subshape = 28323da21150d988f7cf5780488f24adbb116675586Mark Heffernan ShapeUtil::GetSubshape(operand->shape(), operand_index); 28423da21150d988f7cf5780488f24adbb116675586Mark Heffernan const Shape& user_subshape = 28523da21150d988f7cf5780488f24adbb116675586Mark Heffernan ShapeUtil::GetSubshape(user->shape(), user_index); 28623da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Check that operand and user emit the same shape and layout. 28723da21150d988f7cf5780488f24adbb116675586Mark Heffernan if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { 28823da21150d988f7cf5780488f24adbb116675586Mark Heffernan return false; 28923da21150d988f7cf5780488f24adbb116675586Mark Heffernan } 29023da21150d988f7cf5780488f24adbb116675586Mark Heffernan 29123da21150d988f7cf5780488f24adbb116675586Mark Heffernan if (user->opcode() == HloOpcode::kFusion) { 29223da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Get the parameter associated with 'operand'; 29323da21150d988f7cf5780488f24adbb116675586Mark Heffernan HloInstruction* fusion_param = 29423da21150d988f7cf5780488f24adbb116675586Mark Heffernan user->fused_parameter(user->operand_index(operand)); 29523da21150d988f7cf5780488f24adbb116675586Mark Heffernan 29623da21150d988f7cf5780488f24adbb116675586Mark Heffernan const HloValue& value = 29723da21150d988f7cf5780488f24adbb116675586Mark Heffernan dataflow.GetValueDefinedAt(fusion_param, operand_index); 29823da21150d988f7cf5780488f24adbb116675586Mark Heffernan if (value.uses().size() != 1) { 29923da21150d988f7cf5780488f24adbb116675586Mark Heffernan return false; 30023da21150d988f7cf5780488f24adbb116675586Mark Heffernan } 30123da21150d988f7cf5780488f24adbb116675586Mark Heffernan const HloUse& use = value.uses()[0]; 30223da21150d988f7cf5780488f24adbb116675586Mark Heffernan 30323da21150d988f7cf5780488f24adbb116675586Mark Heffernan if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && 30423da21150d988f7cf5780488f24adbb116675586Mark Heffernan user->fused_expression_root()->opcode() == 30523da21150d988f7cf5780488f24adbb116675586Mark Heffernan HloOpcode::kDynamicUpdateSlice) { 30623da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Loop fusion with kDynamicUpdateSlice fused root. 30723da21150d988f7cf5780488f24adbb116675586Mark Heffernan // 30823da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Returns true iff there is exactly one use of 'operand' at shape index 30923da21150d988f7cf5780488f24adbb116675586Mark Heffernan // 'operand_index', and this singleton use is the fused root at operand 31023da21150d988f7cf5780488f24adbb116675586Mark Heffernan // index 0. 31123da21150d988f7cf5780488f24adbb116675586Mark Heffernan return use.instruction == user->fused_expression_root() && 31223da21150d988f7cf5780488f24adbb116675586Mark Heffernan use.operand_number == 0; 31323da21150d988f7cf5780488f24adbb116675586Mark Heffernan } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && 31423da21150d988f7cf5780488f24adbb116675586Mark Heffernan user->fused_expression_root()->opcode() == HloOpcode::kAdd) { 31523da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Output fusion with kAdd fused root. 31623da21150d988f7cf5780488f24adbb116675586Mark Heffernan 31723da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Check if one operand of kAdd fused root is either kDot, or nested 31823da21150d988f7cf5780488f24adbb116675586Mark Heffernan // kFusion of kind kTransposeDot. 31923da21150d988f7cf5780488f24adbb116675586Mark Heffernan auto* add = user->fused_expression_root(); 32023da21150d988f7cf5780488f24adbb116675586Mark Heffernan auto add_operand_it = 32123da21150d988f7cf5780488f24adbb116675586Mark Heffernan std::find_if(add->operands().begin(), add->operands().end(), 32223da21150d988f7cf5780488f24adbb116675586Mark Heffernan [&](HloInstruction* operand) { 32310d1827987b0eca4d0e6f8f56506c93c67e03f83David Majnemer return operand->opcode() == HloOpcode::kConvolution || 32410d1827987b0eca4d0e6f8f56506c93c67e03f83David Majnemer operand->opcode() == HloOpcode::kDot || 32523da21150d988f7cf5780488f24adbb116675586Mark Heffernan (operand->opcode() == HloOpcode::kFusion && 32623da21150d988f7cf5780488f24adbb116675586Mark Heffernan operand->fusion_kind() == 32723da21150d988f7cf5780488f24adbb116675586Mark Heffernan HloInstruction::FusionKind::kTransposeDot); 32823da21150d988f7cf5780488f24adbb116675586Mark Heffernan }); 32923da21150d988f7cf5780488f24adbb116675586Mark Heffernan if (add_operand_it == add->operands().end()) { 33023da21150d988f7cf5780488f24adbb116675586Mark Heffernan return false; 33123da21150d988f7cf5780488f24adbb116675586Mark Heffernan } 33223da21150d988f7cf5780488f24adbb116675586Mark Heffernan auto* matched_add_operand = *add_operand_it; 33323da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Calculate operand index of 'add' operand which was not matched above. 33423da21150d988f7cf5780488f24adbb116675586Mark Heffernan const int64 other_add_operand_index = 33523da21150d988f7cf5780488f24adbb116675586Mark Heffernan matched_add_operand == add->operand(0) ? 1 : 0; 33623da21150d988f7cf5780488f24adbb116675586Mark Heffernan // Returns true iff there is exactly one use of 'operand' at shape index 33723da21150d988f7cf5780488f24adbb116675586Mark Heffernan // 'operand_index', and this singleton use is the fused root (at operand 33823da21150d988f7cf5780488f24adbb116675586Mark Heffernan // index 'other_add_operand_index'). 33923da21150d988f7cf5780488f24adbb116675586Mark Heffernan return use.instruction == user->fused_expression_root() && 34023da21150d988f7cf5780488f24adbb116675586Mark Heffernan use.operand_number == other_add_operand_index; 341e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 34205349e6fe7b2395f541e11a9eacbeff04270e4c6A. Unique TensorFlower } 34305349e6fe7b2395f541e11a9eacbeff04270e4c6A. Unique TensorFlower if (user->opcode() == HloOpcode::kDynamicUpdateSlice || 34405349e6fe7b2395f541e11a9eacbeff04270e4c6A. Unique TensorFlower user->opcode() == HloOpcode::kWhile) { 3454718ac6b15cd5ed6f7da0692c97a79596e465580A. Unique TensorFlower // We eliminated other users in BufferLiveness::live_range_strictly_before, 3464718ac6b15cd5ed6f7da0692c97a79596e465580A. Unique TensorFlower // so here we just need to check that the use is at operand index 0. 3474718ac6b15cd5ed6f7da0692c97a79596e465580A. Unique TensorFlower std::vector<int64> operand_indices = user->OperandIndices(operand); 3484718ac6b15cd5ed6f7da0692c97a79596e465580A. Unique TensorFlower return operand_indices.size() == 1 && operand_indices[0] == 0; 349e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower } 350724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower if (user->opcode() == HloOpcode::kCall) { 351724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // Get all uses of value defined by 'operand' at 'operand_index'. 352724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower const auto& uses = 353724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower dataflow.GetValueDefinedAt(operand, operand_index).uses(); 354724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // Return true iff: 355724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // *) There exists two uses of 'operand'. 356724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // *) One use is by 'user' (caller). 357724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // *) One use is by root instruction of called computation (callee root). 358724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // (Note: we check the root of the called computation, because the 359724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // root result buffer is required to alias with the Call result buffer). 360724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // *) The root instruction of the called computation is element-wise on 361724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower // 'operand'. 362724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower const bool found_caller_use = 363724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { 364724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower return use.instruction == user; 365724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower }) != uses.end(); 366724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower auto* callee_root = user->to_apply()->root_instruction(); 367724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower const bool found_elementwise_callee_use = 368724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower std::find_if( 369724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower uses.begin(), uses.end(), [callee_root](const HloUse& use) { 370724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower return use.instruction == callee_root && 371724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower callee_root->IsElementwiseOnOperand(use.operand_number); 372724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower }) != uses.end(); 373724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; 374724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower } 375e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower // Check if 'user' is element-wise. 376e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower return user->IsElementwise(); 377e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower} 378e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower 379e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower} // namespace xla 380