/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"

#include <iterator>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

HloModule::HloModule(const string& name,
                     const VersionedComputationHandle& entry_computation_handle,
                     const HloModuleConfig& config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(config),
      has_entry_computation_handle_(true),
      entry_computation_handle_(entry_computation_handle),
      unique_id_(next_unique_module_id_++) {}

HloModule::HloModule(const string& name)
    : name_(NameUniquer::GetSanitizedName(name)),
      unique_id_(next_unique_module_id_++) {}

HloModule::HloModule(const string& name, const HloModuleConfig& config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(config),
      unique_id_(next_unique_module_id_++) {}

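// Illustrative sketch of the name-uniquification behavior below (hypothetical
// helper MakeAddComputation(); the exact suffix is chosen by NameUniquer):
//
//   module->AddEmbeddedComputation(MakeAddComputation());  // named "add"
//   module->AddEmbeddedComputation(MakeAddComputation());  // renamed, e.g.
//                                                          // "add.1"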
HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_names) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    // If the module configuration has no entry computation layout set, create
    // a default one based on the program shape.
    if (!config_.has_entry_computation_layout()) {
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
  }

  if (uniquify_names) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }
  } else {
    // Don't uniquify the names of the computation or its instructions, but we
    // must still run the names through the uniquifiers to prevent future name
    // collisions with computations and instructions created later.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
    }
  }

  // Pick unique IDs for each instruction.
  for (auto* instruction : computation->instructions()) {
    instruction->SetUniqueId(NewUniqueInstructionId());
  }
  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}

HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_names=*/true);
}

Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  auto it =
      std::find_if(computations_.begin(), computations_.end(),
                   [&to_remove](const std::unique_ptr<HloComputation>& comp) {
                     return comp.get() == to_remove;
                   });
  TF_RET_CHECK(it != computations_.end());
  computations_.erase(it);
  return Status::OK();
}

HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_names=*/true);
}

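// Illustrative use of ReplaceComputations (hypothetical computations old_f and
// new_f): every calling site (kCall, kMap, kWhile, ...) that referred to old_f
// is redirected to new_f, and old_f itself is dropped from the module:
//
//   std::unordered_map<HloComputation*, HloComputation*> replacements;
//   replacements[old_f] = new_f;
//   module->ReplaceComputations(replacements);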
void HloModule::ReplaceComputations(
    const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceWindow: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        case HloOpcode::kConditional: {
          HloComputation* new_true_computation =
              tensorflow::gtl::FindWithDefault(
                  replacements, instruction->true_computation(), nullptr);
          if (new_true_computation != nullptr) {
            instruction->set_true_computation(new_true_computation);
          }
          HloComputation* new_false_computation =
              tensorflow::gtl::FindWithDefault(
                  replacements, instruction->false_computation(), nullptr);
          if (new_false_computation != nullptr) {
            instruction->set_false_computation(new_false_computation);
          }
          break;
        }
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}

string HloModule::ToString(const HloPrintOptions& options) const {
  std::ostringstream s;
  s << "HloModule " << name() << "\n\n";
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    if (computation == entry_computation()) {
      s << "ENTRY ";
    }
    s << computation->ToString(options) << "\n\n";
  }
  return s.str();
}

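// Serialization sketch (illustrative): ToProto skips fusion computations
// because CreateFromProto re-creates them while deserializing the
// corresponding fusion instructions, so
//
//   HloModuleProto proto = module->ToProto();
//
// carries only the module's non-fusion computations.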
HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_name(name_);
  proto.set_entry_computation_name(entry_computation_->name());
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    // Fusion computations are added when the fusion instructions are created
    // by HloInstruction::CreateFromProto.
    if (computation->IsFusionComputation()) {
      continue;
    }
    HloComputationProto computation_proto = computation->ToProto();
    proto.add_computations()->Swap(&computation_proto);
  }
  return proto;
}

namespace {

// Construct a ProgramShape matching the shape of the parameters and root of
// the given module's entry computation.
StatusOr<ProgramShape> ProgramShapeFromProto(const HloModuleProto& module) {
  const HloComputationProto* entry_computation = nullptr;
  for (const HloComputationProto& computation : module.computations()) {
    if (computation.name() == module.entry_computation_name()) {
      entry_computation = &computation;
      break;
    }
  }
  TF_RET_CHECK(entry_computation != nullptr)
      << "No computation with entry computation name "
      << module.entry_computation_name();

  tensorflow::gtl::FlatMap<int64, std::pair<string, const Shape*>> parameters;
  const HloInstructionProto* root = nullptr;
  for (const HloInstructionProto& instruction :
       entry_computation->instructions()) {
    if (instruction.name() == entry_computation->root_name()) {
      TF_RET_CHECK(root == nullptr) << "Entry computation has more than "
                                       "one instruction with (root) name "
                                    << instruction.name();
      root = &instruction;
    }
    if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
      TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number()))
          << "Entry computation has more than one parameter instruction "
             "with parameter number "
          << instruction.parameter_number();
      parameters[instruction.parameter_number()] = {instruction.name(),
                                                    &instruction.shape()};
    }
  }
  TF_RET_CHECK(root != nullptr)
      << "Entry computation is missing root instruction named "
      << entry_computation->root_name();

  ProgramShape program_shape;
  *program_shape.mutable_result() = root->shape();
  for (int64 i = 0; i < parameters.size(); ++i) {
    TF_RET_CHECK(ContainsKey(parameters, i))
        << "Entry computation missing parameter number " << i;
    const string& name = parameters.at(i).first;
    const Shape& shape = *parameters.at(i).second;
    *program_shape.add_parameters() = shape;
    program_shape.add_parameter_names(name);
  }

  return std::move(program_shape);
}

}  // namespace

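// Typical deserialization sketch (illustrative; `proto` and `handle` are
// assumed to be a previously serialized module and its computation handle):
//
//   TF_ASSIGN_OR_RETURN(HloModuleConfig config,
//                       HloModule::CreateModuleConfigFromProto(proto));
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
//                       HloModule::CreateFromProto(proto, config, handle));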
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config,
    const VersionedComputationHandle& entry_computation_handle) {
  // The ProgramShape in the passed-in module config must match the shapes of
  // the entry parameters and root.
  TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape,
                      ProgramShapeFromProto(proto));
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(
        ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape))
        << "HloModuleConfig has different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

  auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
                                      module_config);

  tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> computation,
        HloComputation::CreateFromProto(
            module.get(), computation_proto, computation_map,
            /*add_fused_computation=*/
            [&module](std::unique_ptr<HloComputation> fused_computation) {
              module->AddComputationInternal(std::move(fused_computation),
                                             /*is_entry=*/false,
                                             /*uniquify_names=*/false);
            }));
    CHECK_NE(computation.get(), nullptr);
    TF_RET_CHECK(!ContainsKey(computation_map, computation->name()));
    string computation_name = computation->name();
    // Don't uniquify names because we want names to be stable across
    // serialization and deserialization.
    computation_map[computation_name] = module->AddComputationInternal(
        std::move(computation),
        /*is_entry=*/proto.entry_computation_name() == computation_name,
        /*uniquify_names=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);

  // Because we didn't uniquify the names, double-check that the computation
  // and instruction names from the proto are unique.
  tensorflow::gtl::FlatSet<string> computation_names;
  tensorflow::gtl::FlatSet<string> instruction_names;
  for (HloComputation* computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }

    TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
        << "Computation name is not unique: " << computation->name();
    computation_names.insert(computation->name());
    for (HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
          << "Instruction name is not unique: " << instruction->name();
      instruction_names.insert(instruction->name());
    }
  }

  return std::move(module);
}

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      ProgramShapeFromProto(module));

  HloModuleConfig module_config(program_shape);

  // The module config is constructed with default layouts regardless of what
  // is passed in via the ProgramShape. Set the layouts to the appropriate
  // values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));

  return module_config;
}

namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(
    const HloInstruction& hlo,
    const std::unordered_set<HloInstruction*>& instructions_in_subcomputation) {
  for (HloInstruction* user : hlo.users()) {
    if (!instructions_in_subcomputation.count(user)) {
      return true;
    }
  }
  return false;
}
}  // anonymous namespace

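// Illustrative outlining sketch (hypothetical instructions mul and add in
// `comp`, listed in topological order, with add the only value used outside
// the pair): both instructions are replaced by a kCall to a new embedded
// computation that evaluates the same expression:
//
//   HloInstruction* call = module->OutlineExpressionFromComputation(
//       {mul, add}, "outlined_mul_add", comp);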
HloInstruction* HloModule::OutlineExpressionFromComputation(
    tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
    const string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new outlined
  // function.
  std::unordered_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  std::unordered_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64 parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands with their counterparts in the new function.
    for (int64 operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), ""));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline as an output if it is used outside the
    // subcomputation or has no users at all (i.e. it is the root of the
    // original computation).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    string error_message =
        "The subcomputation to outline has zero or multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Create a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}

int64 HloModule::instruction_count() const {
  int64 n = 0;
  for (const auto& computation : computations_) {
    n += computation->instruction_count();
  }
  return n;
}

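// Ordering sketch (illustrative): if the entry computation E contains a while
// whose body W calls C, the returned post order lists callees before callers,
// e.g. {C, W, E}, so a pass can safely visit each computation after all the
// computations it calls:
//
//   for (HloComputation* c : module->MakeComputationPostOrder()) {
//     // c's callees have already been visited here.
//   }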
std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  std::set<HloComputation*> nonroot_computations;
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  std::set<HloComputation*> added_computations;
  std::list<HloComputation*> post_order;
  for (auto& computation : computations_) {
    if (nonroot_computations.count(computation.get()) == 0) {
      for (HloComputation* embedded_computation :
           computation->MakeEmbeddedComputationsList()) {
        if (added_computations.count(embedded_computation) == 0) {
          post_order.push_back(embedded_computation);
          added_computations.insert(embedded_computation);
        }
      }
      // Root computations should only be encountered once.
      CHECK_EQ(0, added_computations.count(computation.get()));
      post_order.push_back(computation.get());
      added_computations.insert(computation.get());
    }
  }
  CHECK_EQ(post_order.size(), computations_.size());
  return post_order;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
  std::vector<HloComputation*> result;
  for (auto* c : computations()) {
    if (c->IsFusionComputation()) {
      continue;
    }
    result.push_back(c);
  }
  return result;
}

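// Cloning sketch (illustrative): the clone is a deep copy under a suffixed
// name that shares no computations or instructions with the original:
//
//   std::unique_ptr<HloModule> copy = module->Clone("clone");
//   // copy->name() is module->name() + "-clone".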
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
  VLOG(1) << "Cloning module: " << name_ << " --> " << suffix << "\n";
  auto module = MakeUnique<HloModule>(name_ + "-" + suffix);
  module->config_ = config_;
  module->entry_computation_handle_ = entry_computation_handle_;
  module->has_entry_computation_handle_ = has_entry_computation_handle_;

  std::unordered_map<HloComputation*, HloComputation*> clone_map;
  for (auto& computation : computations_) {
    if (computation->IsFusionComputation()) {
      // Cloning of a fused computation is handled by its fusion instruction.
      continue;
    }

    // When cloning a computation, pass in the new module, so that for any
    // fusion instruction in this computation, the fused computation will be
    // deep cloned to the new module.
    auto cloned_computation = computation->Clone(suffix, module.get());
    InsertOrDie(&clone_map, computation.get(), cloned_computation.get());

    if (entry_computation_ == computation.get()) {
      module->AddEntryComputation(std::move(cloned_computation));
    } else {
      module->AddEmbeddedComputation(std::move(cloned_computation));
    }
  }

  for (auto& cloned_computation : module->computations_) {
    for (auto* instruction : cloned_computation->instructions()) {
      // Rewrite the instruction's called computations to point to the cloned
      // computations.
      instruction->ReplaceCalledComputations([&](HloComputation* hlo) {
        if (hlo->IsFusionComputation()) {
          // Cloning of a fused computation has already been handled when its
          // fusion instruction was cloned. So this hlo computation is already
          // the cloned one.
          return hlo;
        }
        return FindOrDie(clone_map, hlo);
      });
    }
  }
  return module;
}

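// Deep-clone sketch (illustrative; `callee` is a hypothetical computation,
// possibly owned by another module): the computation and, recursively, every
// computation it calls are copied into this module:
//
//   HloComputation* local_copy = module->DeepCloneComputation(callee);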
HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) {
  HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this));
  TF_CHECK_OK(
      clone->root_instruction()->Accept([this](HloInstruction* instruction) {
        instruction->ReplaceCalledComputations([this](HloComputation* callee) {
          return DeepCloneComputation(callee);
        });
        return Status::OK();
      }));
  return clone;
}

uint64 HloModule::RandomNew64() const {
  tensorflow::mutex_lock l(rng_mutex_);
  return rng_();
}

/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);

}  // namespace xla