/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower
16e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/service/liveness_util.h"
17e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower
18e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include <algorithm>
19e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include <utility>
20e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include <vector>
21e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower
22e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_instruction.h"
23e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/service/logical_buffer.h"
24e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
25e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/shape_util.h"
26e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/types.h"
27e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower#include "tensorflow/compiler/xla/util.h"
28e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower
29e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowernamespace xla {
30e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower
31342d315566211a095a06acb1973b94937dadbc0cMark Heffernanbool DoesNotUseOperandBuffer(const HloInstruction* operand,
32342d315566211a095a06acb1973b94937dadbc0cMark Heffernan                             const ShapeIndex& index,
33342d315566211a095a06acb1973b94937dadbc0cMark Heffernan                             const HloInstruction* user,
34e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower                             const TuplePointsToAnalysis& points_to_analysis) {
35e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  CHECK(user->IsUserOf(operand))
36e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower      << "user: " << user->ToString() << " operand: " << operand->ToString();
37e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
38e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    // GetTupleElement instructions only access the top-level buffer of their
39e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    // operand.
40e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    return true;
41e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  } else if (user->opcode() == HloOpcode::kFusion &&
42e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower             user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
43e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    // Find fusion parameter associated with 'operand'.
44e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    auto it = std::find_if(
45e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        user->fused_parameters().begin(), user->fused_parameters().end(),
46e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        [=](HloInstruction* fused_param) {
47e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower          return user->operand(fused_param->parameter_number()) == operand;
48e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        });
49e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    CHECK(it != user->fused_parameters().end());
50e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    // Iterate through all users of all buffer aliases of the buffer in the
51e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    // points-to set of fusion parameter at 'index'.
52e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    // Return false if any uses are detected at 'index', returns true otherwise.
53e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    const LogicalBuffer* buffer =
54e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
55e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    for (const BufferAlias& alias :
56e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower         points_to_analysis.GetBufferAliases(*buffer)) {
57e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower      for (HloInstruction* alias_user : alias.instruction()->users()) {
58e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
59e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower                                    alias_user, points_to_analysis)) {
60e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower          continue;
61e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        }
62e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
63e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        return false;
64e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower      }
65e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    }
66e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    // Return true: found no uses of 'operand' at 'index' in 'user'.
67e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    return true;
68e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  }
69e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  return false;
70e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower}
71e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower
7223da21150d988f7cf5780488f24adbb116675586Mark Heffernanbool DoesNotUseOperandBuffer(const HloInstruction* operand,
7323da21150d988f7cf5780488f24adbb116675586Mark Heffernan                             const ShapeIndex& index,
7423da21150d988f7cf5780488f24adbb116675586Mark Heffernan                             const HloInstruction* user,
7523da21150d988f7cf5780488f24adbb116675586Mark Heffernan                             const HloDataflowAnalysis& dataflow) {
7623da21150d988f7cf5780488f24adbb116675586Mark Heffernan  CHECK(user->IsUserOf(operand))
7723da21150d988f7cf5780488f24adbb116675586Mark Heffernan      << "user: " << user->ToString() << " operand: " << operand->ToString();
7823da21150d988f7cf5780488f24adbb116675586Mark Heffernan  if (user->opcode() == HloOpcode::kFusion &&
7923da21150d988f7cf5780488f24adbb116675586Mark Heffernan      user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
8023da21150d988f7cf5780488f24adbb116675586Mark Heffernan    // Find fusion parameter associated with 'operand'.
8123da21150d988f7cf5780488f24adbb116675586Mark Heffernan    HloInstruction* fusion_param =
8223da21150d988f7cf5780488f24adbb116675586Mark Heffernan        user->fused_parameter(user->operand_index(operand));
8323da21150d988f7cf5780488f24adbb116675586Mark Heffernan    // Iterate through all users of all uses of the fusion parameter value.
8423da21150d988f7cf5780488f24adbb116675586Mark Heffernan    // Return false if any uses are detected, returns true otherwise.
8523da21150d988f7cf5780488f24adbb116675586Mark Heffernan    const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index);
8623da21150d988f7cf5780488f24adbb116675586Mark Heffernan    return value.uses().empty();
8723da21150d988f7cf5780488f24adbb116675586Mark Heffernan  } else {
8823da21150d988f7cf5780488f24adbb116675586Mark Heffernan    // Return false if no value at 'operand' and 'index' is used at 'user'.
8923da21150d988f7cf5780488f24adbb116675586Mark Heffernan    for (const HloValue* value :
9023da21150d988f7cf5780488f24adbb116675586Mark Heffernan         dataflow.GetValueSet(operand, index).values()) {
9123da21150d988f7cf5780488f24adbb116675586Mark Heffernan      for (const HloUse& use : value->uses()) {
9223da21150d988f7cf5780488f24adbb116675586Mark Heffernan        if (use.instruction == user) {
9323da21150d988f7cf5780488f24adbb116675586Mark Heffernan          return false;
9423da21150d988f7cf5780488f24adbb116675586Mark Heffernan        }
9523da21150d988f7cf5780488f24adbb116675586Mark Heffernan      }
9623da21150d988f7cf5780488f24adbb116675586Mark Heffernan    }
9723da21150d988f7cf5780488f24adbb116675586Mark Heffernan  }
9823da21150d988f7cf5780488f24adbb116675586Mark Heffernan
9923da21150d988f7cf5780488f24adbb116675586Mark Heffernan  return true;
10023da21150d988f7cf5780488f24adbb116675586Mark Heffernan}
10123da21150d988f7cf5780488f24adbb116675586Mark Heffernan
102e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowernamespace {
103e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower
104e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
105e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
106724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower// where 'user' is a user of an alias of 'instruction' at 'index', and
107e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower// 'operand_index' is the operand index at which the alias appears in the
108e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower// operand list of 'user'.
109e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowerstd::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
110e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    HloInstruction* instruction, const ShapeIndex& index,
111e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    const TuplePointsToAnalysis& points_to_analysis) {
112e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  std::vector<std::pair<HloInstruction*, int64>> uses;
1135ead76420dee762a5f710fda6893075f1292d5d3A. Unique TensorFlower  const PointsToSet::BufferList& points_to =
114e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower      points_to_analysis.GetPointsToSet(instruction).element(index);
115e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  for (const LogicalBuffer* buffer : points_to) {
116e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    for (const BufferAlias& alias :
117e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower         points_to_analysis.GetBufferAliases(*buffer)) {
118e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower      for (HloInstruction* alias_user : alias.instruction()->users()) {
119e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
120e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower                                    alias_user, points_to_analysis)) {
121e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower          continue;
122e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        }
123e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
124e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower          uses.emplace_back(alias_user, op_idx);
125e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower        }
126e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower      }
127e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    }
128e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  }
129e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  return uses;
130e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower}
131e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower
13209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower// Returns true if there is exactly one use of 'operand' at 'operand_index'
13309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower// in 'fusion.fused_instructions', where the singleton use is the fused
13409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower// root at operand index 'use_operand_index'. Returns false otherwise.
13509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower//
13609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower// REQUIRES: 'fusion' opcode is a kFusion instruction.
13709f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlowerbool HasUniqueFusedUseOfOperandAt(
13809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower    HloInstruction* operand, const ShapeIndex& operand_index,
13909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower    HloInstruction* fusion, const int64 use_operand_index,
14009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower    const TuplePointsToAnalysis& points_to_analysis) {
14109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
14209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  // Check that 'operand' is unique in the operand list of 'fusion'.
14309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  if (fusion->OperandIndices(operand).size() > 1) {
14409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower    return false;
14509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  }
14609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  // Find fusion parameter associated with 'operand'.
14709f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  const auto& fused_params = fusion->fused_parameters();
14809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  auto fused_param_it = std::find_if(
14909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      fused_params.begin(), fused_params.end(),
15009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      [&](HloInstruction* fused_param) {
15109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower        return fusion->operand(fused_param->parameter_number()) == operand;
15209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      });
15309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  if (fused_param_it == fused_params.end()) {
15409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower    return false;
15509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  }
15609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  auto* fused_param = *fused_param_it;
15709f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'.
15809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  auto fused_param_uses = GetAllUsesOfInstructionAtIndex(
15909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      fused_param, operand_index, points_to_analysis);
16009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  // Return true iff there is exactly one use of 'operand' at 'index', and
16109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  // this singleton use is the fused root (at index in 'use_operand_indices').
16209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower  return fused_param_uses.size() == 1 &&
16309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower         fused_param_uses[0].first == fusion->fused_expression_root() &&
16409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower         fused_param_uses[0].second == use_operand_index;
16509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower}
16609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower
167e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower}  // namespace
168e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower
169e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower// User and operand can share buffers iff both instructions emit the same shape
170971b11dcca8942c76b601a1b418ccd98a9b25f4aA. Unique TensorFlower// and layout, and 'user' meets one of the following qualifications:
1715b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan//
1725b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// (1) Is element-wise. Or...
1735b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// (2) Is a loop fusion instruction where the only use of 'operand' at 'index'
1745b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan//     in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
1755b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan//     at operand 0. Or...
1765b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion
1775b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan//     instruction where the only use of 'operand' at 'index' in the set
1785b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan//     'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or...
1795b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index
1805b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan//     0.
1815b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan//
1825b6a203c5c759656b2b7018271219916ddd85cb6Mark Heffernan// (2) and (3) can only be determined if points-to analysis is available.
183e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlowerbool CanShareOperandBufferWithUser(
184e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    HloInstruction* operand, const ShapeIndex& operand_index,
185e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    HloInstruction* user, const ShapeIndex& user_index,
18623da21150d988f7cf5780488f24adbb116675586Mark Heffernan    const TuplePointsToAnalysis& points_to_analysis) {
187e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  CHECK(user->IsUserOf(operand))
188e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower      << "user: " << user->ToString() << " operand: " << operand->ToString();
1897e3d54903037ad44f17c362a510c07cb5190c778Jeffrey A. Dean  const Shape& operand_subshape =
190e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower      ShapeUtil::GetSubshape(operand->shape(), operand_index);
1917e3d54903037ad44f17c362a510c07cb5190c778Jeffrey A. Dean  const Shape& user_subshape =
1927e3d54903037ad44f17c362a510c07cb5190c778Jeffrey A. Dean      ShapeUtil::GetSubshape(user->shape(), user_index);
193e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  // Check that operand and user emit the same shape and layout.
194e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
195e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    return false;
196e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  }
19723da21150d988f7cf5780488f24adbb116675586Mark Heffernan  if (user->opcode() == HloOpcode::kFusion) {
19809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower    if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
19909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower        user->fused_expression_root()->opcode() ==
20009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower            HloOpcode::kDynamicUpdateSlice) {
20109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      // Loop fusion with kDynamicUpdateSlice fused root.
20209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      //
20309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      // Returns true iff there is exactly one use of 'operand' at shape index
20409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      // 'operand_index', and this singleton use is the fused root at operand
20509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      // index 0.
20609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0,
20723da21150d988f7cf5780488f24adbb116675586Mark Heffernan                                          points_to_analysis);
20809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower    } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
20909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower               user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
21009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      // Output fusion with kAdd fused root.
21109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower
21209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      // Check if one operand of kAdd fused root is either kDot, or nested
21309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      // kFusion of kind kTransposeDot.
21409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      auto* add = user->fused_expression_root();
21509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      auto add_operand_it =
21609f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower          std::find_if(add->operands().begin(), add->operands().end(),
21709f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower                       [&](HloInstruction* operand) {
21810d1827987b0eca4d0e6f8f56506c93c67e03f83David Majnemer                         return operand->opcode() == HloOpcode::kConvolution ||
21910d1827987b0eca4d0e6f8f56506c93c67e03f83David Majnemer                                operand->opcode() == HloOpcode::kDot ||
22009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower                                (operand->opcode() == HloOpcode::kFusion &&
22109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower                                 operand->fusion_kind() ==
22209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower                                     HloInstruction::FusionKind::kTransposeDot);
22309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower                       });
22409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      if (add_operand_it == add->operands().end()) {
22509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower        return false;
226e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower      }
22709f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      auto* matched_add_operand = *add_operand_it;
22809f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      // Calculate operand index of 'add' operand which was not matched above.
22909f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      const int64 other_add_operand_index =
23009f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower          matched_add_operand == add->operand(0) ? 1 : 0;
23109f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      // Returns true iff there is exactly one use of 'operand' at shape index
23209f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      // 'operand_index', and this singleton use is the fused root (at operand
23309f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      // index 'other_add_operand_index').
23409f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower      return HasUniqueFusedUseOfOperandAt(operand, operand_index, user,
23509f3fb939c9b395a9bc747cf81d15b2dc2804c3eA. Unique TensorFlower                                          other_add_operand_index,
23623da21150d988f7cf5780488f24adbb116675586Mark Heffernan                                          points_to_analysis);
23723da21150d988f7cf5780488f24adbb116675586Mark Heffernan    }
23823da21150d988f7cf5780488f24adbb116675586Mark Heffernan  }
23923da21150d988f7cf5780488f24adbb116675586Mark Heffernan  if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
24023da21150d988f7cf5780488f24adbb116675586Mark Heffernan      user->opcode() == HloOpcode::kWhile) {
24123da21150d988f7cf5780488f24adbb116675586Mark Heffernan    // We eliminated other users in BufferLiveness::live_range_strictly_before,
24223da21150d988f7cf5780488f24adbb116675586Mark Heffernan    // so here we just need to check that the use is at operand index 0.
24323da21150d988f7cf5780488f24adbb116675586Mark Heffernan    std::vector<int64> operand_indices = user->OperandIndices(operand);
24423da21150d988f7cf5780488f24adbb116675586Mark Heffernan    return operand_indices.size() == 1 && operand_indices[0] == 0;
24523da21150d988f7cf5780488f24adbb116675586Mark Heffernan  }
246724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower  if (user->opcode() == HloOpcode::kCall) {
247724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // TODO(b/62548313): Remove when buffer assignment is module scoped and
248724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // does not assign buffers to calls.
249724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // Find called computation parameter associated with 'operand'.
250724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    const std::vector<int64> operand_indices = user->OperandIndices(operand);
251724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    if (operand_indices.size() > 1) {
252724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower      return false;
253724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    }
254724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    CHECK_EQ(1, operand_indices.size());
255724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    auto* param = user->to_apply()->parameter_instruction(operand_indices[0]);
256724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // Get all uses of 'operand' at 'index' in called computation.
257724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index,
258724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower                                                     points_to_analysis);
259724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower
260724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // Return true iff:
261724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // *) There exists exactly one use of 'operand' in called computation.
262724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // *) The unique use is by the root instruction of called computation.
263724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    //    (Note: we check the root of the called computation, because the
264724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    //     root result buffer is required to alias with the Call result buffer).
265724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // *) The root instruction of the called computation is element-wise on
266724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    //    'operand'.
267724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    auto* callee_root = user->to_apply()->root_instruction();
268724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    return param_uses.size() == 1 && param_uses[0].first == callee_root &&
269724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower           callee_root->IsElementwiseOnOperand(param_uses[0].second);
270724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower  }
27123da21150d988f7cf5780488f24adbb116675586Mark Heffernan  // Check if 'user' is element-wise.
27223da21150d988f7cf5780488f24adbb116675586Mark Heffernan  return user->IsElementwise();
27323da21150d988f7cf5780488f24adbb116675586Mark Heffernan}
27423da21150d988f7cf5780488f24adbb116675586Mark Heffernan
27523da21150d988f7cf5780488f24adbb116675586Mark Heffernanbool CanShareOperandBufferWithUser(HloInstruction* operand,
27623da21150d988f7cf5780488f24adbb116675586Mark Heffernan                                   const ShapeIndex& operand_index,
27723da21150d988f7cf5780488f24adbb116675586Mark Heffernan                                   HloInstruction* user,
27823da21150d988f7cf5780488f24adbb116675586Mark Heffernan                                   const ShapeIndex& user_index,
27923da21150d988f7cf5780488f24adbb116675586Mark Heffernan                                   const HloDataflowAnalysis& dataflow) {
28023da21150d988f7cf5780488f24adbb116675586Mark Heffernan  CHECK(user->IsUserOf(operand))
28123da21150d988f7cf5780488f24adbb116675586Mark Heffernan      << "user: " << user->ToString() << " operand: " << operand->ToString();
28223da21150d988f7cf5780488f24adbb116675586Mark Heffernan  const Shape& operand_subshape =
28323da21150d988f7cf5780488f24adbb116675586Mark Heffernan      ShapeUtil::GetSubshape(operand->shape(), operand_index);
28423da21150d988f7cf5780488f24adbb116675586Mark Heffernan  const Shape& user_subshape =
28523da21150d988f7cf5780488f24adbb116675586Mark Heffernan      ShapeUtil::GetSubshape(user->shape(), user_index);
28623da21150d988f7cf5780488f24adbb116675586Mark Heffernan  // Check that operand and user emit the same shape and layout.
28723da21150d988f7cf5780488f24adbb116675586Mark Heffernan  if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
28823da21150d988f7cf5780488f24adbb116675586Mark Heffernan    return false;
28923da21150d988f7cf5780488f24adbb116675586Mark Heffernan  }
29023da21150d988f7cf5780488f24adbb116675586Mark Heffernan
29123da21150d988f7cf5780488f24adbb116675586Mark Heffernan  if (user->opcode() == HloOpcode::kFusion) {
29223da21150d988f7cf5780488f24adbb116675586Mark Heffernan    // Get the parameter associated with 'operand';
29323da21150d988f7cf5780488f24adbb116675586Mark Heffernan    HloInstruction* fusion_param =
29423da21150d988f7cf5780488f24adbb116675586Mark Heffernan        user->fused_parameter(user->operand_index(operand));
29523da21150d988f7cf5780488f24adbb116675586Mark Heffernan
29623da21150d988f7cf5780488f24adbb116675586Mark Heffernan    const HloValue& value =
29723da21150d988f7cf5780488f24adbb116675586Mark Heffernan        dataflow.GetValueDefinedAt(fusion_param, operand_index);
29823da21150d988f7cf5780488f24adbb116675586Mark Heffernan    if (value.uses().size() != 1) {
29923da21150d988f7cf5780488f24adbb116675586Mark Heffernan      return false;
30023da21150d988f7cf5780488f24adbb116675586Mark Heffernan    }
30123da21150d988f7cf5780488f24adbb116675586Mark Heffernan    const HloUse& use = value.uses()[0];
30223da21150d988f7cf5780488f24adbb116675586Mark Heffernan
30323da21150d988f7cf5780488f24adbb116675586Mark Heffernan    if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
30423da21150d988f7cf5780488f24adbb116675586Mark Heffernan        user->fused_expression_root()->opcode() ==
30523da21150d988f7cf5780488f24adbb116675586Mark Heffernan            HloOpcode::kDynamicUpdateSlice) {
30623da21150d988f7cf5780488f24adbb116675586Mark Heffernan      // Loop fusion with kDynamicUpdateSlice fused root.
30723da21150d988f7cf5780488f24adbb116675586Mark Heffernan      //
30823da21150d988f7cf5780488f24adbb116675586Mark Heffernan      // Returns true iff there is exactly one use of 'operand' at shape index
30923da21150d988f7cf5780488f24adbb116675586Mark Heffernan      // 'operand_index', and this singleton use is the fused root at operand
31023da21150d988f7cf5780488f24adbb116675586Mark Heffernan      // index 0.
31123da21150d988f7cf5780488f24adbb116675586Mark Heffernan      return use.instruction == user->fused_expression_root() &&
31223da21150d988f7cf5780488f24adbb116675586Mark Heffernan             use.operand_number == 0;
31323da21150d988f7cf5780488f24adbb116675586Mark Heffernan    } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
31423da21150d988f7cf5780488f24adbb116675586Mark Heffernan               user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
31523da21150d988f7cf5780488f24adbb116675586Mark Heffernan      // Output fusion with kAdd fused root.
31623da21150d988f7cf5780488f24adbb116675586Mark Heffernan
31723da21150d988f7cf5780488f24adbb116675586Mark Heffernan      // Check if one operand of kAdd fused root is kConvolution, kDot, or a
31823da21150d988f7cf5780488f24adbb116675586Mark Heffernan      // nested kFusion of kind kTransposeDot.
31923da21150d988f7cf5780488f24adbb116675586Mark Heffernan      auto* add = user->fused_expression_root();
32023da21150d988f7cf5780488f24adbb116675586Mark Heffernan      auto add_operand_it =
32123da21150d988f7cf5780488f24adbb116675586Mark Heffernan          std::find_if(add->operands().begin(), add->operands().end(),
32223da21150d988f7cf5780488f24adbb116675586Mark Heffernan                       [&](HloInstruction* operand) {
32310d1827987b0eca4d0e6f8f56506c93c67e03f83David Majnemer                         return operand->opcode() == HloOpcode::kConvolution ||
32410d1827987b0eca4d0e6f8f56506c93c67e03f83David Majnemer                                operand->opcode() == HloOpcode::kDot ||
32523da21150d988f7cf5780488f24adbb116675586Mark Heffernan                                (operand->opcode() == HloOpcode::kFusion &&
32623da21150d988f7cf5780488f24adbb116675586Mark Heffernan                                 operand->fusion_kind() ==
32723da21150d988f7cf5780488f24adbb116675586Mark Heffernan                                     HloInstruction::FusionKind::kTransposeDot);
32823da21150d988f7cf5780488f24adbb116675586Mark Heffernan                       });
32923da21150d988f7cf5780488f24adbb116675586Mark Heffernan      if (add_operand_it == add->operands().end()) {
33023da21150d988f7cf5780488f24adbb116675586Mark Heffernan        return false;
33123da21150d988f7cf5780488f24adbb116675586Mark Heffernan      }
33223da21150d988f7cf5780488f24adbb116675586Mark Heffernan      auto* matched_add_operand = *add_operand_it;
33323da21150d988f7cf5780488f24adbb116675586Mark Heffernan      // Calculate operand index of 'add' operand which was not matched above.
33423da21150d988f7cf5780488f24adbb116675586Mark Heffernan      const int64 other_add_operand_index =
33523da21150d988f7cf5780488f24adbb116675586Mark Heffernan          matched_add_operand == add->operand(0) ? 1 : 0;
33623da21150d988f7cf5780488f24adbb116675586Mark Heffernan      // Returns true iff there is exactly one use of 'operand' at shape index
33723da21150d988f7cf5780488f24adbb116675586Mark Heffernan      // 'operand_index', and this singleton use is the fused root (at operand
33823da21150d988f7cf5780488f24adbb116675586Mark Heffernan      // index 'other_add_operand_index').
33923da21150d988f7cf5780488f24adbb116675586Mark Heffernan      return use.instruction == user->fused_expression_root() &&
34023da21150d988f7cf5780488f24adbb116675586Mark Heffernan             use.operand_number == other_add_operand_index;
341e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower    }
34205349e6fe7b2395f541e11a9eacbeff04270e4c6A. Unique TensorFlower  }
34305349e6fe7b2395f541e11a9eacbeff04270e4c6A. Unique TensorFlower  if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
34405349e6fe7b2395f541e11a9eacbeff04270e4c6A. Unique TensorFlower      user->opcode() == HloOpcode::kWhile) {
3454718ac6b15cd5ed6f7da0692c97a79596e465580A. Unique TensorFlower    // We eliminated other users in BufferLiveness::live_range_strictly_before,
3464718ac6b15cd5ed6f7da0692c97a79596e465580A. Unique TensorFlower    // so here we just need to check that the use is at operand index 0.
3474718ac6b15cd5ed6f7da0692c97a79596e465580A. Unique TensorFlower    std::vector<int64> operand_indices = user->OperandIndices(operand);
3484718ac6b15cd5ed6f7da0692c97a79596e465580A. Unique TensorFlower    return operand_indices.size() == 1 && operand_indices[0] == 0;
349e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  }
350724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower  if (user->opcode() == HloOpcode::kCall) {
351724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // Get all uses of value defined by 'operand' at 'operand_index'.
352724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    const auto& uses =
353724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower        dataflow.GetValueDefinedAt(operand, operand_index).uses();
354724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // Return true iff:
355724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // *) There exist exactly two uses of 'operand'.
356724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // *) One use is by 'user' (caller).
357724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // *) One use is by root instruction of called computation (callee root).
358724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    //    (Note: we check the root of the called computation, because the
359724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    //     root result buffer is required to alias with the Call result buffer).
360724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    // *) The root instruction of the called computation is element-wise on
361724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    //    'operand'.
362724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    const bool found_caller_use =
363724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower        std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) {
364724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower          return use.instruction == user;
365724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower        }) != uses.end();
366724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    auto* callee_root = user->to_apply()->root_instruction();
367724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    const bool found_elementwise_callee_use =
368724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower        std::find_if(
369724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower            uses.begin(), uses.end(), [callee_root](const HloUse& use) {
370724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower              return use.instruction == callee_root &&
371724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower                     callee_root->IsElementwiseOnOperand(use.operand_number);
372724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower            }) != uses.end();
373724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower    return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
374724ca9f1a5a7428e74b62c8e2e6061244af93aceA. Unique TensorFlower  }
375e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  // Check if 'user' is element-wise.
376e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower  return user->IsElementwise();
377e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower}
378e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower
379e0d0c676ec111c711099bf89eb51278bc4493678A. Unique TensorFlower}  // namespace xla
380