// LoopExtractor.cpp revision 2ab36d350293c77fc8941ce1023e4899df7e3a82
1//===- LoopExtractor.cpp - Extract each loop into a new function ----------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// A pass wrapper around the ExtractLoop() scalar transformation to extract each 11// top-level loop into its own new function. If the loop is the ONLY loop in a 12// given function, it is not touched. This is a pass most useful for debugging 13// via bugpoint. 14// 15//===----------------------------------------------------------------------===// 16 17#define DEBUG_TYPE "loop-extract" 18#include "llvm/Transforms/IPO.h" 19#include "llvm/Instructions.h" 20#include "llvm/Module.h" 21#include "llvm/Pass.h" 22#include "llvm/Analysis/Dominators.h" 23#include "llvm/Analysis/LoopPass.h" 24#include "llvm/Support/CommandLine.h" 25#include "llvm/Transforms/Scalar.h" 26#include "llvm/Transforms/Utils/FunctionUtils.h" 27#include "llvm/ADT/Statistic.h" 28#include <fstream> 29#include <set> 30using namespace llvm; 31 32STATISTIC(NumExtracted, "Number of loops extracted"); 33 34namespace { 35 struct LoopExtractor : public LoopPass { 36 static char ID; // Pass identification, replacement for typeid 37 unsigned NumLoops; 38 39 explicit LoopExtractor(unsigned numLoops = ~0) 40 : LoopPass(ID), NumLoops(numLoops) {} 41 42 virtual bool runOnLoop(Loop *L, LPPassManager &LPM); 43 44 virtual void getAnalysisUsage(AnalysisUsage &AU) const { 45 AU.addRequiredID(BreakCriticalEdgesID); 46 AU.addRequiredID(LoopSimplifyID); 47 AU.addRequired<DominatorTree>(); 48 } 49 }; 50} 51 52char LoopExtractor::ID = 0; 53INITIALIZE_PASS_BEGIN(LoopExtractor, "loop-extract", 54 "Extract loops into new functions", false, false) 55INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges) 56INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 57INITIALIZE_PASS_DEPENDENCY(DominatorTree) 
58INITIALIZE_PASS_END(LoopExtractor, "loop-extract", 59 "Extract loops into new functions", false, false) 60 61namespace { 62 /// SingleLoopExtractor - For bugpoint. 63 struct SingleLoopExtractor : public LoopExtractor { 64 static char ID; // Pass identification, replacement for typeid 65 SingleLoopExtractor() : LoopExtractor(1) {} 66 }; 67} // End anonymous namespace 68 69char SingleLoopExtractor::ID = 0; 70INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single", 71 "Extract at most one loop into a new function", false, false) 72 73// createLoopExtractorPass - This pass extracts all natural loops from the 74// program into a function if it can. 75// 76Pass *llvm::createLoopExtractorPass() { return new LoopExtractor(); } 77 78bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &LPM) { 79 // Only visit top-level loops. 80 if (L->getParentLoop()) 81 return false; 82 83 // If LoopSimplify form is not available, stay out of trouble. 84 if (!L->isLoopSimplifyForm()) 85 return false; 86 87 DominatorTree &DT = getAnalysis<DominatorTree>(); 88 bool Changed = false; 89 90 // If there is more than one top-level loop in this function, extract all of 91 // the loops. Otherwise there is exactly one top-level loop; in this case if 92 // this function is more than a minimal wrapper around the loop, extract 93 // the loop. 94 bool ShouldExtractLoop = false; 95 96 // Extract the loop if the entry block doesn't branch to the loop header. 97 TerminatorInst *EntryTI = 98 L->getHeader()->getParent()->getEntryBlock().getTerminator(); 99 if (!isa<BranchInst>(EntryTI) || 100 !cast<BranchInst>(EntryTI)->isUnconditional() || 101 EntryTI->getSuccessor(0) != L->getHeader()) 102 ShouldExtractLoop = true; 103 else { 104 // Check to see if any exits from the loop are more than just return 105 // blocks. 
106 SmallVector<BasicBlock*, 8> ExitBlocks; 107 L->getExitBlocks(ExitBlocks); 108 for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) 109 if (!isa<ReturnInst>(ExitBlocks[i]->getTerminator())) { 110 ShouldExtractLoop = true; 111 break; 112 } 113 } 114 if (ShouldExtractLoop) { 115 if (NumLoops == 0) return Changed; 116 --NumLoops; 117 if (ExtractLoop(DT, L) != 0) { 118 Changed = true; 119 // After extraction, the loop is replaced by a function call, so 120 // we shouldn't try to run any more loop passes on it. 121 LPM.deleteLoopFromQueue(L); 122 } 123 ++NumExtracted; 124 } 125 126 return Changed; 127} 128 129// createSingleLoopExtractorPass - This pass extracts one natural loop from the 130// program into a function if it can. This is used by bugpoint. 131// 132Pass *llvm::createSingleLoopExtractorPass() { 133 return new SingleLoopExtractor(); 134} 135 136 137// BlockFile - A file which contains a list of blocks that should not be 138// extracted. 139static cl::opt<std::string> 140BlockFile("extract-blocks-file", cl::value_desc("filename"), 141 cl::desc("A file containing list of basic blocks to not extract"), 142 cl::Hidden); 143 144namespace { 145 /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks 146 /// from the module into their own functions except for those specified by the 147 /// BlocksToNotExtract list. 
148 class BlockExtractorPass : public ModulePass { 149 void LoadFile(const char *Filename); 150 151 std::vector<BasicBlock*> BlocksToNotExtract; 152 std::vector<std::pair<std::string, std::string> > BlocksToNotExtractByName; 153 public: 154 static char ID; // Pass identification, replacement for typeid 155 BlockExtractorPass() : ModulePass(ID) { 156 if (!BlockFile.empty()) 157 LoadFile(BlockFile.c_str()); 158 } 159 160 bool runOnModule(Module &M); 161 }; 162} 163 164char BlockExtractorPass::ID = 0; 165INITIALIZE_PASS(BlockExtractorPass, "extract-blocks", 166 "Extract Basic Blocks From Module (for bugpoint use)", 167 false, false) 168 169// createBlockExtractorPass - This pass extracts all blocks (except those 170// specified in the argument list) from the functions in the module. 171// 172ModulePass *llvm::createBlockExtractorPass() 173{ 174 return new BlockExtractorPass(); 175} 176 177void BlockExtractorPass::LoadFile(const char *Filename) { 178 // Load the BlockFile... 179 std::ifstream In(Filename); 180 if (!In.good()) { 181 errs() << "WARNING: BlockExtractor couldn't load file '" << Filename 182 << "'!\n"; 183 return; 184 } 185 while (In) { 186 std::string FunctionName, BlockName; 187 In >> FunctionName; 188 In >> BlockName; 189 if (!BlockName.empty()) 190 BlocksToNotExtractByName.push_back( 191 std::make_pair(FunctionName, BlockName)); 192 } 193} 194 195bool BlockExtractorPass::runOnModule(Module &M) { 196 std::set<BasicBlock*> TranslatedBlocksToNotExtract; 197 for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) { 198 BasicBlock *BB = BlocksToNotExtract[i]; 199 Function *F = BB->getParent(); 200 201 // Map the corresponding function in this module. 202 Function *MF = M.getFunction(F->getName()); 203 assert(MF->getFunctionType() == F->getFunctionType() && "Wrong function?"); 204 205 // Figure out which index the basic block is in its function. 
206 Function::iterator BBI = MF->begin(); 207 std::advance(BBI, std::distance(F->begin(), Function::iterator(BB))); 208 TranslatedBlocksToNotExtract.insert(BBI); 209 } 210 211 while (!BlocksToNotExtractByName.empty()) { 212 // There's no way to find BBs by name without looking at every BB inside 213 // every Function. Fortunately, this is always empty except when used by 214 // bugpoint in which case correctness is more important than performance. 215 216 std::string &FuncName = BlocksToNotExtractByName.back().first; 217 std::string &BlockName = BlocksToNotExtractByName.back().second; 218 219 for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) { 220 Function &F = *FI; 221 if (F.getName() != FuncName) continue; 222 223 for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { 224 BasicBlock &BB = *BI; 225 if (BB.getName() != BlockName) continue; 226 227 TranslatedBlocksToNotExtract.insert(BI); 228 } 229 } 230 231 BlocksToNotExtractByName.pop_back(); 232 } 233 234 // Now that we know which blocks to not extract, figure out which ones we WANT 235 // to extract. 236 std::vector<BasicBlock*> BlocksToExtract; 237 for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) 238 for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) 239 if (!TranslatedBlocksToNotExtract.count(BB)) 240 BlocksToExtract.push_back(BB); 241 242 for (unsigned i = 0, e = BlocksToExtract.size(); i != e; ++i) 243 ExtractBasicBlock(BlocksToExtract[i]); 244 245 return !BlocksToExtract.empty(); 246} 247