1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
17
18#include <map>
19#include <utility>
20#include <vector>
21
22#include "tensorflow/compiler/xla/service/heap_simulator.h"
23#include "tensorflow/compiler/xla/service/hlo_computation.h"
24#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
25#include "tensorflow/compiler/xla/shape_util.h"
26#include "tensorflow/compiler/xla/status_macros.h"
27#include "tensorflow/compiler/xla/statusor.h"
28#include "tensorflow/compiler/xla/types.h"
29#include "tensorflow/compiler/xla/util.h"
30#include "tensorflow/core/lib/core/errors.h"
31#include "tensorflow/core/lib/strings/str_util.h"
32#include "tensorflow/core/lib/strings/stringprintf.h"
33#include "tensorflow/core/platform/logging.h"
34
35using ::tensorflow::strings::HumanReadableNumBytes;
36
37namespace xla {
38
39StatusOr<int64> MinimumMemoryForSequence(
40    const SequentialHloOrdering::HloModuleSequence& module_sequence,
41    const LogicalBuffer::SizeFunction& size_function) {
42  if (module_sequence.empty()) {
43    return 0;
44  }
45
46  const HloModule* module = module_sequence.begin()->first->parent();
47  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
48                      TuplePointsToAnalysis::Run(module));
49
50  // The absolute minimum memory required for a given sequence of instructions
51  // is determined by the sequence of Alloc and Free calls on a simulated heap,
52  // ignoring fragmentation. We run the heap simulation on the whole module,
53  // rather than summing each computation, since it gives us a better lower
54  // bound, by minimizing the liveness of sub-computations.
55  TF_ASSIGN_OR_RETURN(
56      HeapSimulator::Result result,
57      HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
58                         module_sequence, *points_to_analysis, size_function));
59  return result.heap_size;
60}
61
62namespace {
63
// Class implementing a list scheduler of HLO instructions which produces a
// sequence which minimizes memory usage.
//
// This is a greedy heuristic: CreateSchedule keeps a ready queue of
// instructions whose operands and control predecessors have all been
// scheduled, and repeatedly appends the ready instruction with the highest
// Priority to the schedule. The result is not guaranteed to be optimal.
class ListScheduler {
 public:
  // Construct and return a memory-minimizing sequence of HLO instructions
  // containing the given HLO computation.
  //
  // The returned sequence contains every instruction of `computation` exactly
  // once (CreateSchedule CHECK-fails otherwise). `points_to_analysis` and
  // `size_function` are only referenced for the duration of this call.
  static StatusOr<std::vector<const HloInstruction*>> Run(
      const HloComputation& computation,
      const TuplePointsToAnalysis& points_to_analysis,
      const LogicalBuffer::SizeFunction& size_function) {
    ListScheduler scheduler(computation, points_to_analysis, size_function);
    return scheduler.CreateSchedule();
  }

  // Returns whether the memory used by the given HLO should be ignored by the
  // scheduling heuristic. Buffers defined by parameters and constants are
  // counted neither as bytes defined nor as bytes freed when computing
  // priorities.
  static bool IgnoreInstruction(const HloInstruction& instruction) {
    return instruction.opcode() == HloOpcode::kParameter ||
           instruction.opcode() == HloOpcode::kConstant;
  }

 private:
  // The scheduling priority of an instruction is first the net number of
  // bytes freed by scheduling the instruction, and second (tie-breaker) the
  // number of users. The net bytes freed is the size of the operand buffers
  // for which this is the last unscheduled use, minus the size of the buffers
  // the instruction defines, so it may be negative. This is represented as a
  // std::pair containing these two values (first element is the bytes
  // freed). std::pair provides the necessary lexicographic comparison
  // operators. A larger Priority is scheduled earlier.
  using Priority = std::pair<int64, int64>;

  // Precomputes buffer_uses_ and unscheduled_use_count_ for `computation`.
  // All three arguments are stored by reference and must outlive this object.
  ListScheduler(const HloComputation& computation,
                const TuplePointsToAnalysis& points_to_analysis,
                const LogicalBuffer::SizeFunction& size_function)
      : computation_(computation),
        points_to_analysis_(points_to_analysis),
        size_function_(size_function) {
    // Create a map containing the LogicalBuffer uses for each HLO
    // instruction. An HLO instruction "uses" a LogicalBuffer if the
    // LogicalBuffer is defined by one of the instruction's operands, as
    // reported by points-to analysis (GetBuffersDefinedByInstruction). The
    // FlatSet removes duplicate buffers, e.g. when the same operand appears
    // more than once.
    for (auto* instruction : computation.instructions()) {
      tensorflow::gtl::FlatSet<const LogicalBuffer*> instr_uses;
      for (auto* operand : instruction->operands()) {
        for (const LogicalBuffer* buffer :
             points_to_analysis.GetBuffersDefinedByInstruction(operand)) {
          instr_uses.insert(buffer);
        }
      }
      buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
          instr_uses.begin(), instr_uses.end());
    }

    // Create map containing the number of unscheduled uses (hlo instructions)
    // of each logical buffer. Every buffer defined in the computation starts
    // at zero, so buffers with no uses still get an entry.
    for (auto* instruction : computation.instructions()) {
      for (auto* buffer :
           points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
        unscheduled_use_count_[buffer] = 0;
      }
    }
    for (auto* instruction : computation.instructions()) {
      for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
        ++unscheduled_use_count_[buffer];
      }
    }

    // Buffers live out of the computation have an implicit use at the end of
    // the computation. This extra count is never decremented by scheduling,
    // so a live-out buffer is never counted as freed by any instruction.
    for (const LogicalBuffer* live_out_buffer :
         points_to_analysis.GetPointsToSet(computation.root_instruction())
             .CreateFlattenedSet()) {
      ++unscheduled_use_count_[live_out_buffer];
    }
  }

  // Returns whether the memory used by the given buffer should be ignored by
  // the scheduling heuristic, i.e. whether its defining instruction is ignored
  // by IgnoreInstruction.
  static bool IgnoreBuffer(const LogicalBuffer& buffer) {
    return IgnoreInstruction(*buffer.instruction());
  }

  // An entry in the worklist used by CreateSchedule.  Corresponds to one
  // HloInstruction, plus some cached metadata, saved for the purposes of making
  // BytesFreedIfScheduled fast.
  struct ReadyListEntry {
    const HloInstruction* instruction;

    // The total size of all non-ignored buffers defined by this instruction.
    int64 bytes_defined;

    // For each non-ignored buffer B used by this instruction, we keep a pair
    // (B, U), where U is the number of uses of B that have not yet been
    // scheduled. Each element points at the (B, U) element stored inside the
    // unscheduled_use_count_ map, so it sees updates for free when we update
    // counts in the map.
    std::vector<const std::pair<const LogicalBuffer* const, int64>*>
        used_buffer_unscheduled_use_counts;
  };

  // Creates a ReadyListEntry for the given instruction. Buffers for which
  // IgnoreBuffer is true are skipped both for bytes_defined and for the used
  // buffer list.
  ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) {
    ReadyListEntry entry;
    entry.instruction = instruction;

    entry.bytes_defined = 0;
    for (auto* buffer :
         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
      if (!IgnoreBuffer(*buffer)) {
        entry.bytes_defined += size_function_(*buffer);
      }
    }

    for (auto* buffer : buffer_uses_.at(instruction)) {
      if (IgnoreBuffer(*buffer)) {
        continue;
      }
      auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
      CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
      entry.used_buffer_unscheduled_use_counts.push_back(
          &*unscheduled_use_count_it);
    }
    return entry;
  }

  // Returns the number of bytes freed if the HLO instruction is scheduled:
  // the total size of the used buffers for which this instruction is the last
  // unscheduled use (count == 1), minus entry.bytes_defined. The result may be
  // negative.
  int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
    int64 freed_bytes = 0;
    for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
      auto buffer = kv->first;
      auto use_count = kv->second;
      if (use_count == 1) {
        freed_bytes += size_function_(*buffer);
      }
    }
    return freed_bytes - entry.bytes_defined;
  }

  // Constructs the scheduling priority of the given instruction.
  Priority GetPriority(const ReadyListEntry& entry) {
    return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
  }

  // Runs the list-scheduling loop and returns the resulting sequence. This
  // mutates unscheduled_use_count_ and scheduled_instructions_, so it is meant
  // to be called once per scheduler (as Run does).
  std::vector<const HloInstruction*> CreateSchedule() {
    std::vector<const HloInstruction*> schedule;

    // Populate the ready list with instructions which have no operands or
    // control predecessors.
    tensorflow::gtl::FlatMap<const HloInstruction*, int64>
        unscheduled_pred_count;
    for (auto* instruction : computation_.instructions()) {
      // TODO(b/34466113): Replace this and above with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (const HloInstruction* user : instruction->users()) {
        unscheduled_pred_count[user]++;
      }
      for (const HloInstruction* succ : instruction->control_successors()) {
        unscheduled_pred_count[succ]++;
      }
    }

    // Use a multimap to sort ReadyListEntry according to their priority.
    std::multimap<Priority, ReadyListEntry> ready_queue;

    // Map of ready instructions to their iterators in ready_queue.
    tensorflow::gtl::FlatMap<const HloInstruction*,
                             std::multimap<Priority, ReadyListEntry>::iterator>
        ready_instructions;

    auto add_to_ready_queue = [&](HloInstruction* inst) {
      auto entry = MakeReadyListEntry(inst);
      auto it = ready_queue.emplace(GetPriority(entry), std::move(entry));
      ready_instructions[inst] = it;
    };

    for (auto* instruction : computation_.instructions()) {
      // Instruction with no operands or control predecessors will
      // not be in the map.
      if (unscheduled_pred_count.count(instruction) == 0) {
        add_to_ready_queue(instruction);
      }
    }

    while (!ready_queue.empty()) {
      // Remove the selected instruction from the ready list and add it to the
      // schedule. The queue is ordered by ascending Priority, so the last
      // element has the highest priority. Among equal priorities,
      // std::multimap places later insertions after earlier ones, so the most
      // recently inserted of them is picked.
      auto best_it = ready_queue.end();
      --best_it;
      const HloInstruction* best = best_it->second.instruction;
      ready_queue.erase(best_it);
      ready_instructions.erase(best);
      schedule.push_back(best);
      scheduled_instructions_.insert(best);

      bool adjust_ready_queue = false;
      // Update the unscheduled uses of the logical buffers.
      for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
        int64& count = unscheduled_use_count_[buffer];
        CHECK_GT(count, 0);
        --count;
        if (count == 1) {
          adjust_ready_queue = true;
        }
      }

      // Add new instructions to ready list.
      auto update_pred_count = [&](HloInstruction* inst) {
        int64 pred_count = --unscheduled_pred_count.at(inst);
        CHECK_GE(pred_count, 0);
        if (pred_count == 0) {
          add_to_ready_queue(inst);
        }
      };
      // TODO(b/34466113): Replace this and above with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (HloInstruction* user : best->users()) {
        update_pred_count(user);
      }
      for (HloInstruction* succ : best->control_successors()) {
        update_pred_count(succ);
      }
      // The unscheduled use count for a buffer has changed to 1, so the
      // priorities of some ready instructions may go up. We update them in the
      // ready queue, so that they can appear earlier. Since uses are defined
      // via buffers defined by operands, the affected instructions are the
      // ready users of best's operands.
      if (adjust_ready_queue) {
        for (HloInstruction* operand : best->operands()) {
          for (HloInstruction* operand_user : operand->users()) {
            auto ready_instructions_it = ready_instructions.find(operand_user);
            if (ready_instructions_it == ready_instructions.end()) {
              continue;
            }
            auto ready_queue_it = ready_instructions_it->second;
            auto& entry = ready_queue_it->second;
            Priority new_priority = GetPriority(entry);
            if (new_priority == ready_queue_it->first) {
              continue;
            }
            // Create a new entry in ready_queue, then update
            // ready_instructions[operand_user] to refer to the new entry.
            // Inserting into a std::multimap does not invalidate
            // ready_queue_it, so it can still be erased below.
            ready_instructions_it->second =
                ready_queue.emplace(new_priority, std::move(entry));
            // Remove the old (moved-from) entry in ready_queue.
            ready_queue.erase(ready_queue_it);
          }
        }
      }
    }
    CHECK_EQ(schedule.size(), computation_.instruction_count());
    CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count());

    return schedule;
  }

  // Inputs, stored by reference; they must outlive the scheduler.
  const HloComputation& computation_;
  const TuplePointsToAnalysis& points_to_analysis_;
  const LogicalBuffer::SizeFunction& size_function_;

  // A map containing the LogicalBuffers that each instruction uses.
  tensorflow::gtl::FlatMap<const HloInstruction*,
                           std::vector<const LogicalBuffer*>>
      buffer_uses_;

  // A map containing the count of unscheduled HLOs that use a particular
  // LogicalBuffer. ReadyListEntry stores pointers to this map's elements, so
  // we rely on std::unordered_map keeping pointers/references to its elements
  // valid (even across rehashing), and on the map entries being std::pair's.
  std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;

  // Set of instructions which have been scheduled. Only used to CHECK that
  // every instruction was scheduled exactly once.
  tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_;
};
331
332int64 SumLogicalBufferSizes(
333    const TuplePointsToAnalysis::BufferDefinitionVector& buffers,
334    const LogicalBuffer::SizeFunction& size_function) {
335  int64 size = 0;
336  for (const LogicalBuffer* buffer : buffers) {
337    size += size_function(*buffer);
338  }
339  return size;
340}
341
342StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler(
343    const HloComputation& computation,
344    const TuplePointsToAnalysis& points_to_analysis,
345    const LogicalBuffer::SizeFunction& size_function) {
346  // This ordering is based on DFS post-order, with a heuristic to decide which
347  // operand to visit first.  The heuristic is based on 'extra_users', which is
348  // simply users-1 for each instruction.  By subtracting 1, we're saying that
349  // instructions with no users or a single user don't count; instructions with
350  // lots of fan-out will be visited earlier.
351  tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
352  tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
353  for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
354    if (ListScheduler::IgnoreInstruction(*hlo)) {
355      extra_users[hlo] = 0;
356      total_sizes[hlo] = 0;
357      continue;
358    }
359    extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
360    total_sizes[hlo] = SumLogicalBufferSizes(
361        points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
362    tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
363        hlo->operands().begin(), hlo->operands().end());
364    for (const HloInstruction* operand : unique_operands) {
365      extra_users[hlo] += extra_users[operand];
366      total_sizes[hlo] += total_sizes[operand];
367    }
368  }
369  CHECK_EQ(extra_users.size(), computation.instruction_count());
370  CHECK_EQ(total_sizes.size(), computation.instruction_count());
371
372  // Construct a total order based on DFS post-order, visiting operands in
373  // decreasing cumulative extra user order, and next by cumulative size, with a
374  // tiebreaker by name for determinism.
375  std::vector<const HloInstruction*> sequence;
376  FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
377    sequence.push_back(hlo);
378    return Status::OK();
379  });
380  TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder(
381      &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
382                                             const HloInstruction* b) {
383        if (extra_users[a] != extra_users[b]) {
384          return extra_users[a] > extra_users[b];
385        }
386        if (total_sizes[a] != total_sizes[b]) {
387          return total_sizes[a] > total_sizes[b];
388        }
389        return a->name() < b->name();
390      }));
391  CHECK_EQ(sequence.size(), computation.instruction_count());
392  return sequence;
393}
394
395StatusOr<int64> MinimumMemoryForComputation(
396    const HloComputation& computation,
397    const std::vector<const HloInstruction*>& sequence,
398    const TuplePointsToAnalysis& points_to_analysis,
399    const LogicalBuffer::SizeFunction& size_function) {
400  TF_ASSIGN_OR_RETURN(
401      HeapSimulator::Result result,
402      HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
403                         sequence, points_to_analysis, size_function));
404  return result.heap_size;
405}
406
407StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
408    const HloComputation& computation,
409    const TuplePointsToAnalysis& points_to_analysis,
410    const LogicalBuffer::SizeFunction& size_function,
411    SchedulerAlgorithm algorithm) {
412  VLOG(2) << "Computation: " << computation.name();
413  if (algorithm == SchedulerAlgorithm::kListSchedule) {
414    return ListScheduler::Run(computation, points_to_analysis, size_function);
415  }
416  if (algorithm == SchedulerAlgorithm::kDfsSchedule) {
417    return RunDFSMemoryScheduler(computation, points_to_analysis,
418                                 size_function);
419  }
420
421  // We try both a list-scheduler based ordering and a DFS based ordering, and
422  // choose whichever returns a lower min-memory, not accounting for
423  // fragmentation.
424  //
425  // Note that this is just a heuristic. One obvious inaccuracy is that the
426  // memory required for sub-computations might be different when considered
427  // within the caller's context. But it's good enough for now.
428  TF_ASSIGN_OR_RETURN(
429      std::vector<const HloInstruction*> list_sequence,
430      ListScheduler::Run(computation, points_to_analysis, size_function));
431  TF_ASSIGN_OR_RETURN(
432      const int64 list_memory,
433      MinimumMemoryForComputation(computation, list_sequence,
434                                  points_to_analysis, size_function));
435  VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
436
437  TF_ASSIGN_OR_RETURN(
438      std::vector<const HloInstruction*> dfs_sequence,
439      RunDFSMemoryScheduler(computation, points_to_analysis, size_function));
440  TF_ASSIGN_OR_RETURN(
441      const int64 dfs_memory,
442      MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
443                                  size_function));
444  VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
445
446  if (list_memory <= dfs_memory) {
447    VLOG(2) << "Chose min-memory list sequence: "
448            << HumanReadableNumBytes(list_memory);
449    return list_sequence;
450  } else {
451    VLOG(2) << "Chose min-memory dfs sequence: "
452            << HumanReadableNumBytes(dfs_memory);
453    return dfs_sequence;
454  }
455}
456
457}  // namespace
458
459StatusOr<SequentialHloOrdering::HloModuleSequence>
460CreateMemoryMinimizingSequence(const HloModule& module,
461                               const LogicalBuffer::SizeFunction& size_function,
462                               SchedulerAlgorithm algorithm) {
463  SequentialHloOrdering::HloModuleSequence sequence;
464  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
465                      TuplePointsToAnalysis::Run(&module));
466  for (const auto* computation : module.MakeNonfusionComputations()) {
467    TF_ASSIGN_OR_RETURN(
468        sequence[computation],
469        CreateMemoryMinimizingSequence(*computation, *points_to_analysis,
470                                       size_function, algorithm));
471  }
472  return sequence;
473}
474
475StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
476    const HloComputation& computation,
477    const LogicalBuffer::SizeFunction& size_function,
478    SchedulerAlgorithm algorithm) {
479  CHECK(!computation.IsFusionComputation());
480  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
481                      TuplePointsToAnalysis::Run(computation.parent()));
482  return CreateMemoryMinimizingSequence(computation, *points_to_analysis,
483                                        size_function, algorithm);
484}
485
486}  // namespace xla
487