1//===- LoadCombine.cpp - Combine Adjacent Loads ---------------------------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9/// \file 10/// This transformation combines adjacent loads. 11/// 12//===----------------------------------------------------------------------===// 13 14#include "llvm/Transforms/Scalar.h" 15#include "llvm/ADT/DenseMap.h" 16#include "llvm/ADT/Statistic.h" 17#include "llvm/Analysis/AliasAnalysis.h" 18#include "llvm/Analysis/AliasSetTracker.h" 19#include "llvm/Analysis/GlobalsModRef.h" 20#include "llvm/Analysis/TargetFolder.h" 21#include "llvm/IR/DataLayout.h" 22#include "llvm/IR/Function.h" 23#include "llvm/IR/IRBuilder.h" 24#include "llvm/IR/Instructions.h" 25#include "llvm/IR/Module.h" 26#include "llvm/Pass.h" 27#include "llvm/Support/Debug.h" 28#include "llvm/Support/MathExtras.h" 29#include "llvm/Support/raw_ostream.h" 30 31using namespace llvm; 32 33#define DEBUG_TYPE "load-combine" 34 35STATISTIC(NumLoadsAnalyzed, "Number of loads analyzed for combining"); 36STATISTIC(NumLoadsCombined, "Number of loads combined"); 37 38#define LDCOMBINE_NAME "Combine Adjacent Loads" 39 40namespace { 41struct PointerOffsetPair { 42 Value *Pointer; 43 APInt Offset; 44}; 45 46struct LoadPOPPair { 47 LoadPOPPair() = default; 48 LoadPOPPair(LoadInst *L, PointerOffsetPair P, unsigned O) 49 : Load(L), POP(P), InsertOrder(O) {} 50 LoadInst *Load; 51 PointerOffsetPair POP; 52 /// \brief The new load needs to be created before the first load in IR order. 53 unsigned InsertOrder; 54}; 55 56class LoadCombine : public BasicBlockPass { 57 LLVMContext *C; 58 AliasAnalysis *AA; 59 60public: 61 LoadCombine() : BasicBlockPass(ID), C(nullptr), AA(nullptr) { 62 initializeLoadCombinePass(*PassRegistry::getPassRegistry()); 63 } 64 65 using llvm::Pass::doInitialization; 66 bool doInitialization(Function &) override; 67 bool runOnBasicBlock(BasicBlock &BB) override; 68 void getAnalysisUsage(AnalysisUsage &AU) const override { 69 AU.setPreservesCFG(); 70 AU.addRequired<AAResultsWrapperPass>(); 71 AU.addPreserved<GlobalsAAWrapperPass>(); 72 } 73 74 const char *getPassName() const override { return LDCOMBINE_NAME; } 75 static char ID; 76 77 typedef IRBuilder<TargetFolder> BuilderTy; 78 79private: 80 BuilderTy *Builder; 81 82 PointerOffsetPair getPointerOffsetPair(LoadInst &); 83 bool combineLoads(DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &); 84 bool aggregateLoads(SmallVectorImpl<LoadPOPPair> &); 85 bool combineLoads(SmallVectorImpl<LoadPOPPair> &); 86}; 87} 88 89bool LoadCombine::doInitialization(Function &F) { 90 DEBUG(dbgs() << "LoadCombine function: " << F.getName() << "\n"); 91 C = &F.getContext(); 92 return true; 93} 94 95PointerOffsetPair LoadCombine::getPointerOffsetPair(LoadInst &LI) { 96 auto &DL = LI.getModule()->getDataLayout(); 97 98 PointerOffsetPair POP; 99 POP.Pointer = LI.getPointerOperand(); 100 unsigned BitWidth = DL.getPointerSizeInBits(LI.getPointerAddressSpace()); 101 POP.Offset = APInt(BitWidth, 0); 102 103 while (isa<BitCastInst>(POP.Pointer) || isa<GetElementPtrInst>(POP.Pointer)) { 104 if (auto *GEP = dyn_cast<GetElementPtrInst>(POP.Pointer)) { 105 APInt LastOffset = POP.Offset; 106 if (!GEP->accumulateConstantOffset(DL, POP.Offset)) { 107 // Can't handle GEPs with variable indices. 108 POP.Offset = LastOffset; 109 return POP; 110 } 111 POP.Pointer = GEP->getPointerOperand(); 112 } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer)) { 113 POP.Pointer = BC->getOperand(0); 114 } 115 } 116 return POP; 117} 118 119bool LoadCombine::combineLoads( 120 DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &LoadMap) { 121 bool Combined = false; 122 for (auto &Loads : LoadMap) { 123 if (Loads.second.size() < 2) 124 continue; 125 std::sort(Loads.second.begin(), Loads.second.end(), 126 [](const LoadPOPPair &A, const LoadPOPPair &B) { 127 return A.POP.Offset.slt(B.POP.Offset); 128 }); 129 if (aggregateLoads(Loads.second)) 130 Combined = true; 131 } 132 return Combined; 133} 134 135/// \brief Try to aggregate loads from a sorted list of loads to be combined. 136/// 137/// It is guaranteed that no writes occur between any of the loads. All loads 138/// have the same base pointer. There are at least two loads. 139bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) { 140 assert(Loads.size() >= 2 && "Insufficient loads!"); 141 LoadInst *BaseLoad = nullptr; 142 SmallVector<LoadPOPPair, 8> AggregateLoads; 143 bool Combined = false; 144 bool ValidPrevOffset = false; 145 APInt PrevOffset; 146 uint64_t PrevSize = 0; 147 for (auto &L : Loads) { 148 if (ValidPrevOffset == false) { 149 BaseLoad = L.Load; 150 PrevOffset = L.POP.Offset; 151 PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize( 152 L.Load->getType()); 153 AggregateLoads.push_back(L); 154 ValidPrevOffset = true; 155 continue; 156 } 157 if (L.Load->getAlignment() > BaseLoad->getAlignment()) 158 continue; 159 APInt PrevEnd = PrevOffset + PrevSize; 160 if (L.POP.Offset.sgt(PrevEnd)) { 161 // No other load will be combinable 162 if (combineLoads(AggregateLoads)) 163 Combined = true; 164 AggregateLoads.clear(); 165 ValidPrevOffset = false; 166 continue; 167 } 168 if (L.POP.Offset != PrevEnd) 169 // This load is offset less than the size of the last load. 170 // FIXME: We may want to handle this case. 171 continue; 172 PrevOffset = L.POP.Offset; 173 PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize( 174 L.Load->getType()); 175 AggregateLoads.push_back(L); 176 } 177 if (combineLoads(AggregateLoads)) 178 Combined = true; 179 return Combined; 180} 181 182/// \brief Given a list of combinable load. Combine the maximum number of them. 183bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) { 184 // Remove loads from the end while the size is not a power of 2. 185 unsigned TotalSize = 0; 186 for (const auto &L : Loads) 187 TotalSize += L.Load->getType()->getPrimitiveSizeInBits(); 188 while (TotalSize != 0 && !isPowerOf2_32(TotalSize)) 189 TotalSize -= Loads.pop_back_val().Load->getType()->getPrimitiveSizeInBits(); 190 if (Loads.size() < 2) 191 return false; 192 193 DEBUG({ 194 dbgs() << "***** Combining Loads ******\n"; 195 for (const auto &L : Loads) { 196 dbgs() << L.POP.Offset << ": " << *L.Load << "\n"; 197 } 198 }); 199 200 // Find first load. This is where we put the new load. 201 LoadPOPPair FirstLP; 202 FirstLP.InsertOrder = -1u; 203 for (const auto &L : Loads) 204 if (L.InsertOrder < FirstLP.InsertOrder) 205 FirstLP = L; 206 207 unsigned AddressSpace = 208 FirstLP.POP.Pointer->getType()->getPointerAddressSpace(); 209 210 Builder->SetInsertPoint(FirstLP.Load); 211 Value *Ptr = Builder->CreateConstGEP1_64( 212 Builder->CreatePointerCast(Loads[0].POP.Pointer, 213 Builder->getInt8PtrTy(AddressSpace)), 214 Loads[0].POP.Offset.getSExtValue()); 215 LoadInst *NewLoad = new LoadInst( 216 Builder->CreatePointerCast( 217 Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize), 218 Ptr->getType()->getPointerAddressSpace())), 219 Twine(Loads[0].Load->getName()) + ".combined", false, 220 Loads[0].Load->getAlignment(), FirstLP.Load); 221 222 for (const auto &L : Loads) { 223 Builder->SetInsertPoint(L.Load); 224 Value *V = Builder->CreateExtractInteger( 225 L.Load->getModule()->getDataLayout(), NewLoad, 226 cast<IntegerType>(L.Load->getType()), 227 (L.POP.Offset - Loads[0].POP.Offset).getZExtValue(), "combine.extract"); 228 L.Load->replaceAllUsesWith(V); 229 } 230 231 NumLoadsCombined = NumLoadsCombined + Loads.size(); 232 return true; 233} 234 235bool LoadCombine::runOnBasicBlock(BasicBlock &BB) { 236 if (skipBasicBlock(BB)) 237 return false; 238 239 AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); 240 241 IRBuilder<TargetFolder> TheBuilder( 242 BB.getContext(), TargetFolder(BB.getModule()->getDataLayout())); 243 Builder = &TheBuilder; 244 245 DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> LoadMap; 246 AliasSetTracker AST(*AA); 247 248 bool Combined = false; 249 unsigned Index = 0; 250 for (auto &I : BB) { 251 if (I.mayThrow() || (I.mayWriteToMemory() && AST.containsUnknown(&I))) { 252 if (combineLoads(LoadMap)) 253 Combined = true; 254 LoadMap.clear(); 255 AST.clear(); 256 continue; 257 } 258 LoadInst *LI = dyn_cast<LoadInst>(&I); 259 if (!LI) 260 continue; 261 ++NumLoadsAnalyzed; 262 if (!LI->isSimple() || !LI->getType()->isIntegerTy()) 263 continue; 264 auto POP = getPointerOffsetPair(*LI); 265 if (!POP.Pointer) 266 continue; 267 LoadMap[POP.Pointer].push_back(LoadPOPPair(LI, POP, Index++)); 268 AST.add(LI); 269 } 270 if (combineLoads(LoadMap)) 271 Combined = true; 272 return Combined; 273} 274 275char LoadCombine::ID = 0; 276 277BasicBlockPass *llvm::createLoadCombinePass() { 278 return new LoadCombine(); 279} 280 281INITIALIZE_PASS_BEGIN(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false) 282INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 283INITIALIZE_PASS_END(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false) 284