1943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
3943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieLicensed under the Apache License, Version 2.0 (the "License");
4943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieyou may not use this file except in compliance with the License.
5943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieYou may obtain a copy of the License at
6943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
7943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    http://www.apache.org/licenses/LICENSE-2.0
8943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
9943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieUnless required by applicable law or agreed to in writing, software
10943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xiedistributed under the License is distributed on an "AS IS" BASIS,
11943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieSee the License for the specific language governing permissions and
13943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xielimitations under the License.
14943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie==============================================================================*/
15943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
16943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
17943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
18289af8e7460e69edc106e834b7fbeee17811f1eaSanjoy Das#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
19943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
20943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
21943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/hlo_computation.h"
22943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/hlo_instruction.h"
23943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/hlo_opcode.h"
24943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
25943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xienamespace xla {
26943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xienamespace cpu {
27943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
28943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieclass SimpleCostModel : public ParallelCostModel {
29943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie public:
30943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  SimpleCostModel(const int64 max_parallelism,
31943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie                  const HloCostAnalysis::ShapeSizeFunction& shape_size)
32943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
33943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  ~SimpleCostModel() override {}
34943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
35943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  int64 GetParallelTaskCount(HloInstruction* instruction) override {
36943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Simple cost model based on hlo size and typical L2 cache size.
37943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    const int64 instruction_cost = shape_size_(instruction->shape());
38943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    const int64 min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
39943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Return target parallel task count in [1, max_parallelism_].
40943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    return std::min(max_parallelism_,
41943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie                    std::max(1LL, instruction_cost / min_cost_per_thread));
42943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  }
43943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
44943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie private:
45943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  const int64 max_parallelism_;
46943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  const HloCostAnalysis::ShapeSizeFunction shape_size_;
47943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie};
48943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
49943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieclass DefaultCostModel : public ParallelCostModel {
50943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie public:
51943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  DefaultCostModel(const int64 max_parallelism,
52a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                   const HloCostAnalysis::ShapeSizeFunction& shape_size,
53943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie                   std::unique_ptr<HloCostAnalysis> cost_analysis)
54943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      : max_parallelism_(max_parallelism),
55a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        shape_size_(shape_size),
56943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie        cost_analysis_(std::move(cost_analysis)) {}
57943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  ~DefaultCostModel() override {}
58943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
59943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  int64 GetParallelTaskCount(HloInstruction* instruction) override {
60a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Parameters for parallel task count computation.
61a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    int64 instruction_cost;
62a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    int64 min_cost_per_thread;
63a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    int64 max_parallelism;
64a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Calculate flops-to-bytes-ratio for 'instruction'.
65a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    const int64 bytes_accessed =
66a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        std::max(1LL, cost_analysis_->bytes_accessed(*instruction));
67a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    const float flops_to_bytes_ratio =
68a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        cost_analysis_->flop_count(*instruction) /
69a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        static_cast<float>(bytes_accessed);
70a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Check for I/O bound instructions.
71a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    if (flops_to_bytes_ratio <= 1.0) {
72a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Limit max parallelism for I/O bound instructions by assuming a
73a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // sub-linear scaling function (fit based on empirical benchmark results).
74a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // TODO(29630486) Develop system bandwidth model.
75a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      max_parallelism =
76a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower          std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs()));
77a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Use shape size instruction cost and L2 cache size min per-thread cost.
78a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      instruction_cost = shape_size_(instruction->shape());
79a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
80a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    } else {
81a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Use max parallelism for compute bound instructions.
82a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      max_parallelism = max_parallelism_;
83a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Calculate the instruction cost in cycles.
84a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // TODO(29630486) Improve on this linear cost model.
85a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Consider making 'min_cost_per_thread' be a function of the target
86a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // bandwidth limit for instructions with low arithmetic complexity.
87a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      instruction_cost =
88a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower          1 * cost_analysis_->flop_count(*instruction) +
89a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower          2 * cost_analysis_->transcendental_count(*instruction) +
90a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower          10 * cost_analysis_->bytes_accessed(*instruction);
91a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Minimum per-thread cost is 100us of work on a 2GHz core.
92a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      min_cost_per_thread = 100000;
93a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
94943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Return target parallel task count in [1, max_parallelism_].
95a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    return std::min(max_parallelism,
96943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie                    std::max(1LL, instruction_cost / min_cost_per_thread));
97943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  }
98943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
99943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie private:
100943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  const int64 max_parallelism_;
101a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const HloCostAnalysis::ShapeSizeFunction shape_size_;
102943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  const std::unique_ptr<HloCostAnalysis> cost_analysis_;
103943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie};
104943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
105943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieParallelTaskAssignment::ParallelTaskAssignment(
106943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    const int64 max_parallelism,
107d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das    const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) {
108943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
109943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // Run cost analysis on 'module'.
110943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
111943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  HloComputation* computation = module->entry_computation();
112943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  Status status = computation->root_instruction()->Accept(cost_analysis.get());
113943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  if (status.ok()) {
114943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Set default cost model based on 'cost_analysis'.
115a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size,
116943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie                                           std::move(cost_analysis)));
117943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  } else {
118943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Fall back to a simple cost model based on hlo size and L2 cache size.
119943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Note that HloCostAnalysis can returns an error status (likely because
120943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // HLOs like CustomCall are not yet implemented in the HloCostAnalysis).
121943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
122943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  }
123943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie}
124943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
125943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieint64 ParallelTaskAssignment::GetTargetParallelTaskCount(
126943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    HloInstruction* instruction) {
127943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // Currently, we do not assign parallel tasks to instructions with at least
128943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // one of the following properties:
1295bf26acd87d3d44183fc28cb9576cda10c0255caA. Unique TensorFlower  // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall).
130943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot).
131943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // *) Tuple-shaped.
132943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // TODO(b/27458679) Parallelize instructions which are skipped here.
133943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  if (instruction->opcode() == HloOpcode::kParameter ||
134943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      instruction->opcode() == HloOpcode::kConstant ||
135943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      instruction->opcode() == HloOpcode::kCall ||
136943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      instruction->opcode() == HloOpcode::kCustomCall ||
137943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      instruction->opcode() == HloOpcode::kSelectAndScatter ||
138017a5021a7fdc713357fceecf31068ae5090afafA. Unique TensorFlower      instruction->opcode() == HloOpcode::kGetTupleElement ||
139017a5021a7fdc713357fceecf31068ae5090afafA. Unique TensorFlower      instruction->opcode() == HloOpcode::kBitcast ||
1405bf26acd87d3d44183fc28cb9576cda10c0255caA. Unique TensorFlower      instruction->opcode() == HloOpcode::kFft ||
141943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      (instruction->opcode() == HloOpcode::kConvolution &&
142943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie       PotentiallyImplementedAsEigenConvolution(*instruction)) ||
143943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      PotentiallyImplementedAsEigenDot(*instruction) ||
144943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      (instruction->opcode() == HloOpcode::kFusion &&
145943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie       instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) ||
146943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      ShapeUtil::IsTuple(instruction->shape())) {
147943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    return 1;
148943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  }
149943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // Consult 'cost_model_' to compute target parallel task count.
150943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  return cost_model_->GetParallelTaskCount(instruction);
151943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie}
152943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
153a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowerStatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
154a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY");
155a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  XLA_VLOG_LINES(3, module->ToString());
156a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Compute target parallel task counts for all instructions in 'module'.
157a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  HloToParallelTasks hlo_to_parallel_tasks;
158a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks);
159a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
160a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Assign parallel tasks to target specific instructions in 'module'.
161a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // TODO(b/27458679) Support inter-op parallelism.
162a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks);
163a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
164a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT");
165a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  XLA_VLOG_LINES(3, module->ToString());
166a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  return changed;
167a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower}
168a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
169a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowerbool ParallelTaskAssigner::AssignParallelTasks(
170a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) {
171a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  return AssignParallelTasksHelper(module, module->entry_computation(),
172a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                   hlo_to_parallel_tasks);
173a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower}
174a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
175a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowerbool ParallelTaskAssigner::AssignParallelTasksHelper(
176a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    HloModule* module, HloComputation* computation,
177a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    const HloToParallelTasks& hlo_to_parallel_tasks) {
178a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  bool changed = false;
179a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Snapshot set of instructions because outlining modifies the set below.
180a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  std::vector<HloInstruction*> instructions(computation->instructions().begin(),
181a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                            computation->instructions().end());
182a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  for (auto* instruction : instructions) {
183a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Assign parallel tasks to sub-computations for While and Call HLOs.
184a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // TODO(b/27458679) Evaluate alternative intra-op parallelsim placement,
185a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // and support other callable computations like reduce.
186a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    if (instruction->opcode() == HloOpcode::kWhile) {
187a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      changed |= AssignParallelTasksHelper(module, instruction->while_body(),
188a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                           hlo_to_parallel_tasks);
189a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      continue;
190a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    } else if (instruction->opcode() == HloOpcode::kCall) {
191a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      changed |= AssignParallelTasksHelper(module, instruction->to_apply(),
192a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                           hlo_to_parallel_tasks);
193a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      continue;
194a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
195a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Skip if no parallel tasks were computed in first pass.
196a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    auto it = hlo_to_parallel_tasks.find(instruction);
197a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    if (it == hlo_to_parallel_tasks.end()) {
198a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      continue;
199a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
200a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Get target parallel task count computed for 'instruction'.
201a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    const int64 target_parallel_task_count = (*it).second;
202a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Assign feasible dimension partitions (based on actual dimension sizes).
203a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    auto dim_partition_counts = ShapePartitionAssigner(instruction->shape())
204a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                    .Run(target_parallel_task_count);
205a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    const int64 total_partition_count =
206a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts);
207a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    if (total_partition_count <= 1) {
208a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Feasible partition calculation resulting in no partitioning, so skip.
209a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      continue;
210a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
211a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
212a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Outline 'instruction' in 'computation' for parallel task assignment.
213a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    auto* call = module->OutlineExpressionFromComputation(
214a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        {instruction},
215a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        tensorflow::strings::StrCat("parallel_", instruction->name()),
216a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        computation);
217a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
218a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Set assigned dimension partitioning to 'instruction'.
219a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    auto* new_root = call->to_apply()->root_instruction();
220a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    new_root->set_outer_dimension_partitions(dim_partition_counts);
221a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
222a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    VLOG(2) << "Assigned parallel task count: " << total_partition_count
223a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower            << " to instruction: " << new_root->name()
224a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower            << " parent: " << new_root->parent()->name();
225a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    changed = true;
226a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  }
227a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  return changed;
228a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower}
229a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
230a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowervoid ParallelTaskAssigner::ComputeTargetParallelTasks(
231a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
232d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das  ParallelTaskAssignment parallel_task_assignment(max_parallelism_,
233d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das                                                  shape_size_function_, module);
234d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das
235a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Compute parallel task counts for all instructions in 'module'.
236a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  for (auto* computation : module->computations()) {
237a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    if (computation->IsFusionComputation()) {
238a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      continue;
239a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
240a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    for (auto* instruction : computation->instructions()) {
241a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Query ParallelTaskAssignment for target parallel task count.
242a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      const int64 target_parallel_task_count =
243d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das          parallel_task_assignment.GetTargetParallelTaskCount(instruction);
244a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      if (target_parallel_task_count > 1) {
245a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        hlo_to_parallel_tasks->insert(
246a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower            {instruction, target_parallel_task_count});
247a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      }
248a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
249a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  }
250a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower}
251a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
252943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie}  // namespace cpu
253943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie}  // namespace xla
254