LoopExtractor.cpp revision a65d6a686e6ad865c61aec70c5bdfb30bf6f5b22
1//===- LoopExtractor.cpp - Extract each loop into a new function ----------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// A pass wrapper around the ExtractLoop() scalar transformation to extract each
11// top-level loop into its own new function. If the loop is the ONLY loop in a
12// given function, it is not touched. This is a pass most useful for debugging
13// via bugpoint.
14//
15//===----------------------------------------------------------------------===//
16
17#define DEBUG_TYPE "loop-extract"
18#include "llvm/Transforms/IPO.h"
19#include "llvm/Instructions.h"
20#include "llvm/Module.h"
21#include "llvm/Pass.h"
22#include "llvm/Analysis/Dominators.h"
23#include "llvm/Analysis/LoopPass.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Transforms/Scalar.h"
26#include "llvm/Transforms/Utils/FunctionUtils.h"
27#include "llvm/ADT/Statistic.h"
28#include <fstream>
29#include <set>
30using namespace llvm;
31
32STATISTIC(NumExtracted, "Number of loops extracted");
33
34namespace {
35  struct LoopExtractor : public LoopPass {
36    static char ID; // Pass identification, replacement for typeid
37    unsigned NumLoops;
38
39    explicit LoopExtractor(unsigned numLoops = ~0)
40      : LoopPass(ID), NumLoops(numLoops) {
41        initializeLoopExtractorPass(*PassRegistry::getPassRegistry());
42      }
43
44    virtual bool runOnLoop(Loop *L, LPPassManager &LPM);
45
46    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
47      AU.addRequiredID(BreakCriticalEdgesID);
48      AU.addRequiredID(LoopSimplifyID);
49      AU.addRequired<DominatorTree>();
50    }
51  };
52}
53
// Definition of the pass-identification object declared in LoopExtractor.
char LoopExtractor::ID = 0;
// Register the pass under the command-line name "loop-extract" and declare
// the passes it depends on; these mirror the requirements listed in
// LoopExtractor::getAnalysisUsage.
INITIALIZE_PASS_BEGIN(LoopExtractor, "loop-extract",
                "Extract loops into new functions", false, false)
INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTree)
INITIALIZE_PASS_END(LoopExtractor, "loop-extract",
                "Extract loops into new functions", false, false)
62
namespace {
  /// SingleLoopExtractor - For bugpoint. A LoopExtractor constructed with an
  /// extraction budget (NumLoops) of 1.
  ///
  /// NOTE(review): the constructor delegates to LoopExtractor(1), whose
  /// initializer passes LoopExtractor::ID (not SingleLoopExtractor::ID) to
  /// LoopPass, so instances appear to carry LoopExtractor's pass ID. Confirm
  /// whether this is intended.
  struct SingleLoopExtractor : public LoopExtractor {
    static char ID; // Pass identification, replacement for typeid
    SingleLoopExtractor() : LoopExtractor(1) {}
  };
} // End anonymous namespace

char SingleLoopExtractor::ID = 0;
// Register under the command-line name "loop-extract-single".
INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single",
                "Extract at most one loop into a new function", false, false)
