liveness_util.cc revision 342d315566211a095a06acb1973b94937dadbc0c
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/liveness_util.h"
17
18#include <algorithm>
19#include <utility>
20#include <vector>
21
22#include "tensorflow/compiler/xla/service/hlo_instruction.h"
23#include "tensorflow/compiler/xla/service/logical_buffer.h"
24#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
25#include "tensorflow/compiler/xla/shape_util.h"
26#include "tensorflow/compiler/xla/types.h"
27#include "tensorflow/compiler/xla/util.h"
28
29namespace xla {
30
31bool DoesNotUseOperandBuffer(const HloInstruction* operand,
32                             const ShapeIndex& index,
33                             const HloInstruction* user,
34                             const TuplePointsToAnalysis& points_to_analysis) {
35  CHECK(user->IsUserOf(operand))
36      << "user: " << user->ToString() << " operand: " << operand->ToString();
37  if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
38    // GetTupleElement instructions only access the top-level buffer of their
39    // operand.
40    return true;
41  } else if (user->opcode() == HloOpcode::kFusion &&
42             user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
43    // Find fusion parameter associated with 'operand'.
44    auto it = std::find_if(
45        user->fused_parameters().begin(), user->fused_parameters().end(),
46        [=](HloInstruction* fused_param) {
47          return user->operand(fused_param->parameter_number()) == operand;
48        });
49    CHECK(it != user->fused_parameters().end());
50    // Iterate through all users of all buffer aliases of the buffer in the
51    // points-to set of fusion parameter at 'index'.
52    // Return false if any uses are detected at 'index', returns true otherwise.
53    const LogicalBuffer* buffer =
54        points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
55    for (const BufferAlias& alias :
56         points_to_analysis.GetBufferAliases(*buffer)) {
57      for (HloInstruction* alias_user : alias.instruction()->users()) {
58        if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
59                                    alias_user, points_to_analysis)) {
60          continue;
61        }
62        // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
63        return false;
64      }
65    }
66    // Return true: found no uses of 'operand' at 'index' in 'user'.
67    return true;
68  }
69  return false;
70}
71
72namespace {
73
74// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
75// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
76// where 'user' is a user of an alias of 'intruction' at 'index', and
77// 'operand_index' is the operand index at which the alias appears in the
78// operand list of 'user'.
79std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
80    HloInstruction* instruction, const ShapeIndex& index,
81    const TuplePointsToAnalysis& points_to_analysis) {
82  std::vector<std::pair<HloInstruction*, int64>> uses;
83  const std::vector<const LogicalBuffer*>& points_to =
84      points_to_analysis.GetPointsToSet(instruction).element(index);
85  for (const LogicalBuffer* buffer : points_to) {
86    for (const BufferAlias& alias :
87         points_to_analysis.GetBufferAliases(*buffer)) {
88      for (HloInstruction* alias_user : alias.instruction()->users()) {
89        if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
90                                    alias_user, points_to_analysis)) {
91          continue;
92        }
93        for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
94          uses.emplace_back(alias_user, op_idx);
95        }
96      }
97    }
98  }
99  return uses;
100}
101
102}  // namespace
103
104// User and operand can share buffers iff both instructions emit the same shape
105// and layout, and 'user' meets one of the following qualifications:
106// *) Is element-wise. Or...
107// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
108//    in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
109//    at operand 0. Or...
110// *) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 0.
111bool CanShareOperandBufferWithUser(
112    HloInstruction* operand, const ShapeIndex& operand_index,
113    HloInstruction* user, const ShapeIndex& user_index,
114    const TuplePointsToAnalysis& points_to_analysis) {
115  CHECK(user->IsUserOf(operand))
116      << "user: " << user->ToString() << " operand: " << operand->ToString();
117  Shape operand_subshape =
118      ShapeUtil::GetSubshape(operand->shape(), operand_index);
119  Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index);
120  // Check that operand and user emit the same shape and layout.
121  if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
122    return false;
123  }
124  // Copy instructions are explicitly added by CopyInsertion to prevent liveness
125  // issues, so they should never re-use their operand buffer.
126  if (user->opcode() == HloOpcode::kCopy) {
127    return false;
128  }
129  // Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice
130  // fused root instruction.
131  if (user->opcode() == HloOpcode::kFusion &&
132      user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
133      user->fused_expression_root()->opcode() ==
134          HloOpcode::kDynamicUpdateSlice) {
135    for (auto& fused_param : user->fused_parameters()) {
136      // Find fusion parameter associated with 'operand'.
137      if (user->operand(fused_param->parameter_number()) != operand) {
138        continue;
139      }
140      // Get all uses of 'operand' at 'index' from 'user.fused_instructions'.
141      auto fused_param_uses = GetAllUsesOfInstructionAtIndex(
142          fused_param, operand_index, points_to_analysis);
143      // Return true iff there is exactly one use of 'operand' at 'index', and
144      // this singleton use is the fused root at operand index 0.
145      if (fused_param_uses.size() == 1 &&
146          fused_param_uses[0].first == user->fused_expression_root() &&
147          fused_param_uses[0].second == 0) {
148        return true;
149      }
150      break;
151    }
152    return false;
153  }
154  if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
155      user->opcode() == HloOpcode::kWhile) {
156    // We eliminated other users in BufferLiveness::live_range_strictly_before,
157    // so here we just need to check that the use is at operand index 0.
158    std::vector<int64> operand_indices = user->OperandIndices(operand);
159    return operand_indices.size() == 1 && operand_indices[0] == 0;
160  }
161  // Check if 'user' is element-wise.
162  return user->IsElementwise();
163}
164
165}  // namespace xla
166