MergeFunctions.cpp revision ff380f202ea6a168a5167cdf238a6a22a9ea3a71
1//===- MergeFunctions.cpp - Merge identical functions ---------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This pass looks for equivalent functions that are mergable and folds them.
11//
12// A hash is computed from the function, based on its type and number of
13// basic blocks.
14//
15// Once all hashes are computed, we perform an expensive equality comparison
16// on each function pair. This takes n^2/2 comparisons per bucket, so it's
17// important that the hash function be high quality. The equality comparison
18// iterates through each instruction in each basic block.
19//
20// When a match is found the functions are folded. If both functions are
21// overridable, we move the functionality into a new internal function and
22// leave two overridable thunks to it.
23//
24//===----------------------------------------------------------------------===//
25//
26// Future work:
27//
28// * virtual functions.
29//
30// Many functions have their address taken by the virtual function table for
31// the object they belong to. However, as long as it's only used for a lookup
32// and call, this is irrelevant, and we'd like to fold such functions.
33//
34// * switch from n^2 pair-wise comparisons to an n-way comparison for each
35// bucket.
36//
37// * be smarter about bitcasts.
38//
39// In order to fold functions, we will sometimes add either bitcast instructions
40// or bitcast constant expressions. Unfortunately, this can confound further
41// analysis since the two functions differ where one has a bitcast and the
42// other doesn't. We should learn to look through bitcasts.
43//
44//===----------------------------------------------------------------------===//
45
46#define DEBUG_TYPE "mergefunc"
47#include "llvm/Transforms/IPO.h"
48#include "llvm/ADT/DenseSet.h"
49#include "llvm/ADT/FoldingSet.h"
50#include "llvm/ADT/SmallSet.h"
51#include "llvm/ADT/Statistic.h"
52#include "llvm/ADT/STLExtras.h"
53#include "llvm/Constants.h"
54#include "llvm/InlineAsm.h"
55#include "llvm/Instructions.h"
56#include "llvm/LLVMContext.h"
57#include "llvm/Module.h"
58#include "llvm/Pass.h"
59#include "llvm/Support/CallSite.h"
60#include "llvm/Support/Debug.h"
61#include "llvm/Support/ErrorHandling.h"
62#include "llvm/Support/IRBuilder.h"
63#include "llvm/Support/ValueHandle.h"
64#include "llvm/Support/raw_ostream.h"
65#include "llvm/Target/TargetData.h"
66#include <vector>
67using namespace llvm;
68
69STATISTIC(NumFunctionsMerged, "Number of functions merged");
70STATISTIC(NumThunksWritten, "Number of thunks generated");
71STATISTIC(NumDoubleWeak, "Number of new functions created");
72
73/// ProfileFunction - Creates a hash-code for the function which is the same
74/// for any two functions that will compare equal, without looking at the
75/// instructions inside the function.
76static unsigned ProfileFunction(const Function *F) {
77  const FunctionType *FTy = F->getFunctionType();
78
79  FoldingSetNodeID ID;
80  ID.AddInteger(F->size());
81  ID.AddInteger(F->getCallingConv());
82  ID.AddBoolean(F->hasGC());
83  ID.AddBoolean(FTy->isVarArg());
84  ID.AddInteger(FTy->getReturnType()->getTypeID());
85  for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i)
86    ID.AddInteger(FTy->getParamType(i)->getTypeID());
87  return ID.ComputeHash();
88}
89
90namespace {
91
92class ComparableFunction {
93public:
94  static const ComparableFunction EmptyKey;
95  static const ComparableFunction TombstoneKey;
96
97  ComparableFunction(Function *Func, TargetData *TD)
98    : Func(Func), Hash(ProfileFunction(Func)), TD(TD) {}
99
100  Function *getFunc() const { return Func; }
101  unsigned getHash() const { return Hash; }
102  TargetData *getTD() const { return TD; }
103
104  // Drops AssertingVH reference to the function. Outside of debug mode, this
105  // does nothing.
106  void release() {
107    assert(Func &&
108           "Attempted to release function twice, or release empty/tombstone!");
109    Func = NULL;
110  }
111
112private:
113  explicit ComparableFunction(unsigned Hash)
114    : Func(NULL), Hash(Hash), TD(NULL) {}
115
116  AssertingVH<Function> Func;
117  unsigned Hash;
118  TargetData *TD;
119};
120
121const ComparableFunction ComparableFunction::EmptyKey = ComparableFunction(0);
122const ComparableFunction ComparableFunction::TombstoneKey =
123    ComparableFunction(1);
124
125}
126
127namespace llvm {
128  template <>
129  struct DenseMapInfo<ComparableFunction> {
130    static ComparableFunction getEmptyKey() {
131      return ComparableFunction::EmptyKey;
132    }
133    static ComparableFunction getTombstoneKey() {
134      return ComparableFunction::TombstoneKey;
135    }
136    static unsigned getHashValue(const ComparableFunction &CF) {
137      return CF.getHash();
138    }
139    static bool isEqual(const ComparableFunction &LHS,
140                        const ComparableFunction &RHS);
141  };
142}
143
144namespace {
145
146/// MergeFunctions finds functions which will generate identical machine code,
147/// by considering all pointer types to be equivalent. Once identified,
148/// MergeFunctions will fold them by replacing a call to one to a call to a
149/// bitcast of the other.
150///
151class MergeFunctions : public ModulePass {
152public:
153  static char ID;
154  MergeFunctions() : ModulePass(ID) {}
155
156  bool runOnModule(Module &M);
157
158private:
159  typedef DenseSet<ComparableFunction> FnSetType;
160
161
162  /// Insert a ComparableFunction into the FnSet, or merge it away if it's
163  /// equal to one that's already present.
164  bool Insert(FnSetType &FnSet, ComparableFunction &NewF);
165
166  /// MergeTwoFunctions - Merge two equivalent functions. Upon completion, G
167  /// may be deleted, or may be converted into a thunk. In either case, it
168  /// should never be visited again.
169  void MergeTwoFunctions(Function *F, Function *G) const;
170
171  /// WriteThunk - Replace G with a simple tail call to bitcast(F). Also
172  /// replace direct uses of G with bitcast(F). Deletes G.
173  void WriteThunk(Function *F, Function *G) const;
174
175  TargetData *TD;
176};
177
178}  // end anonymous namespace
179
180char MergeFunctions::ID = 0;
181INITIALIZE_PASS(MergeFunctions, "mergefunc", "Merge Functions", false, false);
182
183ModulePass *llvm::createMergeFunctionsPass() {
184  return new MergeFunctions();
185}
186
187namespace {
188/// FunctionComparator - Compares two functions to determine whether or not
189/// they will generate machine code with the same behaviour. TargetData is
190/// used if available. The comparator always fails conservatively (erring on the
191/// side of claiming that two functions are different).
192class FunctionComparator {
193public:
194  FunctionComparator(const TargetData *TD, const Function *F1,
195                     const Function *F2)
196    : F1(F1), F2(F2), TD(TD), IDMap1Count(0), IDMap2Count(0) {}
197
198  /// Compare - test whether the two functions have equivalent behaviour.
199  bool Compare();
200
201private:
202  /// Compare - test whether two basic blocks have equivalent behaviour.
203  bool Compare(const BasicBlock *BB1, const BasicBlock *BB2);
204
205  /// Enumerate - Assign or look up previously assigned numbers for the two
206  /// values, and return whether the numbers are equal. Numbers are assigned in
207  /// the order visited.
208  bool Enumerate(const Value *V1, const Value *V2);
209
210  /// isEquivalentOperation - Compare two Instructions for equivalence, similar
211  /// to Instruction::isSameOperationAs but with modifications to the type
212  /// comparison.
213  bool isEquivalentOperation(const Instruction *I1,
214                             const Instruction *I2) const;
215
216  /// isEquivalentGEP - Compare two GEPs for equivalent pointer arithmetic.
217  bool isEquivalentGEP(const GEPOperator *GEP1, const GEPOperator *GEP2);
218  bool isEquivalentGEP(const GetElementPtrInst *GEP1,
219                       const GetElementPtrInst *GEP2) {
220    return isEquivalentGEP(cast<GEPOperator>(GEP1), cast<GEPOperator>(GEP2));
221  }
222
223  /// isEquivalentType - Compare two Types, treating all pointer types as equal.
224  bool isEquivalentType(const Type *Ty1, const Type *Ty2) const;
225
226  // The two functions undergoing comparison.
227  const Function *F1, *F2;
228
229  const TargetData *TD;
230
231  typedef DenseMap<const Value *, unsigned long> IDMap;
232  IDMap Map1, Map2;
233  unsigned long IDMap1Count, IDMap2Count;
234};
235}
236
237/// isEquivalentType - any two pointers in the same address space are
238/// equivalent. Otherwise, standard type equivalence rules apply.
239bool FunctionComparator::isEquivalentType(const Type *Ty1,
240                                          const Type *Ty2) const {
241  if (Ty1 == Ty2)
242    return true;
243  if (Ty1->getTypeID() != Ty2->getTypeID())
244    return false;
245
246  switch(Ty1->getTypeID()) {
247  default:
248    llvm_unreachable("Unknown type!");
249    // Fall through in Release mode.
250  case Type::IntegerTyID:
251  case Type::OpaqueTyID:
252    // Ty1 == Ty2 would have returned true earlier.
253    return false;
254
255  case Type::VoidTyID:
256  case Type::FloatTyID:
257  case Type::DoubleTyID:
258  case Type::X86_FP80TyID:
259  case Type::FP128TyID:
260  case Type::PPC_FP128TyID:
261  case Type::LabelTyID:
262  case Type::MetadataTyID:
263    return true;
264
265  case Type::PointerTyID: {
266    const PointerType *PTy1 = cast<PointerType>(Ty1);
267    const PointerType *PTy2 = cast<PointerType>(Ty2);
268    return PTy1->getAddressSpace() == PTy2->getAddressSpace();
269  }
270
271  case Type::StructTyID: {
272    const StructType *STy1 = cast<StructType>(Ty1);
273    const StructType *STy2 = cast<StructType>(Ty2);
274    if (STy1->getNumElements() != STy2->getNumElements())
275      return false;
276
277    if (STy1->isPacked() != STy2->isPacked())
278      return false;
279
280    for (unsigned i = 0, e = STy1->getNumElements(); i != e; ++i) {
281      if (!isEquivalentType(STy1->getElementType(i), STy2->getElementType(i)))
282        return false;
283    }
284    return true;
285  }
286
287  case Type::FunctionTyID: {
288    const FunctionType *FTy1 = cast<FunctionType>(Ty1);
289    const FunctionType *FTy2 = cast<FunctionType>(Ty2);
290    if (FTy1->getNumParams() != FTy2->getNumParams() ||
291        FTy1->isVarArg() != FTy2->isVarArg())
292      return false;
293
294    if (!isEquivalentType(FTy1->getReturnType(), FTy2->getReturnType()))
295      return false;
296
297    for (unsigned i = 0, e = FTy1->getNumParams(); i != e; ++i) {
298      if (!isEquivalentType(FTy1->getParamType(i), FTy2->getParamType(i)))
299        return false;
300    }
301    return true;
302  }
303
304  case Type::ArrayTyID: {
305    const ArrayType *ATy1 = cast<ArrayType>(Ty1);
306    const ArrayType *ATy2 = cast<ArrayType>(Ty2);
307    return ATy1->getNumElements() == ATy2->getNumElements() &&
308           isEquivalentType(ATy1->getElementType(), ATy2->getElementType());
309  }
310
311  case Type::VectorTyID: {
312    const VectorType *VTy1 = cast<VectorType>(Ty1);
313    const VectorType *VTy2 = cast<VectorType>(Ty2);
314    return VTy1->getNumElements() == VTy2->getNumElements() &&
315           isEquivalentType(VTy1->getElementType(), VTy2->getElementType());
316  }
317  }
318}
319
320/// isEquivalentOperation - determine whether the two operations are the same
321/// except that pointer-to-A and pointer-to-B are equivalent. This should be
322/// kept in sync with Instruction::isSameOperationAs.
323bool FunctionComparator::isEquivalentOperation(const Instruction *I1,
324                                               const Instruction *I2) const {
325  if (I1->getOpcode() != I2->getOpcode() ||
326      I1->getNumOperands() != I2->getNumOperands() ||
327      !isEquivalentType(I1->getType(), I2->getType()) ||
328      !I1->hasSameSubclassOptionalData(I2))
329    return false;
330
331  // We have two instructions of identical opcode and #operands.  Check to see
332  // if all operands are the same type
333  for (unsigned i = 0, e = I1->getNumOperands(); i != e; ++i)
334    if (!isEquivalentType(I1->getOperand(i)->getType(),
335                          I2->getOperand(i)->getType()))
336      return false;
337
338  // Check special state that is a part of some instructions.
339  if (const LoadInst *LI = dyn_cast<LoadInst>(I1))
340    return LI->isVolatile() == cast<LoadInst>(I2)->isVolatile() &&
341           LI->getAlignment() == cast<LoadInst>(I2)->getAlignment();
342  if (const StoreInst *SI = dyn_cast<StoreInst>(I1))
343    return SI->isVolatile() == cast<StoreInst>(I2)->isVolatile() &&
344           SI->getAlignment() == cast<StoreInst>(I2)->getAlignment();
345  if (const CmpInst *CI = dyn_cast<CmpInst>(I1))
346    return CI->getPredicate() == cast<CmpInst>(I2)->getPredicate();
347  if (const CallInst *CI = dyn_cast<CallInst>(I1))
348    return CI->isTailCall() == cast<CallInst>(I2)->isTailCall() &&
349           CI->getCallingConv() == cast<CallInst>(I2)->getCallingConv() &&
350           CI->getAttributes().getRawPointer() ==
351             cast<CallInst>(I2)->getAttributes().getRawPointer();
352  if (const InvokeInst *CI = dyn_cast<InvokeInst>(I1))
353    return CI->getCallingConv() == cast<InvokeInst>(I2)->getCallingConv() &&
354           CI->getAttributes().getRawPointer() ==
355             cast<InvokeInst>(I2)->getAttributes().getRawPointer();
356  if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(I1)) {
357    if (IVI->getNumIndices() != cast<InsertValueInst>(I2)->getNumIndices())
358      return false;
359    for (unsigned i = 0, e = IVI->getNumIndices(); i != e; ++i)
360      if (IVI->idx_begin()[i] != cast<InsertValueInst>(I2)->idx_begin()[i])
361        return false;
362    return true;
363  }
364  if (const ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(I1)) {
365    if (EVI->getNumIndices() != cast<ExtractValueInst>(I2)->getNumIndices())
366      return false;
367    for (unsigned i = 0, e = EVI->getNumIndices(); i != e; ++i)
368      if (EVI->idx_begin()[i] != cast<ExtractValueInst>(I2)->idx_begin()[i])
369        return false;
370    return true;
371  }
372
373  return true;
374}
375
376/// isEquivalentGEP - determine whether two GEP operations perform the same
377/// underlying arithmetic.
378bool FunctionComparator::isEquivalentGEP(const GEPOperator *GEP1,
379                                         const GEPOperator *GEP2) {
380  // When we have target data, we can reduce the GEP down to the value in bytes
381  // added to the address.
382  if (TD && GEP1->hasAllConstantIndices() && GEP2->hasAllConstantIndices()) {
383    SmallVector<Value *, 8> Indices1(GEP1->idx_begin(), GEP1->idx_end());
384    SmallVector<Value *, 8> Indices2(GEP2->idx_begin(), GEP2->idx_end());
385    uint64_t Offset1 = TD->getIndexedOffset(GEP1->getPointerOperandType(),
386                                            Indices1.data(), Indices1.size());
387    uint64_t Offset2 = TD->getIndexedOffset(GEP2->getPointerOperandType(),
388                                            Indices2.data(), Indices2.size());
389    return Offset1 == Offset2;
390  }
391
392  if (GEP1->getPointerOperand()->getType() !=
393      GEP2->getPointerOperand()->getType())
394    return false;
395
396  if (GEP1->getNumOperands() != GEP2->getNumOperands())
397    return false;
398
399  for (unsigned i = 0, e = GEP1->getNumOperands(); i != e; ++i) {
400    if (!Enumerate(GEP1->getOperand(i), GEP2->getOperand(i)))
401      return false;
402  }
403
404  return true;
405}
406
407/// Enumerate - Compare two values used by the two functions under pair-wise
408/// comparison. If this is the first time the values are seen, they're added to
409/// the mapping so that we will detect mismatches on next use.
410bool FunctionComparator::Enumerate(const Value *V1, const Value *V2) {
411  // Check for function @f1 referring to itself and function @f2 referring to
412  // itself, or referring to each other, or both referring to either of them.
413  // They're all equivalent if the two functions are otherwise equivalent.
414  if (V1 == F1 && V2 == F2)
415    return true;
416  if (V1 == F2 && V2 == F1)
417    return true;
418
419  // TODO: constant expressions with GEP or references to F1 or F2.
420  if (isa<Constant>(V1))
421    return V1 == V2;
422
423  if (isa<InlineAsm>(V1) && isa<InlineAsm>(V2)) {
424    const InlineAsm *IA1 = cast<InlineAsm>(V1);
425    const InlineAsm *IA2 = cast<InlineAsm>(V2);
426    return IA1->getAsmString() == IA2->getAsmString() &&
427           IA1->getConstraintString() == IA2->getConstraintString();
428  }
429
430  unsigned long &ID1 = Map1[V1];
431  if (!ID1)
432    ID1 = ++IDMap1Count;
433
434  unsigned long &ID2 = Map2[V2];
435  if (!ID2)
436    ID2 = ++IDMap2Count;
437
438  return ID1 == ID2;
439}
440
441/// Compare - test whether two basic blocks have equivalent behaviour.
442bool FunctionComparator::Compare(const BasicBlock *BB1, const BasicBlock *BB2) {
443  BasicBlock::const_iterator F1I = BB1->begin(), F1E = BB1->end();
444  BasicBlock::const_iterator F2I = BB2->begin(), F2E = BB2->end();
445
446  do {
447    if (!Enumerate(F1I, F2I))
448      return false;
449
450    if (const GetElementPtrInst *GEP1 = dyn_cast<GetElementPtrInst>(F1I)) {
451      const GetElementPtrInst *GEP2 = dyn_cast<GetElementPtrInst>(F2I);
452      if (!GEP2)
453        return false;
454
455      if (!Enumerate(GEP1->getPointerOperand(), GEP2->getPointerOperand()))
456        return false;
457
458      if (!isEquivalentGEP(GEP1, GEP2))
459        return false;
460    } else {
461      if (!isEquivalentOperation(F1I, F2I))
462        return false;
463
464      assert(F1I->getNumOperands() == F2I->getNumOperands());
465      for (unsigned i = 0, e = F1I->getNumOperands(); i != e; ++i) {
466        Value *OpF1 = F1I->getOperand(i);
467        Value *OpF2 = F2I->getOperand(i);
468
469        if (!Enumerate(OpF1, OpF2))
470          return false;
471
472        if (OpF1->getValueID() != OpF2->getValueID() ||
473            !isEquivalentType(OpF1->getType(), OpF2->getType()))
474          return false;
475      }
476    }
477
478    ++F1I, ++F2I;
479  } while (F1I != F1E && F2I != F2E);
480
481  return F1I == F1E && F2I == F2E;
482}
483
484/// Compare - test whether the two functions have equivalent behaviour.
485bool FunctionComparator::Compare() {
486  // We need to recheck everything, but check the things that weren't included
487  // in the hash first.
488
489  if (F1->getAttributes() != F2->getAttributes())
490    return false;
491
492  if (F1->hasGC() != F2->hasGC())
493    return false;
494
495  if (F1->hasGC() && F1->getGC() != F2->getGC())
496    return false;
497
498  if (F1->hasSection() != F2->hasSection())
499    return false;
500
501  if (F1->hasSection() && F1->getSection() != F2->getSection())
502    return false;
503
504  if (F1->isVarArg() != F2->isVarArg())
505    return false;
506
507  // TODO: if it's internal and only used in direct calls, we could handle this
508  // case too.
509  if (F1->getCallingConv() != F2->getCallingConv())
510    return false;
511
512  if (!isEquivalentType(F1->getFunctionType(), F2->getFunctionType()))
513    return false;
514
515  assert(F1->arg_size() == F2->arg_size() &&
516         "Identically typed functions have different numbers of args!");
517
518  // Visit the arguments so that they get enumerated in the order they're
519  // passed in.
520  for (Function::const_arg_iterator f1i = F1->arg_begin(),
521         f2i = F2->arg_begin(), f1e = F1->arg_end(); f1i != f1e; ++f1i, ++f2i) {
522    if (!Enumerate(f1i, f2i))
523      llvm_unreachable("Arguments repeat!");
524  }
525
526  // We do a CFG-ordered walk since the actual ordering of the blocks in the
527  // linked list is immaterial. Our walk starts at the entry block for both
528  // functions, then takes each block from each terminator in order. As an
529  // artifact, this also means that unreachable blocks are ignored.
530  SmallVector<const BasicBlock *, 8> F1BBs, F2BBs;
531  SmallSet<const BasicBlock *, 128> VisitedBBs; // in terms of F1.
532
533  F1BBs.push_back(&F1->getEntryBlock());
534  F2BBs.push_back(&F2->getEntryBlock());
535
536  VisitedBBs.insert(F1BBs[0]);
537  while (!F1BBs.empty()) {
538    const BasicBlock *F1BB = F1BBs.pop_back_val();
539    const BasicBlock *F2BB = F2BBs.pop_back_val();
540
541    if (!Enumerate(F1BB, F2BB) || !Compare(F1BB, F2BB))
542      return false;
543
544    const TerminatorInst *F1TI = F1BB->getTerminator();
545    const TerminatorInst *F2TI = F2BB->getTerminator();
546
547    assert(F1TI->getNumSuccessors() == F2TI->getNumSuccessors());
548    for (unsigned i = 0, e = F1TI->getNumSuccessors(); i != e; ++i) {
549      if (!VisitedBBs.insert(F1TI->getSuccessor(i)))
550        continue;
551
552      F1BBs.push_back(F1TI->getSuccessor(i));
553      F2BBs.push_back(F2TI->getSuccessor(i));
554    }
555  }
556  return true;
557}
558
559/// WriteThunk - Replace G with a simple tail call to bitcast(F). Also replace
560/// direct uses of G with bitcast(F). Deletes G.
561void MergeFunctions::WriteThunk(Function *F, Function *G) const {
562  if (!G->mayBeOverridden()) {
563    // Redirect direct callers of G to F.
564    Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType());
565    for (Value::use_iterator UI = G->use_begin(), UE = G->use_end();
566         UI != UE;) {
567      Value::use_iterator TheIter = UI;
568      ++UI;
569      CallSite CS(*TheIter);
570      if (CS && CS.isCallee(TheIter))
571        TheIter.getUse().set(BitcastF);
572    }
573  }
574
575  // If G was internal then we may have replaced all uses of G with F. If so,
576  // stop here and delete G. There's no need for a thunk.
577  if (G->hasLocalLinkage() && G->use_empty()) {
578    G->eraseFromParent();
579    return;
580  }
581
582  Function *NewG = Function::Create(G->getFunctionType(), G->getLinkage(), "",
583                                    G->getParent());
584  BasicBlock *BB = BasicBlock::Create(F->getContext(), "", NewG);
585  IRBuilder<false> Builder(BB);
586
587  SmallVector<Value *, 16> Args;
588  unsigned i = 0;
589  const FunctionType *FFTy = F->getFunctionType();
590  for (Function::arg_iterator AI = NewG->arg_begin(), AE = NewG->arg_end();
591       AI != AE; ++AI) {
592    Args.push_back(Builder.CreateBitCast(AI, FFTy->getParamType(i)));
593    ++i;
594  }
595
596  CallInst *CI = Builder.CreateCall(F, Args.begin(), Args.end());
597  CI->setTailCall();
598  CI->setCallingConv(F->getCallingConv());
599  if (NewG->getReturnType()->isVoidTy()) {
600    Builder.CreateRetVoid();
601  } else {
602    Builder.CreateRet(Builder.CreateBitCast(CI, NewG->getReturnType()));
603  }
604
605  NewG->copyAttributesFrom(G);
606  NewG->takeName(G);
607  G->replaceAllUsesWith(NewG);
608  G->eraseFromParent();
609
610  DEBUG(dbgs() << "WriteThunk: " << NewG->getName() << '\n');
611  ++NumThunksWritten;
612}
613
614/// MergeTwoFunctions - Merge two equivalent functions. Upon completion,
615/// Function G is deleted.
616void MergeFunctions::MergeTwoFunctions(Function *F, Function *G) const {
617  if (F->mayBeOverridden()) {
618    assert(G->mayBeOverridden());
619
620    // Make them both thunks to the same internal function.
621    Function *H = Function::Create(F->getFunctionType(), F->getLinkage(), "",
622                                   F->getParent());
623    H->copyAttributesFrom(F);
624    H->takeName(F);
625    F->replaceAllUsesWith(H);
626
627    unsigned MaxAlignment = std::max(G->getAlignment(), H->getAlignment());
628
629    WriteThunk(F, G);
630    WriteThunk(F, H);
631
632    F->setAlignment(MaxAlignment);
633    F->setLinkage(GlobalValue::InternalLinkage);
634
635    ++NumDoubleWeak;
636  } else {
637    WriteThunk(F, G);
638  }
639
640  ++NumFunctionsMerged;
641}
642
643// Insert - Insert a ComparableFunction into the FnSet, or merge it away if
644// equal to one that's already inserted.
645bool MergeFunctions::Insert(FnSetType &FnSet, ComparableFunction &NewF) {
646  std::pair<FnSetType::iterator, bool> Result = FnSet.insert(NewF);
647  if (Result.second)
648    return false;
649
650  const ComparableFunction &OldF = *Result.first;
651
652  // Never thunk a strong function to a weak function.
653  assert(!OldF.getFunc()->mayBeOverridden() ||
654         NewF.getFunc()->mayBeOverridden());
655
656  DEBUG(dbgs() << "  " << OldF.getFunc()->getName() << " == "
657               << NewF.getFunc()->getName() << '\n');
658
659  Function *DeleteF = NewF.getFunc();
660  NewF.release();
661  MergeTwoFunctions(OldF.getFunc(), DeleteF);
662  return true;
663}
664
665// IsThunk - This method determines whether or not a given Function is a thunk\// like the ones emitted by this pass and therefore not subject to further
666// merging.
667static bool IsThunk(const Function *F) {
668  // The safe direction to fail is to return true. In that case, the function
669  // will be removed from merging analysis. If we failed to including functions
670  // then we may try to merge unmergable thing (ie., identical weak functions)
671  // which will push us into an infinite loop.
672
673  assert(!F->isDeclaration() && "Expected a function definition.");
674
675  const BasicBlock *BB = &F->front();
676  // A thunk is:
677  //   bitcast-inst*
678  //   optional-reg tail call @thunkee(args...*)
679  //   ret void|optional-reg
680  // where the args are in the same order as the arguments.
681
682  // Put this at the top since it triggers most often.
683  const ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator());
684  if (!RI) return false;
685
686  // Verify that the sequence of bitcast-inst's are all casts of arguments and
687  // that there aren't any extras (ie. no repeated casts).
688  int LastArgNo = -1;
689  BasicBlock::const_iterator I = BB->begin();
690  while (const BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
691    const Argument *A = dyn_cast<Argument>(BCI->getOperand(0));
692    if (!A) return false;
693    if ((int)A->getArgNo() <= LastArgNo) return false;
694    LastArgNo = A->getArgNo();
695    ++I;
696  }
697
698  // Verify that we have a direct tail call and that the calling conventions
699  // and number of arguments match.
700  const CallInst *CI = dyn_cast<CallInst>(I++);
701  if (!CI || !CI->isTailCall() || !CI->getCalledFunction() ||
702      CI->getCallingConv() != CI->getCalledFunction()->getCallingConv() ||
703      CI->getNumArgOperands() != F->arg_size())
704    return false;
705
706  // Verify that the call instruction has the same arguments as this function
707  // and that they're all either the incoming argument or a cast of the right
708  // argument.
709  for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) {
710    const Value *V = CI->getArgOperand(i);
711    const Argument *A = dyn_cast<Argument>(V);
712    if (!A) {
713      const BitCastInst *BCI = dyn_cast<BitCastInst>(V);
714      if (!BCI) return false;
715      A = cast<Argument>(BCI->getOperand(0));
716    }
717    if (A->getArgNo() != i) return false;
718  }
719
720  // Verify that the terminator is a ret void (if we're void) or a ret of the
721  // call's return, or a ret of a bitcast of the call's return.
722  const Value *RetOp = CI;
723  if (const BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
724    ++I;
725    if (BCI->getOperand(0) != CI) return false;
726    RetOp = BCI;
727  }
728  if (RI != I) return false;
729  if (RI->getNumOperands() == 0)
730    return CI->getType()->isVoidTy();
731  return RI->getReturnValue() == CI;
732}
733
734bool MergeFunctions::runOnModule(Module &M) {
735  bool Changed = false;
736  TD = getAnalysisIfAvailable<TargetData>();
737
738  bool LocalChanged;
739  do {
740    DEBUG(dbgs() << "size of module: " << M.size() << '\n');
741    LocalChanged = false;
742    FnSetType FnSet;
743
744    // Insert only strong functions and merge them. Strong function merging
745    // always deletes one of them.
746    for (Module::iterator I = M.begin(), E = M.end(); I != E;) {
747      Function *F = I++;
748      if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage() &&
749          !F->mayBeOverridden() && !IsThunk(F)) {
750        ComparableFunction CF = ComparableFunction(F, TD);
751        LocalChanged |= Insert(FnSet, CF);
752      }
753    }
754
755    // Insert only weak functions and merge them. By doing these second we
756    // create thunks to the strong function when possible. When two weak
757    // functions are identical, we create a new strong function with two weak
758    // weak thunks to it which are identical but not mergable.
759    for (Module::iterator I = M.begin(), E = M.end(); I != E;) {
760      Function *F = I++;
761      if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage() &&
762          F->mayBeOverridden() && !IsThunk(F)) {
763        ComparableFunction CF = ComparableFunction(F, TD);
764        LocalChanged |= Insert(FnSet, CF);
765      }
766    }
767    DEBUG(dbgs() << "size of FnSet: " << FnSet.size() << '\n');
768    Changed |= LocalChanged;
769  } while (LocalChanged);
770
771  return Changed;
772}
773
774bool DenseMapInfo<ComparableFunction>::isEqual(const ComparableFunction &LHS,
775                                               const ComparableFunction &RHS) {
776  if (LHS.getFunc() == RHS.getFunc() &&
777      LHS.getHash() == RHS.getHash())
778    return true;
779  if (!LHS.getFunc() || !RHS.getFunc())
780    return false;
781  assert(LHS.getTD() == RHS.getTD() &&
782         "Comparing functions for different targets");
783  return FunctionComparator(LHS.getTD(),
784                            LHS.getFunc(), RHS.getFunc()).Compare();
785}
786