//===--- Transforms.cpp - Transformations to ARC mode ---------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9
10#include "Transforms.h"
11#include "Internals.h"
12#include "clang/Analysis/DomainSpecific/CocoaConventions.h"
13#include "clang/AST/ASTContext.h"
14#include "clang/AST/RecursiveASTVisitor.h"
15#include "clang/AST/StmtVisitor.h"
16#include "clang/Basic/SourceManager.h"
17#include "clang/Lex/Lexer.h"
18#include "clang/Sema/Sema.h"
19#include "clang/Sema/SemaDiagnostic.h"
20#include "llvm/ADT/StringSwitch.h"
21#include "llvm/ADT/DenseSet.h"
22#include <map>
23
24using namespace clang;
25using namespace arcmt;
26using namespace trans;
27
28ASTTraverser::~ASTTraverser() { }
29
30bool MigrationPass::CFBridgingFunctionsDefined() {
31  if (!EnableCFBridgeFns.hasValue())
32    EnableCFBridgeFns = SemaRef.isKnownName("CFBridgingRetain") &&
33                        SemaRef.isKnownName("CFBridgingRelease");
34  return *EnableCFBridgeFns;
35}
36
37//===----------------------------------------------------------------------===//
38// Helpers.
39//===----------------------------------------------------------------------===//
40
41bool trans::canApplyWeak(ASTContext &Ctx, QualType type,
42                         bool AllowOnUnknownClass) {
43  if (!Ctx.getLangOpts().ObjCARCWeak)
44    return false;
45
46  QualType T = type;
47  if (T.isNull())
48    return false;
49
50  // iOS is always safe to use 'weak'.
51  if (Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::IOS)
52    AllowOnUnknownClass = true;
53
54  while (const PointerType *ptr = T->getAs<PointerType>())
55    T = ptr->getPointeeType();
56  if (const ObjCObjectPointerType *ObjT = T->getAs<ObjCObjectPointerType>()) {
57    ObjCInterfaceDecl *Class = ObjT->getInterfaceDecl();
58    if (!AllowOnUnknownClass && (!Class || Class->getName() == "NSObject"))
59      return false; // id/NSObject is not safe for weak.
60    if (!AllowOnUnknownClass && !Class->hasDefinition())
61      return false; // forward classes are not verifiable, therefore not safe.
62    if (Class && Class->isArcWeakrefUnavailable())
63      return false;
64  }
65
66  return true;
67}
68
69bool trans::isPlusOneAssign(const BinaryOperator *E) {
70  if (E->getOpcode() != BO_Assign)
71    return false;
72
73  if (const ObjCMessageExpr *
74        ME = dyn_cast<ObjCMessageExpr>(E->getRHS()->IgnoreParenCasts()))
75    if (ME->getMethodFamily() == OMF_retain)
76      return true;
77
78  if (const CallExpr *
79        callE = dyn_cast<CallExpr>(E->getRHS()->IgnoreParenCasts())) {
80    if (const FunctionDecl *FD = callE->getDirectCallee()) {
81      if (FD->getAttr<CFReturnsRetainedAttr>())
82        return true;
83
84      if (FD->isGlobal() &&
85          FD->getIdentifier() &&
86          FD->getParent()->isTranslationUnit() &&
87          FD->getLinkage() == ExternalLinkage &&
88          ento::cocoa::isRefType(callE->getType(), "CF",
89                                 FD->getIdentifier()->getName())) {
90        StringRef fname = FD->getIdentifier()->getName();
91        if (fname.endswith("Retain") ||
92            fname.find("Create") != StringRef::npos ||
93            fname.find("Copy") != StringRef::npos) {
94          return true;
95        }
96      }
97    }
98  }
99
100  const ImplicitCastExpr *implCE = dyn_cast<ImplicitCastExpr>(E->getRHS());
101  while (implCE && implCE->getCastKind() ==  CK_BitCast)
102    implCE = dyn_cast<ImplicitCastExpr>(implCE->getSubExpr());
103
104  if (implCE && implCE->getCastKind() == CK_ARCConsumeObject)
105    return true;
106
107  return false;
108}
109
110/// \brief 'Loc' is the end of a statement range. This returns the location
111/// immediately after the semicolon following the statement.
112/// If no semicolon is found or the location is inside a macro, the returned
113/// source location will be invalid.
114SourceLocation trans::findLocationAfterSemi(SourceLocation loc,
115                                            ASTContext &Ctx) {
116  SourceLocation SemiLoc = findSemiAfterLocation(loc, Ctx);
117  if (SemiLoc.isInvalid())
118    return SourceLocation();
119  return SemiLoc.getLocWithOffset(1);
120}
121
122/// \brief \arg Loc is the end of a statement range. This returns the location
123/// of the semicolon following the statement.
124/// If no semicolon is found or the location is inside a macro, the returned
125/// source location will be invalid.
126SourceLocation trans::findSemiAfterLocation(SourceLocation loc,
127                                            ASTContext &Ctx) {
128  SourceManager &SM = Ctx.getSourceManager();
129  if (loc.isMacroID()) {
130    if (!Lexer::isAtEndOfMacroExpansion(loc, SM, Ctx.getLangOpts(), &loc))
131      return SourceLocation();
132  }
133  loc = Lexer::getLocForEndOfToken(loc, /*Offset=*/0, SM, Ctx.getLangOpts());
134
135  // Break down the source location.
136  std::pair<FileID, unsigned> locInfo = SM.getDecomposedLoc(loc);
137
138  // Try to load the file buffer.
139  bool invalidTemp = false;
140  StringRef file = SM.getBufferData(locInfo.first, &invalidTemp);
141  if (invalidTemp)
142    return SourceLocation();
143
144  const char *tokenBegin = file.data() + locInfo.second;
145
146  // Lex from the start of the given location.
147  Lexer lexer(SM.getLocForStartOfFile(locInfo.first),
148              Ctx.getLangOpts(),
149              file.begin(), tokenBegin, file.end());
150  Token tok;
151  lexer.LexFromRawLexer(tok);
152  if (tok.isNot(tok::semi))
153    return SourceLocation();
154
155  return tok.getLocation();
156}
157
158bool trans::hasSideEffects(Expr *E, ASTContext &Ctx) {
159  if (!E || !E->HasSideEffects(Ctx))
160    return false;
161
162  E = E->IgnoreParenCasts();
163  ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(E);
164  if (!ME)
165    return true;
166  switch (ME->getMethodFamily()) {
167  case OMF_autorelease:
168  case OMF_dealloc:
169  case OMF_release:
170  case OMF_retain:
171    switch (ME->getReceiverKind()) {
172    case ObjCMessageExpr::SuperInstance:
173      return false;
174    case ObjCMessageExpr::Instance:
175      return hasSideEffects(ME->getInstanceReceiver(), Ctx);
176    default:
177      break;
178    }
179    break;
180  default:
181    break;
182  }
183
184  return true;
185}
186
187bool trans::isGlobalVar(Expr *E) {
188  E = E->IgnoreParenCasts();
189  if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E))
190    return DRE->getDecl()->getDeclContext()->isFileContext() &&
191           DRE->getDecl()->getLinkage() == ExternalLinkage;
192  if (ConditionalOperator *condOp = dyn_cast<ConditionalOperator>(E))
193    return isGlobalVar(condOp->getTrueExpr()) &&
194           isGlobalVar(condOp->getFalseExpr());
195
196  return false;
197}
198
199StringRef trans::getNilString(ASTContext &Ctx) {
200  if (Ctx.Idents.get("nil").hasMacroDefinition())
201    return "nil";
202  else
203    return "0";
204}
205
206namespace {
207
208class ReferenceClear : public RecursiveASTVisitor<ReferenceClear> {
209  ExprSet &Refs;
210public:
211  ReferenceClear(ExprSet &refs) : Refs(refs) { }
212  bool VisitDeclRefExpr(DeclRefExpr *E) { Refs.erase(E); return true; }
213};
214
215class ReferenceCollector : public RecursiveASTVisitor<ReferenceCollector> {
216  ValueDecl *Dcl;
217  ExprSet &Refs;
218
219public:
220  ReferenceCollector(ValueDecl *D, ExprSet &refs)
221    : Dcl(D), Refs(refs) { }
222
223  bool VisitDeclRefExpr(DeclRefExpr *E) {
224    if (E->getDecl() == Dcl)
225      Refs.insert(E);
226    return true;
227  }
228};
229
230class RemovablesCollector : public RecursiveASTVisitor<RemovablesCollector> {
231  ExprSet &Removables;
232
233public:
234  RemovablesCollector(ExprSet &removables)
235  : Removables(removables) { }
236
237  bool shouldWalkTypesOfTypeLocs() const { return false; }
238
239  bool TraverseStmtExpr(StmtExpr *E) {
240    CompoundStmt *S = E->getSubStmt();
241    for (CompoundStmt::body_iterator
242        I = S->body_begin(), E = S->body_end(); I != E; ++I) {
243      if (I != E - 1)
244        mark(*I);
245      TraverseStmt(*I);
246    }
247    return true;
248  }
249
250  bool VisitCompoundStmt(CompoundStmt *S) {
251    for (CompoundStmt::body_iterator
252        I = S->body_begin(), E = S->body_end(); I != E; ++I)
253      mark(*I);
254    return true;
255  }
256
257  bool VisitIfStmt(IfStmt *S) {
258    mark(S->getThen());
259    mark(S->getElse());
260    return true;
261  }
262
263  bool VisitWhileStmt(WhileStmt *S) {
264    mark(S->getBody());
265    return true;
266  }
267
268  bool VisitDoStmt(DoStmt *S) {
269    mark(S->getBody());
270    return true;
271  }
272
273  bool VisitForStmt(ForStmt *S) {
274    mark(S->getInit());
275    mark(S->getInc());
276    mark(S->getBody());
277    return true;
278  }
279
280private:
281  void mark(Stmt *S) {
282    if (!S) return;
283
284    while (LabelStmt *Label = dyn_cast<LabelStmt>(S))
285      S = Label->getSubStmt();
286    S = S->IgnoreImplicit();
287    if (Expr *E = dyn_cast<Expr>(S))
288      Removables.insert(E);
289  }
290};
291
292} // end anonymous namespace
293
294void trans::clearRefsIn(Stmt *S, ExprSet &refs) {
295  ReferenceClear(refs).TraverseStmt(S);
296}
297
298void trans::collectRefs(ValueDecl *D, Stmt *S, ExprSet &refs) {
299  ReferenceCollector(D, refs).TraverseStmt(S);
300}
301
302void trans::collectRemovables(Stmt *S, ExprSet &exprs) {
303  RemovablesCollector(exprs).TraverseStmt(S);
304}
305
306//===----------------------------------------------------------------------===//
307// MigrationContext
308//===----------------------------------------------------------------------===//
309
310namespace {
311
312class ASTTransform : public RecursiveASTVisitor<ASTTransform> {
313  MigrationContext &MigrateCtx;
314  typedef RecursiveASTVisitor<ASTTransform> base;
315
316public:
317  ASTTransform(MigrationContext &MigrateCtx) : MigrateCtx(MigrateCtx) { }
318
319  bool shouldWalkTypesOfTypeLocs() const { return false; }
320
321  bool TraverseObjCImplementationDecl(ObjCImplementationDecl *D) {
322    ObjCImplementationContext ImplCtx(MigrateCtx, D);
323    for (MigrationContext::traverser_iterator
324           I = MigrateCtx.traversers_begin(),
325           E = MigrateCtx.traversers_end(); I != E; ++I)
326      (*I)->traverseObjCImplementation(ImplCtx);
327
328    return base::TraverseObjCImplementationDecl(D);
329  }
330
331  bool TraverseStmt(Stmt *rootS) {
332    if (!rootS)
333      return true;
334
335    BodyContext BodyCtx(MigrateCtx, rootS);
336    for (MigrationContext::traverser_iterator
337           I = MigrateCtx.traversers_begin(),
338           E = MigrateCtx.traversers_end(); I != E; ++I)
339      (*I)->traverseBody(BodyCtx);
340
341    return true;
342  }
343};
344
345}
346
347MigrationContext::~MigrationContext() {
348  for (traverser_iterator
349         I = traversers_begin(), E = traversers_end(); I != E; ++I)
350    delete *I;
351}
352
353bool MigrationContext::isGCOwnedNonObjC(QualType T) {
354  while (!T.isNull()) {
355    if (const AttributedType *AttrT = T->getAs<AttributedType>()) {
356      if (AttrT->getAttrKind() == AttributedType::attr_objc_ownership)
357        return !AttrT->getModifiedType()->isObjCRetainableType();
358    }
359
360    if (T->isArrayType())
361      T = Pass.Ctx.getBaseElementType(T);
362    else if (const PointerType *PT = T->getAs<PointerType>())
363      T = PT->getPointeeType();
364    else if (const ReferenceType *RT = T->getAs<ReferenceType>())
365      T = RT->getPointeeType();
366    else
367      break;
368  }
369
370  return false;
371}
372
373bool MigrationContext::rewritePropertyAttribute(StringRef fromAttr,
374                                                StringRef toAttr,
375                                                SourceLocation atLoc) {
376  if (atLoc.isMacroID())
377    return false;
378
379  SourceManager &SM = Pass.Ctx.getSourceManager();
380
381  // Break down the source location.
382  std::pair<FileID, unsigned> locInfo = SM.getDecomposedLoc(atLoc);
383
384  // Try to load the file buffer.
385  bool invalidTemp = false;
386  StringRef file = SM.getBufferData(locInfo.first, &invalidTemp);
387  if (invalidTemp)
388    return false;
389
390  const char *tokenBegin = file.data() + locInfo.second;
391
392  // Lex from the start of the given location.
393  Lexer lexer(SM.getLocForStartOfFile(locInfo.first),
394              Pass.Ctx.getLangOpts(),
395              file.begin(), tokenBegin, file.end());
396  Token tok;
397  lexer.LexFromRawLexer(tok);
398  if (tok.isNot(tok::at)) return false;
399  lexer.LexFromRawLexer(tok);
400  if (tok.isNot(tok::raw_identifier)) return false;
401  if (StringRef(tok.getRawIdentifierData(), tok.getLength())
402        != "property")
403    return false;
404  lexer.LexFromRawLexer(tok);
405  if (tok.isNot(tok::l_paren)) return false;
406
407  Token BeforeTok = tok;
408  Token AfterTok;
409  AfterTok.startToken();
410  SourceLocation AttrLoc;
411
412  lexer.LexFromRawLexer(tok);
413  if (tok.is(tok::r_paren))
414    return false;
415
416  while (1) {
417    if (tok.isNot(tok::raw_identifier)) return false;
418    StringRef ident(tok.getRawIdentifierData(), tok.getLength());
419    if (ident == fromAttr) {
420      if (!toAttr.empty()) {
421        Pass.TA.replaceText(tok.getLocation(), fromAttr, toAttr);
422        return true;
423      }
424      // We want to remove the attribute.
425      AttrLoc = tok.getLocation();
426    }
427
428    do {
429      lexer.LexFromRawLexer(tok);
430      if (AttrLoc.isValid() && AfterTok.is(tok::unknown))
431        AfterTok = tok;
432    } while (tok.isNot(tok::comma) && tok.isNot(tok::r_paren));
433    if (tok.is(tok::r_paren))
434      break;
435    if (AttrLoc.isInvalid())
436      BeforeTok = tok;
437    lexer.LexFromRawLexer(tok);
438  }
439
440  if (toAttr.empty() && AttrLoc.isValid() && AfterTok.isNot(tok::unknown)) {
441    // We want to remove the attribute.
442    if (BeforeTok.is(tok::l_paren) && AfterTok.is(tok::r_paren)) {
443      Pass.TA.remove(SourceRange(BeforeTok.getLocation(),
444                                 AfterTok.getLocation()));
445    } else if (BeforeTok.is(tok::l_paren) && AfterTok.is(tok::comma)) {
446      Pass.TA.remove(SourceRange(AttrLoc, AfterTok.getLocation()));
447    } else {
448      Pass.TA.remove(SourceRange(BeforeTok.getLocation(), AttrLoc));
449    }
450
451    return true;
452  }
453
454  return false;
455}
456
457bool MigrationContext::addPropertyAttribute(StringRef attr,
458                                            SourceLocation atLoc) {
459  if (atLoc.isMacroID())
460    return false;
461
462  SourceManager &SM = Pass.Ctx.getSourceManager();
463
464  // Break down the source location.
465  std::pair<FileID, unsigned> locInfo = SM.getDecomposedLoc(atLoc);
466
467  // Try to load the file buffer.
468  bool invalidTemp = false;
469  StringRef file = SM.getBufferData(locInfo.first, &invalidTemp);
470  if (invalidTemp)
471    return false;
472
473  const char *tokenBegin = file.data() + locInfo.second;
474
475  // Lex from the start of the given location.
476  Lexer lexer(SM.getLocForStartOfFile(locInfo.first),
477              Pass.Ctx.getLangOpts(),
478              file.begin(), tokenBegin, file.end());
479  Token tok;
480  lexer.LexFromRawLexer(tok);
481  if (tok.isNot(tok::at)) return false;
482  lexer.LexFromRawLexer(tok);
483  if (tok.isNot(tok::raw_identifier)) return false;
484  if (StringRef(tok.getRawIdentifierData(), tok.getLength())
485        != "property")
486    return false;
487  lexer.LexFromRawLexer(tok);
488
489  if (tok.isNot(tok::l_paren)) {
490    Pass.TA.insert(tok.getLocation(), std::string("(") + attr.str() + ") ");
491    return true;
492  }
493
494  lexer.LexFromRawLexer(tok);
495  if (tok.is(tok::r_paren)) {
496    Pass.TA.insert(tok.getLocation(), attr);
497    return true;
498  }
499
500  if (tok.isNot(tok::raw_identifier)) return false;
501
502  Pass.TA.insert(tok.getLocation(), std::string(attr) + ", ");
503  return true;
504}
505
506void MigrationContext::traverse(TranslationUnitDecl *TU) {
507  for (traverser_iterator
508         I = traversers_begin(), E = traversers_end(); I != E; ++I)
509    (*I)->traverseTU(*this);
510
511  ASTTransform(*this).TraverseDecl(TU);
512}
513
514static void GCRewriteFinalize(MigrationPass &pass) {
515  ASTContext &Ctx = pass.Ctx;
516  TransformActions &TA = pass.TA;
517  DeclContext *DC = Ctx.getTranslationUnitDecl();
518  Selector FinalizeSel =
519   Ctx.Selectors.getNullarySelector(&pass.Ctx.Idents.get("finalize"));
520
521  typedef DeclContext::specific_decl_iterator<ObjCImplementationDecl>
522  impl_iterator;
523  for (impl_iterator I = impl_iterator(DC->decls_begin()),
524       E = impl_iterator(DC->decls_end()); I != E; ++I) {
525    for (ObjCImplementationDecl::instmeth_iterator
526         MI = I->instmeth_begin(),
527         ME = I->instmeth_end(); MI != ME; ++MI) {
528      ObjCMethodDecl *MD = *MI;
529      if (!MD->hasBody())
530        continue;
531
532      if (MD->isInstanceMethod() && MD->getSelector() == FinalizeSel) {
533        ObjCMethodDecl *FinalizeM = MD;
534        Transaction Trans(TA);
535        TA.insert(FinalizeM->getSourceRange().getBegin(),
536                  "#if !__has_feature(objc_arc)\n");
537        CharSourceRange::getTokenRange(FinalizeM->getSourceRange());
538        const SourceManager &SM = pass.Ctx.getSourceManager();
539        const LangOptions &LangOpts = pass.Ctx.getLangOpts();
540        bool Invalid;
541        std::string str = "\n#endif\n";
542        str += Lexer::getSourceText(
543                  CharSourceRange::getTokenRange(FinalizeM->getSourceRange()),
544                                    SM, LangOpts, &Invalid);
545        TA.insertAfterToken(FinalizeM->getSourceRange().getEnd(), str);
546
547        break;
548      }
549    }
550  }
551}
552
553//===----------------------------------------------------------------------===//
554// getAllTransformations.
555//===----------------------------------------------------------------------===//
556
557static void traverseAST(MigrationPass &pass) {
558  MigrationContext MigrateCtx(pass);
559
560  if (pass.isGCMigration()) {
561    MigrateCtx.addTraverser(new GCCollectableCallsTraverser);
562    MigrateCtx.addTraverser(new GCAttrsTraverser());
563  }
564  MigrateCtx.addTraverser(new PropertyRewriteTraverser());
565  MigrateCtx.addTraverser(new BlockObjCVariableTraverser());
566
567  MigrateCtx.traverse(pass.Ctx.getTranslationUnitDecl());
568}
569
/// Runs the individual ARC rewrite passes, in the order listed, as a single
/// transformation step.
///
/// getAllTransformations registers this step before
/// removeEmptyStatementsAndDeallocFinalize, which depends on expressions
/// removed by earlier transformations.
// NOTE(review): the name suggests these passes do not depend on each other's
// results; confirm before reordering them.
static void independentTransforms(MigrationPass &pass) {
  rewriteAutoreleasePool(pass);
  removeRetainReleaseDeallocFinalize(pass);
  rewriteUnusedInitDelegate(pass);
  removeZeroOutPropsInDeallocFinalize(pass);
  makeAssignARCSafe(pass);
  rewriteUnbridgedCasts(pass);
  checkAPIUses(pass);
  // MigrationContext-based traversers (property, block-variable, GC attrs).
  traverseAST(pass);
}
580
581std::vector<TransformFn> arcmt::getAllTransformations(
582                                               LangOptions::GCMode OrigGCMode,
583                                               bool NoFinalizeRemoval) {
584  std::vector<TransformFn> transforms;
585
586  if (OrigGCMode ==  LangOptions::GCOnly && NoFinalizeRemoval)
587    transforms.push_back(GCRewriteFinalize);
588  transforms.push_back(independentTransforms);
589  // This depends on previous transformations removing various expressions.
590  transforms.push_back(removeEmptyStatementsAndDeallocFinalize);
591
592  return transforms;
593}
594