1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
17
18#include <numeric>
19#include <vector>
20
21#include "llvm/IR/Constants.h"
22#include "llvm/IR/Function.h"
23#include "llvm/IR/Instructions.h"
24#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
25#include "tensorflow/compiler/xla/shape_util.h"
26#include "tensorflow/compiler/xla/types.h"
27#include "tensorflow/compiler/xla/xla_data.pb.h"
28#include "tensorflow/core/lib/strings/strcat.h"
29#include "tensorflow/core/lib/strings/stringprintf.h"
30#include "tensorflow/core/platform/logging.h"
31
32namespace xla {
33namespace llvm_ir {
34
35ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
36                 llvm::Value* start_index, llvm::Value* end_index,
37                 llvm::Value* step, bool prevent_unrolling,
38                 bool prevent_vectorization)
39    : prefix_(prefix.ToString()),
40      suffix_(suffix.ToString()),
41      start_index_(start_index),
42      end_index_(end_index),
43      step_(step),
44      insert_before_bb_(nullptr),
45      prevent_unrolling_(prevent_unrolling),
46      prevent_vectorization_(prevent_vectorization) {}
47
48/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
49    tensorflow::StringPiece prefix, llvm::Value* start_index,
50    llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
51    bool prevent_unrolling, bool prevent_vectorization) {
52  std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
53                                            end_index, step, prevent_unrolling,
54                                            prevent_vectorization));
55  loop->Emit(ir_builder);
56  return loop;
57}
58
59void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
60  // The preheader block is the block the builder is currently emitting
61  // code into.
62  preheader_bb_ = ir_builder->GetInsertBlock();
63
64  llvm::BasicBlock::iterator insert_point = ir_builder->GetInsertPoint();
65  if (insert_point == preheader_bb_->end()) {
66    // We're emitting the loop at the end of a basic block. Verify there is no
67    // terminator (eg, branch) in the basic block.
68    CHECK_EQ(nullptr, preheader_bb_->getTerminator());
69
70    exit_bb_ = CreateLoopBB("loop_exit", ir_builder);
71  } else {
72    // We're emitting the loop into the middle of a basic block. splitBasicBlock
73    // requires that this basic block be well-formed (have a terminator).
74    CHECK_NE(nullptr, preheader_bb_->getTerminator());
75
76    // Split the preheader to create an exit basic block. The exit basic block
77    // will contain all instructions at or after insert_point.
78    exit_bb_ = preheader_bb_->splitBasicBlock(
79        insert_point, AsStringRef(GetQualifiedName("loop_exit")));
80
81    // splitBasicBlock adds an unconditional branch between the split basic
82    // blocks. Remove it. An unconditional branch will be added below from the
83    // preheader to the header.
84    preheader_bb_->getTerminator()->eraseFromParent();
85  }
86  insert_before_bb_ = exit_bb_;
87
88  // Create remaining basic block which form the inside of the loop.
89  header_bb_ = CreateLoopBB("loop_header", ir_builder);
90  body_bb_ = CreateLoopBB("loop_body", ir_builder);
91
92  // Function entry basic block.
93  // Emit alloca for the induction variable. We do this at the entry to the
94  // basic block to ensure the alloc only executes once per function (we could
95  // be emitting a nested loop).
96  llvm::Function* func = preheader_bb_->getParent();
97  ir_builder->SetInsertPoint(&func->getEntryBlock(),
98                             func->getEntryBlock().getFirstInsertionPt());
99  llvm::Value* indvar_address =
100      ir_builder->CreateAlloca(ir_builder->getInt64Ty(), nullptr,
101                               AsStringRef(GetQualifiedName("invar_address")));
102
103  // Preheader basic block.
104  // Initialize induction variable starting index. Create branch to the header.
105  ir_builder->SetInsertPoint(preheader_bb_);
106  ir_builder->CreateStore(start_index_, indvar_address);
107  // The preheader should not have a branch yet.
108  CHECK_EQ(preheader_bb_->getTerminator(), nullptr);
109  ir_builder->CreateBr(header_bb_);
110
111  // Header basic block.
112  // Emit the loop conditional branch. Load and compare indvar with ending
113  // index and jump to loop exit if equal. Jump to body otherwise.
114  ir_builder->SetInsertPoint(header_bb_);
115  indvar_ = ir_builder->CreateLoad(indvar_address,
116                                   AsStringRef(GetQualifiedName("indvar")));
117  llvm::Value* exit_cond = ir_builder->CreateICmpUGE(indvar_, end_index_);
118  ir_builder->CreateCondBr(/*Cond=*/exit_cond,
119                           /*True=*/exit_bb_, /*False=*/body_bb_);
120
121  // Body basic block.
122  // Increment indvar, store indvar, and jump to header.
123  ir_builder->SetInsertPoint(body_bb_);
124  llvm::Value* step = step_;
125  llvm::Value* indvar = indvar_;
126
127  llvm::Value* indvar_inc =
128      ir_builder->CreateAdd(indvar, step, "invar.inc",
129                            /*HasNUW=*/true, /*HasNSW=*/true);
130  ir_builder->CreateStore(indvar_inc, indvar_address);
131  llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_);
132
133  std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(ir_builder);
134  if (!loop_metadata.empty()) {
135    llvm::LLVMContext* ctx = &start_index_->getContext();
136    auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
137    loop_metadata.insert(loop_metadata.begin(), temp_node.get());
138    auto loop_id = llvm::MDNode::get(*ctx, loop_metadata);
139    loop_id->replaceOperandWith(0, loop_id);
140    back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
141  }
142
143  // Re-point the IR builder to the loop exit block.
144  ir_builder->SetInsertPoint(exit_bb_);
145}
146
147std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(
148    llvm::IRBuilder<>* ir_builder) {
149  const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
150  const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
151  llvm::LLVMContext* ctx = &start_index_->getContext();
152
153  std::vector<llvm::Metadata*> result;
154  if (prevent_unrolling_) {
155    result.push_back(llvm::MDNode::get(
156        *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
157  }
158
159  if (prevent_vectorization_) {
160    result.push_back(llvm::MDNode::get(
161        *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
162               llvm::ConstantAsMetadata::get(ir_builder->getFalse())}));
163  }
164
165  return result;
166}
167
168string ForLoop::GetQualifiedName(tensorflow::StringPiece name) {
169  return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
170}
171
172llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name,
173                                        llvm::IRBuilder<>* ir_builder) {
174  return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name),
175                          ir_builder);
176}
177
178std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
179                                              llvm::Value* start_index,
180                                              llvm::Value* end_index,
181                                              bool prevent_unrolling,
182                                              bool prevent_vectorization) {
183  return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1),
184                 prevent_unrolling, prevent_vectorization);
185}
186
187std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
188                                              llvm::Value* start_index,
189                                              llvm::Value* end_index,
190                                              llvm::Value* stride,
191                                              bool prevent_unrolling,
192                                              bool prevent_vectorization) {
193  if (inner_loop_body_bb_ != nullptr) {
194    // Create this loop inside the previous one.
195    ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
196  }
197  std::unique_ptr<ForLoop> loop(new ForLoop(
198      /*prefix=*/name_, suffix, start_index, end_index, stride,
199      prevent_unrolling, prevent_vectorization));
200  loop->Emit(ir_builder_);
201
202  if (outer_loop_preheader_bb_ == nullptr) {
203    outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock();
204  }
205
206  if (outer_loop_exit_bb_ == nullptr) {
207    outer_loop_exit_bb_ = loop->GetExitBasicBlock();
208  }
209
210  inner_loop_body_bb_ = loop->GetBodyBasicBlock();
211
212  return loop;
213}
214
215std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
216                                              int64 end_index,
217                                              tensorflow::StringPiece suffix,
218                                              bool prevent_unrolling,
219                                              bool prevent_vectorization) {
220  CHECK_LE(start_index, end_index);
221  return AddLoop(suffix, ir_builder_->getInt64(start_index),
222                 ir_builder_->getInt64(end_index), prevent_unrolling,
223                 prevent_vectorization);
224}
225
226std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
227                                              int64 end_index, int64 stride,
228                                              tensorflow::StringPiece suffix,
229                                              bool prevent_unrolling,
230                                              bool prevent_vectorization) {
231  CHECK_LE(start_index, end_index);
232  return AddLoop(suffix, ir_builder_->getInt64(start_index),
233                 ir_builder_->getInt64(end_index),
234                 ir_builder_->getInt64(stride), prevent_unrolling,
235                 prevent_vectorization);
236}
237
238IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
239                                             tensorflow::StringPiece suffix) {
240  std::vector<int64> dimensions(ShapeUtil::Rank(shape));
241  std::iota(dimensions.begin(), dimensions.end(), 0);
242  return AddLoopsForShapeOnDimensions(shape, dimensions, suffix);
243}
244
245IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
246    const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
247    tensorflow::StringPiece suffix) {
248  llvm_ir::IrArray::Index index(shape.dimensions_size(), nullptr);
249  for (int64 dimension : dimensions) {
250    std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
251        /*start_index=*/0,
252        /*end_index=*/shape.dimensions(dimension),
253        /*suffix=*/
254        llvm_ir::IrName(suffix, tensorflow::strings::StrCat(dimension)));
255    index[dimension] = loop->GetIndVarValue();
256  }
257  return index;
258}
259
260}  // namespace llvm_ir
261}  // namespace xla
262