//===--- TransBlockObjCVariable.cpp - Transformations to ARC mode ---------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// rewriteBlockObjCVariable:
11//
// Adding __block to an obj-c variable could be either because the variable
// is used for output storage or the user wanted to break a retain cycle.
14// This transformation checks whether a reference of the variable for the block
15// is actually needed (it is assigned to or its address is taken) or not.
16// If the reference is not needed it will assume __block was added to break a
17// cycle so it will remove '__block' and add __weak/__unsafe_unretained.
// e.g.:
19//
20//   __block Foo *x;
21//   bar(^ { [x cake]; });
22// ---->
23//   __weak Foo *x;
24//   bar(^ { [x cake]; });
25//
26//===----------------------------------------------------------------------===//
27
28#include "Transforms.h"
29#include "Internals.h"
30#include "clang/Basic/SourceManager.h"
31
32using namespace clang;
33using namespace arcmt;
34using namespace trans;
35
36namespace {
37
38class RootBlockObjCVarRewriter :
39                          public RecursiveASTVisitor<RootBlockObjCVarRewriter> {
40  MigrationPass &Pass;
41  llvm::DenseSet<VarDecl *> &VarsToChange;
42
43  class BlockVarChecker : public RecursiveASTVisitor<BlockVarChecker> {
44    VarDecl *Var;
45
46    typedef RecursiveASTVisitor<BlockVarChecker> base;
47  public:
48    BlockVarChecker(VarDecl *var) : Var(var) { }
49
50    bool TraverseImplicitCastExpr(ImplicitCastExpr *castE) {
51      if (DeclRefExpr *
52            ref = dyn_cast<DeclRefExpr>(castE->getSubExpr())) {
53        if (ref->getDecl() == Var) {
54          if (castE->getCastKind() == CK_LValueToRValue)
55            return true; // Using the value of the variable.
56          if (castE->getCastKind() == CK_NoOp && castE->isLValue() &&
57              Var->getASTContext().getLangOpts().CPlusPlus)
58            return true; // Binding to const C++ reference.
59        }
60      }
61
62      return base::TraverseImplicitCastExpr(castE);
63    }
64
65    bool VisitDeclRefExpr(DeclRefExpr *E) {
66      if (E->getDecl() == Var)
67        return false; // The reference of the variable, and not just its value,
68                      //  is needed.
69      return true;
70    }
71  };
72
73public:
74  RootBlockObjCVarRewriter(MigrationPass &pass,
75                           llvm::DenseSet<VarDecl *> &VarsToChange)
76    : Pass(pass), VarsToChange(VarsToChange) { }
77
78  bool VisitBlockDecl(BlockDecl *block) {
79    SmallVector<VarDecl *, 4> BlockVars;
80
81    for (BlockDecl::capture_iterator
82           I = block->capture_begin(), E = block->capture_end(); I != E; ++I) {
83      VarDecl *var = I->getVariable();
84      if (I->isByRef() &&
85          var->getType()->isObjCObjectPointerType() &&
86          isImplicitStrong(var->getType())) {
87        BlockVars.push_back(var);
88      }
89    }
90
91    for (unsigned i = 0, e = BlockVars.size(); i != e; ++i) {
92      VarDecl *var = BlockVars[i];
93
94      BlockVarChecker checker(var);
95      bool onlyValueOfVarIsNeeded = checker.TraverseStmt(block->getBody());
96      if (onlyValueOfVarIsNeeded)
97        VarsToChange.insert(var);
98      else
99        VarsToChange.erase(var);
100    }
101
102    return true;
103  }
104
105private:
106  bool isImplicitStrong(QualType ty) {
107    if (isa<AttributedType>(ty.getTypePtr()))
108      return false;
109    return ty.getLocalQualifiers().getObjCLifetime() == Qualifiers::OCL_Strong;
110  }
111};
112
113class BlockObjCVarRewriter : public RecursiveASTVisitor<BlockObjCVarRewriter> {
114  MigrationPass &Pass;
115  llvm::DenseSet<VarDecl *> &VarsToChange;
116
117public:
118  BlockObjCVarRewriter(MigrationPass &pass,
119                       llvm::DenseSet<VarDecl *> &VarsToChange)
120    : Pass(pass), VarsToChange(VarsToChange) { }
121
122  bool TraverseBlockDecl(BlockDecl *block) {
123    RootBlockObjCVarRewriter(Pass, VarsToChange).TraverseDecl(block);
124    return true;
125  }
126};
127
128} // anonymous namespace
129
130void BlockObjCVariableTraverser::traverseBody(BodyContext &BodyCtx) {
131  MigrationPass &Pass = BodyCtx.getMigrationContext().Pass;
132  llvm::DenseSet<VarDecl *> VarsToChange;
133
134  BlockObjCVarRewriter trans(Pass, VarsToChange);
135  trans.TraverseStmt(BodyCtx.getTopStmt());
136
137  for (llvm::DenseSet<VarDecl *>::iterator
138         I = VarsToChange.begin(), E = VarsToChange.end(); I != E; ++I) {
139    VarDecl *var = *I;
140    BlocksAttr *attr = var->getAttr<BlocksAttr>();
141    if(!attr)
142      continue;
143    bool useWeak = canApplyWeak(Pass.Ctx, var->getType());
144    SourceManager &SM = Pass.Ctx.getSourceManager();
145    Transaction Trans(Pass.TA);
146    Pass.TA.replaceText(SM.getExpansionLoc(attr->getLocation()),
147                        "__block",
148                        useWeak ? "__weak" : "__unsafe_unretained");
149  }
150}
151