LoopExtractor.cpp revision 081ce940e7351e90fff829320b7dc6738a6b3815
1//===- LoopExtractor.cpp - Extract each loop into a new function ----------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// A pass wrapper around the ExtractLoop() scalar transformation to extract each 11// top-level loop into its own new function. If the loop is the ONLY loop in a 12// given function, it is not touched. This is a pass most useful for debugging 13// via bugpoint. 14// 15//===----------------------------------------------------------------------===// 16 17#define DEBUG_TYPE "loop-extract" 18#include "llvm/Transforms/IPO.h" 19#include "llvm/Instructions.h" 20#include "llvm/Module.h" 21#include "llvm/Pass.h" 22#include "llvm/Analysis/Dominators.h" 23#include "llvm/Analysis/LoopInfo.h" 24#include "llvm/Support/CommandLine.h" 25#include "llvm/Support/Compiler.h" 26#include "llvm/Transforms/Scalar.h" 27#include "llvm/Transforms/Utils/FunctionUtils.h" 28#include "llvm/ADT/Statistic.h" 29#include <fstream> 30#include <set> 31using namespace llvm; 32 33STATISTIC(NumExtracted, "Number of loops extracted"); 34 35namespace { 36 // FIXME: This is not a function pass, but the PassManager doesn't allow 37 // Module passes to require FunctionPasses, so we can't get loop info if we're 38 // not a function pass. 39 struct VISIBILITY_HIDDEN LoopExtractor : public FunctionPass { 40 static char ID; // Pass identification, replacement for typeid 41 unsigned NumLoops; 42 43 explicit LoopExtractor(unsigned numLoops = ~0) 44 : FunctionPass((intptr_t)&ID), NumLoops(numLoops) {} 45 46 virtual bool runOnFunction(Function &F); 47 48 virtual void getAnalysisUsage(AnalysisUsage &AU) const { 49 AU.addRequiredID(BreakCriticalEdgesID); 50 AU.addRequiredID(LoopSimplifyID); 51 AU.addRequired<DominatorTree>(); 52 AU.addRequired<LoopInfo>(); 53 } 54 }; 55 56 char LoopExtractor::ID = 0; 57 RegisterPass<LoopExtractor> 58 X("loop-extract", "Extract loops into new functions"); 59 60 /// SingleLoopExtractor - For bugpoint. 61 struct SingleLoopExtractor : public LoopExtractor { 62 static char ID; // Pass identification, replacement for typeid 63 SingleLoopExtractor() : LoopExtractor(1) {} 64 }; 65 66 char SingleLoopExtractor::ID = 0; 67 RegisterPass<SingleLoopExtractor> 68 Y("loop-extract-single", "Extract at most one loop into a new function"); 69} // End anonymous namespace 70 71// createLoopExtractorPass - This pass extracts all natural loops from the 72// program into a function if it can. 73// 74FunctionPass *llvm::createLoopExtractorPass() { return new LoopExtractor(); } 75 76bool LoopExtractor::runOnFunction(Function &F) { 77 LoopInfo &LI = getAnalysis<LoopInfo>(); 78 79 // If this function has no loops, there is nothing to do. 80 if (LI.begin() == LI.end()) 81 return false; 82 83 DominatorTree &DT = getAnalysis<DominatorTree>(); 84 85 // If there is more than one top-level loop in this function, extract all of 86 // the loops. 87 bool Changed = false; 88 if (LI.end()-LI.begin() > 1) { 89 for (LoopInfo::iterator i = LI.begin(), e = LI.end(); i != e; ++i) { 90 if (NumLoops == 0) return Changed; 91 --NumLoops; 92 Changed |= ExtractLoop(DT, *i) != 0; 93 ++NumExtracted; 94 } 95 } else { 96 // Otherwise there is exactly one top-level loop. If this function is more 97 // than a minimal wrapper around the loop, extract the loop. 98 Loop *TLL = *LI.begin(); 99 bool ShouldExtractLoop = false; 100 101 // Extract the loop if the entry block doesn't branch to the loop header. 102 TerminatorInst *EntryTI = F.getEntryBlock().getTerminator(); 103 if (!isa<BranchInst>(EntryTI) || 104 !cast<BranchInst>(EntryTI)->isUnconditional() || 105 EntryTI->getSuccessor(0) != TLL->getHeader()) 106 ShouldExtractLoop = true; 107 else { 108 // Check to see if any exits from the loop are more than just return 109 // blocks. 110 SmallVector<BasicBlock*, 8> ExitBlocks; 111 TLL->getExitBlocks(ExitBlocks); 112 for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) 113 if (!isa<ReturnInst>(ExitBlocks[i]->getTerminator())) { 114 ShouldExtractLoop = true; 115 break; 116 } 117 } 118 119 if (ShouldExtractLoop) { 120 if (NumLoops == 0) return Changed; 121 --NumLoops; 122 Changed |= ExtractLoop(DT, TLL) != 0; 123 ++NumExtracted; 124 } else { 125 // Okay, this function is a minimal container around the specified loop. 126 // If we extract the loop, we will continue to just keep extracting it 127 // infinitely... so don't extract it. However, if the loop contains any 128 // subloops, extract them. 129 for (Loop::iterator i = TLL->begin(), e = TLL->end(); i != e; ++i) { 130 if (NumLoops == 0) return Changed; 131 --NumLoops; 132 Changed |= ExtractLoop(DT, *i) != 0; 133 ++NumExtracted; 134 } 135 } 136 } 137 138 return Changed; 139} 140 141// createSingleLoopExtractorPass - This pass extracts one natural loop from the 142// program into a function if it can. This is used by bugpoint. 143// 144FunctionPass *llvm::createSingleLoopExtractorPass() { 145 return new SingleLoopExtractor(); 146} 147 148 149namespace { 150 // BlockFile - A file which contains a list of blocks that should not be 151 // extracted. 152 cl::opt<std::string> 153 BlockFile("extract-blocks-file", cl::value_desc("filename"), 154 cl::desc("A file containing list of basic blocks to not extract"), 155 cl::Hidden); 156 157 /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks 158 /// from the module into their own functions except for those specified by the 159 /// BlocksToNotExtract list. 160 class BlockExtractorPass : public ModulePass { 161 void LoadFile(const char *Filename); 162 163 std::vector<BasicBlock*> BlocksToNotExtract; 164 std::vector<std::pair<std::string, std::string> > BlocksToNotExtractByName; 165 public: 166 static char ID; // Pass identification, replacement for typeid 167 explicit BlockExtractorPass(const std::vector<BasicBlock*> &B) 168 : ModulePass((intptr_t)&ID), BlocksToNotExtract(B) { 169 if (!BlockFile.empty()) 170 LoadFile(BlockFile.c_str()); 171 } 172 BlockExtractorPass() : ModulePass((intptr_t)&ID) {} 173 174 bool runOnModule(Module &M); 175 }; 176 177 char BlockExtractorPass::ID = 0; 178 RegisterPass<BlockExtractorPass> 179 XX("extract-blocks", "Extract Basic Blocks From Module (for bugpoint use)"); 180} 181 182// createBlockExtractorPass - This pass extracts all blocks (except those 183// specified in the argument list) from the functions in the module. 184// 185ModulePass *llvm::createBlockExtractorPass(const std::vector<BasicBlock*> &BTNE) 186{ 187 return new BlockExtractorPass(BTNE); 188} 189 190void BlockExtractorPass::LoadFile(const char *Filename) { 191 // Load the BlockFile... 192 std::ifstream In(Filename); 193 if (!In.good()) { 194 cerr << "WARNING: BlockExtractor couldn't load file '" << Filename 195 << "'!\n"; 196 return; 197 } 198 while (In) { 199 std::string FunctionName, BlockName; 200 In >> FunctionName; 201 In >> BlockName; 202 if (!BlockName.empty()) 203 BlocksToNotExtractByName.push_back( 204 std::make_pair(FunctionName, BlockName)); 205 } 206} 207 208bool BlockExtractorPass::runOnModule(Module &M) { 209 std::set<BasicBlock*> TranslatedBlocksToNotExtract; 210 for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) { 211 BasicBlock *BB = BlocksToNotExtract[i]; 212 Function *F = BB->getParent(); 213 214 // Map the corresponding function in this module. 215 Function *MF = M.getFunction(F->getName()); 216 assert(MF->getFunctionType() == F->getFunctionType() && "Wrong function?"); 217 218 // Figure out which index the basic block is in its function. 219 Function::iterator BBI = MF->begin(); 220 std::advance(BBI, std::distance(F->begin(), Function::iterator(BB))); 221 TranslatedBlocksToNotExtract.insert(BBI); 222 } 223 224 while (!BlocksToNotExtractByName.empty()) { 225 // There's no way to find BBs by name without looking at every BB inside 226 // every Function. Fortunately, this is always empty except when used by 227 // bugpoint in which case correctness is more important than performance. 228 229 std::string &FuncName = BlocksToNotExtractByName.back().first; 230 std::string &BlockName = BlocksToNotExtractByName.back().second; 231 232 for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) { 233 Function &F = *FI; 234 if (F.getName() != FuncName) continue; 235 236 for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { 237 BasicBlock &BB = *BI; 238 if (BB.getName() != BlockName) continue; 239 240 TranslatedBlocksToNotExtract.insert(BI); 241 } 242 } 243 244 BlocksToNotExtractByName.pop_back(); 245 } 246 247 // Now that we know which blocks to not extract, figure out which ones we WANT 248 // to extract. 249 std::vector<BasicBlock*> BlocksToExtract; 250 for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) 251 for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) 252 if (!TranslatedBlocksToNotExtract.count(BB)) 253 BlocksToExtract.push_back(BB); 254 255 for (unsigned i = 0, e = BlocksToExtract.size(); i != e; ++i) 256 ExtractBasicBlock(BlocksToExtract[i]); 257 258 return !BlocksToExtract.empty(); 259} 260