1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" 17 18#include <numeric> 19#include <vector> 20 21#include "llvm/IR/Constants.h" 22#include "llvm/IR/Function.h" 23#include "llvm/IR/Instructions.h" 24#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 25#include "tensorflow/compiler/xla/shape_util.h" 26#include "tensorflow/compiler/xla/types.h" 27#include "tensorflow/compiler/xla/xla_data.pb.h" 28#include "tensorflow/core/lib/strings/strcat.h" 29#include "tensorflow/core/lib/strings/stringprintf.h" 30#include "tensorflow/core/platform/logging.h" 31 32namespace xla { 33namespace llvm_ir { 34 35ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, 36 llvm::Value* start_index, llvm::Value* end_index, 37 llvm::Value* step, bool prevent_unrolling, 38 bool prevent_vectorization) 39 : prefix_(prefix.ToString()), 40 suffix_(suffix.ToString()), 41 start_index_(start_index), 42 end_index_(end_index), 43 step_(step), 44 insert_before_bb_(nullptr), 45 prevent_unrolling_(prevent_unrolling), 46 prevent_vectorization_(prevent_vectorization) {} 47 48/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop( 49 tensorflow::StringPiece prefix, llvm::Value* start_index, 50 llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, 51 bool prevent_unrolling, bool prevent_vectorization) { 52 std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index, 53 end_index, step, prevent_unrolling, 54 prevent_vectorization)); 55 loop->Emit(ir_builder); 56 return loop; 57} 58 59void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { 60 // The preheader block is the block the builder is currently emitting 61 // code into. 62 preheader_bb_ = ir_builder->GetInsertBlock(); 63 64 llvm::BasicBlock::iterator insert_point = ir_builder->GetInsertPoint(); 65 if (insert_point == preheader_bb_->end()) { 66 // We're emitting the loop at the end of a basic block. Verify there is no 67 // terminator (eg, branch) in the basic block. 68 CHECK_EQ(nullptr, preheader_bb_->getTerminator()); 69 70 exit_bb_ = CreateLoopBB("loop_exit", ir_builder); 71 } else { 72 // We're emitting the loop into the middle of a basic block. splitBasicBlock 73 // requires that this basic block be well-formed (have a terminator). 74 CHECK_NE(nullptr, preheader_bb_->getTerminator()); 75 76 // Split the preheader to create an exit basic block. The exit basic block 77 // will contain all instructions at or after insert_point. 78 exit_bb_ = preheader_bb_->splitBasicBlock( 79 insert_point, AsStringRef(GetQualifiedName("loop_exit"))); 80 81 // splitBasicBlock adds an unconditional branch between the split basic 82 // blocks. Remove it. An unconditional branch will be added below from the 83 // preheader to the header. 84 preheader_bb_->getTerminator()->eraseFromParent(); 85 } 86 insert_before_bb_ = exit_bb_; 87 88 // Create remaining basic block which form the inside of the loop. 89 header_bb_ = CreateLoopBB("loop_header", ir_builder); 90 body_bb_ = CreateLoopBB("loop_body", ir_builder); 91 92 // Function entry basic block. 93 // Emit alloca for the induction variable. We do this at the entry to the 94 // basic block to ensure the alloc only executes once per function (we could 95 // be emitting a nested loop). 96 llvm::Function* func = preheader_bb_->getParent(); 97 ir_builder->SetInsertPoint(&func->getEntryBlock(), 98 func->getEntryBlock().getFirstInsertionPt()); 99 llvm::Value* indvar_address = 100 ir_builder->CreateAlloca(ir_builder->getInt64Ty(), nullptr, 101 AsStringRef(GetQualifiedName("invar_address"))); 102 103 // Preheader basic block. 104 // Initialize induction variable starting index. Create branch to the header. 105 ir_builder->SetInsertPoint(preheader_bb_); 106 ir_builder->CreateStore(start_index_, indvar_address); 107 // The preheader should not have a branch yet. 108 CHECK_EQ(preheader_bb_->getTerminator(), nullptr); 109 ir_builder->CreateBr(header_bb_); 110 111 // Header basic block. 112 // Emit the loop conditional branch. Load and compare indvar with ending 113 // index and jump to loop exit if equal. Jump to body otherwise. 114 ir_builder->SetInsertPoint(header_bb_); 115 indvar_ = ir_builder->CreateLoad(indvar_address, 116 AsStringRef(GetQualifiedName("indvar"))); 117 llvm::Value* exit_cond = ir_builder->CreateICmpUGE(indvar_, end_index_); 118 ir_builder->CreateCondBr(/*Cond=*/exit_cond, 119 /*True=*/exit_bb_, /*False=*/body_bb_); 120 121 // Body basic block. 122 // Increment indvar, store indvar, and jump to header. 123 ir_builder->SetInsertPoint(body_bb_); 124 llvm::Value* step = step_; 125 llvm::Value* indvar = indvar_; 126 127 llvm::Value* indvar_inc = 128 ir_builder->CreateAdd(indvar, step, "invar.inc", 129 /*HasNUW=*/true, /*HasNSW=*/true); 130 ir_builder->CreateStore(indvar_inc, indvar_address); 131 llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_); 132 133 std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(ir_builder); 134 if (!loop_metadata.empty()) { 135 llvm::LLVMContext* ctx = &start_index_->getContext(); 136 auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None); 137 loop_metadata.insert(loop_metadata.begin(), temp_node.get()); 138 auto loop_id = llvm::MDNode::get(*ctx, loop_metadata); 139 loop_id->replaceOperandWith(0, loop_id); 140 back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id); 141 } 142 143 // Re-point the IR builder to the loop exit block. 144 ir_builder->SetInsertPoint(exit_bb_); 145} 146 147std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata( 148 llvm::IRBuilder<>* ir_builder) { 149 const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; 150 const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; 151 llvm::LLVMContext* ctx = &start_index_->getContext(); 152 153 std::vector<llvm::Metadata*> result; 154 if (prevent_unrolling_) { 155 result.push_back(llvm::MDNode::get( 156 *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)})); 157 } 158 159 if (prevent_vectorization_) { 160 result.push_back(llvm::MDNode::get( 161 *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName), 162 llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); 163 } 164 165 return result; 166} 167 168string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { 169 return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_)); 170} 171 172llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, 173 llvm::IRBuilder<>* ir_builder) { 174 return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), 175 ir_builder); 176} 177 178std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, 179 llvm::Value* start_index, 180 llvm::Value* end_index, 181 bool prevent_unrolling, 182 bool prevent_vectorization) { 183 return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), 184 prevent_unrolling, prevent_vectorization); 185} 186 187std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, 188 llvm::Value* start_index, 189 llvm::Value* end_index, 190 llvm::Value* stride, 191 bool prevent_unrolling, 192 bool prevent_vectorization) { 193 if (inner_loop_body_bb_ != nullptr) { 194 // Create this loop inside the previous one. 195 ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); 196 } 197 std::unique_ptr<ForLoop> loop(new ForLoop( 198 /*prefix=*/name_, suffix, start_index, end_index, stride, 199 prevent_unrolling, prevent_vectorization)); 200 loop->Emit(ir_builder_); 201 202 if (outer_loop_preheader_bb_ == nullptr) { 203 outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock(); 204 } 205 206 if (outer_loop_exit_bb_ == nullptr) { 207 outer_loop_exit_bb_ = loop->GetExitBasicBlock(); 208 } 209 210 inner_loop_body_bb_ = loop->GetBodyBasicBlock(); 211 212 return loop; 213} 214 215std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, 216 int64 end_index, 217 tensorflow::StringPiece suffix, 218 bool prevent_unrolling, 219 bool prevent_vectorization) { 220 CHECK_LE(start_index, end_index); 221 return AddLoop(suffix, ir_builder_->getInt64(start_index), 222 ir_builder_->getInt64(end_index), prevent_unrolling, 223 prevent_vectorization); 224} 225 226std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, 227 int64 end_index, int64 stride, 228 tensorflow::StringPiece suffix, 229 bool prevent_unrolling, 230 bool prevent_vectorization) { 231 CHECK_LE(start_index, end_index); 232 return AddLoop(suffix, ir_builder_->getInt64(start_index), 233 ir_builder_->getInt64(end_index), 234 ir_builder_->getInt64(stride), prevent_unrolling, 235 prevent_vectorization); 236} 237 238IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, 239 tensorflow::StringPiece suffix) { 240 std::vector<int64> dimensions(ShapeUtil::Rank(shape)); 241 std::iota(dimensions.begin(), dimensions.end(), 0); 242 return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); 243} 244 245IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( 246 const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions, 247 tensorflow::StringPiece suffix) { 248 llvm_ir::IrArray::Index index(shape.dimensions_size(), nullptr); 249 for (int64 dimension : dimensions) { 250 std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop( 251 /*start_index=*/0, 252 /*end_index=*/shape.dimensions(dimension), 253 /*suffix=*/ 254 llvm_ir::IrName(suffix, tensorflow::strings::StrCat(dimension))); 255 index[dimension] = loop->GetIndVarValue(); 256 } 257 return index; 258} 259 260} // namespace llvm_ir 261} // namespace xla 262