/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"

#include <functional>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::IrArray;

35Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
36  generators_[hlo] =
37      [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
38    if (generated_value_cache_[hlo].count(index.multidim()) > 0) {
39      llvm::Value* generated_value =
40          generated_value_cache_[hlo][index.multidim()];
41      llvm::BasicBlock* generated_value_bb = nullptr;
42      if (auto* generated_instruction =
43              llvm::dyn_cast<llvm::Instruction>(generated_value)) {
44        generated_value_bb = generated_instruction->getParent();
45      }
46      // Ideally, we should be able to reuse the cached generated value if it
47      // dominates the current insertion block. However, the check for dominance
48      // can be expensive and unreliable when the function is being constructed.
49      //
50      // It's also worth experimenting what if we don't do caching at all.
51      // LLVM's CSE or GVN should be able to easily merge common subexpressions
52      // that would be regenerated without caching. But this might increase the
53      // JIT compilation time.
54      if (generated_value_bb == nullptr ||
55          generated_value_bb == ir_builder_->GetInsertBlock()) {
56        VLOG(3) << "The cached generated value is reused.";
57        return generated_value;
58      }
59      VLOG(3) << "The cached generated value can't be reused, because it is in "
60                 "a different BB ("
61              << llvm_ir::AsString(generated_value_bb->getName())
62              << ") from the current insertion block ("
63              << llvm_ir::AsString(ir_builder_->GetInsertBlock()->getName())
64              << ").";
65    }
66
67    TF_ASSIGN_OR_RETURN(
68        generated_value_cache_[hlo][index.multidim()],
69        elemental_emitter_->MakeElementGenerator(hlo, generators_)(index));
70    return generated_value_cache_[hlo][index.multidim()];
71  };
72  return Status::OK();
73}
74
75Status FusedIrEmitter::HandleConstant(HloInstruction* constant) {
76  const Literal& literal = constant->literal();
77  llvm::Constant* initializer =
78      llvm_ir::ConvertLiteralToIrConstant(literal, module_);
79  llvm::GlobalVariable* global = new llvm::GlobalVariable(
80      *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(),
81      /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer,
82      /*Name=*/"");
83  generators_[constant] = [=](const IrArray::Index& index) {
84    return IrArray(global, constant->shape())
85        .EmitReadArrayElement(index, ir_builder_);
86  };
87
88  return Status::OK();
89}
90
91Status FusedIrEmitter::HandleGetTupleElement(
92    HloInstruction* get_tuple_element) {
93  // Lookup ir value for 'operand'.
94  auto operand = get_tuple_element->operand(0);
95  auto it = gte_values_.find(operand);
96  if (it == gte_values_.end()) {
97    return Unimplemented(
98        "GetTupleElement fusion currently only supports"
99        " parameter operands, but found operand: %s",
100        operand->name().c_str());
101  }
102  // Emit code to lookup tuple element pointer, and store it in 'gte_values_'.
103  llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement(
104      get_tuple_element->shape(), get_tuple_element->tuple_index(),
105      /*alignment=*/1, it->second, ir_builder_, module_);
106  gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr));
107  // Emit code to read base tuple element array (if non-tuple shaped).
108  if (!ShapeUtil::IsTuple(get_tuple_element->shape())) {
109    generators_[get_tuple_element] =
110        [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
111      // TODO(b/34080002) Add aliasing information to tuple element IrArray.
112      return IrArray(tuple_element_ptr, get_tuple_element->shape())
113          .EmitReadArrayElement(index, ir_builder_);
114    };
115  }
116  return Status::OK();
117}
118
119Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
120  generators_[parameter] = [=](const IrArray::Index& index) {
121    return parameter_arrays_[parameter->parameter_number()]
122        .EmitReadArrayElement(index, ir_builder_);
123  };
124  // Store ir value for fusion operand associated with fusion parameter to be
125  // accessed by subsequent fused GetTupleElement instructions.
126  gte_values_.insert(std::make_pair(
127      parameter,
128      parameter_arrays_[parameter->parameter_number()].GetBasePointer()));
129  return Status::OK();
130}
131
132Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) {
133  tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
134  std::vector<llvm::Type*> operand_elemental_ir_types;
135  for (HloInstruction* operand : operands) {
136    operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
137        operand->shape().element_type(), module_));
138  }
139  generators_[tuple] =
140      [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
141    llvm::Value* ret = llvm::UndefValue::get(llvm::StructType::get(
142        ir_builder_->getContext(), operand_elemental_ir_types));
143    for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) {
144      TF_ASSIGN_OR_RETURN(llvm::Value * val_i, generators_[operands[i]](index));
145      ret = ir_builder_->CreateInsertValue(ret, val_i, i);
146    }
147    return ret;
148  };
149  return Status::OK();
150}
151
152Status FusedIrEmitter::FinishVisit(HloInstruction* root) {
153  fused_root_ = root;
154  return tensorflow::Status::OK();
155}
156
157FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const {
158  CHECK_NE(nullptr, fused_root_)
159      << "GetRootGenerator should be called after Accept.";
160  return generators_.at(fused_root_);
161}
162
163FusedIrEmitter::Generator FusedIrEmitter::GetGenerator(
164    const HloInstruction* instruction) const {
165  return generators_.at(instruction);
166}

}  // namespace xla