1//===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// Lower aggregate copies, memset, memcpy, memmov intrinsics into loops when 10// the size is large or is not a compile-time constant. 11// 12//===----------------------------------------------------------------------===// 13 14#include "NVPTXLowerAggrCopies.h" 15#include "llvm/CodeGen/MachineFunctionAnalysis.h" 16#include "llvm/CodeGen/StackProtector.h" 17#include "llvm/IR/Constants.h" 18#include "llvm/IR/DataLayout.h" 19#include "llvm/IR/Function.h" 20#include "llvm/IR/IRBuilder.h" 21#include "llvm/IR/InstIterator.h" 22#include "llvm/IR/Instructions.h" 23#include "llvm/IR/IntrinsicInst.h" 24#include "llvm/IR/Intrinsics.h" 25#include "llvm/IR/LLVMContext.h" 26#include "llvm/IR/Module.h" 27#include "llvm/Support/Debug.h" 28 29#define DEBUG_TYPE "nvptx" 30 31using namespace llvm; 32 33namespace { 34// actual analysis class, which is a functionpass 35struct NVPTXLowerAggrCopies : public FunctionPass { 36 static char ID; 37 38 NVPTXLowerAggrCopies() : FunctionPass(ID) {} 39 40 void getAnalysisUsage(AnalysisUsage &AU) const override { 41 AU.addPreserved<MachineFunctionAnalysis>(); 42 AU.addPreserved<StackProtector>(); 43 } 44 45 bool runOnFunction(Function &F) override; 46 47 static const unsigned MaxAggrCopySize = 128; 48 49 const char *getPassName() const override { 50 return "Lower aggregate copies/intrinsics into loops"; 51 } 52}; 53} // namespace 54 55char NVPTXLowerAggrCopies::ID = 0; 56 57// Lower MemTransferInst or load-store pair to loop 58static void convertTransferToLoop( 59 Instruction *splitAt, Value *srcAddr, Value *dstAddr, Value *len, 60 //unsigned numLoads, 61 bool srcVolatile, bool dstVolatile, LLVMContext &Context, Function &F) { 62 Type *indType = len->getType(); 63 64 BasicBlock *origBB = splitAt->getParent(); 65 BasicBlock *newBB = splitAt->getParent()->splitBasicBlock(splitAt, "split"); 66 BasicBlock *loopBB = BasicBlock::Create(Context, "loadstoreloop", &F, newBB); 67 68 origBB->getTerminator()->setSuccessor(0, loopBB); 69 IRBuilder<> builder(origBB, origBB->getTerminator()); 70 71 // srcAddr and dstAddr are expected to be pointer types, 72 // so no check is made here. 73 unsigned srcAS = cast<PointerType>(srcAddr->getType())->getAddressSpace(); 74 unsigned dstAS = cast<PointerType>(dstAddr->getType())->getAddressSpace(); 75 76 // Cast pointers to (char *) 77 srcAddr = builder.CreateBitCast(srcAddr, Type::getInt8PtrTy(Context, srcAS)); 78 dstAddr = builder.CreateBitCast(dstAddr, Type::getInt8PtrTy(Context, dstAS)); 79 80 IRBuilder<> loop(loopBB); 81 // The loop index (ind) is a phi node. 82 PHINode *ind = loop.CreatePHI(indType, 0); 83 // Incoming value for ind is 0 84 ind->addIncoming(ConstantInt::get(indType, 0), origBB); 85 86 // load from srcAddr+ind 87 Value *val = loop.CreateLoad(loop.CreateGEP(loop.getInt8Ty(), srcAddr, ind), 88 srcVolatile); 89 // store at dstAddr+ind 90 loop.CreateStore(val, loop.CreateGEP(loop.getInt8Ty(), dstAddr, ind), 91 dstVolatile); 92 93 // The value for ind coming from backedge is (ind + 1) 94 Value *newind = loop.CreateAdd(ind, ConstantInt::get(indType, 1)); 95 ind->addIncoming(newind, loopBB); 96 97 loop.CreateCondBr(loop.CreateICmpULT(newind, len), loopBB, newBB); 98} 99 100// Lower MemSetInst to loop 101static void convertMemSetToLoop(Instruction *splitAt, Value *dstAddr, 102 Value *len, Value *val, LLVMContext &Context, 103 Function &F) { 104 BasicBlock *origBB = splitAt->getParent(); 105 BasicBlock *newBB = splitAt->getParent()->splitBasicBlock(splitAt, "split"); 106 BasicBlock *loopBB = BasicBlock::Create(Context, "loadstoreloop", &F, newBB); 107 108 origBB->getTerminator()->setSuccessor(0, loopBB); 109 IRBuilder<> builder(origBB, origBB->getTerminator()); 110 111 unsigned dstAS = cast<PointerType>(dstAddr->getType())->getAddressSpace(); 112 113 // Cast pointer to the type of value getting stored 114 dstAddr = 115 builder.CreateBitCast(dstAddr, PointerType::get(val->getType(), dstAS)); 116 117 IRBuilder<> loop(loopBB); 118 PHINode *ind = loop.CreatePHI(len->getType(), 0); 119 ind->addIncoming(ConstantInt::get(len->getType(), 0), origBB); 120 121 loop.CreateStore(val, loop.CreateGEP(val->getType(), dstAddr, ind), false); 122 123 Value *newind = loop.CreateAdd(ind, ConstantInt::get(len->getType(), 1)); 124 ind->addIncoming(newind, loopBB); 125 126 loop.CreateCondBr(loop.CreateICmpULT(newind, len), loopBB, newBB); 127} 128 129bool NVPTXLowerAggrCopies::runOnFunction(Function &F) { 130 SmallVector<LoadInst *, 4> aggrLoads; 131 SmallVector<MemTransferInst *, 4> aggrMemcpys; 132 SmallVector<MemSetInst *, 4> aggrMemsets; 133 134 const DataLayout &DL = F.getParent()->getDataLayout(); 135 LLVMContext &Context = F.getParent()->getContext(); 136 137 // 138 // Collect all the aggrLoads, aggrMemcpys and addrMemsets. 139 // 140 //const BasicBlock *firstBB = &F.front(); // first BB in F 141 for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { 142 //BasicBlock *bb = BI; 143 for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); II != IE; 144 ++II) { 145 if (LoadInst *load = dyn_cast<LoadInst>(II)) { 146 147 if (!load->hasOneUse()) 148 continue; 149 150 if (DL.getTypeStoreSize(load->getType()) < MaxAggrCopySize) 151 continue; 152 153 User *use = load->user_back(); 154 if (StoreInst *store = dyn_cast<StoreInst>(use)) { 155 if (store->getOperand(0) != load) //getValueOperand 156 continue; 157 aggrLoads.push_back(load); 158 } 159 } else if (MemTransferInst *intr = dyn_cast<MemTransferInst>(II)) { 160 Value *len = intr->getLength(); 161 // If the number of elements being copied is greater 162 // than MaxAggrCopySize, lower it to a loop 163 if (ConstantInt *len_int = dyn_cast<ConstantInt>(len)) { 164 if (len_int->getZExtValue() >= MaxAggrCopySize) { 165 aggrMemcpys.push_back(intr); 166 } 167 } else { 168 // turn variable length memcpy/memmov into loop 169 aggrMemcpys.push_back(intr); 170 } 171 } else if (MemSetInst *memsetintr = dyn_cast<MemSetInst>(II)) { 172 Value *len = memsetintr->getLength(); 173 if (ConstantInt *len_int = dyn_cast<ConstantInt>(len)) { 174 if (len_int->getZExtValue() >= MaxAggrCopySize) { 175 aggrMemsets.push_back(memsetintr); 176 } 177 } else { 178 // turn variable length memset into loop 179 aggrMemsets.push_back(memsetintr); 180 } 181 } 182 } 183 } 184 if ((aggrLoads.size() == 0) && (aggrMemcpys.size() == 0) && 185 (aggrMemsets.size() == 0)) 186 return false; 187 188 // 189 // Do the transformation of an aggr load/copy/set to a loop 190 // 191 for (unsigned i = 0, e = aggrLoads.size(); i != e; ++i) { 192 LoadInst *load = aggrLoads[i]; 193 StoreInst *store = dyn_cast<StoreInst>(*load->user_begin()); 194 Value *srcAddr = load->getOperand(0); 195 Value *dstAddr = store->getOperand(1); 196 unsigned numLoads = DL.getTypeStoreSize(load->getType()); 197 Value *len = ConstantInt::get(Type::getInt32Ty(Context), numLoads); 198 199 convertTransferToLoop(store, srcAddr, dstAddr, len, load->isVolatile(), 200 store->isVolatile(), Context, F); 201 202 store->eraseFromParent(); 203 load->eraseFromParent(); 204 } 205 206 for (unsigned i = 0, e = aggrMemcpys.size(); i != e; ++i) { 207 MemTransferInst *cpy = aggrMemcpys[i]; 208 Value *len = cpy->getLength(); 209 // llvm 2.7 version of memcpy does not have volatile 210 // operand yet. So always making it non-volatile 211 // optimistically, so that we don't see unnecessary 212 // st.volatile in ptx 213 convertTransferToLoop(cpy, cpy->getSource(), cpy->getDest(), len, false, 214 false, Context, F); 215 cpy->eraseFromParent(); 216 } 217 218 for (unsigned i = 0, e = aggrMemsets.size(); i != e; ++i) { 219 MemSetInst *memsetinst = aggrMemsets[i]; 220 Value *len = memsetinst->getLength(); 221 Value *val = memsetinst->getValue(); 222 convertMemSetToLoop(memsetinst, memsetinst->getDest(), len, val, Context, 223 F); 224 memsetinst->eraseFromParent(); 225 } 226 227 return true; 228} 229 230FunctionPass *llvm::createLowerAggrCopies() { 231 return new NVPTXLowerAggrCopies(); 232} 233