TailRecursionElimination.cpp revision 108e4ab159b59a616b0868e396dc7ddc1fb48616
//===- TailRecursionElimination.cpp - Eliminate Tail Calls ----------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file was developed by the LLVM research group and is distributed under
// the University of Illinois Open Source License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements tail recursion elimination: a self-recursive call that
// is immediately followed by a return is turned into a branch back to the top
// of the function.
//
// Caveats: The algorithm implemented is trivially simple.  There are several
// improvements that could be made:
//
//  1. If the function has any alloca instructions, these instructions will no
//     longer be in the entry block of the function after the transformation.
//     Keeping them there requires analysis to prove that the alloca is not
//     reachable by the recursively invoked function call.
//  2. Tail recursion is only performed if the call immediately precedes the
//     return instruction.  Would it be useful to generalize this somehow?
//  3. TRE is only performed if the function returns void or if the return
//     returns the result returned by the call.  It is possible, but unlikely,
//     that the return returns something else (like constant 0), and can still
//     be TRE'd.  It can be TRE'd if ALL OTHER return instructions in the
//     function return the exact same value.
//
//===----------------------------------------------------------------------===//
28
#include "llvm/Transforms/Scalar.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Function.h"
#include "llvm/Instructions.h"
#include "llvm/Pass.h"
#include "Support/Statistic.h"
#include <string>
#include <vector>
35
36using namespace llvm;
37
38namespace {
39  Statistic<> NumEliminated("tailcallelim", "Number of tail calls removed");
40
41  struct TailCallElim : public FunctionPass {
42    virtual bool runOnFunction(Function &F);
43  };
44  RegisterOpt<TailCallElim> X("tailcallelim", "Tail Call Elimination");
45}
46
47// Public interface to the TailCallElimination pass
48FunctionPass *llvm::createTailCallEliminationPass() {
49  return new TailCallElim();
50}
51
52
53bool TailCallElim::runOnFunction(Function &F) {
54  // If this function is a varargs function, we won't be able to PHI the args
55  // right, so don't even try to convert it...
56  if (F.getFunctionType()->isVarArg()) return false;
57
58  BasicBlock *OldEntry = 0;
59  std::vector<PHINode*> ArgumentPHIs;
60  bool MadeChange = false;
61
62  // Loop over the function, looking for any returning blocks...
63  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
64    if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator()))
65      if (Ret != BB->begin())  // Make sure there is something before the ret...
66        if (CallInst *CI = dyn_cast<CallInst>(Ret->getPrev()))
67          // Make sure the tail call is to the current function, and that the
68          // return either returns void or returns the value computed by the
69          // call.
70          if (CI->getCalledFunction() == &F &&
71              (Ret->getNumOperands() == 0 || Ret->getReturnValue() == CI)) {
72            // Ohh, it looks like we found a tail call, is this the first?
73            if (!OldEntry) {
74              // Ok, so this is the first tail call we have found in this
75              // function.  Insert a new entry block into the function, allowing
76              // us to branch back to the old entry block.
77              OldEntry = &F.getEntryBlock();
78              BasicBlock *NewEntry = new BasicBlock("tailrecurse", OldEntry);
79              new BranchInst(OldEntry, NewEntry);
80
81              // Now that we have created a new block, which jumps to the entry
82              // block, insert a PHI node for each argument of the function.
83              // For now, we initialize each PHI to only have the real arguments
84              // which are passed in.
85              Instruction *InsertPos = OldEntry->begin();
86              for (Function::aiterator I = F.abegin(), E = F.aend(); I!=E; ++I){
87                PHINode *PN = new PHINode(I->getType(), I->getName()+".tr",
88                                          InsertPos);
89                I->replaceAllUsesWith(PN); // Everyone use the PHI node now!
90                PN->addIncoming(I, NewEntry);
91                ArgumentPHIs.push_back(PN);
92              }
93            }
94
95            // Ok, now that we know we have a pseudo-entry block WITH all of the
96            // required PHI nodes, add entries into the PHI node for the actual
97            // parameters passed into the tail-recursive call.
98            for (unsigned i = 0, e = CI->getNumOperands()-1; i != e; ++i)
99              ArgumentPHIs[i]->addIncoming(CI->getOperand(i+1), BB);
100
101            // Now that all of the PHI nodes are in place, remove the call and
102            // ret instructions, replacing them with an unconditional branch.
103            new BranchInst(OldEntry, CI);
104            BB->getInstList().pop_back();  // Remove return.
105            BB->getInstList().pop_back();  // Remove call.
106            MadeChange = true;
107            NumEliminated++;
108          }
109
110  return MadeChange;
111}
112