1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/flatten_call_graph.h" 17 18#include "tensorflow/compiler/xla/service/call_graph.h" 19#include "tensorflow/compiler/xla/service/hlo_computation.h" 20#include "tensorflow/compiler/xla/service/hlo_instruction.h" 21#include "tensorflow/compiler/xla/service/hlo_module.h" 22#include "tensorflow/compiler/xla/util.h" 23#include "tensorflow/core/lib/core/errors.h" 24 25namespace xla { 26 27namespace { 28 29// Helper to replace the called computation at a while-, call-, or 30// conditional-instruction. This function replaces exactly one instance of 31// 'computation' with 'new_computation' even if 'instruction' calls 32// 'computation' more than once. 33void ReplaceCalledComputation(HloInstruction* instruction, 34 HloComputation* computation, 35 HloComputation* new_computation) { 36 switch (instruction->opcode()) { 37 case HloOpcode::kWhile: { 38 if (computation == instruction->while_condition()) { 39 instruction->set_while_condition(new_computation); 40 } else { 41 CHECK_EQ(computation, instruction->while_body()); 42 instruction->set_while_body(new_computation); 43 } 44 break; 45 } 46 case HloOpcode::kCall: { 47 CHECK_EQ(instruction->to_apply(), computation); 48 instruction->set_to_apply(new_computation); 49 break; 50 } 51 case HloOpcode::kConditional: { 52 if (computation == instruction->true_computation()) { 53 instruction->set_true_computation(new_computation); 54 } else { 55 CHECK_EQ(computation, instruction->false_computation()); 56 instruction->set_false_computation(new_computation); 57 } 58 break; 59 } 60 default: 61 LOG(FATAL) << "unexpected opcode: " 62 << HloOpcodeString(instruction->opcode()); 63 } 64} 65 66// Flatten a single call graph node. Expects to visit nodes in postorder. 67Status FlattenNode(const CallGraphNode& node) { 68 HloComputation* computation = node.computation(); 69 HloModule* module = computation->parent(); 70 // Clone callee for all call-sites except the first one. 71 for (int i = 0; i < node.caller_callsites().size(); ++i) { 72 CallSite call_site = node.caller_callsites()[i]; 73 // Only consider sequential call contexts. 74 if (call_site.context() == CallContext::kParallel) { 75 continue; 76 } 77 CHECK_EQ(call_site.context(), CallContext::kSequential); 78 79 // Skip first element if this computation is only called from a sequential 80 // context. 81 if (node.context() != CallContext::kBoth && i == 0) { 82 continue; 83 } 84 85 // Clone computation for the remaining sequential context call sites. 86 HloComputation* clone = 87 module->AddEmbeddedComputation(computation->Clone()); 88 ReplaceCalledComputation(call_site.instruction(), computation, clone); 89 // Clone the sub-tree of all computations called from this node. 90 std::vector<HloComputation*> worklist; 91 worklist.push_back(clone); 92 while (!worklist.empty()) { 93 auto current = worklist.back(); 94 worklist.pop_back(); 95 for (auto* instruction : current->instructions()) { 96 if (GetInstructionCallContext(instruction) != 97 CallContext::kSequential) { 98 continue; 99 } 100 for (auto callee : instruction->called_computations()) { 101 HloComputation* callee_clone = 102 module->AddEmbeddedComputation(callee->Clone()); 103 ReplaceCalledComputation(instruction, callee, callee_clone); 104 worklist.push_back(callee_clone); 105 } 106 } 107 } 108 } 109 return Status::OK(); 110} 111 112} // namespace 113 114StatusOr<bool> FlattenCallGraph::Run(HloModule* module) { 115 XLA_VLOG_LINES(3, "Before flatten call graph:\n" + module->ToString()); 116 117 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); 118 TF_RETURN_IF_ERROR(call_graph->VisitNodes(FlattenNode)); 119 120 XLA_VLOG_LINES(3, "After flatten call graph:\n" + module->ToString()); 121 return true; 122} 123 124} // namespace xla 125