1//===--- TransAutoreleasePool.cpp - Transformations to ARC mode -----------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
// rewriteAutoreleasePool:
//
// Uses of NSAutoreleasePool are rewritten as an @autoreleasepool scope.
//
//  NSAutoreleasePool *pool = [[NSAutoreleasePool alloc] init];
//  ...
//  [pool release];
// ---->
//  @autoreleasepool {
//  ...
//  }
21//
22// An NSAutoreleasePool will not be touched if:
23// - There is not a corresponding -release/-drain in the same scope
24// - Not all references of the NSAutoreleasePool variable can be removed
25// - There is a variable that is declared inside the intended @autorelease scope
26//   which is also used outside it.
27//
28//===----------------------------------------------------------------------===//
29
30#include "Transforms.h"
31#include "Internals.h"
32#include "clang/AST/ASTContext.h"
33#include "clang/Basic/SourceManager.h"
34#include "clang/Sema/SemaDiagnostic.h"
35#include <map>
36
37using namespace clang;
38using namespace arcmt;
39using namespace trans;
40
41namespace {
42
43class ReleaseCollector : public RecursiveASTVisitor<ReleaseCollector> {
44  Decl *Dcl;
45  SmallVectorImpl<ObjCMessageExpr *> &Releases;
46
47public:
48  ReleaseCollector(Decl *D, SmallVectorImpl<ObjCMessageExpr *> &releases)
49    : Dcl(D), Releases(releases) { }
50
51  bool VisitObjCMessageExpr(ObjCMessageExpr *E) {
52    if (!E->isInstanceMessage())
53      return true;
54    if (E->getMethodFamily() != OMF_release)
55      return true;
56    Expr *instance = E->getInstanceReceiver()->IgnoreParenCasts();
57    if (DeclRefExpr *DE = dyn_cast<DeclRefExpr>(instance)) {
58      if (DE->getDecl() == Dcl)
59        Releases.push_back(E);
60    }
61    return true;
62  }
63};
64
65}
66
67namespace {
68
69class AutoreleasePoolRewriter
70                         : public RecursiveASTVisitor<AutoreleasePoolRewriter> {
71public:
72  AutoreleasePoolRewriter(MigrationPass &pass)
73    : Body(nullptr), Pass(pass) {
74    PoolII = &pass.Ctx.Idents.get("NSAutoreleasePool");
75    DrainSel = pass.Ctx.Selectors.getNullarySelector(
76                                                 &pass.Ctx.Idents.get("drain"));
77  }
78
79  void transformBody(Stmt *body, Decl *ParentD) {
80    Body = body;
81    TraverseStmt(body);
82  }
83
84  ~AutoreleasePoolRewriter() {
85    SmallVector<VarDecl *, 8> VarsToHandle;
86
87    for (std::map<VarDecl *, PoolVarInfo>::iterator
88           I = PoolVars.begin(), E = PoolVars.end(); I != E; ++I) {
89      VarDecl *var = I->first;
90      PoolVarInfo &info = I->second;
91
92      // Check that we can handle/rewrite all references of the pool.
93
94      clearRefsIn(info.Dcl, info.Refs);
95      for (SmallVectorImpl<PoolScope>::iterator
96             scpI = info.Scopes.begin(),
97             scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
98        PoolScope &scope = *scpI;
99        clearRefsIn(*scope.Begin, info.Refs);
100        clearRefsIn(*scope.End, info.Refs);
101        clearRefsIn(scope.Releases.begin(), scope.Releases.end(), info.Refs);
102      }
103
104      // Even if one reference is not handled we will not do anything about that
105      // pool variable.
106      if (info.Refs.empty())
107        VarsToHandle.push_back(var);
108    }
109
110    for (unsigned i = 0, e = VarsToHandle.size(); i != e; ++i) {
111      PoolVarInfo &info = PoolVars[VarsToHandle[i]];
112
113      Transaction Trans(Pass.TA);
114
115      clearUnavailableDiags(info.Dcl);
116      Pass.TA.removeStmt(info.Dcl);
117
118      // Add "@autoreleasepool { }"
119      for (SmallVectorImpl<PoolScope>::iterator
120             scpI = info.Scopes.begin(),
121             scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
122        PoolScope &scope = *scpI;
123        clearUnavailableDiags(*scope.Begin);
124        clearUnavailableDiags(*scope.End);
125        if (scope.IsFollowedBySimpleReturnStmt) {
126          // Include the return in the scope.
127          Pass.TA.replaceStmt(*scope.Begin, "@autoreleasepool {");
128          Pass.TA.removeStmt(*scope.End);
129          Stmt::child_iterator retI = scope.End;
130          ++retI;
131          SourceLocation afterSemi = findLocationAfterSemi((*retI)->getLocEnd(),
132                                                           Pass.Ctx);
133          assert(afterSemi.isValid() &&
134                 "Didn't we check before setting IsFollowedBySimpleReturnStmt "
135                 "to true?");
136          Pass.TA.insertAfterToken(afterSemi, "\n}");
137          Pass.TA.increaseIndentation(
138                                SourceRange(scope.getIndentedRange().getBegin(),
139                                            (*retI)->getLocEnd()),
140                                      scope.CompoundParent->getLocStart());
141        } else {
142          Pass.TA.replaceStmt(*scope.Begin, "@autoreleasepool {");
143          Pass.TA.replaceStmt(*scope.End, "}");
144          Pass.TA.increaseIndentation(scope.getIndentedRange(),
145                                      scope.CompoundParent->getLocStart());
146        }
147      }
148
149      // Remove rest of pool var references.
150      for (SmallVectorImpl<PoolScope>::iterator
151             scpI = info.Scopes.begin(),
152             scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
153        PoolScope &scope = *scpI;
154        for (SmallVectorImpl<ObjCMessageExpr *>::iterator
155               relI = scope.Releases.begin(),
156               relE = scope.Releases.end(); relI != relE; ++relI) {
157          clearUnavailableDiags(*relI);
158          Pass.TA.removeStmt(*relI);
159        }
160      }
161    }
162  }
163
164  bool VisitCompoundStmt(CompoundStmt *S) {
165    SmallVector<PoolScope, 4> Scopes;
166
167    for (Stmt::child_iterator
168           I = S->body_begin(), E = S->body_end(); I != E; ++I) {
169      Stmt *child = getEssential(*I);
170      if (DeclStmt *DclS = dyn_cast<DeclStmt>(child)) {
171        if (DclS->isSingleDecl()) {
172          if (VarDecl *VD = dyn_cast<VarDecl>(DclS->getSingleDecl())) {
173            if (isNSAutoreleasePool(VD->getType())) {
174              PoolVarInfo &info = PoolVars[VD];
175              info.Dcl = DclS;
176              collectRefs(VD, S, info.Refs);
177              // Does this statement follow the pattern:
178              // NSAutoreleasePool * pool = [NSAutoreleasePool  new];
179              if (isPoolCreation(VD->getInit())) {
180                Scopes.push_back(PoolScope());
181                Scopes.back().PoolVar = VD;
182                Scopes.back().CompoundParent = S;
183                Scopes.back().Begin = I;
184              }
185            }
186          }
187        }
188      } else if (BinaryOperator *bop = dyn_cast<BinaryOperator>(child)) {
189        if (DeclRefExpr *dref = dyn_cast<DeclRefExpr>(bop->getLHS())) {
190          if (VarDecl *VD = dyn_cast<VarDecl>(dref->getDecl())) {
191            // Does this statement follow the pattern:
192            // pool = [NSAutoreleasePool  new];
193            if (isNSAutoreleasePool(VD->getType()) &&
194                isPoolCreation(bop->getRHS())) {
195              Scopes.push_back(PoolScope());
196              Scopes.back().PoolVar = VD;
197              Scopes.back().CompoundParent = S;
198              Scopes.back().Begin = I;
199            }
200          }
201        }
202      }
203
204      if (Scopes.empty())
205        continue;
206
207      if (isPoolDrain(Scopes.back().PoolVar, child)) {
208        PoolScope &scope = Scopes.back();
209        scope.End = I;
210        handlePoolScope(scope, S);
211        Scopes.pop_back();
212      }
213    }
214    return true;
215  }
216
217private:
218  void clearUnavailableDiags(Stmt *S) {
219    if (S)
220      Pass.TA.clearDiagnostic(diag::err_unavailable,
221                              diag::err_unavailable_message,
222                              S->getSourceRange());
223  }
224
225  struct PoolScope {
226    VarDecl *PoolVar;
227    CompoundStmt *CompoundParent;
228    Stmt::child_iterator Begin;
229    Stmt::child_iterator End;
230    bool IsFollowedBySimpleReturnStmt;
231    SmallVector<ObjCMessageExpr *, 4> Releases;
232
233    PoolScope() : PoolVar(nullptr), CompoundParent(nullptr), Begin(), End(),
234                  IsFollowedBySimpleReturnStmt(false) { }
235
236    SourceRange getIndentedRange() const {
237      Stmt::child_iterator rangeS = Begin;
238      ++rangeS;
239      if (rangeS == End)
240        return SourceRange();
241      Stmt::child_iterator rangeE = Begin;
242      for (Stmt::child_iterator I = rangeS; I != End; ++I)
243        ++rangeE;
244      return SourceRange((*rangeS)->getLocStart(), (*rangeE)->getLocEnd());
245    }
246  };
247
248  class NameReferenceChecker : public RecursiveASTVisitor<NameReferenceChecker>{
249    ASTContext &Ctx;
250    SourceRange ScopeRange;
251    SourceLocation &referenceLoc, &declarationLoc;
252
253  public:
254    NameReferenceChecker(ASTContext &ctx, PoolScope &scope,
255                         SourceLocation &referenceLoc,
256                         SourceLocation &declarationLoc)
257      : Ctx(ctx), referenceLoc(referenceLoc),
258        declarationLoc(declarationLoc) {
259      ScopeRange = SourceRange((*scope.Begin)->getLocStart(),
260                               (*scope.End)->getLocStart());
261    }
262
263    bool VisitDeclRefExpr(DeclRefExpr *E) {
264      return checkRef(E->getLocation(), E->getDecl()->getLocation());
265    }
266
267    bool VisitTypedefTypeLoc(TypedefTypeLoc TL) {
268      return checkRef(TL.getBeginLoc(), TL.getTypedefNameDecl()->getLocation());
269    }
270
271    bool VisitTagTypeLoc(TagTypeLoc TL) {
272      return checkRef(TL.getBeginLoc(), TL.getDecl()->getLocation());
273    }
274
275  private:
276    bool checkRef(SourceLocation refLoc, SourceLocation declLoc) {
277      if (isInScope(declLoc)) {
278        referenceLoc = refLoc;
279        declarationLoc = declLoc;
280        return false;
281      }
282      return true;
283    }
284
285    bool isInScope(SourceLocation loc) {
286      if (loc.isInvalid())
287        return false;
288
289      SourceManager &SM = Ctx.getSourceManager();
290      if (SM.isBeforeInTranslationUnit(loc, ScopeRange.getBegin()))
291        return false;
292      return SM.isBeforeInTranslationUnit(loc, ScopeRange.getEnd());
293    }
294  };
295
296  void handlePoolScope(PoolScope &scope, CompoundStmt *compoundS) {
297    // Check that all names declared inside the scope are not used
298    // outside the scope.
299    {
300      bool nameUsedOutsideScope = false;
301      SourceLocation referenceLoc, declarationLoc;
302      Stmt::child_iterator SI = scope.End, SE = compoundS->body_end();
303      ++SI;
304      // Check if the autoreleasepool scope is followed by a simple return
305      // statement, in which case we will include the return in the scope.
306      if (SI != SE)
307        if (ReturnStmt *retS = dyn_cast<ReturnStmt>(*SI))
308          if ((retS->getRetValue() == nullptr ||
309               isa<DeclRefExpr>(retS->getRetValue()->IgnoreParenCasts())) &&
310              findLocationAfterSemi(retS->getLocEnd(), Pass.Ctx).isValid()) {
311            scope.IsFollowedBySimpleReturnStmt = true;
312            ++SI; // the return will be included in scope, don't check it.
313          }
314
315      for (; SI != SE; ++SI) {
316        nameUsedOutsideScope = !NameReferenceChecker(Pass.Ctx, scope,
317                                                     referenceLoc,
318                                              declarationLoc).TraverseStmt(*SI);
319        if (nameUsedOutsideScope)
320          break;
321      }
322
323      // If not all references were cleared it means some variables/typenames/etc
324      // declared inside the pool scope are used outside of it.
325      // We won't try to rewrite the pool.
326      if (nameUsedOutsideScope) {
327        Pass.TA.reportError("a name is referenced outside the "
328            "NSAutoreleasePool scope that it was declared in", referenceLoc);
329        Pass.TA.reportNote("name declared here", declarationLoc);
330        Pass.TA.reportNote("intended @autoreleasepool scope begins here",
331                           (*scope.Begin)->getLocStart());
332        Pass.TA.reportNote("intended @autoreleasepool scope ends here",
333                           (*scope.End)->getLocStart());
334        return;
335      }
336    }
337
338    // Collect all releases of the pool; they will be removed.
339    {
340      ReleaseCollector releaseColl(scope.PoolVar, scope.Releases);
341      Stmt::child_iterator I = scope.Begin;
342      ++I;
343      for (; I != scope.End; ++I)
344        releaseColl.TraverseStmt(*I);
345    }
346
347    PoolVars[scope.PoolVar].Scopes.push_back(scope);
348  }
349
350  bool isPoolCreation(Expr *E) {
351    if (!E) return false;
352    E = getEssential(E);
353    ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(E);
354    if (!ME) return false;
355    if (ME->getMethodFamily() == OMF_new &&
356        ME->getReceiverKind() == ObjCMessageExpr::Class &&
357        isNSAutoreleasePool(ME->getReceiverInterface()))
358      return true;
359    if (ME->getReceiverKind() == ObjCMessageExpr::Instance &&
360        ME->getMethodFamily() == OMF_init) {
361      Expr *rec = getEssential(ME->getInstanceReceiver());
362      if (ObjCMessageExpr *recME = dyn_cast_or_null<ObjCMessageExpr>(rec)) {
363        if (recME->getMethodFamily() == OMF_alloc &&
364            recME->getReceiverKind() == ObjCMessageExpr::Class &&
365            isNSAutoreleasePool(recME->getReceiverInterface()))
366          return true;
367      }
368    }
369
370    return false;
371  }
372
373  bool isPoolDrain(VarDecl *poolVar, Stmt *S) {
374    if (!S) return false;
375    S = getEssential(S);
376    ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(S);
377    if (!ME) return false;
378    if (ME->getReceiverKind() == ObjCMessageExpr::Instance) {
379      Expr *rec = getEssential(ME->getInstanceReceiver());
380      if (DeclRefExpr *dref = dyn_cast<DeclRefExpr>(rec))
381        if (dref->getDecl() == poolVar)
382          return ME->getMethodFamily() == OMF_release ||
383                 ME->getSelector() == DrainSel;
384    }
385
386    return false;
387  }
388
389  bool isNSAutoreleasePool(ObjCInterfaceDecl *IDecl) {
390    return IDecl && IDecl->getIdentifier() == PoolII;
391  }
392
393  bool isNSAutoreleasePool(QualType Ty) {
394    QualType pointee = Ty->getPointeeType();
395    if (pointee.isNull())
396      return false;
397    if (const ObjCInterfaceType *interT = pointee->getAs<ObjCInterfaceType>())
398      return isNSAutoreleasePool(interT->getDecl());
399    return false;
400  }
401
402  static Expr *getEssential(Expr *E) {
403    return cast<Expr>(getEssential((Stmt*)E));
404  }
405  static Stmt *getEssential(Stmt *S) {
406    if (ExprWithCleanups *EWC = dyn_cast<ExprWithCleanups>(S))
407      S = EWC->getSubExpr();
408    if (Expr *E = dyn_cast<Expr>(S))
409      S = E->IgnoreParenCasts();
410    return S;
411  }
412
413  Stmt *Body;
414  MigrationPass &Pass;
415
416  IdentifierInfo *PoolII;
417  Selector DrainSel;
418
419  struct PoolVarInfo {
420    DeclStmt *Dcl;
421    ExprSet Refs;
422    SmallVector<PoolScope, 2> Scopes;
423
424    PoolVarInfo() : Dcl(nullptr) { }
425  };
426
427  std::map<VarDecl *, PoolVarInfo> PoolVars;
428};
429
430} // anonymous namespace
431
432void trans::rewriteAutoreleasePool(MigrationPass &pass) {
433  BodyTransform<AutoreleasePoolRewriter> trans(pass);
434  trans.TraverseDecl(pass.Ctx.getTranslationUnitDecl());
435}
436