Consumed.cpp revision 0e8534efc3c536795ede0128aed86a6b8ad53ab7
1//===- Consumed.cpp --------------------------------------------*- C++ --*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
// An intra-procedural analysis for checking consumed properties.  This is based,
11// in part, on research on linear types.
12//
13//===----------------------------------------------------------------------===//
14
15#include "clang/AST/ASTContext.h"
16#include "clang/AST/Attr.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/ExprCXX.h"
19#include "clang/AST/RecursiveASTVisitor.h"
20#include "clang/AST/StmtVisitor.h"
21#include "clang/AST/StmtCXX.h"
22#include "clang/AST/Type.h"
23#include "clang/Analysis/Analyses/PostOrderCFGView.h"
24#include "clang/Analysis/AnalysisContext.h"
25#include "clang/Analysis/CFG.h"
26#include "clang/Analysis/Analyses/Consumed.h"
27#include "clang/Basic/OperatorKinds.h"
28#include "clang/Basic/SourceLocation.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Compiler.h"
32#include "llvm/Support/raw_ostream.h"
33
34// TODO: Add notes about the actual and expected state for
35// TODO: Correctly identify unreachable blocks when chaining boolean operators.
36// TODO: Warn about unreachable code.
37// TODO: Switch to using a bitmap to track unreachable blocks.
38// TODO: Mark variables as Unknown going into while- or for-loops only if they
39//       are referenced inside that block. (Deferred)
40// TODO: Handle variable definitions, e.g. bool valid = x.isValid();
41//       if (valid) ...; (Deferred)
42// TODO: Add a method(s) to identify which method calls perform what state
43//       transitions. (Deferred)
44// TODO: Take notes on state transitions to provide better warning messages.
45//       (Deferred)
// TODO: Test nested conditionals: A) Checking the same value multiple times,
//       and B) Checking different values. (Deferred)
48
49using namespace clang;
50using namespace consumed;
51
52// Key method definition
53ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() {}
54
55static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
56  switch (State) {
57  case CS_Unconsumed:
58    return CS_Consumed;
59  case CS_Consumed:
60    return CS_Unconsumed;
61  case CS_None:
62    return CS_None;
63  case CS_Unknown:
64    return CS_Unknown;
65  }
66  llvm_unreachable("invalid enum");
67}
68
69static bool isConsumableType(const QualType &QT) {
70  if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
71    return RD->hasAttr<ConsumableAttr>();
72  else
73    return false;
74}
75
76static bool isKnownState(ConsumedState State) {
77  switch (State) {
78  case CS_Unconsumed:
79  case CS_Consumed:
80    return true;
81  case CS_None:
82  case CS_Unknown:
83    return false;
84  }
85  llvm_unreachable("invalid enum");
86}
87
88static bool isTestingFunction(const FunctionDecl *FunDecl) {
89  return FunDecl->hasAttr<TestsUnconsumedAttr>();
90}
91
92static ConsumedState mapReturnTypestateAttrState(
93  const ReturnTypestateAttr *RTSAttr) {
94
95  switch (RTSAttr->getState()) {
96  case ReturnTypestateAttr::Unknown:
97    return CS_Unknown;
98  case ReturnTypestateAttr::Unconsumed:
99    return CS_Unconsumed;
100  case ReturnTypestateAttr::Consumed:
101    return CS_Consumed;
102  }
103}
104
105static StringRef stateToString(ConsumedState State) {
106  switch (State) {
107  case consumed::CS_None:
108    return "none";
109
110  case consumed::CS_Unknown:
111    return "unknown";
112
113  case consumed::CS_Unconsumed:
114    return "unconsumed";
115
116  case consumed::CS_Consumed:
117    return "consumed";
118  }
119  llvm_unreachable("invalid enum");
120}
121
122namespace {
123struct VarTestResult {
124  const VarDecl *Var;
125  ConsumedState TestsFor;
126};
127} // end anonymous::VarTestResult
128
129namespace clang {
130namespace consumed {
131
// The logical connective a two-operand test behaves as.  Inverting a binary
// test swaps EO_And and EO_Or (see PropagationInfo::invertTest).  The numeric
// values are spelled out so that code converting integers to this enum keeps
// working if enumerators are ever reordered.
enum EffectiveOp {
  EO_And = 0,
  EO_Or  = 1
};
136
137class PropagationInfo {
138  enum {
139    IT_None,
140    IT_State,
141    IT_Test,
142    IT_BinTest,
143    IT_Var
144  } InfoType;
145
146  struct BinTestTy {
147    const BinaryOperator *Source;
148    EffectiveOp EOp;
149    VarTestResult LTest;
150    VarTestResult RTest;
151  };
152
153  union {
154    ConsumedState State;
155    VarTestResult Test;
156    const VarDecl *Var;
157    BinTestTy BinTest;
158  };
159
160public:
161  PropagationInfo() : InfoType(IT_None) {}
162
163  PropagationInfo(const VarTestResult &Test) : InfoType(IT_Test), Test(Test) {}
164  PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
165    : InfoType(IT_Test) {
166
167    Test.Var      = Var;
168    Test.TestsFor = TestsFor;
169  }
170
171  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
172                  const VarTestResult &LTest, const VarTestResult &RTest)
173    : InfoType(IT_BinTest) {
174
175    BinTest.Source  = Source;
176    BinTest.EOp     = EOp;
177    BinTest.LTest   = LTest;
178    BinTest.RTest   = RTest;
179  }
180
181  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
182                  const VarDecl *LVar, ConsumedState LTestsFor,
183                  const VarDecl *RVar, ConsumedState RTestsFor)
184    : InfoType(IT_BinTest) {
185
186    BinTest.Source         = Source;
187    BinTest.EOp            = EOp;
188    BinTest.LTest.Var      = LVar;
189    BinTest.LTest.TestsFor = LTestsFor;
190    BinTest.RTest.Var      = RVar;
191    BinTest.RTest.TestsFor = RTestsFor;
192  }
193
194  PropagationInfo(ConsumedState State) : InfoType(IT_State), State(State) {}
195  PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
196
197  const ConsumedState & getState() const {
198    assert(InfoType == IT_State);
199    return State;
200  }
201
202  const VarTestResult & getTest() const {
203    assert(InfoType == IT_Test);
204    return Test;
205  }
206
207  const VarTestResult & getLTest() const {
208    assert(InfoType == IT_BinTest);
209    return BinTest.LTest;
210  }
211
212  const VarTestResult & getRTest() const {
213    assert(InfoType == IT_BinTest);
214    return BinTest.RTest;
215  }
216
217  const VarDecl * getVar() const {
218    assert(InfoType == IT_Var);
219    return Var;
220  }
221
222  EffectiveOp testEffectiveOp() const {
223    assert(InfoType == IT_BinTest);
224    return BinTest.EOp;
225  }
226
227  const BinaryOperator * testSourceNode() const {
228    assert(InfoType == IT_BinTest);
229    return BinTest.Source;
230  }
231
232  bool isValid()   const { return InfoType != IT_None;     }
233  bool isState()   const { return InfoType == IT_State;    }
234  bool isTest()    const { return InfoType == IT_Test;     }
235  bool isBinTest() const { return InfoType == IT_BinTest;  }
236  bool isVar()     const { return InfoType == IT_Var;      }
237
238  PropagationInfo invertTest() const {
239    assert(InfoType == IT_Test || InfoType == IT_BinTest);
240
241    if (InfoType == IT_Test) {
242      return PropagationInfo(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
243
244    } else if (InfoType == IT_BinTest) {
245      return PropagationInfo(BinTest.Source,
246        BinTest.EOp == EO_And ? EO_Or : EO_And,
247        BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
248        BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
249    } else {
250      return PropagationInfo();
251    }
252  }
253};
254
255class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
256
257  typedef llvm::DenseMap<const Stmt *, PropagationInfo> MapType;
258  typedef std::pair<const Stmt *, PropagationInfo> PairType;
259  typedef MapType::iterator InfoEntry;
260  typedef MapType::const_iterator ConstInfoEntry;
261
262  AnalysisDeclContext &AC;
263  ConsumedAnalyzer &Analyzer;
264  ConsumedStateMap *StateMap;
265  MapType PropagationMap;
266
267  void checkCallability(const PropagationInfo &PInfo,
268                        const FunctionDecl *FunDecl,
269                        const CallExpr *Call);
270  void forwardInfo(const Stmt *From, const Stmt *To);
271  void handleTestingFunctionCall(const CallExpr *Call, const VarDecl *Var);
272  bool isLikeMoveAssignment(const CXXMethodDecl *MethodDecl);
273  void propagateReturnType(const Stmt *Call, const FunctionDecl *Fun,
274                           QualType ReturnType);
275
276public:
277
278  void Visit(const Stmt *StmtNode);
279
280  void VisitBinaryOperator(const BinaryOperator *BinOp);
281  void VisitCallExpr(const CallExpr *Call);
282  void VisitCastExpr(const CastExpr *Cast);
283  void VisitCXXConstructExpr(const CXXConstructExpr *Call);
284  void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
285  void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
286  void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
287  void VisitDeclStmt(const DeclStmt *DelcS);
288  void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
289  void VisitMemberExpr(const MemberExpr *MExpr);
290  void VisitParmVarDecl(const ParmVarDecl *Param);
291  void VisitReturnStmt(const ReturnStmt *Ret);
292  void VisitUnaryOperator(const UnaryOperator *UOp);
293  void VisitVarDecl(const VarDecl *Var);
294
295  ConsumedStmtVisitor(AnalysisDeclContext &AC, ConsumedAnalyzer &Analyzer,
296                      ConsumedStateMap *StateMap)
297      : AC(AC), Analyzer(Analyzer), StateMap(StateMap) {}
298
299  PropagationInfo getInfo(const Stmt *StmtNode) const {
300    ConstInfoEntry Entry = PropagationMap.find(StmtNode);
301
302    if (Entry != PropagationMap.end())
303      return Entry->second;
304    else
305      return PropagationInfo();
306  }
307
308  void reset(ConsumedStateMap *NewStateMap) {
309    StateMap = NewStateMap;
310  }
311};
312
313// TODO: When we support CallableWhenConsumed this will have to check for
314//       the different attributes and change the behavior bellow. (Deferred)
315void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
316                                           const FunctionDecl *FunDecl,
317                                           const CallExpr *Call) {
318
319  if (!FunDecl->hasAttr<CallableWhenUnconsumedAttr>()) return;
320
321  if (PInfo.isVar()) {
322    const VarDecl *Var = PInfo.getVar();
323
324    switch (StateMap->getState(Var)) {
325    case CS_Consumed:
326      Analyzer.WarningsHandler.warnUseWhileConsumed(
327        FunDecl->getNameAsString(), Var->getNameAsString(),
328        Call->getExprLoc());
329      break;
330
331    case CS_Unknown:
332      Analyzer.WarningsHandler.warnUseInUnknownState(
333        FunDecl->getNameAsString(), Var->getNameAsString(),
334        Call->getExprLoc());
335      break;
336
337    case CS_None:
338    case CS_Unconsumed:
339      break;
340    }
341
342  } else {
343    switch (PInfo.getState()) {
344    case CS_Consumed:
345      Analyzer.WarningsHandler.warnUseOfTempWhileConsumed(
346        FunDecl->getNameAsString(), Call->getExprLoc());
347      break;
348
349    case CS_Unknown:
350      Analyzer.WarningsHandler.warnUseOfTempInUnknownState(
351        FunDecl->getNameAsString(), Call->getExprLoc());
352      break;
353
354    case CS_None:
355    case CS_Unconsumed:
356      break;
357    }
358  }
359}
360
361void ConsumedStmtVisitor::forwardInfo(const Stmt *From, const Stmt *To) {
362  InfoEntry Entry = PropagationMap.find(From);
363
364  if (Entry != PropagationMap.end())
365    PropagationMap.insert(PairType(To, Entry->second));
366}
367
368void ConsumedStmtVisitor::handleTestingFunctionCall(const CallExpr *Call,
369                                                    const VarDecl  *Var) {
370
371  ConsumedState VarState = StateMap->getState(Var);
372
373  if (VarState != CS_Unknown) {
374    SourceLocation CallLoc = Call->getExprLoc();
375
376    if (!CallLoc.isMacroID())
377      Analyzer.WarningsHandler.warnUnnecessaryTest(Var->getNameAsString(),
378        stateToString(VarState), CallLoc);
379  }
380
381  PropagationMap.insert(PairType(Call, PropagationInfo(Var, CS_Unconsumed)));
382}
383
384bool ConsumedStmtVisitor::isLikeMoveAssignment(
385  const CXXMethodDecl *MethodDecl) {
386
387  return MethodDecl->isMoveAssignmentOperator() ||
388         (MethodDecl->getOverloadedOperator() == OO_Equal &&
389          MethodDecl->getNumParams() == 1 &&
390          MethodDecl->getParamDecl(0)->getType()->isRValueReferenceType());
391}
392
393void ConsumedStmtVisitor::propagateReturnType(const Stmt *Call,
394                                              const FunctionDecl *Fun,
395                                              QualType ReturnType) {
396  if (isConsumableType(ReturnType)) {
397
398    ConsumedState ReturnState;
399
400    if (Fun->hasAttr<ReturnTypestateAttr>())
401      ReturnState = mapReturnTypestateAttrState(
402        Fun->getAttr<ReturnTypestateAttr>());
403    else
404      ReturnState = CS_Unknown;
405
406    PropagationMap.insert(PairType(Call,
407      PropagationInfo(ReturnState)));
408  }
409}
410
411void ConsumedStmtVisitor::Visit(const Stmt *StmtNode) {
412
413  ConstStmtVisitor<ConsumedStmtVisitor>::Visit(StmtNode);
414
415  for (Stmt::const_child_iterator CI = StmtNode->child_begin(),
416       CE = StmtNode->child_end(); CI != CE; ++CI) {
417
418    PropagationMap.erase(*CI);
419  }
420}
421
422void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
423  switch (BinOp->getOpcode()) {
424  case BO_LAnd:
425  case BO_LOr : {
426    InfoEntry LEntry = PropagationMap.find(BinOp->getLHS()),
427              REntry = PropagationMap.find(BinOp->getRHS());
428
429    VarTestResult LTest, RTest;
430
431    if (LEntry != PropagationMap.end() && LEntry->second.isTest()) {
432      LTest = LEntry->second.getTest();
433
434    } else {
435      LTest.Var      = NULL;
436      LTest.TestsFor = CS_None;
437    }
438
439    if (REntry != PropagationMap.end() && REntry->second.isTest()) {
440      RTest = REntry->second.getTest();
441
442    } else {
443      RTest.Var      = NULL;
444      RTest.TestsFor = CS_None;
445    }
446
447    if (!(LTest.Var == NULL && RTest.Var == NULL))
448      PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
449        static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
450
451    break;
452  }
453
454  case BO_PtrMemD:
455  case BO_PtrMemI:
456    forwardInfo(BinOp->getLHS(), BinOp);
457    break;
458
459  default:
460    break;
461  }
462}
463
464void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
465  if (const FunctionDecl *FunDecl =
466    dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee())) {
467
468    // Special case for the std::move function.
469    // TODO: Make this more specific. (Deferred)
470    if (FunDecl->getNameAsString() == "move") {
471      InfoEntry Entry = PropagationMap.find(Call->getArg(0));
472
473      if (Entry != PropagationMap.end()) {
474        PropagationMap.insert(PairType(Call, Entry->second));
475      }
476
477      return;
478    }
479
480    unsigned Offset = Call->getNumArgs() - FunDecl->getNumParams();
481
482    for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
483      QualType ParamType = FunDecl->getParamDecl(Index - Offset)->getType();
484
485      InfoEntry Entry = PropagationMap.find(Call->getArg(Index));
486
487      if (Entry == PropagationMap.end() || !Entry->second.isVar()) {
488        continue;
489      }
490
491      PropagationInfo PInfo = Entry->second;
492
493      if (ParamType->isRValueReferenceType() ||
494          (ParamType->isLValueReferenceType() &&
495           !cast<LValueReferenceType>(*ParamType).isSpelledAsLValue())) {
496
497        StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
498
499      } else if (!(ParamType.isConstQualified() ||
500                   ((ParamType->isReferenceType() ||
501                     ParamType->isPointerType()) &&
502                    ParamType->getPointeeType().isConstQualified()))) {
503
504        StateMap->setState(PInfo.getVar(), consumed::CS_Unknown);
505      }
506    }
507
508    propagateReturnType(Call, FunDecl, FunDecl->getCallResultType());
509  }
510}
511
512void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
513  forwardInfo(Cast->getSubExpr(), Cast);
514}
515
516void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
517  CXXConstructorDecl *Constructor = Call->getConstructor();
518
519  ASTContext &CurrContext = AC.getASTContext();
520  QualType ThisType = Constructor->getThisType(CurrContext)->getPointeeType();
521
522  if (isConsumableType(ThisType)) {
523    if (Constructor->isDefaultConstructor()) {
524
525      PropagationMap.insert(PairType(Call,
526        PropagationInfo(consumed::CS_Consumed)));
527
528    } else if (Constructor->isMoveConstructor()) {
529
530      PropagationInfo PInfo =
531        PropagationMap.find(Call->getArg(0))->second;
532
533      if (PInfo.isVar()) {
534        const VarDecl* Var = PInfo.getVar();
535
536        PropagationMap.insert(PairType(Call,
537          PropagationInfo(StateMap->getState(Var))));
538
539        StateMap->setState(Var, consumed::CS_Consumed);
540
541      } else {
542        PropagationMap.insert(PairType(Call, PInfo));
543      }
544
545    } else if (Constructor->isCopyConstructor()) {
546      MapType::iterator Entry = PropagationMap.find(Call->getArg(0));
547
548      if (Entry != PropagationMap.end())
549        PropagationMap.insert(PairType(Call, Entry->second));
550
551    } else {
552      propagateReturnType(Call, Constructor, ThisType);
553    }
554  }
555}
556
557void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
558  const CXXMemberCallExpr *Call) {
559
560  VisitCallExpr(Call);
561
562  InfoEntry Entry = PropagationMap.find(Call->getCallee()->IgnoreParens());
563
564  if (Entry != PropagationMap.end()) {
565    PropagationInfo PInfo = Entry->second;
566    const CXXMethodDecl *MethodDecl = Call->getMethodDecl();
567
568    checkCallability(PInfo, MethodDecl, Call);
569
570    if (PInfo.isVar()) {
571      if (isTestingFunction(MethodDecl))
572        handleTestingFunctionCall(Call, PInfo.getVar());
573      else if (MethodDecl->hasAttr<ConsumesAttr>())
574        StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
575    }
576  }
577}
578
579void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
580  const CXXOperatorCallExpr *Call) {
581
582  const FunctionDecl *FunDecl =
583    dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
584
585  if (!FunDecl) return;
586
587  if (isa<CXXMethodDecl>(FunDecl) &&
588      isLikeMoveAssignment(cast<CXXMethodDecl>(FunDecl))) {
589
590    InfoEntry LEntry = PropagationMap.find(Call->getArg(0));
591    InfoEntry REntry = PropagationMap.find(Call->getArg(1));
592
593    PropagationInfo LPInfo, RPInfo;
594
595    if (LEntry != PropagationMap.end() &&
596        REntry != PropagationMap.end()) {
597
598      LPInfo = LEntry->second;
599      RPInfo = REntry->second;
600
601      if (LPInfo.isVar() && RPInfo.isVar()) {
602        StateMap->setState(LPInfo.getVar(),
603          StateMap->getState(RPInfo.getVar()));
604
605        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
606
607        PropagationMap.insert(PairType(Call, LPInfo));
608
609      } else if (LPInfo.isVar() && !RPInfo.isVar()) {
610        StateMap->setState(LPInfo.getVar(), RPInfo.getState());
611
612        PropagationMap.insert(PairType(Call, LPInfo));
613
614      } else if (!LPInfo.isVar() && RPInfo.isVar()) {
615        PropagationMap.insert(PairType(Call,
616          PropagationInfo(StateMap->getState(RPInfo.getVar()))));
617
618        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
619
620      } else {
621        PropagationMap.insert(PairType(Call, RPInfo));
622      }
623
624    } else if (LEntry != PropagationMap.end() &&
625               REntry == PropagationMap.end()) {
626
627      LPInfo = LEntry->second;
628
629      if (LPInfo.isVar()) {
630        StateMap->setState(LPInfo.getVar(), consumed::CS_Unknown);
631
632        PropagationMap.insert(PairType(Call, LPInfo));
633
634      } else {
635        PropagationMap.insert(PairType(Call,
636          PropagationInfo(consumed::CS_Unknown)));
637      }
638
639    } else if (LEntry == PropagationMap.end() &&
640               REntry != PropagationMap.end()) {
641
642      RPInfo = REntry->second;
643
644      if (RPInfo.isVar()) {
645        const VarDecl *Var = RPInfo.getVar();
646
647        PropagationMap.insert(PairType(Call,
648          PropagationInfo(StateMap->getState(Var))));
649
650        StateMap->setState(Var, consumed::CS_Consumed);
651
652      } else {
653        PropagationMap.insert(PairType(Call, RPInfo));
654      }
655    }
656
657  } else {
658
659    VisitCallExpr(Call);
660
661    InfoEntry Entry = PropagationMap.find(Call->getArg(0));
662
663    if (Entry != PropagationMap.end()) {
664      PropagationInfo PInfo = Entry->second;
665
666      checkCallability(PInfo, FunDecl, Call);
667
668      if (PInfo.isVar()) {
669        if (isTestingFunction(FunDecl))
670          handleTestingFunctionCall(Call, PInfo.getVar());
671        else if (FunDecl->hasAttr<ConsumesAttr>())
672          StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
673      }
674    }
675  }
676}
677
678void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
679  if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
680    if (StateMap->getState(Var) != consumed::CS_None)
681      PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
682}
683
684void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
685  for (DeclStmt::const_decl_iterator DI = DeclS->decl_begin(),
686       DE = DeclS->decl_end(); DI != DE; ++DI) {
687
688    if (isa<VarDecl>(*DI)) VisitVarDecl(cast<VarDecl>(*DI));
689  }
690
691  if (DeclS->isSingleDecl())
692    if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
693      PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
694}
695
696void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
697  const MaterializeTemporaryExpr *Temp) {
698
699  InfoEntry Entry = PropagationMap.find(Temp->GetTemporaryExpr());
700
701  if (Entry != PropagationMap.end())
702    PropagationMap.insert(PairType(Temp, Entry->second));
703}
704
705void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
706  forwardInfo(MExpr->getBase(), MExpr);
707}
708
709
710void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
711  if (isConsumableType(Param->getType()))
712    StateMap->setState(Param, consumed::CS_Unknown);
713}
714
715void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
716  if (ConsumedState ExpectedState = Analyzer.getExpectedReturnState()) {
717    InfoEntry Entry = PropagationMap.find(Ret->getRetValue());
718
719    if (Entry != PropagationMap.end()) {
720      assert(Entry->second.isState() || Entry->second.isVar());
721
722      ConsumedState RetState = Entry->second.isState() ?
723        Entry->second.getState() : StateMap->getState(Entry->second.getVar());
724
725      if (RetState != ExpectedState)
726        Analyzer.WarningsHandler.warnReturnTypestateMismatch(
727          Ret->getReturnLoc(), stateToString(ExpectedState),
728          stateToString(RetState));
729    }
730  }
731}
732
733void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
734  InfoEntry Entry = PropagationMap.find(UOp->getSubExpr()->IgnoreParens());
735  if (Entry == PropagationMap.end()) return;
736
737  switch (UOp->getOpcode()) {
738  case UO_AddrOf:
739    PropagationMap.insert(PairType(UOp, Entry->second));
740    break;
741
742  case UO_LNot:
743    if (Entry->second.isTest() || Entry->second.isBinTest())
744      PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
745    break;
746
747  default:
748    break;
749  }
750}
751
752void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
753  if (isConsumableType(Var->getType())) {
754    if (Var->hasInit()) {
755      PropagationInfo PInfo =
756        PropagationMap.find(Var->getInit())->second;
757
758      StateMap->setState(Var, PInfo.isVar() ?
759        StateMap->getState(PInfo.getVar()) : PInfo.getState());
760
761    } else {
762      StateMap->setState(Var, consumed::CS_Unknown);
763    }
764  }
765}
766}} // end clang::consumed::ConsumedStmtVisitor
767
768namespace clang {
769namespace consumed {
770
771void splitVarStateForIf(const IfStmt * IfNode, const VarTestResult &Test,
772                        ConsumedStateMap *ThenStates,
773                        ConsumedStateMap *ElseStates) {
774
775  ConsumedState VarState = ThenStates->getState(Test.Var);
776
777  if (VarState == CS_Unknown) {
778    ThenStates->setState(Test.Var, Test.TestsFor);
779    if (ElseStates)
780      ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
781
782  } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
783    ThenStates->markUnreachable();
784
785  } else if (VarState == Test.TestsFor && ElseStates) {
786    ElseStates->markUnreachable();
787  }
788}
789
790void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
791  ConsumedStateMap *ThenStates, ConsumedStateMap *ElseStates) {
792
793  const VarTestResult &LTest = PInfo.getLTest(),
794                      &RTest = PInfo.getRTest();
795
796  ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
797                RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;
798
799  if (LTest.Var) {
800    if (PInfo.testEffectiveOp() == EO_And) {
801      if (LState == CS_Unknown) {
802        ThenStates->setState(LTest.Var, LTest.TestsFor);
803
804      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
805        ThenStates->markUnreachable();
806
807      } else if (LState == LTest.TestsFor && isKnownState(RState)) {
808        if (RState == RTest.TestsFor) {
809          if (ElseStates)
810            ElseStates->markUnreachable();
811        } else {
812          ThenStates->markUnreachable();
813        }
814      }
815
816    } else {
817      if (LState == CS_Unknown && ElseStates) {
818        ElseStates->setState(LTest.Var,
819                             invertConsumedUnconsumed(LTest.TestsFor));
820
821      } else if (LState == LTest.TestsFor && ElseStates) {
822        ElseStates->markUnreachable();
823
824      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
825                 isKnownState(RState)) {
826
827        if (RState == RTest.TestsFor) {
828          if (ElseStates)
829            ElseStates->markUnreachable();
830        } else {
831          ThenStates->markUnreachable();
832        }
833      }
834    }
835  }
836
837  if (RTest.Var) {
838    if (PInfo.testEffectiveOp() == EO_And) {
839      if (RState == CS_Unknown)
840        ThenStates->setState(RTest.Var, RTest.TestsFor);
841      else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
842        ThenStates->markUnreachable();
843
844    } else if (ElseStates) {
845      if (RState == CS_Unknown)
846        ElseStates->setState(RTest.Var,
847                             invertConsumedUnconsumed(RTest.TestsFor));
848      else if (RState == RTest.TestsFor)
849        ElseStates->markUnreachable();
850    }
851  }
852}
853
854void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
855                                ConsumedStateMap *StateMap,
856                                bool &AlreadyOwned) {
857
858  if (VisitedBlocks.alreadySet(Block)) return;
859
860  ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
861
862  if (Entry) {
863    Entry->intersect(StateMap);
864
865  } else if (AlreadyOwned) {
866    StateMapsArray[Block->getBlockID()] = new ConsumedStateMap(*StateMap);
867
868  } else {
869    StateMapsArray[Block->getBlockID()] = StateMap;
870    AlreadyOwned = true;
871  }
872}
873
874void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
875                                ConsumedStateMap *StateMap) {
876
877  if (VisitedBlocks.alreadySet(Block)) {
878    delete StateMap;
879    return;
880  }
881
882  ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
883
884  if (Entry) {
885    Entry->intersect(StateMap);
886    delete StateMap;
887
888  } else {
889    StateMapsArray[Block->getBlockID()] = StateMap;
890  }
891}
892
893ConsumedStateMap* ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
894  return StateMapsArray[Block->getBlockID()];
895}
896
897void ConsumedBlockInfo::markVisited(const CFGBlock *Block) {
898  VisitedBlocks.insert(Block);
899}
900
901ConsumedState ConsumedStateMap::getState(const VarDecl *Var) {
902  MapType::const_iterator Entry = Map.find(Var);
903
904  if (Entry != Map.end()) {
905    return Entry->second;
906
907  } else {
908    return CS_None;
909  }
910}
911
912void ConsumedStateMap::intersect(const ConsumedStateMap *Other) {
913  ConsumedState LocalState;
914
915  if (this->From && this->From == Other->From && !Other->Reachable) {
916    this->markUnreachable();
917    return;
918  }
919
920  for (MapType::const_iterator DMI = Other->Map.begin(),
921       DME = Other->Map.end(); DMI != DME; ++DMI) {
922
923    LocalState = this->getState(DMI->first);
924
925    if (LocalState == CS_None)
926      continue;
927
928    if (LocalState != DMI->second)
929       Map[DMI->first] = CS_Unknown;
930  }
931}
932
933void ConsumedStateMap::markUnreachable() {
934  this->Reachable = false;
935  Map.clear();
936}
937
938void ConsumedStateMap::makeUnknown() {
939  for (MapType::const_iterator DMI = Map.begin(), DME = Map.end(); DMI != DME;
940       ++DMI) {
941
942    Map[DMI->first] = CS_Unknown;
943  }
944}
945
946void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
947  Map[Var] = State;
948}
949
950void ConsumedStateMap::remove(const VarDecl *Var) {
951  Map.erase(Var);
952}
953
954bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
955                                  const ConsumedStmtVisitor &Visitor) {
956
957  ConsumedStateMap *FalseStates = new ConsumedStateMap(*CurrStates);
958  PropagationInfo PInfo;
959
960  if (const IfStmt *IfNode =
961    dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
962
963    bool HasElse = IfNode->getElse() != NULL;
964    const Stmt *Cond = IfNode->getCond();
965
966    PInfo = Visitor.getInfo(Cond);
967    if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
968      PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
969
970    if (PInfo.isTest()) {
971      CurrStates->setSource(Cond);
972      FalseStates->setSource(Cond);
973
974      splitVarStateForIf(IfNode, PInfo.getTest(), CurrStates,
975                         HasElse ? FalseStates : NULL);
976
977    } else if (PInfo.isBinTest()) {
978      CurrStates->setSource(PInfo.testSourceNode());
979      FalseStates->setSource(PInfo.testSourceNode());
980
981      splitVarStateForIfBinOp(PInfo, CurrStates, HasElse ? FalseStates : NULL);
982
983    } else {
984      delete FalseStates;
985      return false;
986    }
987
988  } else if (const BinaryOperator *BinOp =
989    dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
990
991    PInfo = Visitor.getInfo(BinOp->getLHS());
992    if (!PInfo.isTest()) {
993      if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
994        PInfo = Visitor.getInfo(BinOp->getRHS());
995
996        if (!PInfo.isTest()) {
997          delete FalseStates;
998          return false;
999        }
1000
1001      } else {
1002        delete FalseStates;
1003        return false;
1004      }
1005    }
1006
1007    CurrStates->setSource(BinOp);
1008    FalseStates->setSource(BinOp);
1009
1010    const VarTestResult &Test = PInfo.getTest();
1011    ConsumedState VarState = CurrStates->getState(Test.Var);
1012
1013    if (BinOp->getOpcode() == BO_LAnd) {
1014      if (VarState == CS_Unknown)
1015        CurrStates->setState(Test.Var, Test.TestsFor);
1016      else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
1017        CurrStates->markUnreachable();
1018
1019    } else if (BinOp->getOpcode() == BO_LOr) {
1020      if (VarState == CS_Unknown)
1021        FalseStates->setState(Test.Var,
1022                              invertConsumedUnconsumed(Test.TestsFor));
1023      else if (VarState == Test.TestsFor)
1024        FalseStates->markUnreachable();
1025    }
1026
1027  } else {
1028    delete FalseStates;
1029    return false;
1030  }
1031
1032  CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
1033
1034  if (*SI)
1035    BlockInfo.addInfo(*SI, CurrStates);
1036  else
1037    delete CurrStates;
1038
1039  if (*++SI)
1040    BlockInfo.addInfo(*SI, FalseStates);
1041  else
1042    delete FalseStates;
1043
1044  CurrStates = NULL;
1045  return true;
1046}
1047
1048void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1049  const FunctionDecl *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1050
1051  if (!D) return;
1052
1053  // FIXME: This should be removed when template instantiation propagates
1054  //        attributes at template specialization definition, not declaration.
1055  //        When it is removed the test needs to be enabled in SemaDeclAttr.cpp.
1056  QualType ReturnType;
1057  if (const CXXConstructorDecl *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
1058    ASTContext &CurrContext = AC.getASTContext();
1059    ReturnType = Constructor->getThisType(CurrContext)->getPointeeType();
1060
1061  } else {
1062    ReturnType = D->getCallResultType();
1063  }
1064
1065  // Determine the expected return value.
1066  if (D->hasAttr<ReturnTypestateAttr>()) {
1067
1068    ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>();
1069
1070    const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
1071    if (!RD || !RD->hasAttr<ConsumableAttr>()) {
1072        // FIXME: This branch can be removed with the code above.
1073        WarningsHandler.warnReturnTypestateForUnconsumableType(
1074          RTSAttr->getLocation(), ReturnType.getAsString());
1075        ExpectedReturnState = CS_None;
1076
1077    } else {
1078      switch (RTSAttr->getState()) {
1079      case ReturnTypestateAttr::Unknown:
1080        ExpectedReturnState = CS_Unknown;
1081        break;
1082
1083      case ReturnTypestateAttr::Unconsumed:
1084        ExpectedReturnState = CS_Unconsumed;
1085        break;
1086
1087      case ReturnTypestateAttr::Consumed:
1088        ExpectedReturnState = CS_Consumed;
1089        break;
1090      }
1091    }
1092
1093  } else if (isConsumableType(ReturnType)) {
1094    ExpectedReturnState = CS_Unknown;
1095
1096  } else {
1097    ExpectedReturnState = CS_None;
1098  }
1099
1100  BlockInfo = ConsumedBlockInfo(AC.getCFG());
1101
1102  PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1103
1104  CurrStates = new ConsumedStateMap();
1105  ConsumedStmtVisitor Visitor(AC, *this, CurrStates);
1106
1107  // Add all trackable parameters to the state map.
1108  for (FunctionDecl::param_const_iterator PI = D->param_begin(),
1109       PE = D->param_end(); PI != PE; ++PI) {
1110    Visitor.VisitParmVarDecl(*PI);
1111  }
1112
1113  // Visit all of the function's basic blocks.
1114  for (PostOrderCFGView::iterator I = SortedGraph->begin(),
1115       E = SortedGraph->end(); I != E; ++I) {
1116
1117    const CFGBlock *CurrBlock = *I;
1118    BlockInfo.markVisited(CurrBlock);
1119
1120    if (CurrStates == NULL)
1121      CurrStates = BlockInfo.getInfo(CurrBlock);
1122
1123    if (!CurrStates) {
1124      continue;
1125
1126    } else if (!CurrStates->isReachable()) {
1127      delete CurrStates;
1128      CurrStates = NULL;
1129      continue;
1130    }
1131
1132    Visitor.reset(CurrStates);
1133
1134    // Visit all of the basic block's statements.
1135    for (CFGBlock::const_iterator BI = CurrBlock->begin(),
1136         BE = CurrBlock->end(); BI != BE; ++BI) {
1137
1138      switch (BI->getKind()) {
1139      case CFGElement::Statement:
1140        Visitor.Visit(BI->castAs<CFGStmt>().getStmt());
1141        break;
1142      case CFGElement::AutomaticObjectDtor:
1143        CurrStates->remove(BI->castAs<CFGAutomaticObjDtor>().getVarDecl());
1144      default:
1145        break;
1146      }
1147    }
1148
1149    // TODO: Handle other forms of branching with precision, including while-
1150    //       and for-loops. (Deferred)
1151    if (!splitState(CurrBlock, Visitor)) {
1152      CurrStates->setSource(NULL);
1153
1154      if (CurrBlock->succ_size() > 1) {
1155        CurrStates->makeUnknown();
1156
1157        bool OwnershipTaken = false;
1158
1159        for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1160             SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1161
1162          if (*SI) BlockInfo.addInfo(*SI, CurrStates, OwnershipTaken);
1163        }
1164
1165        if (!OwnershipTaken)
1166          delete CurrStates;
1167
1168        CurrStates = NULL;
1169
1170      } else if (CurrBlock->succ_size() == 1 &&
1171                 (*CurrBlock->succ_begin())->pred_size() > 1) {
1172
1173        BlockInfo.addInfo(*CurrBlock->succ_begin(), CurrStates);
1174        CurrStates = NULL;
1175      }
1176    }
1177  } // End of block iterator.
1178
1179  // Delete the last existing state map.
1180  delete CurrStates;
1181
1182  WarningsHandler.emitDiagnostics();
1183}
1184}} // end namespace clang::consumed
1185