1//===- LoadCombine.cpp - Combine Adjacent Loads ---------------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9/// \file
10/// This transformation combines adjacent loads.
11///
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Transforms/Scalar.h"
15#include "llvm/ADT/DenseMap.h"
16#include "llvm/ADT/Statistic.h"
17#include "llvm/Analysis/AliasAnalysis.h"
18#include "llvm/Analysis/AliasSetTracker.h"
19#include "llvm/Analysis/GlobalsModRef.h"
20#include "llvm/Analysis/TargetFolder.h"
21#include "llvm/IR/DataLayout.h"
22#include "llvm/IR/Function.h"
23#include "llvm/IR/IRBuilder.h"
24#include "llvm/IR/Instructions.h"
25#include "llvm/IR/Module.h"
26#include "llvm/Pass.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/MathExtras.h"
29#include "llvm/Support/raw_ostream.h"
30
31using namespace llvm;
32
33#define DEBUG_TYPE "load-combine"
34
35STATISTIC(NumLoadsAnalyzed, "Number of loads analyzed for combining");
36STATISTIC(NumLoadsCombined, "Number of loads combined");
37
38#define LDCOMBINE_NAME "Combine Adjacent Loads"
39
40namespace {
41struct PointerOffsetPair {
42  Value *Pointer;
43  APInt Offset;
44};
45
46struct LoadPOPPair {
47  LoadPOPPair() = default;
48  LoadPOPPair(LoadInst *L, PointerOffsetPair P, unsigned O)
49      : Load(L), POP(P), InsertOrder(O) {}
50  LoadInst *Load;
51  PointerOffsetPair POP;
52  /// \brief The new load needs to be created before the first load in IR order.
53  unsigned InsertOrder;
54};
55
56class LoadCombine : public BasicBlockPass {
57  LLVMContext *C;
58  AliasAnalysis *AA;
59
60public:
61  LoadCombine() : BasicBlockPass(ID), C(nullptr), AA(nullptr) {
62    initializeLoadCombinePass(*PassRegistry::getPassRegistry());
63  }
64
65  using llvm::Pass::doInitialization;
66  bool doInitialization(Function &) override;
67  bool runOnBasicBlock(BasicBlock &BB) override;
68  void getAnalysisUsage(AnalysisUsage &AU) const override {
69    AU.setPreservesCFG();
70    AU.addRequired<AAResultsWrapperPass>();
71    AU.addPreserved<GlobalsAAWrapperPass>();
72  }
73
74  const char *getPassName() const override { return LDCOMBINE_NAME; }
75  static char ID;
76
77  typedef IRBuilder<TargetFolder> BuilderTy;
78
79private:
80  BuilderTy *Builder;
81
82  PointerOffsetPair getPointerOffsetPair(LoadInst &);
83  bool combineLoads(DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &);
84  bool aggregateLoads(SmallVectorImpl<LoadPOPPair> &);
85  bool combineLoads(SmallVectorImpl<LoadPOPPair> &);
86};
87}
88
89bool LoadCombine::doInitialization(Function &F) {
90  DEBUG(dbgs() << "LoadCombine function: " << F.getName() << "\n");
91  C = &F.getContext();
92  return true;
93}
94
95PointerOffsetPair LoadCombine::getPointerOffsetPair(LoadInst &LI) {
96  auto &DL = LI.getModule()->getDataLayout();
97
98  PointerOffsetPair POP;
99  POP.Pointer = LI.getPointerOperand();
100  unsigned BitWidth = DL.getPointerSizeInBits(LI.getPointerAddressSpace());
101  POP.Offset = APInt(BitWidth, 0);
102
103  while (isa<BitCastInst>(POP.Pointer) || isa<GetElementPtrInst>(POP.Pointer)) {
104    if (auto *GEP = dyn_cast<GetElementPtrInst>(POP.Pointer)) {
105      APInt LastOffset = POP.Offset;
106      if (!GEP->accumulateConstantOffset(DL, POP.Offset)) {
107        // Can't handle GEPs with variable indices.
108        POP.Offset = LastOffset;
109        return POP;
110      }
111      POP.Pointer = GEP->getPointerOperand();
112    } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer)) {
113      POP.Pointer = BC->getOperand(0);
114    }
115  }
116  return POP;
117}
118
119bool LoadCombine::combineLoads(
120    DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &LoadMap) {
121  bool Combined = false;
122  for (auto &Loads : LoadMap) {
123    if (Loads.second.size() < 2)
124      continue;
125    std::sort(Loads.second.begin(), Loads.second.end(),
126              [](const LoadPOPPair &A, const LoadPOPPair &B) {
127                return A.POP.Offset.slt(B.POP.Offset);
128              });
129    if (aggregateLoads(Loads.second))
130      Combined = true;
131  }
132  return Combined;
133}
134
135/// \brief Try to aggregate loads from a sorted list of loads to be combined.
136///
137/// It is guaranteed that no writes occur between any of the loads. All loads
138/// have the same base pointer. There are at least two loads.
139bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
140  assert(Loads.size() >= 2 && "Insufficient loads!");
141  LoadInst *BaseLoad = nullptr;
142  SmallVector<LoadPOPPair, 8> AggregateLoads;
143  bool Combined = false;
144  bool ValidPrevOffset = false;
145  APInt PrevOffset;
146  uint64_t PrevSize = 0;
147  for (auto &L : Loads) {
148    if (ValidPrevOffset == false) {
149      BaseLoad = L.Load;
150      PrevOffset = L.POP.Offset;
151      PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
152          L.Load->getType());
153      AggregateLoads.push_back(L);
154      ValidPrevOffset = true;
155      continue;
156    }
157    if (L.Load->getAlignment() > BaseLoad->getAlignment())
158      continue;
159    APInt PrevEnd = PrevOffset + PrevSize;
160    if (L.POP.Offset.sgt(PrevEnd)) {
161      // No other load will be combinable
162      if (combineLoads(AggregateLoads))
163        Combined = true;
164      AggregateLoads.clear();
165      ValidPrevOffset = false;
166      continue;
167    }
168    if (L.POP.Offset != PrevEnd)
169      // This load is offset less than the size of the last load.
170      // FIXME: We may want to handle this case.
171      continue;
172    PrevOffset = L.POP.Offset;
173    PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
174        L.Load->getType());
175    AggregateLoads.push_back(L);
176  }
177  if (combineLoads(AggregateLoads))
178    Combined = true;
179  return Combined;
180}
181
182/// \brief Given a list of combinable load. Combine the maximum number of them.
183bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
184  // Remove loads from the end while the size is not a power of 2.
185  unsigned TotalSize = 0;
186  for (const auto &L : Loads)
187    TotalSize += L.Load->getType()->getPrimitiveSizeInBits();
188  while (TotalSize != 0 && !isPowerOf2_32(TotalSize))
189    TotalSize -= Loads.pop_back_val().Load->getType()->getPrimitiveSizeInBits();
190  if (Loads.size() < 2)
191    return false;
192
193  DEBUG({
194    dbgs() << "***** Combining Loads ******\n";
195    for (const auto &L : Loads) {
196      dbgs() << L.POP.Offset << ": " << *L.Load << "\n";
197    }
198  });
199
200  // Find first load. This is where we put the new load.
201  LoadPOPPair FirstLP;
202  FirstLP.InsertOrder = -1u;
203  for (const auto &L : Loads)
204    if (L.InsertOrder < FirstLP.InsertOrder)
205      FirstLP = L;
206
207  unsigned AddressSpace =
208      FirstLP.POP.Pointer->getType()->getPointerAddressSpace();
209
210  Builder->SetInsertPoint(FirstLP.Load);
211  Value *Ptr = Builder->CreateConstGEP1_64(
212      Builder->CreatePointerCast(Loads[0].POP.Pointer,
213                                 Builder->getInt8PtrTy(AddressSpace)),
214      Loads[0].POP.Offset.getSExtValue());
215  LoadInst *NewLoad = new LoadInst(
216      Builder->CreatePointerCast(
217          Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize),
218                                Ptr->getType()->getPointerAddressSpace())),
219      Twine(Loads[0].Load->getName()) + ".combined", false,
220      Loads[0].Load->getAlignment(), FirstLP.Load);
221
222  for (const auto &L : Loads) {
223    Builder->SetInsertPoint(L.Load);
224    Value *V = Builder->CreateExtractInteger(
225        L.Load->getModule()->getDataLayout(), NewLoad,
226        cast<IntegerType>(L.Load->getType()),
227        (L.POP.Offset - Loads[0].POP.Offset).getZExtValue(), "combine.extract");
228    L.Load->replaceAllUsesWith(V);
229  }
230
231  NumLoadsCombined = NumLoadsCombined + Loads.size();
232  return true;
233}
234
235bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {
236  if (skipBasicBlock(BB))
237    return false;
238
239  AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
240
241  IRBuilder<TargetFolder> TheBuilder(
242      BB.getContext(), TargetFolder(BB.getModule()->getDataLayout()));
243  Builder = &TheBuilder;
244
245  DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> LoadMap;
246  AliasSetTracker AST(*AA);
247
248  bool Combined = false;
249  unsigned Index = 0;
250  for (auto &I : BB) {
251    if (I.mayThrow() || (I.mayWriteToMemory() && AST.containsUnknown(&I))) {
252      if (combineLoads(LoadMap))
253        Combined = true;
254      LoadMap.clear();
255      AST.clear();
256      continue;
257    }
258    LoadInst *LI = dyn_cast<LoadInst>(&I);
259    if (!LI)
260      continue;
261    ++NumLoadsAnalyzed;
262    if (!LI->isSimple() || !LI->getType()->isIntegerTy())
263      continue;
264    auto POP = getPointerOffsetPair(*LI);
265    if (!POP.Pointer)
266      continue;
267    LoadMap[POP.Pointer].push_back(LoadPOPPair(LI, POP, Index++));
268    AST.add(LI);
269  }
270  if (combineLoads(LoadMap))
271    Combined = true;
272  return Combined;
273}
274
// Definition of the static pass ID that the LoadCombine constructor hands to
// the BasicBlockPass base class.
char LoadCombine::ID = 0;
276
/// \brief Factory for the pass: returns a newly heap-allocated LoadCombine.
///
/// NOTE(review): the returned pointer is presumably owned by the pass manager
/// it gets added to; confirm against the callers.
BasicBlockPass *llvm::createLoadCombinePass() {
  return new LoadCombine();
}
280
// Pass registration under the name "load-combine", listing
// AAResultsWrapperPass as a dependency. These macros presumably provide the
// initializeLoadCombinePass() function called from the LoadCombine
// constructor.
INITIALIZE_PASS_BEGIN(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false)
284