LoadCombine.cpp revision cd81d94322a39503e4a3e87b6ee03d4fcb3465fb
//===- LoadCombine.cpp - Combine Adjacent Loads ---------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
/// \file
/// This transformation combines adjacent loads.
///
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetFolder.h"
#include "llvm/Pass.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

#define DEBUG_TYPE "load-combine"

STATISTIC(NumLoadsAnalyzed, "Number of loads analyzed for combining");
STATISTIC(NumLoadsCombined, "Number of loads combined");

namespace {
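/// \brief A base pointer and a constant byte offset from it, as computed by
/// walking through bitcasts and constant-offset GEPs.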
struct PointerOffsetPair {
  Value *Pointer;
  uint64_t Offset;
};

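/// \brief A load together with its decomposed pointer/offset pair and its
/// position in the enclosing basic block.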
struct LoadPOPPair {
  LoadPOPPair(LoadInst *L, PointerOffsetPair P, unsigned O)
      : Load(L), POP(P), InsertOrder(O) {}
  LoadPOPPair() {}
  LoadInst *Load;
  PointerOffsetPair POP;
  /// \brief Position in IR order; the combined load must be created before
  /// the earliest of the loads it replaces.
  unsigned InsertOrder;
};

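/// \brief A basic block pass that merges adjacent, non-overlapping integer
/// loads off a common base pointer into a single wider load.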
class LoadCombine : public BasicBlockPass {
  LLVMContext *C;
  const DataLayout *DL;

public:
  LoadCombine()
      : BasicBlockPass(ID),
        C(nullptr), DL(nullptr) {
    initializeLoadCombinePass(*PassRegistry::getPassRegistry());
  }
  bool doInitialization(Function &) override;
  bool runOnBasicBlock(BasicBlock &BB) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

  const char *getPassName() const override { return "LoadCombine"; }
  static char ID;

  typedef IRBuilder<true, TargetFolder> BuilderTy;

private:
  BuilderTy *Builder;

  PointerOffsetPair getPointerOffsetPair(LoadInst &);
  bool combineLoads(DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &);
  bool aggregateLoads(SmallVectorImpl<LoadPOPPair> &);
  bool combineLoads(SmallVectorImpl<LoadPOPPair> &);
};
}

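/// \brief Cache the function's context and DataLayout; the pass does nothing
/// when no target data is available.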
bool LoadCombine::doInitialization(Function &F) {
  DEBUG(dbgs() << "LoadCombine function: " << F.getName() << "\n");
  C = &F.getContext();
  DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
  if (!DLP) {
    DEBUG(dbgs() << "  Skipping LoadCombine -- no target data!\n");
    return false;
  }
  DL = &DLP->getDataLayout();
  return true;
}

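/// \brief Decompose a load's pointer operand into an underlying base pointer
/// and a constant byte offset by walking through bitcasts and GEPs with
/// constant indices.
///
/// Illustrative example (hypothetical IR):
///   %gep = getelementptr inbounds i8* %base, i64 4
///   %ptr = bitcast i8* %gep to i32*
///   %val = load i32* %ptr
/// yields {Pointer = %base, Offset = 4}. GEPs with variable indices stop the
/// walk.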
PointerOffsetPair LoadCombine::getPointerOffsetPair(LoadInst &LI) {
  PointerOffsetPair POP;
  POP.Pointer = LI.getPointerOperand();
  POP.Offset = 0;
  while (isa<BitCastInst>(POP.Pointer) || isa<GetElementPtrInst>(POP.Pointer)) {
    if (auto *GEP = dyn_cast<GetElementPtrInst>(POP.Pointer)) {
      unsigned BitWidth = DL->getPointerTypeSizeInBits(GEP->getType());
      APInt Offset(BitWidth, 0);
      if (GEP->accumulateConstantOffset(*DL, Offset))
        POP.Offset += Offset.getZExtValue();
      else
        // Can't handle GEPs with variable indices.
        return POP;
      POP.Pointer = GEP->getPointerOperand();
    } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer))
      POP.Pointer = BC->getOperand(0);
  }
  return POP;
}

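/// \brief For each base pointer, sort its candidate loads by offset and try
/// to aggregate contiguous runs of them.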
bool LoadCombine::combineLoads(
    DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &LoadMap) {
  bool Combined = false;
  for (auto &Loads : LoadMap) {
    if (Loads.second.size() < 2)
      continue;
    std::sort(Loads.second.begin(), Loads.second.end(),
              [](const LoadPOPPair &A, const LoadPOPPair &B) {
      return A.POP.Offset < B.POP.Offset;
    });
    if (aggregateLoads(Loads.second))
      Combined = true;
  }
  return Combined;
}

/// \brief Try to aggregate loads from a sorted list of loads to be combined.
///
/// It is guaranteed that no writes occur between any of the loads. All loads
/// have the same base pointer. There are at least two loads.
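///
/// Illustrative example: i32 loads at byte offsets 0, 4, and 16 off the same
/// base form a contiguous run {0, 4}; the gap before offset 16 ends the run,
/// which is then handed to combineLoads.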
bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
  assert(Loads.size() >= 2 && "Insufficient loads!");
  LoadInst *BaseLoad = nullptr;
  SmallVector<LoadPOPPair, 8> AggregateLoads;
  bool Combined = false;
  uint64_t PrevOffset = -1ull;
  uint64_t PrevSize = 0;
  for (auto &L : Loads) {
    if (PrevOffset == -1ull) {
      BaseLoad = L.Load;
      PrevOffset = L.POP.Offset;
      PrevSize = DL->getTypeStoreSize(L.Load->getType());
      AggregateLoads.push_back(L);
      continue;
    }
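    // Loads with greater alignment than the base load are left out; the
    // combined load is given the base load's alignment.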
    if (L.Load->getAlignment() > BaseLoad->getAlignment())
      continue;
    if (L.POP.Offset > PrevOffset + PrevSize) {
      // No other load will be combinable.
      if (combineLoads(AggregateLoads))
        Combined = true;
      AggregateLoads.clear();
      PrevOffset = -1ull;
      continue;
    }
    if (L.POP.Offset != PrevOffset + PrevSize)
      // This load starts before the previous load ends; the two overlap.
      // FIXME: We may want to handle this case.
      continue;
    PrevOffset = L.POP.Offset;
    PrevSize = DL->getTypeStoreSize(L.Load->getType());
    AggregateLoads.push_back(L);
  }
  if (combineLoads(AggregateLoads))
    Combined = true;
  return Combined;
}

/// \brief Given a list of combinable loads, combine the maximum number of
/// them.
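///
/// Loads are dropped from the end of the list until the combined width is a
/// power of two. Illustrative example: two i32 loads at offsets 0 and 4 are
/// replaced by one i64 load, and each original value is then extracted from
/// the wide load to replace the old loads' uses.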
bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
  // Remove loads from the end while the size is not a power of 2.
  unsigned TotalSize = 0;
  for (const auto &L : Loads)
    TotalSize += L.Load->getType()->getPrimitiveSizeInBits();
  while (TotalSize != 0 && !isPowerOf2_32(TotalSize))
    TotalSize -= Loads.pop_back_val().Load->getType()->getPrimitiveSizeInBits();
  if (Loads.size() < 2)
    return false;

  DEBUG({
    dbgs() << "***** Combining Loads ******\n";
    for (const auto &L : Loads) {
      dbgs() << L.POP.Offset << ": " << *L.Load << "\n";
    }
  });

  // Find first load. This is where we put the new load.
  LoadPOPPair FirstLP;
  FirstLP.InsertOrder = -1u;
  for (const auto &L : Loads)
    if (L.InsertOrder < FirstLP.InsertOrder)
      FirstLP = L;

  unsigned AddressSpace =
      FirstLP.POP.Pointer->getType()->getPointerAddressSpace();

  Builder->SetInsertPoint(FirstLP.Load);
  Value *Ptr = Builder->CreateConstGEP1_64(
      Builder->CreatePointerCast(Loads[0].POP.Pointer,
                                 Builder->getInt8PtrTy(AddressSpace)),
      Loads[0].POP.Offset);
  LoadInst *NewLoad = new LoadInst(
      Builder->CreatePointerCast(
          Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize),
                                Ptr->getType()->getPointerAddressSpace())),
      Twine(Loads[0].Load->getName()) + ".combined", false,
      Loads[0].Load->getAlignment(), FirstLP.Load);

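  // Replace each original load with the matching slice of the wide load,
  // extracted with CreateExtractInteger (a shift and truncate).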
  for (const auto &L : Loads) {
    Builder->SetInsertPoint(L.Load);
    Value *V = Builder->CreateExtractInteger(
        *DL, NewLoad, cast<IntegerType>(L.Load->getType()),
        L.POP.Offset - Loads[0].POP.Offset, "combine.extract");
    L.Load->replaceAllUsesWith(V);
  }

  NumLoadsCombined += Loads.size();
  return true;
}

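/// \brief Walk the block in order, grouping simple integer loads by base
/// pointer. Any instruction that may write to memory or throw flushes the
/// pending groups so loads are never combined across a potential clobber.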
bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {
  if (skipOptnoneFunction(BB) || !DL)
    return false;

  IRBuilder<true, TargetFolder> TheBuilder(BB.getContext(), TargetFolder(DL));
  Builder = &TheBuilder;

  DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> LoadMap;

  bool Combined = false;
  unsigned Index = 0;
  for (auto &I : BB) {
    if (I.mayWriteToMemory() || I.mayThrow()) {
      if (combineLoads(LoadMap))
        Combined = true;
      LoadMap.clear();
      continue;
    }
    LoadInst *LI = dyn_cast<LoadInst>(&I);
    if (!LI)
      continue;
    ++NumLoadsAnalyzed;
    if (!LI->isSimple() || !LI->getType()->isIntegerTy())
      continue;
    auto POP = getPointerOffsetPair(*LI);
    if (!POP.Pointer)
      continue;
    LoadMap[POP.Pointer].push_back(LoadPOPPair(LI, POP, Index++));
  }
  if (combineLoads(LoadMap))
    Combined = true;
  return Combined;
}

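/// \brief Only instructions within a basic block are rewritten, so the CFG is
/// preserved.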
void LoadCombine::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesCFG();
}

char LoadCombine::ID = 0;

BasicBlockPass *llvm::createLoadCombinePass() {
  return new LoadCombine();
}

INITIALIZE_PASS(LoadCombine, "load-combine", "Combine Adjacent Loads", false,
                false)