/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"

#include <iterator>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

HloModule::HloModule(const string& name,
                     const VersionedComputationHandle& entry_computation_handle,
                     const HloModuleConfig& config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(config),
      has_entry_computation_handle_(true),
      entry_computation_handle_(entry_computation_handle),
      unique_id_(next_unique_module_id_++) {}

HloModule::HloModule(const string& name)
    : name_(NameUniquer::GetSanitizedName(name)),
      unique_id_(next_unique_module_id_++) {}

HloModule::HloModule(const string& name, const HloModuleConfig& config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(config),
      unique_id_(next_unique_module_id_++) {}
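// Shared implementation behind AddEntryComputation, AddEmbeddedComputation,
// and proto deserialization: takes ownership of `computation`, optionally
// installs it as the entry computation, reserves (or rewrites) its names in
// the module-level uniquifiers, assigns fresh instruction ids, and appends it
// to computations_.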
HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_names) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    // If the module configuration has no entry computation layout set, create
    // a default one based on the program shape.
    if (!config_.has_entry_computation_layout()) {
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
  }

  if (uniquify_names) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }
  } else {
    // Don't uniquify the names of the computation or its instructions, but we
    // must still run the names through the uniquifiers to prevent future name
    // collisions for computations and instructions created later.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
    }
  }

  // Pick unique IDs for each instruction.
  for (auto* instruction : computation->instructions()) {
    instruction->SetUniqueId(NewUniqueInstructionId());
  }
  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}

HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_names=*/true);
}

Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  auto it =
      std::find_if(computations_.begin(), computations_.end(),
                   [&to_remove](const std::unique_ptr<HloComputation>& comp) {
                     return comp.get() == to_remove;
                   });
  // Guard against dereferencing end() if the computation is not found.
  TF_RET_CHECK(it != computations_.end() && it->get() == to_remove);
  computations_.erase(it);
  return Status::OK();
}

HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_names=*/true);
}
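// Illustrative use of ReplaceComputations (names hypothetical): a pass that
// deduplicates computations might map a redundant %max_F32.1 to a canonical
// %max_F32. Every caller (kReduce, kMap, kWhile, ...) is then redirected to
// %max_F32, and %max_F32.1 is dropped from the module's computation list.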
void HloModule::ReplaceComputations(
    const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceWindow: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        case HloOpcode::kConditional: {
          HloComputation* new_true_computation =
              tensorflow::gtl::FindWithDefault(
                  replacements, instruction->true_computation(), nullptr);
          if (new_true_computation != nullptr) {
            instruction->set_true_computation(new_true_computation);
          }
          HloComputation* new_false_computation =
              tensorflow::gtl::FindWithDefault(
                  replacements, instruction->false_computation(), nullptr);
          if (new_false_computation != nullptr) {
            instruction->set_false_computation(new_false_computation);
          }
          break;
        }
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    // Computations that appear as keys of `replacements` are dropped here;
    // they are destroyed when computations_ is overwritten below.
    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}

string HloModule::ToString(const HloPrintOptions& options) const {
  std::ostringstream s;
  s << "HloModule " << name() << "\n\n";
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    if (computation == entry_computation()) {
      s << "ENTRY ";
    }
    s << computation->ToString(options) << "\n\n";
  }
  return s.str();
}

HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_name(name_);
  proto.set_entry_computation_name(entry_computation_->name());
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    // Fusion computations are added when the fusion instructions are created
    // by HloInstruction::CreateFromProto.
    if (computation->IsFusionComputation()) {
      continue;
    }
    HloComputationProto computation_proto = computation->ToProto();
    proto.add_computations()->Swap(&computation_proto);
  }
  return proto;
}

namespace {

// Constructs a ProgramShape matching the shapes of the parameters and root of
// the given module's entry computation.
StatusOr<ProgramShape> ProgramShapeFromProto(const HloModuleProto& module) {
  const HloComputationProto* entry_computation = nullptr;
  for (const HloComputationProto& computation : module.computations()) {
    if (computation.name() == module.entry_computation_name()) {
      entry_computation = &computation;
      break;
    }
  }
  TF_RET_CHECK(entry_computation != nullptr)
      << "No computation with entry computation name "
      << module.entry_computation_name();

  tensorflow::gtl::FlatMap<int64, std::pair<string, const Shape*>> parameters;
  const HloInstructionProto* root = nullptr;
  for (const HloInstructionProto& instruction :
       entry_computation->instructions()) {
    if (instruction.name() == entry_computation->root_name()) {
      TF_RET_CHECK(root == nullptr) << "Entry computation has more than "
                                       "one instruction with (root) name "
                                    << instruction.name();
      root = &instruction;
    }
    if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
      TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number()))
          << "Entry computation has more than one parameter instruction "
             "with parameter number "
          << instruction.parameter_number();
      parameters[instruction.parameter_number()] = {instruction.name(),
                                                    &instruction.shape()};
    }
  }
  TF_RET_CHECK(root != nullptr)
      << "Entry computation is missing root instruction named "
      << entry_computation->root_name();

  ProgramShape program_shape;
  *program_shape.mutable_result() = root->shape();
  for (int64 i = 0; i < parameters.size(); ++i) {
    TF_RET_CHECK(ContainsKey(parameters, i))
        << "Entry computation missing parameter number " << i;
    const string& name = parameters.at(i).first;
    const Shape& shape = *parameters.at(i).second;
    *program_shape.add_parameters() = shape;
    program_shape.add_parameter_names(name);
  }

  return std::move(program_shape);
}

}  // namespace
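// For example (shapes illustrative), an entry computation with parameters
//   %p0 = f32[16]{0} parameter(0)
//   %p1 = f32[16]{0} parameter(1)
// and root %add = f32[16]{0} add(%p0, %p1) reconstructs as the ProgramShape
// (f32[16], f32[16]) -> f32[16], which CreateFromProto below checks against
// the entry computation layout in the module config.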
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config,
    const VersionedComputationHandle& entry_computation_handle) {
  // The ProgramShape in the passed-in module config must match the shapes of
  // the entry parameters and root.
  TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape,
                      ProgramShapeFromProto(proto));
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(
        ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape))
        << "HloModuleConfig has different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

  auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
                                      module_config);

  tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> computation,
        HloComputation::CreateFromProto(
            module.get(), computation_proto, computation_map,
            /*add_fused_computation=*/
            [&module](std::unique_ptr<HloComputation> fused_computation) {
              module->AddComputationInternal(std::move(fused_computation),
                                             /*is_entry=*/false,
                                             /*uniquify_names=*/false);
            }));
    CHECK_NE(computation.get(), nullptr);
    TF_RET_CHECK(!ContainsKey(computation_map, computation->name()));
    string computation_name = computation->name();
    // Don't uniquify names because we want names to be stable across
    // serialization and deserialization.
    computation_map[computation_name] = module->AddComputationInternal(
        std::move(computation),
        /*is_entry=*/proto.entry_computation_name() == computation_name,
        /*uniquify_names=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);

  // Because we didn't uniquify the names, double-check that the instruction
  // and computation names from the proto are unique.
  tensorflow::gtl::FlatSet<string> computation_names;
  tensorflow::gtl::FlatSet<string> instruction_names;
  for (HloComputation* computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }

    TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
        << "Computation name is not unique: " << computation->name();
    computation_names.insert(computation->name());
    for (HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
          << "Instruction name is not unique: " << instruction->name();
      instruction_names.insert(instruction->name());
    }
  }

  return std::move(module);
}
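// A typical deserialization round trip (a sketch; callers and error handling
// vary) first derives a config from the proto, then builds the module:
//
//   TF_ASSIGN_OR_RETURN(HloModuleConfig config,
//                       HloModule::CreateModuleConfigFromProto(proto));
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
//                       HloModule::CreateFromProto(proto, config, handle));
//
// where `handle` is a VersionedComputationHandle supplied by the caller.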
/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      ProgramShapeFromProto(module));

  HloModuleConfig module_config(program_shape);

  // The module config is constructed with default layouts regardless of what
  // is passed in via the ProgramShape. Set the layouts to the appropriate
  // values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));

  return module_config;
}
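// Illustrative example for OutlineExpressionFromComputation below
// (instruction names hypothetical): outlining {%mul, %sub} from
//   %mul = f32[] multiply(%a, %b)
//   %sub = f32[] subtract(%mul, %c)
// produces an embedded computation with parameters for %a, %b, and %c and
// root %sub, and replaces the two instructions with a single kCall.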
namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(
    const HloInstruction& hlo,
    const std::unordered_set<HloInstruction*>& instructions_in_subcomputation) {
  for (HloInstruction* user : hlo.users()) {
    if (!instructions_in_subcomputation.count(user)) {
      return true;
    }
  }
  return false;
}
}  // anonymous namespace

HloInstruction* HloModule::OutlineExpressionFromComputation(
    tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
    const string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new
  // outlined function.
  std::unordered_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  std::unordered_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64 parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands with their counterparts in the new function.
    for (int64 operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), ""));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline as an output if it is used outside the
    // subcomputation or is the output of the original computation (i.e. used
    // externally).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    string error_message =
        "The subcomputation to outline has multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Create a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}

int64 HloModule::instruction_count() const {
  int64 n = 0;
  for (const auto& computation : computations_) {
    n += computation->instruction_count();
  }
  return n;
}
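// Example of the post order computed below (names illustrative): if ENTRY
// %main contains a kWhile whose body %while_body calls %add_reducer via
// kReduce, the returned order is {%add_reducer, %while_body, %main};
// callees always precede their callers.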
std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  std::set<HloComputation*> nonroot_computations;
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  std::set<HloComputation*> added_computations;
  std::list<HloComputation*> post_order;
  for (auto& computation : computations_) {
    if (nonroot_computations.count(computation.get()) == 0) {
      for (HloComputation* embedded_computation :
           computation->MakeEmbeddedComputationsList()) {
        if (added_computations.count(embedded_computation) == 0) {
          post_order.push_back(embedded_computation);
          added_computations.insert(embedded_computation);
        }
      }
      // Root computations should only be encountered once.
      CHECK_EQ(0, added_computations.count(computation.get()));
      post_order.push_back(computation.get());
      added_computations.insert(computation.get());
    }
  }
  CHECK_EQ(post_order.size(), computations_.size());
  return post_order;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
  std::vector<HloComputation*> result;
  for (auto* c : computations()) {
    if (c->IsFusionComputation()) {
      continue;
    }
    result.push_back(c);
  }
  return result;
}

std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
  VLOG(1) << "Cloning module: " << name_ << " --> " << suffix << "\n";
  auto module = MakeUnique<HloModule>(name_ + "-" + suffix);
  module->config_ = config_;
  module->entry_computation_handle_ = entry_computation_handle_;
  module->has_entry_computation_handle_ = has_entry_computation_handle_;

  std::unordered_map<HloComputation*, HloComputation*> clone_map;
  for (auto& computation : computations_) {
    if (computation->IsFusionComputation()) {
      // Cloning of a fused computation is handled by its fusion instruction.
      continue;
    }

    // When cloning a computation, pass in the new module, so that for any
    // fusion instruction in this computation, the fused computation will be
    // deep cloned to the new module.
    auto cloned_computation = computation->Clone(suffix, module.get());
    InsertOrDie(&clone_map, computation.get(), cloned_computation.get());

    if (entry_computation_ == computation.get()) {
      module->AddEntryComputation(std::move(cloned_computation));
    } else {
      module->AddEmbeddedComputation(std::move(cloned_computation));
    }
  }

  for (auto& cloned_computation : module->computations_) {
    for (auto* instruction : cloned_computation->instructions()) {
      // Rewrite instruction's called_computations to point to the cloned
      // computations.
      instruction->ReplaceCalledComputations([&](HloComputation* hlo) {
        if (hlo->IsFusionComputation()) {
          // Cloning of a fused computation was already handled when its
          // fusion instruction was cloned, so this computation is already
          // the clone.
          return hlo;
        }
        return FindOrDie(clone_map, hlo);
      });
    }
  }
  return module;
}

HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) {
  HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this));
  TF_CHECK_OK(
      clone->root_instruction()->Accept([this](HloInstruction* instruction) {
        instruction->ReplaceCalledComputations([this](HloComputation* callee) {
          return DeepCloneComputation(callee);
        });
        return Status::OK();
      }));
  return clone;
}

uint64 HloModule::RandomNew64() const {
  tensorflow::mutex_lock l(rng_mutex_);
  return rng_();
}

/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);

}  // namespace xla