parallel_task_assignment.cc revision d0de8738e3401bbc5fd142846b4fc124951e5e07
1943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
3943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieLicensed under the Apache License, Version 2.0 (the "License");
4943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieyou may not use this file except in compliance with the License.
5943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieYou may obtain a copy of the License at
6943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
7943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    http://www.apache.org/licenses/LICENSE-2.0
8943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
9943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieUnless required by applicable law or agreed to in writing, software
10943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xiedistributed under the License is distributed on an "AS IS" BASIS,
11943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieSee the License for the specific language governing permissions and
13943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xielimitations under the License.
14943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie==============================================================================*/
15943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
16943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
17943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
18943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
19943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
20943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/hlo_computation.h"
21943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/hlo_instruction.h"
22943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/hlo_opcode.h"
23943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
24943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xienamespace xla {
25943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xienamespace cpu {
26943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
27943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieclass SimpleCostModel : public ParallelCostModel {
28943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie public:
29943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  SimpleCostModel(const int64 max_parallelism,
30943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie                  const HloCostAnalysis::ShapeSizeFunction& shape_size)
31943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
32943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  ~SimpleCostModel() override {}
33943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
34943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  int64 GetParallelTaskCount(HloInstruction* instruction) override {
35943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Simple cost model based on hlo size and typical L2 cache size.
36943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    const int64 instruction_cost = shape_size_(instruction->shape());
37943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    const int64 min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
38943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Return target parallel task count in [1, max_parallelism_].
39943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    return std::min(max_parallelism_,
40943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie                    std::max(1LL, instruction_cost / min_cost_per_thread));
41943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  }
42943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
43943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie private:
44943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  const int64 max_parallelism_;
45943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  const HloCostAnalysis::ShapeSizeFunction shape_size_;
46943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie};
47943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
48943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieclass DefaultCostModel : public ParallelCostModel {
49943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie public:
50943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  DefaultCostModel(const int64 max_parallelism,
51a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                   const HloCostAnalysis::ShapeSizeFunction& shape_size,
52943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie                   std::unique_ptr<HloCostAnalysis> cost_analysis)
53943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      : max_parallelism_(max_parallelism),
54a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        shape_size_(shape_size),
55943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie        cost_analysis_(std::move(cost_analysis)) {}
56943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  ~DefaultCostModel() override {}
57943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
58943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  int64 GetParallelTaskCount(HloInstruction* instruction) override {
59a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Parameters for parallel task count computation.
60a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    int64 instruction_cost;
61a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    int64 min_cost_per_thread;
62a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    int64 max_parallelism;
63a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Calculate flops-to-bytes-ratio for 'instruction'.
64a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    const int64 bytes_accessed =
65a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        std::max(1LL, cost_analysis_->bytes_accessed(*instruction));
66a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    const float flops_to_bytes_ratio =
67a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        cost_analysis_->flop_count(*instruction) /
68a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        static_cast<float>(bytes_accessed);
69a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Check for I/O bound instructions.
70a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    if (flops_to_bytes_ratio <= 1.0) {
71a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Limit max parallelism for I/O bound instructions by assuming a
72a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // sub-linear scaling function (fit based on empirical benchmark results).
73a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // TODO(29630486) Develop system bandwidth model.
74a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      max_parallelism =
75a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower          std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs()));
76a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Use shape size instruction cost and L2 cache size min per-thread cost.
77a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      instruction_cost = shape_size_(instruction->shape());
78a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
79a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    } else {
80a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Use max parallelism for compute bound instructions.
81a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      max_parallelism = max_parallelism_;
82a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Calculate the instruction cost in cycles.
83a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // TODO(29630486) Improve on this linear cost model.
84a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Consider making 'min_cost_per_thread' be a function of the target
85a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // bandwidth limit for instructions with low arithmetic complexity.
86a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      instruction_cost =
87a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower          1 * cost_analysis_->flop_count(*instruction) +
88a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower          2 * cost_analysis_->transcendental_count(*instruction) +
89a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower          10 * cost_analysis_->bytes_accessed(*instruction);
90a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Minimum per-thread cost is 100us of work on a 2GHz core.
91a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      min_cost_per_thread = 100000;
92a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
93943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Return target parallel task count in [1, max_parallelism_].
94a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    return std::min(max_parallelism,
95943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie                    std::max(1LL, instruction_cost / min_cost_per_thread));
96943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  }
97943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
98943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie private:
99943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  const int64 max_parallelism_;
100a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  const HloCostAnalysis::ShapeSizeFunction shape_size_;
101943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  const std::unique_ptr<HloCostAnalysis> cost_analysis_;
102943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie};
103943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
104943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieParallelTaskAssignment::ParallelTaskAssignment(
105943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    const int64 max_parallelism,
106d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das    const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) {
107943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
108943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // Run cost analysis on 'module'.
109943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
110943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  HloComputation* computation = module->entry_computation();
111943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  Status status = computation->root_instruction()->Accept(cost_analysis.get());
112943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  if (status.ok()) {
113943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Set default cost model based on 'cost_analysis'.
114a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size,
115943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie                                           std::move(cost_analysis)));
116943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  } else {
117943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Fall back to a simple cost model based on hlo size and L2 cache size.
118943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // Note that HloCostAnalysis can returns an error status (likely because
119943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    // HLOs like CustomCall are not yet implemented in the HloCostAnalysis).
120943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
121943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  }
122943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie}
123943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
124943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieint64 ParallelTaskAssignment::GetTargetParallelTaskCount(
125943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    HloInstruction* instruction) {
126943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // Currently, we do not assign parallel tasks to instructions with at least
127943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // one of the following properties:
128943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // *) Internal threading (library calls to kConv, kDot, and kCustomCall).
129943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot).
130943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // *) Tuple-shaped.
131943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // TODO(b/27458679) Parallelize instructions which are skipped here.
132943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  if (instruction->opcode() == HloOpcode::kParameter ||
133943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      instruction->opcode() == HloOpcode::kConstant ||
134943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      instruction->opcode() == HloOpcode::kCall ||
135943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      instruction->opcode() == HloOpcode::kCustomCall ||
136943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      instruction->opcode() == HloOpcode::kSelectAndScatter ||
137017a5021a7fdc713357fceecf31068ae5090afafA. Unique TensorFlower      instruction->opcode() == HloOpcode::kGetTupleElement ||
138017a5021a7fdc713357fceecf31068ae5090afafA. Unique TensorFlower      instruction->opcode() == HloOpcode::kBitcast ||
139943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      (instruction->opcode() == HloOpcode::kConvolution &&
140943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie       PotentiallyImplementedAsEigenConvolution(*instruction)) ||
141943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      PotentiallyImplementedAsEigenDot(*instruction) ||
142943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      (instruction->opcode() == HloOpcode::kFusion &&
143943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie       instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) ||
144943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie      ShapeUtil::IsTuple(instruction->shape())) {
145943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie    return 1;
146943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  }
147943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  // Consult 'cost_model_' to compute target parallel task count.
148943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie  return cost_model_->GetParallelTaskCount(instruction);
149943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie}
150943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie
151a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowerStatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
152a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY");
153a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  XLA_VLOG_LINES(3, module->ToString());
154a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Compute target parallel task counts for all instructions in 'module'.
155a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  HloToParallelTasks hlo_to_parallel_tasks;
156a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks);
157a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
158a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Assign parallel tasks to target specific instructions in 'module'.
159a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // TODO(b/27458679) Support inter-op parallelism.
160a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks);
161a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
162a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT");
163a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  XLA_VLOG_LINES(3, module->ToString());
164a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  return changed;
165a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower}
166a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
167a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowerbool ParallelTaskAssigner::AssignParallelTasks(
168a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) {
169a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  return AssignParallelTasksHelper(module, module->entry_computation(),
170a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                   hlo_to_parallel_tasks);
171a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower}
172a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
173a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowerbool ParallelTaskAssigner::AssignParallelTasksHelper(
174a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    HloModule* module, HloComputation* computation,
175a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    const HloToParallelTasks& hlo_to_parallel_tasks) {
176a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  bool changed = false;
177a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Snapshot set of instructions because outlining modifies the set below.
178a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  std::vector<HloInstruction*> instructions(computation->instructions().begin(),
179a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                            computation->instructions().end());
180a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  for (auto* instruction : instructions) {
181a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Assign parallel tasks to sub-computations for While and Call HLOs.
182a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // TODO(b/27458679) Evaluate alternative intra-op parallelsim placement,
183a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // and support other callable computations like reduce.
184a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    if (instruction->opcode() == HloOpcode::kWhile) {
185a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      changed |= AssignParallelTasksHelper(module, instruction->while_body(),
186a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                           hlo_to_parallel_tasks);
187a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      continue;
188a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    } else if (instruction->opcode() == HloOpcode::kCall) {
189a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      changed |= AssignParallelTasksHelper(module, instruction->to_apply(),
190a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                           hlo_to_parallel_tasks);
191a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      continue;
192a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
193a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Skip if no parallel tasks were computed in first pass.
194a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    auto it = hlo_to_parallel_tasks.find(instruction);
195a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    if (it == hlo_to_parallel_tasks.end()) {
196a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      continue;
197a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
198a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Get target parallel task count computed for 'instruction'.
199a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    const int64 target_parallel_task_count = (*it).second;
200a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Assign feasible dimension partitions (based on actual dimension sizes).
201a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    auto dim_partition_counts = ShapePartitionAssigner(instruction->shape())
202a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower                                    .Run(target_parallel_task_count);
203a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    const int64 total_partition_count =
204a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts);
205a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    if (total_partition_count <= 1) {
206a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Feasible partition calculation resulting in no partitioning, so skip.
207a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      continue;
208a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
209a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
210a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Outline 'instruction' in 'computation' for parallel task assignment.
211a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    auto* call = module->OutlineExpressionFromComputation(
212a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        {instruction},
213a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        tensorflow::strings::StrCat("parallel_", instruction->name()),
214a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        computation);
215a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
216a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    // Set assigned dimension partitioning to 'instruction'.
217a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    auto* new_root = call->to_apply()->root_instruction();
218a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    new_root->set_outer_dimension_partitions(dim_partition_counts);
219a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
220a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    VLOG(2) << "Assigned parallel task count: " << total_partition_count
221a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower            << " to instruction: " << new_root->name()
222a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower            << " parent: " << new_root->parent()->name();
223a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    changed = true;
224a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  }
225a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  return changed;
226a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower}
227a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
228a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowervoid ParallelTaskAssigner::ComputeTargetParallelTasks(
229a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
230d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das  ParallelTaskAssignment parallel_task_assignment(max_parallelism_,
231d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das                                                  shape_size_function_, module);
232d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das
233a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  // Compute parallel task counts for all instructions in 'module'.
234a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  for (auto* computation : module->computations()) {
235a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    if (computation->IsFusionComputation()) {
236a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      continue;
237a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
238a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    for (auto* instruction : computation->instructions()) {
239a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      // Query ParallelTaskAssignment for target parallel task count.
240a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      const int64 target_parallel_task_count =
241d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das          parallel_task_assignment.GetTargetParallelTaskCount(instruction);
242a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      if (target_parallel_task_count > 1) {
243a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower        hlo_to_parallel_tasks->insert(
244a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower            {instruction, target_parallel_task_count});
245a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower      }
246a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower    }
247a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower  }
248a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower}
249a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower
250943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie}  // namespace cpu
251943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie}  // namespace xla
252