/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"

#include <functional>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::IrArray;

// Fallback visitor action: registers an element generator for `hlo` that
// delegates to the elemental emitter, memoizing the emitted llvm::Value per
// multidimensional index so that asking for the same element twice does not
// re-emit the IR that computes it.
Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
  generators_[hlo] =
      [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
    if (generated_value_cache_[hlo].count(index.multidim()) > 0) {
      llvm::Value* generated_value =
          generated_value_cache_[hlo][index.multidim()];
      // Determine which basic block the cached value was emitted in. A value
      // that is not an llvm::Instruction (e.g. a constant or an argument)
      // has no parent block and is always safe to reuse.
      llvm::BasicBlock* generated_value_bb = nullptr;
      if (auto* generated_instruction =
              llvm::dyn_cast<llvm::Instruction>(generated_value)) {
        generated_value_bb = generated_instruction->getParent();
      }
      // Ideally, we should be able to reuse the cached generated value if it
      // dominates the current insertion block. However, the check for dominance
      // can be expensive and unreliable when the function is being constructed.
      //
      // It's also worth experimenting what if we don't do caching at all.
      // LLVM's CSE or GVN should be able to easily merge common subexpressions
      // that would be regenerated without caching. But this might increase the
      // JIT compilation time.
      if (generated_value_bb == nullptr ||
          generated_value_bb == ir_builder_->GetInsertBlock()) {
        VLOG(3) << "The cached generated value is reused.";
        return generated_value;
      }
      // Cached value lives in a different block; conservatively fall through
      // and regenerate it at the current insertion point.
      VLOG(3) << "The cached generated value can't be reused, because it is in "
                 "a different BB ("
              << llvm_ir::AsString(generated_value_bb->getName())
              << ") from the current insertion block ("
              << llvm_ir::AsString(ir_builder_->GetInsertBlock()->getName())
              << ").";
    }

    // Cache miss (or unusable cache entry): emit the element via the
    // elemental emitter and store the result before returning it.
    TF_ASSIGN_OR_RETURN(
        generated_value_cache_[hlo][index.multidim()],
        elemental_emitter_->MakeElementGenerator(hlo, generators_)(index));
    return generated_value_cache_[hlo][index.multidim()];
  };
  return Status::OK();
}

// Materializes a fused constant as a read-only global variable in the LLVM
// module and registers a generator that reads individual elements from it.
Status FusedIrEmitter::HandleConstant(HloInstruction* constant) {
  const Literal& literal = constant->literal();
  llvm::Constant* initializer =
      llvm_ir::ConvertLiteralToIrConstant(literal, module_);
  // Raw `new` is the LLVM idiom here: the GlobalVariable constructor links
  // the global into the module, which then owns it.
  llvm::GlobalVariable* global = new llvm::GlobalVariable(
      *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(),
      /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer,
      /*Name=*/"");
  generators_[constant] = [=](const IrArray::Index& index) {
    return IrArray(global, constant->shape())
        .EmitReadArrayElement(index, ir_builder_);
  };

  return Status::OK();
}

// Handles a fused get-tuple-element by computing a pointer to the selected
// tuple element. The operand's base pointer must already be recorded in
// gte_values_ — which happens for fusion parameters (HandleParameter) and for
// previously visited GTEs (this function) — otherwise the op is unsupported.
Status FusedIrEmitter::HandleGetTupleElement(
    HloInstruction* get_tuple_element) {
  // Lookup ir value for 'operand'.
  auto operand = get_tuple_element->operand(0);
  auto it = gte_values_.find(operand);
  if (it == gte_values_.end()) {
    return Unimplemented(
        "GetTupleElement fusion currently only supports"
        " parameter operands, but found operand: %s",
        operand->name().c_str());
  }
  // Emit code to lookup tuple element pointer, and store it in 'gte_values_'.
  // Alignment of 1 makes no alignment assumption about the element pointer.
  llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement(
      get_tuple_element->shape(), get_tuple_element->tuple_index(),
      /*alignment=*/1, it->second, ir_builder_, module_);
  gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr));
  // Emit code to read base tuple element array (if non-tuple shaped). A
  // tuple-shaped result has no element generator; only its pointer is cached
  // above for nested GTEs to consume.
  if (!ShapeUtil::IsTuple(get_tuple_element->shape())) {
    generators_[get_tuple_element] =
        [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
      // TODO(b/34080002) Add aliasing information to tuple element IrArray.
      return IrArray(tuple_element_ptr, get_tuple_element->shape())
          .EmitReadArrayElement(index, ir_builder_);
    };
  }
  return Status::OK();
}

// Registers a generator that reads elements of the caller-provided IrArray
// corresponding to this fusion parameter's position.
Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
  generators_[parameter] = [=](const IrArray::Index& index) {
    return parameter_arrays_[parameter->parameter_number()]
        .EmitReadArrayElement(index, ir_builder_);
  };
  // Store ir value for fusion operand associated with fusion parameter to be
  // accessed by subsequent fused GetTupleElement instructions.
  gte_values_.insert(std::make_pair(
      parameter,
      parameter_arrays_[parameter->parameter_number()].GetBasePointer()));
  return Status::OK();
}

// Registers a generator for a fused tuple op that builds an LLVM struct value
// element-by-element, where element i comes from operand i's generator
// evaluated at the same index.
Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) {
  tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
  // The struct's field types mirror the elemental IR type of each operand.
  std::vector<llvm::Type*> operand_elemental_ir_types;
  for (HloInstruction* operand : operands) {
    operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
        operand->shape().element_type(), module_));
  }
  generators_[tuple] =
      [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
    // Start from undef and insert each generated element value in turn.
    llvm::Value* ret = llvm::UndefValue::get(llvm::StructType::get(
        ir_builder_->getContext(), operand_elemental_ir_types));
    for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) {
      TF_ASSIGN_OR_RETURN(llvm::Value * val_i, generators_[operands[i]](index));
      ret = ir_builder_->CreateInsertValue(ret, val_i, i);
    }
    return ret;
  };
  return Status::OK();
}

// Records the fusion root so GetRootGenerator() can return its generator
// after the visitor traversal completes.
Status FusedIrEmitter::FinishVisit(HloInstruction* root) {
  fused_root_ = root;
  return tensorflow::Status::OK();
}

// Returns the generator for the fusion root. Must be called after Accept()
// has traversed the fused computation (enforced by the CHECK).
FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const {
  CHECK_NE(nullptr, fused_root_)
      << "GetRootGenerator should be called after Accept.";
  return generators_.at(fused_root_);
}

// Returns the generator registered for `instruction`. CHECK-fails (via
// map::at) if the instruction has not been visited.
FusedIrEmitter::Generator FusedIrEmitter::GetGenerator(
    const HloInstruction* instruction) const {
  return generators_.at(instruction);
}

}  // namespace xla