1//===- PartialInlining.cpp - Inline parts of functions --------------------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// This pass performs partial inlining, typically by inlining an if statement 11// that surrounds the body of the function. 12// 13//===----------------------------------------------------------------------===// 14 15#include "llvm/Transforms/IPO/PartialInlining.h" 16#include "llvm/ADT/Statistic.h" 17#include "llvm/IR/CFG.h" 18#include "llvm/IR/Dominators.h" 19#include "llvm/IR/Instructions.h" 20#include "llvm/IR/Module.h" 21#include "llvm/Pass.h" 22#include "llvm/Transforms/IPO.h" 23#include "llvm/Transforms/Utils/Cloning.h" 24#include "llvm/Transforms/Utils/CodeExtractor.h" 25using namespace llvm; 26 27#define DEBUG_TYPE "partialinlining" 28 29STATISTIC(NumPartialInlined, "Number of functions partially inlined"); 30 31namespace { 32struct PartialInlinerLegacyPass : public ModulePass { 33 static char ID; // Pass identification, replacement for typeid 34 PartialInlinerLegacyPass() : ModulePass(ID) { 35 initializePartialInlinerLegacyPassPass(*PassRegistry::getPassRegistry()); 36 } 37 38 bool runOnModule(Module &M) override { 39 if (skipModule(M)) 40 return false; 41 ModuleAnalysisManager DummyMAM; 42 auto PA = Impl.run(M, DummyMAM); 43 return !PA.areAllPreserved(); 44 } 45 46private: 47 PartialInlinerPass Impl; 48 }; 49} 50 51char PartialInlinerLegacyPass::ID = 0; 52INITIALIZE_PASS(PartialInlinerLegacyPass, "partial-inliner", "Partial Inliner", 53 false, false) 54 55ModulePass *llvm::createPartialInliningPass() { 56 return new PartialInlinerLegacyPass(); 57} 58 59Function *PartialInlinerPass::unswitchFunction(Function *F) { 60 // First, verify that this function is an unswitching candidate... 61 BasicBlock *entryBlock = &F->front(); 62 BranchInst *BR = dyn_cast<BranchInst>(entryBlock->getTerminator()); 63 if (!BR || BR->isUnconditional()) 64 return nullptr; 65 66 BasicBlock* returnBlock = nullptr; 67 BasicBlock* nonReturnBlock = nullptr; 68 unsigned returnCount = 0; 69 for (BasicBlock *BB : successors(entryBlock)) { 70 if (isa<ReturnInst>(BB->getTerminator())) { 71 returnBlock = BB; 72 returnCount++; 73 } else 74 nonReturnBlock = BB; 75 } 76 77 if (returnCount != 1) 78 return nullptr; 79 80 // Clone the function, so that we can hack away on it. 81 ValueToValueMapTy VMap; 82 Function* duplicateFunction = CloneFunction(F, VMap); 83 duplicateFunction->setLinkage(GlobalValue::InternalLinkage); 84 BasicBlock* newEntryBlock = cast<BasicBlock>(VMap[entryBlock]); 85 BasicBlock* newReturnBlock = cast<BasicBlock>(VMap[returnBlock]); 86 BasicBlock* newNonReturnBlock = cast<BasicBlock>(VMap[nonReturnBlock]); 87 88 // Go ahead and update all uses to the duplicate, so that we can just 89 // use the inliner functionality when we're done hacking. 90 F->replaceAllUsesWith(duplicateFunction); 91 92 // Special hackery is needed with PHI nodes that have inputs from more than 93 // one extracted block. For simplicity, just split the PHIs into a two-level 94 // sequence of PHIs, some of which will go in the extracted region, and some 95 // of which will go outside. 96 BasicBlock* preReturn = newReturnBlock; 97 newReturnBlock = newReturnBlock->splitBasicBlock( 98 newReturnBlock->getFirstNonPHI()->getIterator()); 99 BasicBlock::iterator I = preReturn->begin(); 100 Instruction *Ins = &newReturnBlock->front(); 101 while (I != preReturn->end()) { 102 PHINode* OldPhi = dyn_cast<PHINode>(I); 103 if (!OldPhi) break; 104 105 PHINode *retPhi = PHINode::Create(OldPhi->getType(), 2, "", Ins); 106 OldPhi->replaceAllUsesWith(retPhi); 107 Ins = newReturnBlock->getFirstNonPHI(); 108 109 retPhi->addIncoming(&*I, preReturn); 110 retPhi->addIncoming(OldPhi->getIncomingValueForBlock(newEntryBlock), 111 newEntryBlock); 112 OldPhi->removeIncomingValue(newEntryBlock); 113 114 ++I; 115 } 116 newEntryBlock->getTerminator()->replaceUsesOfWith(preReturn, newReturnBlock); 117 118 // Gather up the blocks that we're going to extract. 119 std::vector<BasicBlock*> toExtract; 120 toExtract.push_back(newNonReturnBlock); 121 for (BasicBlock &BB : *duplicateFunction) 122 if (&BB != newEntryBlock && &BB != newReturnBlock && 123 &BB != newNonReturnBlock) 124 toExtract.push_back(&BB); 125 126 // The CodeExtractor needs a dominator tree. 127 DominatorTree DT; 128 DT.recalculate(*duplicateFunction); 129 130 // Extract the body of the if. 131 Function* extractedFunction 132 = CodeExtractor(toExtract, &DT).extractCodeRegion(); 133 134 InlineFunctionInfo IFI; 135 136 // Inline the top-level if test into all callers. 137 std::vector<User *> Users(duplicateFunction->user_begin(), 138 duplicateFunction->user_end()); 139 for (User *User : Users) 140 if (CallInst *CI = dyn_cast<CallInst>(User)) 141 InlineFunction(CI, IFI); 142 else if (InvokeInst *II = dyn_cast<InvokeInst>(User)) 143 InlineFunction(II, IFI); 144 145 // Ditch the duplicate, since we're done with it, and rewrite all remaining 146 // users (function pointers, etc.) back to the original function. 147 duplicateFunction->replaceAllUsesWith(F); 148 duplicateFunction->eraseFromParent(); 149 150 ++NumPartialInlined; 151 152 return extractedFunction; 153} 154 155PreservedAnalyses PartialInlinerPass::run(Module &M, ModuleAnalysisManager &) { 156 std::vector<Function*> worklist; 157 worklist.reserve(M.size()); 158 for (Function &F : M) 159 if (!F.use_empty() && !F.isDeclaration()) 160 worklist.push_back(&F); 161 162 bool changed = false; 163 while (!worklist.empty()) { 164 Function* currFunc = worklist.back(); 165 worklist.pop_back(); 166 167 if (currFunc->use_empty()) continue; 168 169 bool recursive = false; 170 for (User *U : currFunc->users()) 171 if (Instruction* I = dyn_cast<Instruction>(U)) 172 if (I->getParent()->getParent() == currFunc) { 173 recursive = true; 174 break; 175 } 176 if (recursive) continue; 177 178 179 if (Function* newFunc = unswitchFunction(currFunc)) { 180 worklist.push_back(newFunc); 181 changed = true; 182 } 183 184 } 185 186 if (changed) 187 return PreservedAnalyses::none(); 188 return PreservedAnalyses::all(); 189} 190