1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_ 17#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_ 18 19#include <map> 20#include <unordered_map> 21 22#include "llvm/IR/IRBuilder.h" 23#include "llvm/IR/Value.h" 24#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 25#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" 26#include "tensorflow/compiler/xla/service/hlo_instruction.h" 27#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" 28#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" 29#include "tensorflow/compiler/xla/statusor.h" 30#include "tensorflow/compiler/xla/xla_data.pb.h" 31#include "tensorflow/core/lib/gtl/array_slice.h" 32 33namespace xla { 34 35// FusedIrEmitter is used to generate code for fusion nodes. 36// 37// Unlike IrEmitter and its ilk, which directly create LLVM IR in an LLVM 38// Module, FusedIrEmitter is better understood as "IR generator generator". 39// FusedIrEmitter recursively creates a generator (a host function) which the 40// compiler can invoke at a later time. Invoking the generator emits LLVM IR 41// that, when run, produces the value at a particular index of the output. 42// 43// After building this generator, the compiler creates a loop (or its moral 44// equivalent, e.g. a GPU kernel) and calls the generator from within the loop. 45// This generates code that produces each element of the output. 46// 47// This class handles both vanilla fusion and multi-output fusion. In the MOF 48// case, the fusion node ends with a kTuple instruction, and the generator 49// created produces an LLVM struct with N elements, one for each element of the 50// arrays in the tuple. It follows that the arrays in the tuple must have the 51// same length. 52class FusedIrEmitter : public DfsHloVisitorWithDefault { 53 public: 54 using Generator = llvm_ir::ElementGenerator; 55 56 FusedIrEmitter(tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays, 57 ElementalIrEmitter* elemental_emitter) 58 : parameter_arrays_(parameter_arrays), 59 elemental_emitter_(elemental_emitter), 60 ir_builder_(elemental_emitter->ir_builder()), 61 module_(elemental_emitter->module()) {} 62 63 Status DefaultAction(HloInstruction* hlo) override; 64 65 Status HandleConstant(HloInstruction* constant) override; 66 67 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 68 69 Status HandleParameter(HloInstruction* parameter) override; 70 71 // Emits the ir value for each element in the tuple. 72 Status HandleTuple(HloInstruction* tuple) override; 73 74 Status FinishVisit(HloInstruction* root) override; 75 76 // Returns the generator function for the root of the fused computation. 77 Generator GetRootGenerator() const; 78 79 // Returns the generator function for the given instruction. 80 Generator GetGenerator(const HloInstruction* instruction) const; 81 82 // Returns the ir value for instruction 'hlo'. 83 llvm::Value* GetIrValueForGTE(const HloInstruction* hlo) const { 84 auto it = gte_values_.find(hlo); 85 CHECK(it != gte_values_.end()); 86 return it->second; 87 } 88 89 private: 90 // Arrays of parameters of fusion instruction 91 tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays_; 92 93 ElementalIrEmitter* elemental_emitter_; 94 95 // This member will be set by FinishVisit and used in GetRootGenerator. 96 const HloInstruction* fused_root_ = nullptr; 97 98 // Borrowed 99 llvm::IRBuilder<>* ir_builder_; 100 llvm::Module* module_; 101 102 // Map from instruction pointers to functions to generate elements of their 103 // outputs 104 std::unordered_map<const HloInstruction*, Generator> generators_; 105 106 // Cache of generated values, lest we regenerate an element of a node with 107 // multiple outgoing edges 108 std::unordered_map<const HloInstruction*, 109 std::map<std::vector<llvm::Value*>, llvm::Value*>> 110 generated_value_cache_; 111 112 // Stores ir values required to emit fused (and possibly nested) 113 // GetTupleElement instructions. 114 std::unordered_map<const HloInstruction*, llvm::Value*> gte_values_; 115}; 116 117} // namespace xla 118 119#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_ 120