1//===--- TransBlockObjCVariable.cpp - Transformations to ARC mode ---------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// rewriteBlockObjCVariable:
11//
12// Adding __block to an obj-c variable could be either because the variable
13// is used for output storage or the user wanted to break a retain cycle.
// This transformation checks whether the block actually needs a reference to
// the variable (i.e. the variable is assigned to or its address is taken).
// If the reference is not needed, it assumes __block was added to break a
// retain cycle, so it removes '__block' and adds __weak/__unsafe_unretained.
// e.g.
19//
20//   __block Foo *x;
21//   bar(^ { [x cake]; });
22// ---->
23//   __weak Foo *x;
24//   bar(^ { [x cake]; });
25//
26//===----------------------------------------------------------------------===//
27
28#include "Transforms.h"
29#include "Internals.h"
30#include "clang/AST/ASTContext.h"
31#include "clang/AST/Attr.h"
32#include "clang/Basic/SourceManager.h"
33
34using namespace clang;
35using namespace arcmt;
36using namespace trans;
37
38namespace {
39
40class RootBlockObjCVarRewriter :
41                          public RecursiveASTVisitor<RootBlockObjCVarRewriter> {
42  llvm::DenseSet<VarDecl *> &VarsToChange;
43
44  class BlockVarChecker : public RecursiveASTVisitor<BlockVarChecker> {
45    VarDecl *Var;
46
47    typedef RecursiveASTVisitor<BlockVarChecker> base;
48  public:
49    BlockVarChecker(VarDecl *var) : Var(var) { }
50
51    bool TraverseImplicitCastExpr(ImplicitCastExpr *castE) {
52      if (DeclRefExpr *
53            ref = dyn_cast<DeclRefExpr>(castE->getSubExpr())) {
54        if (ref->getDecl() == Var) {
55          if (castE->getCastKind() == CK_LValueToRValue)
56            return true; // Using the value of the variable.
57          if (castE->getCastKind() == CK_NoOp && castE->isLValue() &&
58              Var->getASTContext().getLangOpts().CPlusPlus)
59            return true; // Binding to const C++ reference.
60        }
61      }
62
63      return base::TraverseImplicitCastExpr(castE);
64    }
65
66    bool VisitDeclRefExpr(DeclRefExpr *E) {
67      if (E->getDecl() == Var)
68        return false; // The reference of the variable, and not just its value,
69                      //  is needed.
70      return true;
71    }
72  };
73
74public:
75  RootBlockObjCVarRewriter(llvm::DenseSet<VarDecl *> &VarsToChange)
76    : VarsToChange(VarsToChange) { }
77
78  bool VisitBlockDecl(BlockDecl *block) {
79    SmallVector<VarDecl *, 4> BlockVars;
80
81    for (const auto &I : block->captures()) {
82      VarDecl *var = I.getVariable();
83      if (I.isByRef() &&
84          var->getType()->isObjCObjectPointerType() &&
85          isImplicitStrong(var->getType())) {
86        BlockVars.push_back(var);
87      }
88    }
89
90    for (unsigned i = 0, e = BlockVars.size(); i != e; ++i) {
91      VarDecl *var = BlockVars[i];
92
93      BlockVarChecker checker(var);
94      bool onlyValueOfVarIsNeeded = checker.TraverseStmt(block->getBody());
95      if (onlyValueOfVarIsNeeded)
96        VarsToChange.insert(var);
97      else
98        VarsToChange.erase(var);
99    }
100
101    return true;
102  }
103
104private:
105  bool isImplicitStrong(QualType ty) {
106    if (isa<AttributedType>(ty.getTypePtr()))
107      return false;
108    return ty.getLocalQualifiers().getObjCLifetime() == Qualifiers::OCL_Strong;
109  }
110};
111
112class BlockObjCVarRewriter : public RecursiveASTVisitor<BlockObjCVarRewriter> {
113  llvm::DenseSet<VarDecl *> &VarsToChange;
114
115public:
116  BlockObjCVarRewriter(llvm::DenseSet<VarDecl *> &VarsToChange)
117    : VarsToChange(VarsToChange) { }
118
119  bool TraverseBlockDecl(BlockDecl *block) {
120    RootBlockObjCVarRewriter(VarsToChange).TraverseDecl(block);
121    return true;
122  }
123};
124
125} // anonymous namespace
126
127void BlockObjCVariableTraverser::traverseBody(BodyContext &BodyCtx) {
128  MigrationPass &Pass = BodyCtx.getMigrationContext().Pass;
129  llvm::DenseSet<VarDecl *> VarsToChange;
130
131  BlockObjCVarRewriter trans(VarsToChange);
132  trans.TraverseStmt(BodyCtx.getTopStmt());
133
134  for (llvm::DenseSet<VarDecl *>::iterator
135         I = VarsToChange.begin(), E = VarsToChange.end(); I != E; ++I) {
136    VarDecl *var = *I;
137    BlocksAttr *attr = var->getAttr<BlocksAttr>();
138    if(!attr)
139      continue;
140    bool useWeak = canApplyWeak(Pass.Ctx, var->getType());
141    SourceManager &SM = Pass.Ctx.getSourceManager();
142    Transaction Trans(Pass.TA);
143    Pass.TA.replaceText(SM.getExpansionLoc(attr->getLocation()),
144                        "__block",
145                        useWeak ? "__weak" : "__unsafe_unretained");
146  }
147}
148