1//===---- DemandedBits.cpp - Determine demanded bits ----------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This pass implements a demanded bits analysis. A demanded bit is one that
11// contributes to a result; bits that are not demanded can be either zero or
12// one without affecting control or data flow. For example in this sequence:
13//
14//   %1 = add i32 %x, %y
15//   %2 = trunc i32 %1 to i16
16//
17// Only the lowest 16 bits of %1 are demanded; the rest are removed by the
18// trunc.
19//
20//===----------------------------------------------------------------------===//
21
22#include "llvm/Analysis/DemandedBits.h"
23#include "llvm/Transforms/Scalar.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/DepthFirstIterator.h"
26#include "llvm/ADT/SmallPtrSet.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/StringExtras.h"
29#include "llvm/Analysis/AssumptionCache.h"
30#include "llvm/Analysis/ValueTracking.h"
31#include "llvm/IR/BasicBlock.h"
32#include "llvm/IR/CFG.h"
33#include "llvm/IR/DataLayout.h"
34#include "llvm/IR/Dominators.h"
35#include "llvm/IR/InstIterator.h"
36#include "llvm/IR/Instructions.h"
37#include "llvm/IR/IntrinsicInst.h"
38#include "llvm/IR/Module.h"
39#include "llvm/IR/Operator.h"
40#include "llvm/Pass.h"
41#include "llvm/Support/Debug.h"
42#include "llvm/Support/raw_ostream.h"
43using namespace llvm;
44
45#define DEBUG_TYPE "demanded-bits"
46
// Pass identifier; its address is handed to FunctionPass(ID) in the
// DemandedBits constructor.
char DemandedBits::ID = 0;
// Register the pass under the command-line name "demanded-bits", declaring
// the same two analyses that getAnalysisUsage() requires.
// NOTE(review): the trailing (false, false) arguments are presumably the
// "CFG-only" and "is analysis" registration flags — confirm against the
// INITIALIZE_PASS_* macro definitions, since this pass is in fact an analysis.
INITIALIZE_PASS_BEGIN(DemandedBits, "demanded-bits", "Demanded bits analysis",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(DemandedBits, "demanded-bits", "Demanded bits analysis",
                    false, false)
54
55DemandedBits::DemandedBits() : FunctionPass(ID), F(nullptr), Analyzed(false) {
56  initializeDemandedBitsPass(*PassRegistry::getPassRegistry());
57}
58
59void DemandedBits::getAnalysisUsage(AnalysisUsage &AU) const {
60  AU.setPreservesCFG();
61  AU.addRequired<AssumptionCacheTracker>();
62  AU.addRequired<DominatorTreeWrapperPass>();
63  AU.setPreservesAll();
64}
65
66static bool isAlwaysLive(Instruction *I) {
67  return isa<TerminatorInst>(I) || isa<DbgInfoIntrinsic>(I) ||
68      I->isEHPad() || I->mayHaveSideEffects();
69}
70
71void DemandedBits::determineLiveOperandBits(
72    const Instruction *UserI, const Instruction *I, unsigned OperandNo,
73    const APInt &AOut, APInt &AB, APInt &KnownZero, APInt &KnownOne,
74    APInt &KnownZero2, APInt &KnownOne2) {
75  unsigned BitWidth = AB.getBitWidth();
76
77  // We're called once per operand, but for some instructions, we need to
78  // compute known bits of both operands in order to determine the live bits of
79  // either (when both operands are instructions themselves). We don't,
80  // however, want to do this twice, so we cache the result in APInts that live
81  // in the caller. For the two-relevant-operands case, both operand values are
82  // provided here.
83  auto ComputeKnownBits =
84      [&](unsigned BitWidth, const Value *V1, const Value *V2) {
85        const DataLayout &DL = I->getModule()->getDataLayout();
86        KnownZero = APInt(BitWidth, 0);
87        KnownOne = APInt(BitWidth, 0);
88        computeKnownBits(const_cast<Value *>(V1), KnownZero, KnownOne, DL, 0,
89                         AC, UserI, DT);
90
91        if (V2) {
92          KnownZero2 = APInt(BitWidth, 0);
93          KnownOne2 = APInt(BitWidth, 0);
94          computeKnownBits(const_cast<Value *>(V2), KnownZero2, KnownOne2, DL,
95                           0, AC, UserI, DT);
96        }
97      };
98
99  switch (UserI->getOpcode()) {
100  default: break;
101  case Instruction::Call:
102  case Instruction::Invoke:
103    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(UserI))
104      switch (II->getIntrinsicID()) {
105      default: break;
106      case Intrinsic::bswap:
107        // The alive bits of the input are the swapped alive bits of
108        // the output.
109        AB = AOut.byteSwap();
110        break;
111      case Intrinsic::ctlz:
112        if (OperandNo == 0) {
113          // We need some output bits, so we need all bits of the
114          // input to the left of, and including, the leftmost bit
115          // known to be one.
116          ComputeKnownBits(BitWidth, I, nullptr);
117          AB = APInt::getHighBitsSet(BitWidth,
118                 std::min(BitWidth, KnownOne.countLeadingZeros()+1));
119        }
120        break;
121      case Intrinsic::cttz:
122        if (OperandNo == 0) {
123          // We need some output bits, so we need all bits of the
124          // input to the right of, and including, the rightmost bit
125          // known to be one.
126          ComputeKnownBits(BitWidth, I, nullptr);
127          AB = APInt::getLowBitsSet(BitWidth,
128                 std::min(BitWidth, KnownOne.countTrailingZeros()+1));
129        }
130        break;
131      }
132    break;
133  case Instruction::Add:
134  case Instruction::Sub:
135  case Instruction::Mul:
136    // Find the highest live output bit. We don't need any more input
137    // bits than that (adds, and thus subtracts, ripple only to the
138    // left).
139    AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
140    break;
141  case Instruction::Shl:
142    if (OperandNo == 0)
143      if (ConstantInt *CI =
144            dyn_cast<ConstantInt>(UserI->getOperand(1))) {
145        uint64_t ShiftAmt = CI->getLimitedValue(BitWidth-1);
146        AB = AOut.lshr(ShiftAmt);
147
148        // If the shift is nuw/nsw, then the high bits are not dead
149        // (because we've promised that they *must* be zero).
150        const ShlOperator *S = cast<ShlOperator>(UserI);
151        if (S->hasNoSignedWrap())
152          AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
153        else if (S->hasNoUnsignedWrap())
154          AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
155      }
156    break;
157  case Instruction::LShr:
158    if (OperandNo == 0)
159      if (ConstantInt *CI =
160            dyn_cast<ConstantInt>(UserI->getOperand(1))) {
161        uint64_t ShiftAmt = CI->getLimitedValue(BitWidth-1);
162        AB = AOut.shl(ShiftAmt);
163
164        // If the shift is exact, then the low bits are not dead
165        // (they must be zero).
166        if (cast<LShrOperator>(UserI)->isExact())
167          AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
168      }
169    break;
170  case Instruction::AShr:
171    if (OperandNo == 0)
172      if (ConstantInt *CI =
173            dyn_cast<ConstantInt>(UserI->getOperand(1))) {
174        uint64_t ShiftAmt = CI->getLimitedValue(BitWidth-1);
175        AB = AOut.shl(ShiftAmt);
176        // Because the high input bit is replicated into the
177        // high-order bits of the result, if we need any of those
178        // bits, then we must keep the highest input bit.
179        if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt))
180            .getBoolValue())
181          AB.setBit(BitWidth-1);
182
183        // If the shift is exact, then the low bits are not dead
184        // (they must be zero).
185        if (cast<AShrOperator>(UserI)->isExact())
186          AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
187      }
188    break;
189  case Instruction::And:
190    AB = AOut;
191
192    // For bits that are known zero, the corresponding bits in the
193    // other operand are dead (unless they're both zero, in which
194    // case they can't both be dead, so just mark the LHS bits as
195    // dead).
196    if (OperandNo == 0) {
197      ComputeKnownBits(BitWidth, I, UserI->getOperand(1));
198      AB &= ~KnownZero2;
199    } else {
200      if (!isa<Instruction>(UserI->getOperand(0)))
201        ComputeKnownBits(BitWidth, UserI->getOperand(0), I);
202      AB &= ~(KnownZero & ~KnownZero2);
203    }
204    break;
205  case Instruction::Or:
206    AB = AOut;
207
208    // For bits that are known one, the corresponding bits in the
209    // other operand are dead (unless they're both one, in which
210    // case they can't both be dead, so just mark the LHS bits as
211    // dead).
212    if (OperandNo == 0) {
213      ComputeKnownBits(BitWidth, I, UserI->getOperand(1));
214      AB &= ~KnownOne2;
215    } else {
216      if (!isa<Instruction>(UserI->getOperand(0)))
217        ComputeKnownBits(BitWidth, UserI->getOperand(0), I);
218      AB &= ~(KnownOne & ~KnownOne2);
219    }
220    break;
221  case Instruction::Xor:
222  case Instruction::PHI:
223    AB = AOut;
224    break;
225  case Instruction::Trunc:
226    AB = AOut.zext(BitWidth);
227    break;
228  case Instruction::ZExt:
229    AB = AOut.trunc(BitWidth);
230    break;
231  case Instruction::SExt:
232    AB = AOut.trunc(BitWidth);
233    // Because the high input bit is replicated into the
234    // high-order bits of the result, if we need any of those
235    // bits, then we must keep the highest input bit.
236    if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(),
237                                      AOut.getBitWidth() - BitWidth))
238        .getBoolValue())
239      AB.setBit(BitWidth-1);
240    break;
241  case Instruction::Select:
242    if (OperandNo != 0)
243      AB = AOut;
244    break;
245  case Instruction::ICmp:
246    // Count the number of leading zeroes in each operand.
247    ComputeKnownBits(BitWidth, I, UserI->getOperand(1));
248    auto NumLeadingZeroes = std::min(KnownZero.countLeadingOnes(),
249                                     KnownZero2.countLeadingOnes());
250    AB = ~APInt::getHighBitsSet(BitWidth, NumLeadingZeroes);
251    break;
252  }
253}
254
255bool DemandedBits::runOnFunction(Function& Fn) {
256  F = &Fn;
257  Analyzed = false;
258  return false;
259}
260
261void DemandedBits::performAnalysis() {
262  if (Analyzed)
263    // Analysis already completed for this function.
264    return;
265  Analyzed = true;
266  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(*F);
267  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
268
269  Visited.clear();
270  AliveBits.clear();
271
272  SmallVector<Instruction*, 128> Worklist;
273
274  // Collect the set of "root" instructions that are known live.
275  for (Instruction &I : instructions(*F)) {
276    if (!isAlwaysLive(&I))
277      continue;
278
279    DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
280    // For integer-valued instructions, set up an initial empty set of alive
281    // bits and add the instruction to the work list. For other instructions
282    // add their operands to the work list (for integer values operands, mark
283    // all bits as live).
284    if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) {
285      if (!AliveBits.count(&I)) {
286        AliveBits[&I] = APInt(IT->getBitWidth(), 0);
287        Worklist.push_back(&I);
288      }
289
290      continue;
291    }
292
293    // Non-integer-typed instructions...
294    for (Use &OI : I.operands()) {
295      if (Instruction *J = dyn_cast<Instruction>(OI)) {
296        if (IntegerType *IT = dyn_cast<IntegerType>(J->getType()))
297          AliveBits[J] = APInt::getAllOnesValue(IT->getBitWidth());
298        Worklist.push_back(J);
299      }
300    }
301    // To save memory, we don't add I to the Visited set here. Instead, we
302    // check isAlwaysLive on every instruction when searching for dead
303    // instructions later (we need to check isAlwaysLive for the
304    // integer-typed instructions anyway).
305  }
306
307  // Propagate liveness backwards to operands.
308  while (!Worklist.empty()) {
309    Instruction *UserI = Worklist.pop_back_val();
310
311    DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
312    APInt AOut;
313    if (UserI->getType()->isIntegerTy()) {
314      AOut = AliveBits[UserI];
315      DEBUG(dbgs() << " Alive Out: " << AOut);
316    }
317    DEBUG(dbgs() << "\n");
318
319    if (!UserI->getType()->isIntegerTy())
320      Visited.insert(UserI);
321
322    APInt KnownZero, KnownOne, KnownZero2, KnownOne2;
323    // Compute the set of alive bits for each operand. These are anded into the
324    // existing set, if any, and if that changes the set of alive bits, the
325    // operand is added to the work-list.
326    for (Use &OI : UserI->operands()) {
327      if (Instruction *I = dyn_cast<Instruction>(OI)) {
328        if (IntegerType *IT = dyn_cast<IntegerType>(I->getType())) {
329          unsigned BitWidth = IT->getBitWidth();
330          APInt AB = APInt::getAllOnesValue(BitWidth);
331          if (UserI->getType()->isIntegerTy() && !AOut &&
332              !isAlwaysLive(UserI)) {
333            AB = APInt(BitWidth, 0);
334          } else {
335            // If all bits of the output are dead, then all bits of the input
336            // Bits of each operand that are used to compute alive bits of the
337            // output are alive, all others are dead.
338            determineLiveOperandBits(UserI, I, OI.getOperandNo(), AOut, AB,
339                                     KnownZero, KnownOne,
340                                     KnownZero2, KnownOne2);
341          }
342
343          // If we've added to the set of alive bits (or the operand has not
344          // been previously visited), then re-queue the operand to be visited
345          // again.
346          APInt ABPrev(BitWidth, 0);
347          auto ABI = AliveBits.find(I);
348          if (ABI != AliveBits.end())
349            ABPrev = ABI->second;
350
351          APInt ABNew = AB | ABPrev;
352          if (ABNew != ABPrev || ABI == AliveBits.end()) {
353            AliveBits[I] = std::move(ABNew);
354            Worklist.push_back(I);
355          }
356        } else if (!Visited.count(I)) {
357          Worklist.push_back(I);
358        }
359      }
360    }
361  }
362}
363
364APInt DemandedBits::getDemandedBits(Instruction *I) {
365  performAnalysis();
366
367  const DataLayout &DL = I->getParent()->getModule()->getDataLayout();
368  if (AliveBits.count(I))
369    return AliveBits[I];
370  return APInt::getAllOnesValue(DL.getTypeSizeInBits(I->getType()));
371}
372
373bool DemandedBits::isInstructionDead(Instruction *I) {
374  performAnalysis();
375
376  return !Visited.count(I) && AliveBits.find(I) == AliveBits.end() &&
377    !isAlwaysLive(I);
378}
379
380void DemandedBits::print(raw_ostream &OS, const Module *M) const {
381  // This is gross. But the alternative is making all the state mutable
382  // just because of this one debugging method.
383  const_cast<DemandedBits*>(this)->performAnalysis();
384  for (auto &KV : AliveBits) {
385    OS << "DemandedBits: 0x" << utohexstr(KV.second.getLimitedValue()) << " for "
386       << *KV.first << "\n";
387  }
388}
389
390FunctionPass *llvm::createDemandedBitsPass() {
391  return new DemandedBits();
392}
393