1943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 3943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieLicensed under the Apache License, Version 2.0 (the "License"); 4943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieyou may not use this file except in compliance with the License. 5943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieYou may obtain a copy of the License at 6943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 7943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie http://www.apache.org/licenses/LICENSE-2.0 8943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 9943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieUnless required by applicable law or agreed to in writing, software 10943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xiedistributed under the License is distributed on an "AS IS" BASIS, 11943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieSee the License for the specific language governing permissions and 13943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xielimitations under the License. 14943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie==============================================================================*/ 15943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 16943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" 17943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 18289af8e7460e69edc106e834b7fbeee17811f1eaSanjoy Das#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" 19943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" 20943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" 21943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/hlo_computation.h" 22943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/hlo_instruction.h" 23943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie#include "tensorflow/compiler/xla/service/hlo_opcode.h" 24943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 25943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xienamespace xla { 26943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xienamespace cpu { 27943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 28943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieclass SimpleCostModel : public ParallelCostModel { 29943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie public: 30943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie SimpleCostModel(const int64 max_parallelism, 31943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie const HloCostAnalysis::ShapeSizeFunction& shape_size) 32943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie : max_parallelism_(max_parallelism), shape_size_(shape_size) {} 33943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie ~SimpleCostModel() override {} 34943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 35943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie int64 GetParallelTaskCount(HloInstruction* instruction) override { 36943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // Simple cost model based on hlo size and typical L2 cache size. 37943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie const int64 instruction_cost = shape_size_(instruction->shape()); 38943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. 39943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // Return target parallel task count in [1, max_parallelism_]. 40943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie return std::min(max_parallelism_, 41943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie std::max(1LL, instruction_cost / min_cost_per_thread)); 42943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie } 43943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 44943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie private: 45943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie const int64 max_parallelism_; 46943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie const HloCostAnalysis::ShapeSizeFunction shape_size_; 47943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie}; 48943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 49943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieclass DefaultCostModel : public ParallelCostModel { 50943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie public: 51943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie DefaultCostModel(const int64 max_parallelism, 52a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const HloCostAnalysis::ShapeSizeFunction& shape_size, 53943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie std::unique_ptr<HloCostAnalysis> cost_analysis) 54943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie : max_parallelism_(max_parallelism), 55a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower shape_size_(shape_size), 56943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie cost_analysis_(std::move(cost_analysis)) {} 57943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie ~DefaultCostModel() override {} 58943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 59943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie int64 GetParallelTaskCount(HloInstruction* instruction) override { 60a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Parameters for parallel task count computation. 61a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower int64 instruction_cost; 62a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower int64 min_cost_per_thread; 63a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower int64 max_parallelism; 64a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Calculate flops-to-bytes-ratio for 'instruction'. 65a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 bytes_accessed = 66a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower std::max(1LL, cost_analysis_->bytes_accessed(*instruction)); 67a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const float flops_to_bytes_ratio = 68a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower cost_analysis_->flop_count(*instruction) / 69a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower static_cast<float>(bytes_accessed); 70a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Check for I/O bound instructions. 71a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower if (flops_to_bytes_ratio <= 1.0) { 72a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Limit max parallelism for I/O bound instructions by assuming a 73a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // sub-linear scaling function (fit based on empirical benchmark results). 74a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // TODO(29630486) Develop system bandwidth model. 75a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower max_parallelism = 76a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs())); 77a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Use shape size instruction cost and L2 cache size min per-thread cost. 78a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower instruction_cost = shape_size_(instruction->shape()); 79a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. 80a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } else { 81a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Use max parallelism for compute bound instructions. 82a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower max_parallelism = max_parallelism_; 83a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Calculate the instruction cost in cycles. 84a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // TODO(29630486) Improve on this linear cost model. 85a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Consider making 'min_cost_per_thread' be a function of the target 86a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // bandwidth limit for instructions with low arithmetic complexity. 87a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower instruction_cost = 88a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 1 * cost_analysis_->flop_count(*instruction) + 89a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 2 * cost_analysis_->transcendental_count(*instruction) + 90a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 10 * cost_analysis_->bytes_accessed(*instruction); 91a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Minimum per-thread cost is 100us of work on a 2GHz core. 92a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower min_cost_per_thread = 100000; 93a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } 94943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // Return target parallel task count in [1, max_parallelism_]. 95a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower return std::min(max_parallelism, 96943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie std::max(1LL, instruction_cost / min_cost_per_thread)); 97943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie } 98943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 99943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie private: 100943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie const int64 max_parallelism_; 101a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const HloCostAnalysis::ShapeSizeFunction shape_size_; 102943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie const std::unique_ptr<HloCostAnalysis> cost_analysis_; 103943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie}; 104943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 105943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei XieParallelTaskAssignment::ParallelTaskAssignment( 106943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie const int64 max_parallelism, 107d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) { 108943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; 109943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // Run cost analysis on 'module'. 110943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size); 111943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie HloComputation* computation = module->entry_computation(); 112943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie Status status = computation->root_instruction()->Accept(cost_analysis.get()); 113943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie if (status.ok()) { 114943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // Set default cost model based on 'cost_analysis'. 115a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size, 116943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie std::move(cost_analysis))); 117943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie } else { 118943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // Fall back to a simple cost model based on hlo size and L2 cache size. 119943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // Note that HloCostAnalysis can returns an error status (likely because 120943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // HLOs like CustomCall are not yet implemented in the HloCostAnalysis). 121943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size)); 122943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie } 123943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie} 124943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 125943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xieint64 ParallelTaskAssignment::GetTargetParallelTaskCount( 126943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie HloInstruction* instruction) { 127943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // Currently, we do not assign parallel tasks to instructions with at least 128943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // one of the following properties: 1295bf26acd87d3d44183fc28cb9576cda10c0255caA. Unique TensorFlower // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). 130943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). 131943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // *) Tuple-shaped. 132943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // TODO(b/27458679) Parallelize instructions which are skipped here. 133943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie if (instruction->opcode() == HloOpcode::kParameter || 134943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie instruction->opcode() == HloOpcode::kConstant || 135943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie instruction->opcode() == HloOpcode::kCall || 136943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie instruction->opcode() == HloOpcode::kCustomCall || 137943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie instruction->opcode() == HloOpcode::kSelectAndScatter || 138017a5021a7fdc713357fceecf31068ae5090afafA. Unique TensorFlower instruction->opcode() == HloOpcode::kGetTupleElement || 139017a5021a7fdc713357fceecf31068ae5090afafA. Unique TensorFlower instruction->opcode() == HloOpcode::kBitcast || 1405bf26acd87d3d44183fc28cb9576cda10c0255caA. Unique TensorFlower instruction->opcode() == HloOpcode::kFft || 141943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie (instruction->opcode() == HloOpcode::kConvolution && 142943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie PotentiallyImplementedAsEigenConvolution(*instruction)) || 143943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie PotentiallyImplementedAsEigenDot(*instruction) || 144943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie (instruction->opcode() == HloOpcode::kFusion && 145943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || 146943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie ShapeUtil::IsTuple(instruction->shape())) { 147943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie return 1; 148943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie } 149943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie // Consult 'cost_model_' to compute target parallel task count. 150943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie return cost_model_->GetParallelTaskCount(instruction); 151943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie} 152943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie 153a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowerStatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) { 154a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY"); 155a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower XLA_VLOG_LINES(3, module->ToString()); 156a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Compute target parallel task counts for all instructions in 'module'. 157a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower HloToParallelTasks hlo_to_parallel_tasks; 158a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks); 159a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 160a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Assign parallel tasks to target specific instructions in 'module'. 161a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // TODO(b/27458679) Support inter-op parallelism. 162a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks); 163a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 164a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT"); 165a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower XLA_VLOG_LINES(3, module->ToString()); 166a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower return changed; 167a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower} 168a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 169a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowerbool ParallelTaskAssigner::AssignParallelTasks( 170a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) { 171a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower return AssignParallelTasksHelper(module, module->entry_computation(), 172a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower hlo_to_parallel_tasks); 173a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower} 174a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 175a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowerbool ParallelTaskAssigner::AssignParallelTasksHelper( 176a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower HloModule* module, HloComputation* computation, 177a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const HloToParallelTasks& hlo_to_parallel_tasks) { 178a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower bool changed = false; 179a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Snapshot set of instructions because outlining modifies the set below. 180a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower std::vector<HloInstruction*> instructions(computation->instructions().begin(), 181a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower computation->instructions().end()); 182a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower for (auto* instruction : instructions) { 183a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Assign parallel tasks to sub-computations for While and Call HLOs. 184a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // TODO(b/27458679) Evaluate alternative intra-op parallelsim placement, 185a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // and support other callable computations like reduce. 186a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower if (instruction->opcode() == HloOpcode::kWhile) { 187a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower changed |= AssignParallelTasksHelper(module, instruction->while_body(), 188a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower hlo_to_parallel_tasks); 189a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower continue; 190a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } else if (instruction->opcode() == HloOpcode::kCall) { 191a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower changed |= AssignParallelTasksHelper(module, instruction->to_apply(), 192a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower hlo_to_parallel_tasks); 193a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower continue; 194a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } 195a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Skip if no parallel tasks were computed in first pass. 196a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto it = hlo_to_parallel_tasks.find(instruction); 197a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower if (it == hlo_to_parallel_tasks.end()) { 198a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower continue; 199a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } 200a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Get target parallel task count computed for 'instruction'. 201a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 target_parallel_task_count = (*it).second; 202a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Assign feasible dimension partitions (based on actual dimension sizes). 203a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto dim_partition_counts = ShapePartitionAssigner(instruction->shape()) 204a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower .Run(target_parallel_task_count); 205a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 total_partition_count = 206a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts); 207a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower if (total_partition_count <= 1) { 208a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Feasible partition calculation resulting in no partitioning, so skip. 209a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower continue; 210a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } 211a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 212a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Outline 'instruction' in 'computation' for parallel task assignment. 213a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto* call = module->OutlineExpressionFromComputation( 214a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower {instruction}, 215a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower tensorflow::strings::StrCat("parallel_", instruction->name()), 216a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower computation); 217a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 218a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Set assigned dimension partitioning to 'instruction'. 219a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower auto* new_root = call->to_apply()->root_instruction(); 220a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower new_root->set_outer_dimension_partitions(dim_partition_counts); 221a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 222a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower VLOG(2) << "Assigned parallel task count: " << total_partition_count 223a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower << " to instruction: " << new_root->name() 224a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower << " parent: " << new_root->parent()->name(); 225a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower changed = true; 226a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } 227a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower return changed; 228a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower} 229a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 230a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlowervoid ParallelTaskAssigner::ComputeTargetParallelTasks( 231a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { 232d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das ParallelTaskAssignment parallel_task_assignment(max_parallelism_, 233d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das shape_size_function_, module); 234d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das 235a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Compute parallel task counts for all instructions in 'module'. 236a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower for (auto* computation : module->computations()) { 237a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower if (computation->IsFusionComputation()) { 238a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower continue; 239a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } 240a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower for (auto* instruction : computation->instructions()) { 241a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower // Query ParallelTaskAssignment for target parallel task count. 242a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower const int64 target_parallel_task_count = 243d0de8738e3401bbc5fd142846b4fc124951e5e07Sanjoy Das parallel_task_assignment.GetTargetParallelTaskCount(instruction); 244a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower if (target_parallel_task_count > 1) { 245a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower hlo_to_parallel_tasks->insert( 246a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower {instruction, target_parallel_task_count}); 247a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } 248a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } 249a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower } 250a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower} 251a799ade213cecb3c1c1d19eca6a0bfa3fddf0113A. Unique TensorFlower 252943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie} // namespace cpu 253943c6d7af7a8ccd4f824a2c0f90b251587c63feaJianwei Xie} // namespace xla 254