1#define DEBUG_TYPE "lower-expect-intrinsic"
2#include "llvm/Constants.h"
3#include "llvm/Function.h"
4#include "llvm/BasicBlock.h"
5#include "llvm/LLVMContext.h"
6#include "llvm/Instructions.h"
7#include "llvm/Intrinsics.h"
8#include "llvm/Metadata.h"
9#include "llvm/Pass.h"
10#include "llvm/Transforms/Scalar.h"
11#include "llvm/Support/CommandLine.h"
12#include "llvm/Support/Debug.h"
13#include "llvm/ADT/Statistic.h"
14#include <vector>
15
16using namespace llvm;
17
18STATISTIC(IfHandled, "Number of 'expect' intrinsic intructions handled");
19
20static cl::opt<uint32_t>
21LikelyBranchWeight("likely-branch-weight", cl::Hidden, cl::init(64),
22                   cl::desc("Weight of the branch likely to be taken (default = 64)"));
23static cl::opt<uint32_t>
24UnlikelyBranchWeight("unlikely-branch-weight", cl::Hidden, cl::init(4),
25                   cl::desc("Weight of the branch unlikely to be taken (default = 4)"));
26
27namespace {
28
29  class LowerExpectIntrinsic : public FunctionPass {
30
31    bool HandleSwitchExpect(SwitchInst *SI);
32
33    bool HandleIfExpect(BranchInst *BI);
34
35  public:
36    static char ID;
37    LowerExpectIntrinsic() : FunctionPass(ID) {
38      initializeLowerExpectIntrinsicPass(*PassRegistry::getPassRegistry());
39    }
40
41    bool runOnFunction(Function &F);
42  };
43}
44
45
46bool LowerExpectIntrinsic::HandleSwitchExpect(SwitchInst *SI) {
47  CallInst *CI = dyn_cast<CallInst>(SI->getCondition());
48  if (!CI)
49    return false;
50
51  Function *Fn = CI->getCalledFunction();
52  if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
53    return false;
54
55  Value *ArgValue = CI->getArgOperand(0);
56  ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
57  if (!ExpectedValue)
58    return false;
59
60  LLVMContext &Context = CI->getContext();
61  Type *Int32Ty = Type::getInt32Ty(Context);
62
63  unsigned caseNo = SI->findCaseValue(ExpectedValue);
64  std::vector<Value *> Vec;
65  unsigned n = SI->getNumCases();
66  Vec.resize(n + 1); // +1 for MDString
67
68  Vec[0] = MDString::get(Context, "branch_weights");
69  for (unsigned i = 0; i < n; ++i) {
70    Vec[i + 1] = ConstantInt::get(Int32Ty, i == caseNo ? LikelyBranchWeight : UnlikelyBranchWeight);
71  }
72
73  MDNode *WeightsNode = llvm::MDNode::get(Context, Vec);
74  SI->setMetadata(LLVMContext::MD_prof, WeightsNode);
75
76  SI->setCondition(ArgValue);
77  return true;
78}
79
80
81bool LowerExpectIntrinsic::HandleIfExpect(BranchInst *BI) {
82  if (BI->isUnconditional())
83    return false;
84
85  // Handle non-optimized IR code like:
86  //   %expval = call i64 @llvm.expect.i64.i64(i64 %conv1, i64 1)
87  //   %tobool = icmp ne i64 %expval, 0
88  //   br i1 %tobool, label %if.then, label %if.end
89
90  ICmpInst *CmpI = dyn_cast<ICmpInst>(BI->getCondition());
91  if (!CmpI || CmpI->getPredicate() != CmpInst::ICMP_NE)
92    return false;
93
94  CallInst *CI = dyn_cast<CallInst>(CmpI->getOperand(0));
95  if (!CI)
96    return false;
97
98  Function *Fn = CI->getCalledFunction();
99  if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
100    return false;
101
102  Value *ArgValue = CI->getArgOperand(0);
103  ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
104  if (!ExpectedValue)
105    return false;
106
107  LLVMContext &Context = CI->getContext();
108  Type *Int32Ty = Type::getInt32Ty(Context);
109  bool Likely = ExpectedValue->isOne();
110
111  // If expect value is equal to 1 it means that we are more likely to take
112  // branch 0, in other case more likely is branch 1.
113  Value *Ops[] = {
114    MDString::get(Context, "branch_weights"),
115    ConstantInt::get(Int32Ty, Likely ? LikelyBranchWeight : UnlikelyBranchWeight),
116    ConstantInt::get(Int32Ty, Likely ? UnlikelyBranchWeight : LikelyBranchWeight)
117  };
118
119  MDNode *WeightsNode = MDNode::get(Context, Ops);
120  BI->setMetadata(LLVMContext::MD_prof, WeightsNode);
121
122  CmpI->setOperand(0, ArgValue);
123  return true;
124}
125
126
127bool LowerExpectIntrinsic::runOnFunction(Function &F) {
128  for (Function::iterator I = F.begin(), E = F.end(); I != E;) {
129    BasicBlock *BB = I++;
130
131    // Create "block_weights" metadata.
132    if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {
133      if (HandleIfExpect(BI))
134        IfHandled++;
135    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) {
136      if (HandleSwitchExpect(SI))
137        IfHandled++;
138    }
139
140    // remove llvm.expect intrinsics.
141    for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
142         BI != BE; ) {
143      CallInst *CI = dyn_cast<CallInst>(BI++);
144      if (!CI)
145        continue;
146
147      Function *Fn = CI->getCalledFunction();
148      if (Fn && Fn->getIntrinsicID() == Intrinsic::expect) {
149        Value *Exp = CI->getArgOperand(0);
150        CI->replaceAllUsesWith(Exp);
151        CI->eraseFromParent();
152      }
153    }
154  }
155
156  return false;
157}
158
159
160char LowerExpectIntrinsic::ID = 0;
161INITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect", "Lower 'expect' "
162                "Intrinsics", false, false)
163
164FunctionPass *llvm::createLowerExpectIntrinsicPass() {
165  return new LowerExpectIntrinsic();
166}
167