1//===-- AArch64AddressTypePromotion.cpp --- Promote type for addr accesses -==//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
// This pass tries to promote the computations used to obtain a sign-extended
// value that is used in memory accesses.
12// E.g.
13// a = add nsw i32 b, 3
14// d = sext i32 a to i64
15// e = getelementptr ..., i64 d
16//
17// =>
18// f = sext i32 b to i64
19// a = add nsw i64 f, 3
20// e = getelementptr ..., i64 a
21//
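// This is legal only if the computations are marked with the nsw flag: sign
// extension distributes over an operation that cannot overflow in the signed
// sense, whereas the nuw flag alone gives no such guarantee.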
24// Moreover, the current heuristic is simple: it does not create new sext
25// operations, i.e., it gives up when a sext would have forked (e.g., if
26// a = add i32 b, c, two sexts are required to promote the computation).
27//
28// FIXME: This pass may be useful for other targets too.
29// ===---------------------------------------------------------------------===//
30
31#include "AArch64.h"
32#include "llvm/ADT/DenseMap.h"
33#include "llvm/ADT/SmallPtrSet.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/IR/Constants.h"
36#include "llvm/IR/Dominators.h"
37#include "llvm/IR/Function.h"
38#include "llvm/IR/Instructions.h"
39#include "llvm/IR/Module.h"
40#include "llvm/IR/Operator.h"
41#include "llvm/Pass.h"
42#include "llvm/Support/CommandLine.h"
43#include "llvm/Support/Debug.h"
44
45using namespace llvm;
46
47#define DEBUG_TYPE "aarch64-type-promotion"
48
49static cl::opt<bool>
50EnableAddressTypePromotion("aarch64-type-promotion", cl::Hidden,
51                           cl::desc("Enable the type promotion pass"),
52                           cl::init(true));
53static cl::opt<bool>
54EnableMerge("aarch64-type-promotion-merge", cl::Hidden,
55            cl::desc("Enable merging of redundant sexts when one is dominating"
56                     " the other."),
57            cl::init(true));
58
59//===----------------------------------------------------------------------===//
60//                       AArch64AddressTypePromotion
61//===----------------------------------------------------------------------===//
62
63namespace llvm {
64void initializeAArch64AddressTypePromotionPass(PassRegistry &);
65}
66
67namespace {
68class AArch64AddressTypePromotion : public FunctionPass {
69
70public:
71  static char ID;
72  AArch64AddressTypePromotion()
73      : FunctionPass(ID), Func(nullptr), ConsideredSExtType(nullptr) {
74    initializeAArch64AddressTypePromotionPass(*PassRegistry::getPassRegistry());
75  }
76
77  const char *getPassName() const override {
78    return "AArch64 Address Type Promotion";
79  }
80
81  /// Iterate over the functions and promote the computation of interesting
82  // sext instructions.
83  bool runOnFunction(Function &F) override;
84
85private:
86  /// The current function.
87  Function *Func;
88  /// Filter out all sexts that does not have this type.
89  /// Currently initialized with Int64Ty.
90  Type *ConsideredSExtType;
91
92  // This transformation requires dominator info.
93  void getAnalysisUsage(AnalysisUsage &AU) const override {
94    AU.setPreservesCFG();
95    AU.addRequired<DominatorTreeWrapperPass>();
96    AU.addPreserved<DominatorTreeWrapperPass>();
97    FunctionPass::getAnalysisUsage(AU);
98  }
99
100  typedef SmallPtrSet<Instruction *, 32> SetOfInstructions;
101  typedef SmallVector<Instruction *, 16> Instructions;
102  typedef DenseMap<Value *, Instructions> ValueToInsts;
103
104  /// Check if it is profitable to move a sext through this instruction.
105  /// Currently, we consider it is profitable if:
106  /// - Inst is used only once (no need to insert truncate).
107  /// - Inst has only one operand that will require a sext operation (we do
108  ///   do not create new sext operation).
109  bool shouldGetThrough(const Instruction *Inst);
110
111  /// Check if it is possible and legal to move a sext through this
112  /// instruction.
113  /// Current heuristic considers that we can get through:
114  /// - Arithmetic operation marked with the nsw or nuw flag.
115  /// - Other sext operation.
116  /// - Truncate operation if it was just dropping sign extended bits.
117  bool canGetThrough(const Instruction *Inst);
118
119  /// Move sext operations through safe to sext instructions.
120  bool propagateSignExtension(Instructions &SExtInsts);
121
122  /// Is this sext should be considered for code motion.
123  /// We look for sext with ConsideredSExtType and uses in at least one
124  // GetElementPtrInst.
125  bool shouldConsiderSExt(const Instruction *SExt) const;
126
127  /// Collect all interesting sext operations, i.e., the ones with the right
128  /// type and used in memory accesses.
129  /// More precisely, a sext instruction is considered as interesting if it
130  /// is used in a "complex" getelementptr or it exits at least another
131  /// sext instruction that sign extended the same initial value.
132  /// A getelementptr is considered as "complex" if it has more than 2
133  // operands.
134  void analyzeSExtension(Instructions &SExtInsts);
135
136  /// Merge redundant sign extension operations in common dominator.
137  void mergeSExts(ValueToInsts &ValToSExtendedUses,
138                  SetOfInstructions &ToRemove);
139};
140} // end anonymous namespace.
141
// LLVM identifies the pass by the address of ID; its value is not meaningful.
char AArch64AddressTypePromotion::ID = 0;

// Register the pass under the "aarch64-type-promotion" command-line name and
// declare its dependency on the dominator tree analysis. The trailing
// (false, false) arguments mean: not a CFG-only pass, not an analysis.
INITIALIZE_PASS_BEGIN(AArch64AddressTypePromotion, "aarch64-type-promotion",
                      "AArch64 Type Promotion Pass", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(AArch64AddressTypePromotion, "aarch64-type-promotion",
                    "AArch64 Type Promotion Pass", false, false)
149
150FunctionPass *llvm::createAArch64AddressTypePromotionPass() {
151  return new AArch64AddressTypePromotion();
152}
153
154bool AArch64AddressTypePromotion::canGetThrough(const Instruction *Inst) {
155  if (isa<SExtInst>(Inst))
156    return true;
157
158  const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
159  if (BinOp && isa<OverflowingBinaryOperator>(BinOp) &&
160      (BinOp->hasNoUnsignedWrap() || BinOp->hasNoSignedWrap()))
161    return true;
162
163  // sext(trunc(sext)) --> sext
164  if (isa<TruncInst>(Inst) && isa<SExtInst>(Inst->getOperand(0))) {
165    const Instruction *Opnd = cast<Instruction>(Inst->getOperand(0));
166    // Check that the truncate just drop sign extended bits.
167    if (Inst->getType()->getIntegerBitWidth() >=
168            Opnd->getOperand(0)->getType()->getIntegerBitWidth() &&
169        Inst->getOperand(0)->getType()->getIntegerBitWidth() <=
170            ConsideredSExtType->getIntegerBitWidth())
171      return true;
172  }
173
174  return false;
175}
176
177bool AArch64AddressTypePromotion::shouldGetThrough(const Instruction *Inst) {
178  // If the type of the sext is the same as the considered one, this sext
179  // will become useless.
180  // Otherwise, we will have to do something to preserve the original value,
181  // unless it is used once.
182  if (isa<SExtInst>(Inst) &&
183      (Inst->getType() == ConsideredSExtType || Inst->hasOneUse()))
184    return true;
185
186  // If the Inst is used more that once, we may need to insert truncate
187  // operations and we don't do that at the moment.
188  if (!Inst->hasOneUse())
189    return false;
190
191  // This truncate is used only once, thus if we can get thourgh, it will become
192  // useless.
193  if (isa<TruncInst>(Inst))
194    return true;
195
196  // If both operands are not constant, a new sext will be created here.
197  // Current heuristic is: each step should be profitable.
198  // Therefore we don't allow to increase the number of sext even if it may
199  // be profitable later on.
200  if (isa<BinaryOperator>(Inst) && isa<ConstantInt>(Inst->getOperand(1)))
201    return true;
202
203  return false;
204}
205
206static bool shouldSExtOperand(const Instruction *Inst, int OpIdx) {
207  if (isa<SelectInst>(Inst) && OpIdx == 0)
208    return false;
209  return true;
210}
211
212bool
213AArch64AddressTypePromotion::shouldConsiderSExt(const Instruction *SExt) const {
214  if (SExt->getType() != ConsideredSExtType)
215    return false;
216
217  for (const User *U : SExt->users()) {
218    if (isa<GetElementPtrInst>(U))
219      return true;
220  }
221
222  return false;
223}
224
// Input:
// - SExtInsts contains all the sext instructions that are used directly in
//   GetElementPtrInst, i.e., access to memory. The vector is consumed (emptied)
//   by this function.
// Algorithm:
// - For each sext operation in SExtInsts:
//   Let var be the operand of sext.
//   while it is profitable (see shouldGetThrough), legal, and safe
//   (see canGetThrough) to move sext through var's definition:
//   * promote the type of var's definition.
//   * fold var into sext uses.
//   * move sext above var's definition.
//   * update sext operand to use the operand of var that should be sign
//     extended (by construction there is only one).
//
//   E.g.,
//   a = ... i32 c, 3
//   b = sext i32 a to i64 <- is it legal/safe/profitable to get through 'a'
//   ...
//   = b
// => Yes, update the code
//   b = sext i32 c to i64
//   a = ... i64 b, 3
//   ...
//   = a
// Iterate on 'c'.
// - Afterwards, redundant sexts of the same value are merged (only when
//   EnableMerge is set), and every instruction collected in ToRemove is erased.
// Returns true as soon as at least one sext has been moved through one
// instruction.
bool
AArch64AddressTypePromotion::propagateSignExtension(Instructions &SExtInsts) {
  DEBUG(dbgs() << "*** Propagate Sign Extension ***\n");

  bool LocalChange = false;
  // Instructions made dead by the transformation. They are erased only at the
  // very end, after the main loop and mergeSExts have finished inspecting them.
  SetOfInstructions ToRemove;
  // For each value, the sexts that end up extending it (merge candidates).
  ValueToInsts ValToSExtendedUses;
  while (!SExtInsts.empty()) {
    // Get through simple chain.
    Instruction *SExt = SExtInsts.pop_back_val();

    DEBUG(dbgs() << "Consider:\n" << *SExt << '\n');

    // If this SExt has already been merged continue.
    // (It was folded into the chain of a previously processed sext: its uses
    // were redirected and it is already scheduled for deletion.)
    if (SExt->use_empty() && ToRemove.count(SExt)) {
      DEBUG(dbgs() << "No uses => marked as delete\n");
      continue;
    }

    // Now try to get through the chain of definitions.
    while (auto *Inst = dyn_cast<Instruction>(SExt->getOperand(0))) {
      DEBUG(dbgs() << "Try to get through:\n" << *Inst << '\n');
      if (!canGetThrough(Inst) || !shouldGetThrough(Inst)) {
        // We cannot get through something that is not an Instruction
        // or not safe to SExt.
        DEBUG(dbgs() << "Cannot get through\n");
        break;
      }

      LocalChange = true;
      // If this is a sign extend, it becomes useless.
      // The same holds for a truncate accepted by canGetThrough, since it only
      // drops sign-extended bits.
      if (isa<SExtInst>(Inst) || isa<TruncInst>(Inst)) {
        DEBUG(dbgs() << "SExt or trunc, mark it as to remove\n");
        // We cannot use replaceAllUsesWith here because we may trigger some
        // assertion on the type as all involved sext operation may have not
        // been moved yet.
        // Note: SExt itself is one of the users, so this temporarily makes
        // SExt its own operand; the setOperand(0, ...) below repairs that.
        while (!Inst->use_empty()) {
          Use &U = *Inst->use_begin();
          Instruction *User = dyn_cast<Instruction>(U.getUser());
          assert(User && "User of sext is not an Instruction!");
          User->setOperand(U.getOperandNo(), SExt);
        }
        ToRemove.insert(Inst);
        // Bypass Inst: SExt now extends Inst's operand and is placed right
        // before Inst, i.e. above all of Inst's former users.
        SExt->setOperand(0, Inst->getOperand(0));
        SExt->moveBefore(Inst);
        continue;
      }

      // Get through the Instruction:
      // 1. Update its type.
      // 2. Replace the uses of SExt by Inst.
      // 3. Sign extend each operand that needs to be sign extended.

      // Step #1.
      // Inst's only use is SExt (shouldGetThrough requires a single use for
      // such instructions), so widening its result type in place is safe.
      Inst->mutateType(SExt->getType());
      // Step #2.
      SExt->replaceAllUsesWith(Inst);
      // Step #3.
      // The now use-less SExt instruction is recycled to extend the (single)
      // operand that needs an explicit sign extension.
      Instruction *SExtForOpnd = SExt;

      DEBUG(dbgs() << "Propagate SExt to operands\n");
      for (int OpIdx = 0, EndOpIdx = Inst->getNumOperands(); OpIdx != EndOpIdx;
           ++OpIdx) {
        DEBUG(dbgs() << "Operand:\n" << *(Inst->getOperand(OpIdx)) << '\n');
        if (Inst->getOperand(OpIdx)->getType() == SExt->getType() ||
            !shouldSExtOperand(Inst, OpIdx)) {
          DEBUG(dbgs() << "No need to propagate\n");
          continue;
        }
        // Check if we can statically sign extend the operand.
        Value *Opnd = Inst->getOperand(OpIdx);
        if (const ConstantInt *Cst = dyn_cast<ConstantInt>(Opnd)) {
          DEBUG(dbgs() << "Statically sign extend\n");
          Inst->setOperand(OpIdx, ConstantInt::getSigned(SExt->getType(),
                                                         Cst->getSExtValue()));
          continue;
        }
        // UndefValue are typed, so we have to statically sign extend them.
        if (isa<UndefValue>(Opnd)) {
          DEBUG(dbgs() << "Statically sign extend\n");
          Inst->setOperand(OpIdx, UndefValue::get(SExt->getType()));
          continue;
        }

        // Otherwise we have to explicitly sign extend it.
        assert(SExtForOpnd &&
               "Only one operand should have been sign extended");

        SExtForOpnd->setOperand(0, Opnd);

        DEBUG(dbgs() << "Move before:\n" << *Inst << "\nSign extend\n");
        // Move the sign extension before the insertion point.
        SExtForOpnd->moveBefore(Inst);
        Inst->setOperand(OpIdx, SExtForOpnd);
        // If more sext are required, new instructions will have to be created.
        SExtForOpnd = nullptr;
      }
      // No operand needed an explicit sext (all were already wide, constants
      // or undef). SExt has no uses left after step #2, so it is dead and the
      // chain ends here.
      if (SExtForOpnd == SExt) {
        DEBUG(dbgs() << "Sign extension is useless now\n");
        ToRemove.insert(SExt);
        break;
      }
    }

    // If the use is already of the right type, connect its uses to its argument
    // and delete it.
    // This can happen for an Instruction whose uses are all sign extended.
    // Otherwise, SExt is recorded as a merge candidate for its final operand.
    // A SExt already in ToRemove also lands here; mergeSExts skips it.
    if (!ToRemove.count(SExt) &&
        SExt->getType() == SExt->getOperand(0)->getType()) {
      DEBUG(dbgs() << "Sign extension is useless, attach its use to "
                      "its argument\n");
      SExt->replaceAllUsesWith(SExt->getOperand(0));
      ToRemove.insert(SExt);
    } else
      ValToSExtendedUses[SExt->getOperand(0)].push_back(SExt);
  }

  if (EnableMerge)
    mergeSExts(ValToSExtendedUses, ToRemove);

  // Remove all instructions marked as ToRemove.
  for (Instruction *I: ToRemove)
    I->eraseFromParent();
  return LocalChange;
}
375
376void AArch64AddressTypePromotion::mergeSExts(ValueToInsts &ValToSExtendedUses,
377                                             SetOfInstructions &ToRemove) {
378  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
379
380  for (auto &Entry : ValToSExtendedUses) {
381    Instructions &Insts = Entry.second;
382    Instructions CurPts;
383    for (Instruction *Inst : Insts) {
384      if (ToRemove.count(Inst))
385        continue;
386      bool inserted = false;
387      for (auto &Pt : CurPts) {
388        if (DT.dominates(Inst, Pt)) {
389          DEBUG(dbgs() << "Replace all uses of:\n" << *Pt << "\nwith:\n"
390                       << *Inst << '\n');
391          Pt->replaceAllUsesWith(Inst);
392          ToRemove.insert(Pt);
393          Pt = Inst;
394          inserted = true;
395          break;
396        }
397        if (!DT.dominates(Pt, Inst))
398          // Give up if we need to merge in a common dominator as the
399          // expermients show it is not profitable.
400          continue;
401
402        DEBUG(dbgs() << "Replace all uses of:\n" << *Inst << "\nwith:\n"
403                     << *Pt << '\n');
404        Inst->replaceAllUsesWith(Pt);
405        ToRemove.insert(Inst);
406        inserted = true;
407        break;
408      }
409      if (!inserted)
410        CurPts.push_back(Inst);
411    }
412  }
413}
414
415void AArch64AddressTypePromotion::analyzeSExtension(Instructions &SExtInsts) {
416  DEBUG(dbgs() << "*** Analyze Sign Extensions ***\n");
417
418  DenseMap<Value *, Instruction *> SeenChains;
419
420  for (auto &BB : *Func) {
421    for (auto &II : BB) {
422      Instruction *SExt = &II;
423
424      // Collect all sext operation per type.
425      if (!isa<SExtInst>(SExt) || !shouldConsiderSExt(SExt))
426        continue;
427
428      DEBUG(dbgs() << "Found:\n" << (*SExt) << '\n');
429
430      // Cases where we actually perform the optimization:
431      // 1. SExt is used in a getelementptr with more than 2 operand =>
432      //    likely we can merge some computation if they are done on 64 bits.
433      // 2. The beginning of the SExt chain is SExt several time. =>
434      //    code sharing is possible.
435
436      bool insert = false;
437      // #1.
438      for (const User *U : SExt->users()) {
439        const Instruction *Inst = dyn_cast<GetElementPtrInst>(U);
440        if (Inst && Inst->getNumOperands() > 2) {
441          DEBUG(dbgs() << "Interesting use in GetElementPtrInst\n" << *Inst
442                       << '\n');
443          insert = true;
444          break;
445        }
446      }
447
448      // #2.
449      // Check the head of the chain.
450      Instruction *Inst = SExt;
451      Value *Last;
452      do {
453        int OpdIdx = 0;
454        const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
455        if (BinOp && isa<ConstantInt>(BinOp->getOperand(0)))
456          OpdIdx = 1;
457        Last = Inst->getOperand(OpdIdx);
458        Inst = dyn_cast<Instruction>(Last);
459      } while (Inst && canGetThrough(Inst) && shouldGetThrough(Inst));
460
461      DEBUG(dbgs() << "Head of the chain:\n" << *Last << '\n');
462      DenseMap<Value *, Instruction *>::iterator AlreadySeen =
463          SeenChains.find(Last);
464      if (insert || AlreadySeen != SeenChains.end()) {
465        DEBUG(dbgs() << "Insert\n");
466        SExtInsts.push_back(SExt);
467        if (AlreadySeen != SeenChains.end() && AlreadySeen->second != nullptr) {
468          DEBUG(dbgs() << "Insert chain member\n");
469          SExtInsts.push_back(AlreadySeen->second);
470          SeenChains[Last] = nullptr;
471        }
472      } else {
473        DEBUG(dbgs() << "Record its chain membership\n");
474        SeenChains[Last] = SExt;
475      }
476    }
477  }
478}
479
480bool AArch64AddressTypePromotion::runOnFunction(Function &F) {
481  if (!EnableAddressTypePromotion || F.isDeclaration())
482    return false;
483  Func = &F;
484  ConsideredSExtType = Type::getInt64Ty(Func->getContext());
485
486  DEBUG(dbgs() << "*** " << getPassName() << ": " << Func->getName() << '\n');
487
488  Instructions SExtInsts;
489  analyzeSExtension(SExtInsts);
490  return propagateSignExtension(SExtInsts);
491}
492