1//===- LoadCombine.cpp - Combine Adjacent Loads ---------------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9/// \file
10/// This transformation combines adjacent loads.
11///
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Transforms/Scalar.h"
15#include "llvm/ADT/DenseMap.h"
16#include "llvm/ADT/Statistic.h"
17#include "llvm/Analysis/AliasAnalysis.h"
18#include "llvm/Analysis/AliasSetTracker.h"
19#include "llvm/Analysis/TargetFolder.h"
20#include "llvm/IR/DataLayout.h"
21#include "llvm/IR/Function.h"
22#include "llvm/IR/IRBuilder.h"
23#include "llvm/IR/Instructions.h"
24#include "llvm/IR/Module.h"
25#include "llvm/Pass.h"
26#include "llvm/Support/Debug.h"
27#include "llvm/Support/MathExtras.h"
28#include "llvm/Support/raw_ostream.h"
29
30using namespace llvm;
31
32#define DEBUG_TYPE "load-combine"
33
34STATISTIC(NumLoadsAnalyzed, "Number of loads analyzed for combining");
35STATISTIC(NumLoadsCombined, "Number of loads combined");
36
37namespace {
/// \brief A pointer value paired with a byte offset, as computed by
/// LoadCombine::getPointerOffsetPair().
struct PointerOffsetPair {
  /// Pointer the offset is relative to.
  Value *Pointer;
  /// Accumulated constant offset from \c Pointer, stored unsigned.
  uint64_t Offset;
};
42
/// \brief A candidate load together with its base-pointer/offset decomposition
/// and the order in which it was recorded.
struct LoadPOPPair {
  // Note: the defaulted constructor leaves every member uninitialized.
  LoadPOPPair() = default;
  LoadPOPPair(LoadInst *L, PointerOffsetPair P, unsigned O)
      : Load(L), POP(P), InsertOrder(O) {}
  /// The load instruction itself.
  LoadInst *Load;
  /// Base pointer and offset the load reads from.
  PointerOffsetPair POP;
  /// \brief The new load needs to be created before the first load in IR order.
  unsigned InsertOrder;
};
52
53class LoadCombine : public BasicBlockPass {
54  LLVMContext *C;
55  AliasAnalysis *AA;
56
57public:
58  LoadCombine() : BasicBlockPass(ID), C(nullptr), AA(nullptr) {
59    initializeSROAPass(*PassRegistry::getPassRegistry());
60  }
61
62  using llvm::Pass::doInitialization;
63  bool doInitialization(Function &) override;
64  bool runOnBasicBlock(BasicBlock &BB) override;
65  void getAnalysisUsage(AnalysisUsage &AU) const override;
66
67  const char *getPassName() const override { return "LoadCombine"; }
68  static char ID;
69
70  typedef IRBuilder<true, TargetFolder> BuilderTy;
71
72private:
73  BuilderTy *Builder;
74
75  PointerOffsetPair getPointerOffsetPair(LoadInst &);
76  bool combineLoads(DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &);
77  bool aggregateLoads(SmallVectorImpl<LoadPOPPair> &);
78  bool combineLoads(SmallVectorImpl<LoadPOPPair> &);
79};
80}
81
82bool LoadCombine::doInitialization(Function &F) {
83  DEBUG(dbgs() << "LoadCombine function: " << F.getName() << "\n");
84  C = &F.getContext();
85  return true;
86}
87
88PointerOffsetPair LoadCombine::getPointerOffsetPair(LoadInst &LI) {
89  PointerOffsetPair POP;
90  POP.Pointer = LI.getPointerOperand();
91  POP.Offset = 0;
92  while (isa<BitCastInst>(POP.Pointer) || isa<GetElementPtrInst>(POP.Pointer)) {
93    if (auto *GEP = dyn_cast<GetElementPtrInst>(POP.Pointer)) {
94      auto &DL = LI.getModule()->getDataLayout();
95      unsigned BitWidth = DL.getPointerTypeSizeInBits(GEP->getType());
96      APInt Offset(BitWidth, 0);
97      if (GEP->accumulateConstantOffset(DL, Offset))
98        POP.Offset += Offset.getZExtValue();
99      else
100        // Can't handle GEPs with variable indices.
101        return POP;
102      POP.Pointer = GEP->getPointerOperand();
103    } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer))
104      POP.Pointer = BC->getOperand(0);
105  }
106  return POP;
107}
108
109bool LoadCombine::combineLoads(
110    DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &LoadMap) {
111  bool Combined = false;
112  for (auto &Loads : LoadMap) {
113    if (Loads.second.size() < 2)
114      continue;
115    std::sort(Loads.second.begin(), Loads.second.end(),
116              [](const LoadPOPPair &A, const LoadPOPPair &B) {
117      return A.POP.Offset < B.POP.Offset;
118    });
119    if (aggregateLoads(Loads.second))
120      Combined = true;
121  }
122  return Combined;
123}
124
125/// \brief Try to aggregate loads from a sorted list of loads to be combined.
126///
127/// It is guaranteed that no writes occur between any of the loads. All loads
128/// have the same base pointer. There are at least two loads.
129bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
130  assert(Loads.size() >= 2 && "Insufficient loads!");
131  LoadInst *BaseLoad = nullptr;
132  SmallVector<LoadPOPPair, 8> AggregateLoads;
133  bool Combined = false;
134  uint64_t PrevOffset = -1ull;
135  uint64_t PrevSize = 0;
136  for (auto &L : Loads) {
137    if (PrevOffset == -1ull) {
138      BaseLoad = L.Load;
139      PrevOffset = L.POP.Offset;
140      PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
141          L.Load->getType());
142      AggregateLoads.push_back(L);
143      continue;
144    }
145    if (L.Load->getAlignment() > BaseLoad->getAlignment())
146      continue;
147    if (L.POP.Offset > PrevOffset + PrevSize) {
148      // No other load will be combinable
149      if (combineLoads(AggregateLoads))
150        Combined = true;
151      AggregateLoads.clear();
152      PrevOffset = -1;
153      continue;
154    }
155    if (L.POP.Offset != PrevOffset + PrevSize)
156      // This load is offset less than the size of the last load.
157      // FIXME: We may want to handle this case.
158      continue;
159    PrevOffset = L.POP.Offset;
160    PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
161        L.Load->getType());
162    AggregateLoads.push_back(L);
163  }
164  if (combineLoads(AggregateLoads))
165    Combined = true;
166  return Combined;
167}
168
169/// \brief Given a list of combinable load. Combine the maximum number of them.
170bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
171  // Remove loads from the end while the size is not a power of 2.
172  unsigned TotalSize = 0;
173  for (const auto &L : Loads)
174    TotalSize += L.Load->getType()->getPrimitiveSizeInBits();
175  while (TotalSize != 0 && !isPowerOf2_32(TotalSize))
176    TotalSize -= Loads.pop_back_val().Load->getType()->getPrimitiveSizeInBits();
177  if (Loads.size() < 2)
178    return false;
179
180  DEBUG({
181    dbgs() << "***** Combining Loads ******\n";
182    for (const auto &L : Loads) {
183      dbgs() << L.POP.Offset << ": " << *L.Load << "\n";
184    }
185  });
186
187  // Find first load. This is where we put the new load.
188  LoadPOPPair FirstLP;
189  FirstLP.InsertOrder = -1u;
190  for (const auto &L : Loads)
191    if (L.InsertOrder < FirstLP.InsertOrder)
192      FirstLP = L;
193
194  unsigned AddressSpace =
195      FirstLP.POP.Pointer->getType()->getPointerAddressSpace();
196
197  Builder->SetInsertPoint(FirstLP.Load);
198  Value *Ptr = Builder->CreateConstGEP1_64(
199      Builder->CreatePointerCast(Loads[0].POP.Pointer,
200                                 Builder->getInt8PtrTy(AddressSpace)),
201      Loads[0].POP.Offset);
202  LoadInst *NewLoad = new LoadInst(
203      Builder->CreatePointerCast(
204          Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize),
205                                Ptr->getType()->getPointerAddressSpace())),
206      Twine(Loads[0].Load->getName()) + ".combined", false,
207      Loads[0].Load->getAlignment(), FirstLP.Load);
208
209  for (const auto &L : Loads) {
210    Builder->SetInsertPoint(L.Load);
211    Value *V = Builder->CreateExtractInteger(
212        L.Load->getModule()->getDataLayout(), NewLoad,
213        cast<IntegerType>(L.Load->getType()),
214        L.POP.Offset - Loads[0].POP.Offset, "combine.extract");
215    L.Load->replaceAllUsesWith(V);
216  }
217
218  NumLoadsCombined = NumLoadsCombined + Loads.size();
219  return true;
220}
221
222bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {
223  if (skipOptnoneFunction(BB))
224    return false;
225
226  AA = &getAnalysis<AliasAnalysis>();
227
228  IRBuilder<true, TargetFolder> TheBuilder(
229      BB.getContext(), TargetFolder(BB.getModule()->getDataLayout()));
230  Builder = &TheBuilder;
231
232  DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> LoadMap;
233  AliasSetTracker AST(*AA);
234
235  bool Combined = false;
236  unsigned Index = 0;
237  for (auto &I : BB) {
238    if (I.mayThrow() || (I.mayWriteToMemory() && AST.containsUnknown(&I))) {
239      if (combineLoads(LoadMap))
240        Combined = true;
241      LoadMap.clear();
242      AST.clear();
243      continue;
244    }
245    LoadInst *LI = dyn_cast<LoadInst>(&I);
246    if (!LI)
247      continue;
248    ++NumLoadsAnalyzed;
249    if (!LI->isSimple() || !LI->getType()->isIntegerTy())
250      continue;
251    auto POP = getPointerOffsetPair(*LI);
252    if (!POP.Pointer)
253      continue;
254    LoadMap[POP.Pointer].push_back(LoadPOPPair(LI, POP, Index++));
255    AST.add(LI);
256  }
257  if (combineLoads(LoadMap))
258    Combined = true;
259  return Combined;
260}
261
262void LoadCombine::getAnalysisUsage(AnalysisUsage &AU) const {
263  AU.setPreservesCFG();
264
265  AU.addRequired<AliasAnalysis>();
266  AU.addPreserved<AliasAnalysis>();
267}
268
// Definition of the static pass ID declared in the class; it is passed to the
// BasicBlockPass base constructor.
char LoadCombine::ID = 0;
270
271BasicBlockPass *llvm::createLoadCombinePass() {
272  return new LoadCombine();
273}
274
// Pass registration boilerplate: describes LoadCombine with the argument
// string "load-combine" and the description "Combine Adjacent Loads", and
// lists AliasAnalysis as a dependency.
// NOTE(review): the two `false` arguments are presumably the CFG-only and
// is-analysis flags -- confirm against the macro definition.
INITIALIZE_PASS_BEGIN(LoadCombine, "load-combine", "Combine Adjacent Loads",
                      false, false)
INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
INITIALIZE_PASS_END(LoadCombine, "load-combine", "Combine Adjacent Loads",
                    false, false)
280
281