74
75// createLoopExtractorPass - This pass extracts all natural loops from the
76// program into a function if it can.
77//
78Pass *llvm::createLoopExtractorPass() { return new LoopExtractor(); }
79
80bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &LPM) {
81  // Only visit top-level loops.
82  if (L->getParentLoop())
83    return false;
84
85  // If LoopSimplify form is not available, stay out of trouble.
86  if (!L->isLoopSimplifyForm())
87    return false;
88
89  DominatorTree &DT = getAnalysis<DominatorTree>();
90  bool Changed = false;
91
92  // If there is more than one top-level loop in this function, extract all of
93  // the loops. Otherwise there is exactly one top-level loop; in this case if
94  // this function is more than a minimal wrapper around the loop, extract
95  // the loop.
96  bool ShouldExtractLoop = false;
97
98  // Extract the loop if the entry block doesn't branch to the loop header.
99  TerminatorInst *EntryTI =
100    L->getHeader()->getParent()->getEntryBlock().getTerminator();
101  if (!isa<BranchInst>(EntryTI) ||
102      !cast<BranchInst>(EntryTI)->isUnconditional() ||
103      EntryTI->getSuccessor(0) != L->getHeader())
104    ShouldExtractLoop = true;
105  else {
106    // Check to see if any exits from the loop are more than just return
107    // blocks.
108    SmallVector<BasicBlock*, 8> ExitBlocks;
109    L->getExitBlocks(ExitBlocks);
110    for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i)
111      if (!isa<ReturnInst>(ExitBlocks[i]->getTerminator())) {
112        ShouldExtractLoop = true;
113        break;
114      }
115  }
116  if (ShouldExtractLoop) {
117    if (NumLoops == 0) return Changed;
118    --NumLoops;
119    if (ExtractLoop(DT, L) != 0) {
120      Changed = true;
121      // After extraction, the loop is replaced by a function call, so
122      // we shouldn't try to run any more loop passes on it.
123      LPM.deleteLoopFromQueue(L);
124    }
125    ++NumExtracted;
126  }
127
128  return Changed;
129}
130
131// createSingleLoopExtractorPass - This pass extracts one natural loop from the
132// program into a function if it can.  This is used by bugpoint.
133//
134Pass *llvm::createSingleLoopExtractorPass() {
135  return new SingleLoopExtractor();
136}
137
138
139// BlockFile - A file which contains a list of blocks that should not be
140// extracted.
141static cl::opt<std::string>
142BlockFile("extract-blocks-file", cl::value_desc("filename"),
143          cl::desc("A file containing list of basic blocks to not extract"),
144          cl::Hidden);
145
146namespace {
147  /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks
148  /// from the module into their own functions except for those specified by the
149  /// BlocksToNotExtract list.
150  class BlockExtractorPass : public ModulePass {
151    void LoadFile(const char *Filename);
152
153    std::vector<BasicBlock*> BlocksToNotExtract;
154    std::vector<std::pair<std::string, std::string> > BlocksToNotExtractByName;
155  public:
156    static char ID; // Pass identification, replacement for typeid
157    BlockExtractorPass() : ModulePass(ID) {
158      if (!BlockFile.empty())
159        LoadFile(BlockFile.c_str());
160    }
161
162    bool runOnModule(Module &M);
163  };
164}
165
// Definition of the pass-identification object declared in BlockExtractorPass.
char BlockExtractorPass::ID = 0;
// Register the pass under the command-line name "extract-blocks".
INITIALIZE_PASS(BlockExtractorPass, "extract-blocks",
                "Extract Basic Blocks From Module (for bugpoint use)",
                false, false)
170
171// createBlockExtractorPass - This pass extracts all blocks (except those
172// specified in the argument list) from the functions in the module.
173//
174ModulePass *llvm::createBlockExtractorPass()
175{
176  return new BlockExtractorPass();
177}
178
179void BlockExtractorPass::LoadFile(const char *Filename) {
180  // Load the BlockFile...
181  std::ifstream In(Filename);
182  if (!In.good()) {
183    errs() << "WARNING: BlockExtractor couldn't load file '" << Filename
184           << "'!\n";
185    return;
186  }
187  while (In) {
188    std::string FunctionName, BlockName;
189    In >> FunctionName;
190    In >> BlockName;
191    if (!BlockName.empty())
192      BlocksToNotExtractByName.push_back(
193          std::make_pair(FunctionName, BlockName));
194  }
195}
196
197bool BlockExtractorPass::runOnModule(Module &M) {
198  std::set<BasicBlock*> TranslatedBlocksToNotExtract;
199  for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) {
200    BasicBlock *BB = BlocksToNotExtract[i];
201    Function *F = BB->getParent();
202
203    // Map the corresponding function in this module.
204    Function *MF = M.getFunction(F->getName());
205    assert(MF->getFunctionType() == F->getFunctionType() && "Wrong function?");
206
207    // Figure out which index the basic block is in its function.
208    Function::iterator BBI = MF->begin();
209    std::advance(BBI, std::distance(F->begin(), Function::iterator(BB)));
210    TranslatedBlocksToNotExtract.insert(BBI);
211  }
212
213  while (!BlocksToNotExtractByName.empty()) {
214    // There's no way to find BBs by name without looking at every BB inside
215    // every Function. Fortunately, this is always empty except when used by
216    // bugpoint in which case correctness is more important than performance.
217
218    std::string &FuncName  = BlocksToNotExtractByName.back().first;
219    std::string &BlockName = BlocksToNotExtractByName.back().second;
220
221    for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) {
222      Function &F = *FI;
223      if (F.getName() != FuncName) continue;
224
225      for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) {
226        BasicBlock &BB = *BI;
227        if (BB.getName() != BlockName) continue;
228
229        TranslatedBlocksToNotExtract.insert(BI);
230      }
231    }
232
233    BlocksToNotExtractByName.pop_back();
234  }
235
236  // Now that we know which blocks to not extract, figure out which ones we WANT
237  // to extract.
238  std::vector<BasicBlock*> BlocksToExtract;
239  for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F)
240    for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB)
241      if (!TranslatedBlocksToNotExtract.count(BB))
242        BlocksToExtract.push_back(BB);
243
244  for (unsigned i = 0, e = BlocksToExtract.size(); i != e; ++i)
245    ExtractBasicBlock(BlocksToExtract[i]);
246
247  return !BlocksToExtract.empty();
248}
249