1//===--- PartiallyInlineLibCalls.cpp - Partially inline libcalls ----------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This pass tries to partially inline the fast path of well-known library
11// functions, such as using square-root instructions for cases where sqrt()
12// does not need to set errno.
13//
14//===----------------------------------------------------------------------===//
15
16#include "llvm/Analysis/TargetTransformInfo.h"
17#include "llvm/IR/IRBuilder.h"
18#include "llvm/IR/Intrinsics.h"
19#include "llvm/Pass.h"
20#include "llvm/Support/CommandLine.h"
21#include "llvm/Target/TargetLibraryInfo.h"
22#include "llvm/Transforms/Scalar.h"
23#include "llvm/Transforms/Utils/BasicBlockUtils.h"
24
25using namespace llvm;
26
27#define DEBUG_TYPE "partially-inline-libcalls"
28
29namespace {
30  class PartiallyInlineLibCalls : public FunctionPass {
31  public:
32    static char ID;
33
34    PartiallyInlineLibCalls() :
35      FunctionPass(ID) {
36      initializePartiallyInlineLibCallsPass(*PassRegistry::getPassRegistry());
37    }
38
39    void getAnalysisUsage(AnalysisUsage &AU) const override;
40    bool runOnFunction(Function &F) override;
41
42  private:
43    /// Optimize calls to sqrt.
44    bool optimizeSQRT(CallInst *Call, Function *CalledFunc,
45                      BasicBlock &CurrBB, Function::iterator &BB);
46  };
47
48  char PartiallyInlineLibCalls::ID = 0;
49}
50
51INITIALIZE_PASS(PartiallyInlineLibCalls, "partially-inline-libcalls",
52                "Partially inline calls to library functions", false, false)
53
54void PartiallyInlineLibCalls::getAnalysisUsage(AnalysisUsage &AU) const {
55  AU.addRequired<TargetLibraryInfo>();
56  AU.addRequired<TargetTransformInfo>();
57  FunctionPass::getAnalysisUsage(AU);
58}
59
60bool PartiallyInlineLibCalls::runOnFunction(Function &F) {
61  bool Changed = false;
62  Function::iterator CurrBB;
63  TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfo>();
64  const TargetTransformInfo *TTI = &getAnalysis<TargetTransformInfo>();
65  for (Function::iterator BB = F.begin(), BE = F.end(); BB != BE;) {
66    CurrBB = BB++;
67
68    for (BasicBlock::iterator II = CurrBB->begin(), IE = CurrBB->end();
69         II != IE; ++II) {
70      CallInst *Call = dyn_cast<CallInst>(&*II);
71      Function *CalledFunc;
72
73      if (!Call || !(CalledFunc = Call->getCalledFunction()))
74        continue;
75
76      // Skip if function either has local linkage or is not a known library
77      // function.
78      LibFunc::Func LibFunc;
79      if (CalledFunc->hasLocalLinkage() || !CalledFunc->hasName() ||
80          !TLI->getLibFunc(CalledFunc->getName(), LibFunc))
81        continue;
82
83      switch (LibFunc) {
84      case LibFunc::sqrtf:
85      case LibFunc::sqrt:
86        if (TTI->haveFastSqrt(Call->getType()) &&
87            optimizeSQRT(Call, CalledFunc, *CurrBB, BB))
88          break;
89        continue;
90      default:
91        continue;
92      }
93
94      Changed = true;
95      break;
96    }
97  }
98
99  return Changed;
100}
101
102bool PartiallyInlineLibCalls::optimizeSQRT(CallInst *Call,
103                                           Function *CalledFunc,
104                                           BasicBlock &CurrBB,
105                                           Function::iterator &BB) {
106  // There is no need to change the IR, since backend will emit sqrt
107  // instruction if the call has already been marked read-only.
108  if (Call->onlyReadsMemory())
109    return false;
110
111  // Do the following transformation:
112  //
113  // (before)
114  // dst = sqrt(src)
115  //
116  // (after)
117  // v0 = sqrt_noreadmem(src) # native sqrt instruction.
118  // if (v0 is a NaN)
119  //   v1 = sqrt(src)         # library call.
120  // dst = phi(v0, v1)
121  //
122
123  // Move all instructions following Call to newly created block JoinBB.
124  // Create phi and replace all uses.
125  BasicBlock *JoinBB = llvm::SplitBlock(&CurrBB, Call->getNextNode(), this);
126  IRBuilder<> Builder(JoinBB, JoinBB->begin());
127  PHINode *Phi = Builder.CreatePHI(Call->getType(), 2);
128  Call->replaceAllUsesWith(Phi);
129
130  // Create basic block LibCallBB and insert a call to library function sqrt.
131  BasicBlock *LibCallBB = BasicBlock::Create(CurrBB.getContext(), "call.sqrt",
132                                             CurrBB.getParent(), JoinBB);
133  Builder.SetInsertPoint(LibCallBB);
134  Instruction *LibCall = Call->clone();
135  Builder.Insert(LibCall);
136  Builder.CreateBr(JoinBB);
137
138  // Add attribute "readnone" so that backend can use a native sqrt instruction
139  // for this call. Insert a FP compare instruction and a conditional branch
140  // at the end of CurrBB.
141  Call->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone);
142  CurrBB.getTerminator()->eraseFromParent();
143  Builder.SetInsertPoint(&CurrBB);
144  Value *FCmp = Builder.CreateFCmpOEQ(Call, Call);
145  Builder.CreateCondBr(FCmp, JoinBB, LibCallBB);
146
147  // Add phi operands.
148  Phi->addIncoming(Call, &CurrBB);
149  Phi->addIncoming(LibCall, LibCallBB);
150
151  BB = JoinBB;
152  return true;
153}
154
155FunctionPass *llvm::createPartiallyInlineLibCallsPass() {
156  return new PartiallyInlineLibCalls();
157}
158