1//===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// \file 11// Lower aggregate copies, memset, memcpy, memmov intrinsics into loops when 12// the size is large or is not a compile-time constant. 13// 14//===----------------------------------------------------------------------===// 15 16#include "NVPTXLowerAggrCopies.h" 17#include "llvm/CodeGen/MachineFunctionAnalysis.h" 18#include "llvm/CodeGen/StackProtector.h" 19#include "llvm/IR/Constants.h" 20#include "llvm/IR/DataLayout.h" 21#include "llvm/IR/Function.h" 22#include "llvm/IR/IRBuilder.h" 23#include "llvm/IR/Instructions.h" 24#include "llvm/IR/IntrinsicInst.h" 25#include "llvm/IR/Intrinsics.h" 26#include "llvm/IR/LLVMContext.h" 27#include "llvm/IR/Module.h" 28#include "llvm/Support/Debug.h" 29#include "llvm/Transforms/Utils/BasicBlockUtils.h" 30 31#define DEBUG_TYPE "nvptx" 32 33using namespace llvm; 34 35namespace { 36 37// actual analysis class, which is a functionpass 38struct NVPTXLowerAggrCopies : public FunctionPass { 39 static char ID; 40 41 NVPTXLowerAggrCopies() : FunctionPass(ID) {} 42 43 void getAnalysisUsage(AnalysisUsage &AU) const override { 44 AU.addPreserved<MachineFunctionAnalysis>(); 45 AU.addPreserved<StackProtector>(); 46 } 47 48 bool runOnFunction(Function &F) override; 49 50 static const unsigned MaxAggrCopySize = 128; 51 52 const char *getPassName() const override { 53 return "Lower aggregate copies/intrinsics into loops"; 54 } 55}; 56 57char NVPTXLowerAggrCopies::ID = 0; 58 59// Lower memcpy to loop. 60void convertMemCpyToLoop(Instruction *ConvertedInst, Value *SrcAddr, 61 Value *DstAddr, Value *CopyLen, bool SrcIsVolatile, 62 bool DstIsVolatile, LLVMContext &Context, 63 Function &F) { 64 Type *TypeOfCopyLen = CopyLen->getType(); 65 66 BasicBlock *OrigBB = ConvertedInst->getParent(); 67 BasicBlock *NewBB = 68 ConvertedInst->getParent()->splitBasicBlock(ConvertedInst, "split"); 69 BasicBlock *LoopBB = BasicBlock::Create(Context, "loadstoreloop", &F, NewBB); 70 71 OrigBB->getTerminator()->setSuccessor(0, LoopBB); 72 IRBuilder<> Builder(OrigBB->getTerminator()); 73 74 // SrcAddr and DstAddr are expected to be pointer types, 75 // so no check is made here. 76 unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace(); 77 unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); 78 79 // Cast pointers to (char *) 80 SrcAddr = Builder.CreateBitCast(SrcAddr, Builder.getInt8PtrTy(SrcAS)); 81 DstAddr = Builder.CreateBitCast(DstAddr, Builder.getInt8PtrTy(DstAS)); 82 83 IRBuilder<> LoopBuilder(LoopBB); 84 PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 0); 85 LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), OrigBB); 86 87 // load from SrcAddr+LoopIndex 88 // TODO: we can leverage the align parameter of llvm.memcpy for more efficient 89 // word-sized loads and stores. 90 Value *Element = 91 LoopBuilder.CreateLoad(LoopBuilder.CreateInBoundsGEP( 92 LoopBuilder.getInt8Ty(), SrcAddr, LoopIndex), 93 SrcIsVolatile); 94 // store at DstAddr+LoopIndex 95 LoopBuilder.CreateStore(Element, 96 LoopBuilder.CreateInBoundsGEP(LoopBuilder.getInt8Ty(), 97 DstAddr, LoopIndex), 98 DstIsVolatile); 99 100 // The value for LoopIndex coming from backedge is (LoopIndex + 1) 101 Value *NewIndex = 102 LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1)); 103 LoopIndex->addIncoming(NewIndex, LoopBB); 104 105 LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB, 106 NewBB); 107} 108 109// Lower memmove to IR. memmove is required to correctly copy overlapping memory 110// regions; therefore, it has to check the relative positions of the source and 111// destination pointers and choose the copy direction accordingly. 112// 113// The code below is an IR rendition of this C function: 114// 115// void* memmove(void* dst, const void* src, size_t n) { 116// unsigned char* d = dst; 117// const unsigned char* s = src; 118// if (s < d) { 119// // copy backwards 120// while (n--) { 121// d[n] = s[n]; 122// } 123// } else { 124// // copy forward 125// for (size_t i = 0; i < n; ++i) { 126// d[i] = s[i]; 127// } 128// } 129// return dst; 130// } 131void convertMemMoveToLoop(Instruction *ConvertedInst, Value *SrcAddr, 132 Value *DstAddr, Value *CopyLen, bool SrcIsVolatile, 133 bool DstIsVolatile, LLVMContext &Context, 134 Function &F) { 135 Type *TypeOfCopyLen = CopyLen->getType(); 136 BasicBlock *OrigBB = ConvertedInst->getParent(); 137 138 // Create the a comparison of src and dst, based on which we jump to either 139 // the forward-copy part of the function (if src >= dst) or the backwards-copy 140 // part (if src < dst). 141 // SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else 142 // structure. Its block terminators (unconditional branches) are replaced by 143 // the appropriate conditional branches when the loop is built. 144 ICmpInst *PtrCompare = new ICmpInst(ConvertedInst, ICmpInst::ICMP_ULT, 145 SrcAddr, DstAddr, "compare_src_dst"); 146 TerminatorInst *ThenTerm, *ElseTerm; 147 SplitBlockAndInsertIfThenElse(PtrCompare, ConvertedInst, &ThenTerm, 148 &ElseTerm); 149 150 // Each part of the function consists of two blocks: 151 // copy_backwards: used to skip the loop when n == 0 152 // copy_backwards_loop: the actual backwards loop BB 153 // copy_forward: used to skip the loop when n == 0 154 // copy_forward_loop: the actual forward loop BB 155 BasicBlock *CopyBackwardsBB = ThenTerm->getParent(); 156 CopyBackwardsBB->setName("copy_backwards"); 157 BasicBlock *CopyForwardBB = ElseTerm->getParent(); 158 CopyForwardBB->setName("copy_forward"); 159 BasicBlock *ExitBB = ConvertedInst->getParent(); 160 ExitBB->setName("memmove_done"); 161 162 // Initial comparison of n == 0 that lets us skip the loops altogether. Shared 163 // between both backwards and forward copy clauses. 164 ICmpInst *CompareN = 165 new ICmpInst(OrigBB->getTerminator(), ICmpInst::ICMP_EQ, CopyLen, 166 ConstantInt::get(TypeOfCopyLen, 0), "compare_n_to_0"); 167 168 // Copying backwards. 169 BasicBlock *LoopBB = 170 BasicBlock::Create(Context, "copy_backwards_loop", &F, CopyForwardBB); 171 IRBuilder<> LoopBuilder(LoopBB); 172 PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0); 173 Value *IndexPtr = LoopBuilder.CreateSub( 174 LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr"); 175 Value *Element = LoopBuilder.CreateLoad( 176 LoopBuilder.CreateInBoundsGEP(SrcAddr, IndexPtr), "element"); 177 LoopBuilder.CreateStore(Element, 178 LoopBuilder.CreateInBoundsGEP(DstAddr, IndexPtr)); 179 LoopBuilder.CreateCondBr( 180 LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)), 181 ExitBB, LoopBB); 182 LoopPhi->addIncoming(IndexPtr, LoopBB); 183 LoopPhi->addIncoming(CopyLen, CopyBackwardsBB); 184 BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm); 185 ThenTerm->eraseFromParent(); 186 187 // Copying forward. 188 BasicBlock *FwdLoopBB = 189 BasicBlock::Create(Context, "copy_forward_loop", &F, ExitBB); 190 IRBuilder<> FwdLoopBuilder(FwdLoopBB); 191 PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr"); 192 Value *FwdElement = FwdLoopBuilder.CreateLoad( 193 FwdLoopBuilder.CreateInBoundsGEP(SrcAddr, FwdCopyPhi), "element"); 194 FwdLoopBuilder.CreateStore( 195 FwdElement, FwdLoopBuilder.CreateInBoundsGEP(DstAddr, FwdCopyPhi)); 196 Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd( 197 FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment"); 198 FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen), 199 ExitBB, FwdLoopBB); 200 FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB); 201 FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), CopyForwardBB); 202 203 BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm); 204 ElseTerm->eraseFromParent(); 205} 206 207// Lower memset to loop. 208void convertMemSetToLoop(Instruction *ConvertedInst, Value *DstAddr, 209 Value *CopyLen, Value *SetValue, LLVMContext &Context, 210 Function &F) { 211 BasicBlock *OrigBB = ConvertedInst->getParent(); 212 BasicBlock *NewBB = 213 ConvertedInst->getParent()->splitBasicBlock(ConvertedInst, "split"); 214 BasicBlock *LoopBB = BasicBlock::Create(Context, "loadstoreloop", &F, NewBB); 215 216 OrigBB->getTerminator()->setSuccessor(0, LoopBB); 217 IRBuilder<> Builder(OrigBB->getTerminator()); 218 219 // Cast pointer to the type of value getting stored 220 unsigned dstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); 221 DstAddr = Builder.CreateBitCast(DstAddr, 222 PointerType::get(SetValue->getType(), dstAS)); 223 224 IRBuilder<> LoopBuilder(LoopBB); 225 PHINode *LoopIndex = LoopBuilder.CreatePHI(CopyLen->getType(), 0); 226 LoopIndex->addIncoming(ConstantInt::get(CopyLen->getType(), 0), OrigBB); 227 228 LoopBuilder.CreateStore( 229 SetValue, 230 LoopBuilder.CreateInBoundsGEP(SetValue->getType(), DstAddr, LoopIndex), 231 false); 232 233 Value *NewIndex = 234 LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLen->getType(), 1)); 235 LoopIndex->addIncoming(NewIndex, LoopBB); 236 237 LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB, 238 NewBB); 239} 240 241bool NVPTXLowerAggrCopies::runOnFunction(Function &F) { 242 SmallVector<LoadInst *, 4> AggrLoads; 243 SmallVector<MemIntrinsic *, 4> MemCalls; 244 245 const DataLayout &DL = F.getParent()->getDataLayout(); 246 LLVMContext &Context = F.getParent()->getContext(); 247 248 // Collect all aggregate loads and mem* calls. 249 for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { 250 for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); II != IE; 251 ++II) { 252 if (LoadInst *LI = dyn_cast<LoadInst>(II)) { 253 if (!LI->hasOneUse()) 254 continue; 255 256 if (DL.getTypeStoreSize(LI->getType()) < MaxAggrCopySize) 257 continue; 258 259 if (StoreInst *SI = dyn_cast<StoreInst>(LI->user_back())) { 260 if (SI->getOperand(0) != LI) 261 continue; 262 AggrLoads.push_back(LI); 263 } 264 } else if (MemIntrinsic *IntrCall = dyn_cast<MemIntrinsic>(II)) { 265 // Convert intrinsic calls with variable size or with constant size 266 // larger than the MaxAggrCopySize threshold. 267 if (ConstantInt *LenCI = dyn_cast<ConstantInt>(IntrCall->getLength())) { 268 if (LenCI->getZExtValue() >= MaxAggrCopySize) { 269 MemCalls.push_back(IntrCall); 270 } 271 } else { 272 MemCalls.push_back(IntrCall); 273 } 274 } 275 } 276 } 277 278 if (AggrLoads.size() == 0 && MemCalls.size() == 0) { 279 return false; 280 } 281 282 // 283 // Do the transformation of an aggr load/copy/set to a loop 284 // 285 for (LoadInst *LI : AggrLoads) { 286 StoreInst *SI = dyn_cast<StoreInst>(*LI->user_begin()); 287 Value *SrcAddr = LI->getOperand(0); 288 Value *DstAddr = SI->getOperand(1); 289 unsigned NumLoads = DL.getTypeStoreSize(LI->getType()); 290 Value *CopyLen = ConstantInt::get(Type::getInt32Ty(Context), NumLoads); 291 292 convertMemCpyToLoop(/* ConvertedInst */ SI, 293 /* SrcAddr */ SrcAddr, /* DstAddr */ DstAddr, 294 /* CopyLen */ CopyLen, 295 /* SrcIsVolatile */ LI->isVolatile(), 296 /* DstIsVolatile */ SI->isVolatile(), 297 /* Context */ Context, 298 /* Function F */ F); 299 300 SI->eraseFromParent(); 301 LI->eraseFromParent(); 302 } 303 304 // Transform mem* intrinsic calls. 305 for (MemIntrinsic *MemCall : MemCalls) { 306 if (MemCpyInst *Memcpy = dyn_cast<MemCpyInst>(MemCall)) { 307 convertMemCpyToLoop(/* ConvertedInst */ Memcpy, 308 /* SrcAddr */ Memcpy->getRawSource(), 309 /* DstAddr */ Memcpy->getRawDest(), 310 /* CopyLen */ Memcpy->getLength(), 311 /* SrcIsVolatile */ Memcpy->isVolatile(), 312 /* DstIsVolatile */ Memcpy->isVolatile(), 313 /* Context */ Context, 314 /* Function F */ F); 315 } else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(MemCall)) { 316 convertMemMoveToLoop(/* ConvertedInst */ Memmove, 317 /* SrcAddr */ Memmove->getRawSource(), 318 /* DstAddr */ Memmove->getRawDest(), 319 /* CopyLen */ Memmove->getLength(), 320 /* SrcIsVolatile */ Memmove->isVolatile(), 321 /* DstIsVolatile */ Memmove->isVolatile(), 322 /* Context */ Context, 323 /* Function F */ F); 324 325 } else if (MemSetInst *Memset = dyn_cast<MemSetInst>(MemCall)) { 326 convertMemSetToLoop(/* ConvertedInst */ Memset, 327 /* DstAddr */ Memset->getRawDest(), 328 /* CopyLen */ Memset->getLength(), 329 /* SetValue */ Memset->getValue(), 330 /* Context */ Context, 331 /* Function F */ F); 332 } 333 MemCall->eraseFromParent(); 334 } 335 336 return true; 337} 338 339} // namespace 340 341namespace llvm { 342void initializeNVPTXLowerAggrCopiesPass(PassRegistry &); 343} 344 345INITIALIZE_PASS(NVPTXLowerAggrCopies, "nvptx-lower-aggr-copies", 346 "Lower aggregate copies, and llvm.mem* intrinsics into loops", 347 false, false) 348 349FunctionPass *llvm::createLowerAggrCopies() { 350 return new NVPTXLowerAggrCopies(); 351} 352