Consumed.cpp revision e988dc45254405aff0950337d82aa8623fb1b88e
1//===- Consumed.cpp --------------------------------------------*- C++ --*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// A intra-procedural analysis for checking consumed properties.  This is based,
11// in part, on research on linear types.
12//
13//===----------------------------------------------------------------------===//
14
15#include "clang/AST/ASTContext.h"
16#include "clang/AST/Attr.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/ExprCXX.h"
19#include "clang/AST/RecursiveASTVisitor.h"
20#include "clang/AST/StmtVisitor.h"
21#include "clang/AST/StmtCXX.h"
22#include "clang/AST/Type.h"
23#include "clang/Analysis/Analyses/PostOrderCFGView.h"
24#include "clang/Analysis/AnalysisContext.h"
25#include "clang/Analysis/CFG.h"
26#include "clang/Analysis/Analyses/Consumed.h"
27#include "clang/Basic/OperatorKinds.h"
28#include "clang/Basic/SourceLocation.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Compiler.h"
32#include "llvm/Support/raw_ostream.h"
33
34// TODO: Add notes about the actual and expected state for
35// TODO: Correctly identify unreachable blocks when chaining boolean operators.
36// TODO: Warn about unreachable code.
37// TODO: Switch to using a bitmap to track unreachable blocks.
38// TODO: Mark variables as Unknown going into while- or for-loops only if they
39//       are referenced inside that block. (Deferred)
40// TODO: Handle variable definitions, e.g. bool valid = x.isValid();
41//       if (valid) ...; (Deferred)
42// TODO: Add a method(s) to identify which method calls perform what state
43//       transitions. (Deferred)
44// TODO: Take notes on state transitions to provide better warning messages.
45//       (Deferred)
46// TODO: Test nested conditionals: 1) Checking the same value multiple times,
47//       and 2) Checking different values. (Deferred)
48
49using namespace clang;
50using namespace consumed;
51
52// Key method definition
53ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() {}
54
55static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
56  switch (State) {
57  case CS_Unconsumed:
58    return CS_Consumed;
59  case CS_Consumed:
60    return CS_Unconsumed;
61  case CS_None:
62    return CS_None;
63  case CS_Unknown:
64    return CS_Unknown;
65  }
66  llvm_unreachable("invalid enum");
67}
68
69static bool isConsumableType(const QualType &QT) {
70  if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
71    return RD->hasAttr<ConsumableAttr>();
72  else
73    return false;
74}
75
76static bool isKnownState(ConsumedState State) {
77  switch (State) {
78  case CS_Unconsumed:
79  case CS_Consumed:
80    return true;
81  case CS_None:
82  case CS_Unknown:
83    return false;
84  }
85  llvm_unreachable("invalid enum");
86}
87
88static bool isTestingFunction(const FunctionDecl *FunDecl) {
89  return FunDecl->hasAttr<TestsUnconsumedAttr>();
90}
91
92static ConsumedState
93mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {
94
95  switch (RTSAttr->getState()) {
96  case ReturnTypestateAttr::Unknown:
97    return CS_Unknown;
98  case ReturnTypestateAttr::Unconsumed:
99    return CS_Unconsumed;
100  case ReturnTypestateAttr::Consumed:
101    return CS_Consumed;
102  }
103  llvm_unreachable("invalid enum");
104}
105
106static StringRef stateToString(ConsumedState State) {
107  switch (State) {
108  case consumed::CS_None:
109    return "none";
110
111  case consumed::CS_Unknown:
112    return "unknown";
113
114  case consumed::CS_Unconsumed:
115    return "unconsumed";
116
117  case consumed::CS_Consumed:
118    return "consumed";
119  }
120  llvm_unreachable("invalid enum");
121}
122
123namespace {
124struct VarTestResult {
125  const VarDecl *Var;
126  ConsumedState TestsFor;
127};
128} // end anonymous::VarTestResult
129
130namespace clang {
131namespace consumed {
132
133enum EffectiveOp {
134  EO_And,
135  EO_Or
136};
137
138class PropagationInfo {
139  enum {
140    IT_None,
141    IT_State,
142    IT_Test,
143    IT_BinTest,
144    IT_Var
145  } InfoType;
146
147  struct BinTestTy {
148    const BinaryOperator *Source;
149    EffectiveOp EOp;
150    VarTestResult LTest;
151    VarTestResult RTest;
152  };
153
154  union {
155    ConsumedState State;
156    VarTestResult Test;
157    const VarDecl *Var;
158    BinTestTy BinTest;
159  };
160
161public:
162  PropagationInfo() : InfoType(IT_None) {}
163
164  PropagationInfo(const VarTestResult &Test) : InfoType(IT_Test), Test(Test) {}
165  PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
166    : InfoType(IT_Test) {
167
168    Test.Var      = Var;
169    Test.TestsFor = TestsFor;
170  }
171
172  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
173                  const VarTestResult &LTest, const VarTestResult &RTest)
174    : InfoType(IT_BinTest) {
175
176    BinTest.Source  = Source;
177    BinTest.EOp     = EOp;
178    BinTest.LTest   = LTest;
179    BinTest.RTest   = RTest;
180  }
181
182  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
183                  const VarDecl *LVar, ConsumedState LTestsFor,
184                  const VarDecl *RVar, ConsumedState RTestsFor)
185    : InfoType(IT_BinTest) {
186
187    BinTest.Source         = Source;
188    BinTest.EOp            = EOp;
189    BinTest.LTest.Var      = LVar;
190    BinTest.LTest.TestsFor = LTestsFor;
191    BinTest.RTest.Var      = RVar;
192    BinTest.RTest.TestsFor = RTestsFor;
193  }
194
195  PropagationInfo(ConsumedState State) : InfoType(IT_State), State(State) {}
196  PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
197
198  const ConsumedState & getState() const {
199    assert(InfoType == IT_State);
200    return State;
201  }
202
203  const VarTestResult & getTest() const {
204    assert(InfoType == IT_Test);
205    return Test;
206  }
207
208  const VarTestResult & getLTest() const {
209    assert(InfoType == IT_BinTest);
210    return BinTest.LTest;
211  }
212
213  const VarTestResult & getRTest() const {
214    assert(InfoType == IT_BinTest);
215    return BinTest.RTest;
216  }
217
218  const VarDecl * getVar() const {
219    assert(InfoType == IT_Var);
220    return Var;
221  }
222
223  EffectiveOp testEffectiveOp() const {
224    assert(InfoType == IT_BinTest);
225    return BinTest.EOp;
226  }
227
228  const BinaryOperator * testSourceNode() const {
229    assert(InfoType == IT_BinTest);
230    return BinTest.Source;
231  }
232
233  bool isValid()   const { return InfoType != IT_None;     }
234  bool isState()   const { return InfoType == IT_State;    }
235  bool isTest()    const { return InfoType == IT_Test;     }
236  bool isBinTest() const { return InfoType == IT_BinTest;  }
237  bool isVar()     const { return InfoType == IT_Var;      }
238
239  PropagationInfo invertTest() const {
240    assert(InfoType == IT_Test || InfoType == IT_BinTest);
241
242    if (InfoType == IT_Test) {
243      return PropagationInfo(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
244
245    } else if (InfoType == IT_BinTest) {
246      return PropagationInfo(BinTest.Source,
247        BinTest.EOp == EO_And ? EO_Or : EO_And,
248        BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
249        BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
250    } else {
251      return PropagationInfo();
252    }
253  }
254};
255
256class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
257
258  typedef llvm::DenseMap<const Stmt *, PropagationInfo> MapType;
259  typedef std::pair<const Stmt *, PropagationInfo> PairType;
260  typedef MapType::iterator InfoEntry;
261  typedef MapType::const_iterator ConstInfoEntry;
262
263  AnalysisDeclContext &AC;
264  ConsumedAnalyzer &Analyzer;
265  ConsumedStateMap *StateMap;
266  MapType PropagationMap;
267
268  void checkCallability(const PropagationInfo &PInfo,
269                        const FunctionDecl *FunDecl,
270                        const CallExpr *Call);
271  void forwardInfo(const Stmt *From, const Stmt *To);
272  void handleTestingFunctionCall(const CallExpr *Call, const VarDecl *Var);
273  bool isLikeMoveAssignment(const CXXMethodDecl *MethodDecl);
274  void propagateReturnType(const Stmt *Call, const FunctionDecl *Fun,
275                           QualType ReturnType);
276
277public:
278
279  void Visit(const Stmt *StmtNode);
280
281  void VisitBinaryOperator(const BinaryOperator *BinOp);
282  void VisitCallExpr(const CallExpr *Call);
283  void VisitCastExpr(const CastExpr *Cast);
284  void VisitCXXConstructExpr(const CXXConstructExpr *Call);
285  void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
286  void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
287  void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
288  void VisitDeclStmt(const DeclStmt *DelcS);
289  void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
290  void VisitMemberExpr(const MemberExpr *MExpr);
291  void VisitParmVarDecl(const ParmVarDecl *Param);
292  void VisitReturnStmt(const ReturnStmt *Ret);
293  void VisitUnaryOperator(const UnaryOperator *UOp);
294  void VisitVarDecl(const VarDecl *Var);
295
296  ConsumedStmtVisitor(AnalysisDeclContext &AC, ConsumedAnalyzer &Analyzer,
297                      ConsumedStateMap *StateMap)
298      : AC(AC), Analyzer(Analyzer), StateMap(StateMap) {}
299
300  PropagationInfo getInfo(const Stmt *StmtNode) const {
301    ConstInfoEntry Entry = PropagationMap.find(StmtNode);
302
303    if (Entry != PropagationMap.end())
304      return Entry->second;
305    else
306      return PropagationInfo();
307  }
308
309  void reset(ConsumedStateMap *NewStateMap) {
310    StateMap = NewStateMap;
311  }
312};
313
314// TODO: When we support CallableWhenConsumed this will have to check for
315//       the different attributes and change the behavior bellow. (Deferred)
316void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
317                                           const FunctionDecl *FunDecl,
318                                           const CallExpr *Call) {
319
320  if (!FunDecl->hasAttr<CallableWhenUnconsumedAttr>()) return;
321
322  if (PInfo.isVar()) {
323    const VarDecl *Var = PInfo.getVar();
324
325    switch (StateMap->getState(Var)) {
326    case CS_Consumed:
327      Analyzer.WarningsHandler.warnUseWhileConsumed(
328        FunDecl->getNameAsString(), Var->getNameAsString(),
329        Call->getExprLoc());
330      break;
331
332    case CS_Unknown:
333      Analyzer.WarningsHandler.warnUseInUnknownState(
334        FunDecl->getNameAsString(), Var->getNameAsString(),
335        Call->getExprLoc());
336      break;
337
338    case CS_None:
339    case CS_Unconsumed:
340      break;
341    }
342
343  } else {
344    switch (PInfo.getState()) {
345    case CS_Consumed:
346      Analyzer.WarningsHandler.warnUseOfTempWhileConsumed(
347        FunDecl->getNameAsString(), Call->getExprLoc());
348      break;
349
350    case CS_Unknown:
351      Analyzer.WarningsHandler.warnUseOfTempInUnknownState(
352        FunDecl->getNameAsString(), Call->getExprLoc());
353      break;
354
355    case CS_None:
356    case CS_Unconsumed:
357      break;
358    }
359  }
360}
361
362void ConsumedStmtVisitor::forwardInfo(const Stmt *From, const Stmt *To) {
363  InfoEntry Entry = PropagationMap.find(From);
364
365  if (Entry != PropagationMap.end())
366    PropagationMap.insert(PairType(To, Entry->second));
367}
368
369void ConsumedStmtVisitor::handleTestingFunctionCall(const CallExpr *Call,
370                                                    const VarDecl  *Var) {
371
372  ConsumedState VarState = StateMap->getState(Var);
373
374  if (VarState != CS_Unknown) {
375    SourceLocation CallLoc = Call->getExprLoc();
376
377    if (!CallLoc.isMacroID())
378      Analyzer.WarningsHandler.warnUnnecessaryTest(Var->getNameAsString(),
379        stateToString(VarState), CallLoc);
380  }
381
382  PropagationMap.insert(PairType(Call, PropagationInfo(Var, CS_Unconsumed)));
383}
384
385bool ConsumedStmtVisitor::isLikeMoveAssignment(
386  const CXXMethodDecl *MethodDecl) {
387
388  return MethodDecl->isMoveAssignmentOperator() ||
389         (MethodDecl->getOverloadedOperator() == OO_Equal &&
390          MethodDecl->getNumParams() == 1 &&
391          MethodDecl->getParamDecl(0)->getType()->isRValueReferenceType());
392}
393
394void ConsumedStmtVisitor::propagateReturnType(const Stmt *Call,
395                                              const FunctionDecl *Fun,
396                                              QualType ReturnType) {
397  if (isConsumableType(ReturnType)) {
398
399    ConsumedState ReturnState;
400
401    if (Fun->hasAttr<ReturnTypestateAttr>())
402      ReturnState = mapReturnTypestateAttrState(
403        Fun->getAttr<ReturnTypestateAttr>());
404    else
405      ReturnState = CS_Unknown;
406
407    PropagationMap.insert(PairType(Call,
408      PropagationInfo(ReturnState)));
409  }
410}
411
412void ConsumedStmtVisitor::Visit(const Stmt *StmtNode) {
413
414  ConstStmtVisitor<ConsumedStmtVisitor>::Visit(StmtNode);
415
416  for (Stmt::const_child_iterator CI = StmtNode->child_begin(),
417       CE = StmtNode->child_end(); CI != CE; ++CI) {
418
419    PropagationMap.erase(*CI);
420  }
421}
422
423void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
424  switch (BinOp->getOpcode()) {
425  case BO_LAnd:
426  case BO_LOr : {
427    InfoEntry LEntry = PropagationMap.find(BinOp->getLHS()),
428              REntry = PropagationMap.find(BinOp->getRHS());
429
430    VarTestResult LTest, RTest;
431
432    if (LEntry != PropagationMap.end() && LEntry->second.isTest()) {
433      LTest = LEntry->second.getTest();
434
435    } else {
436      LTest.Var      = NULL;
437      LTest.TestsFor = CS_None;
438    }
439
440    if (REntry != PropagationMap.end() && REntry->second.isTest()) {
441      RTest = REntry->second.getTest();
442
443    } else {
444      RTest.Var      = NULL;
445      RTest.TestsFor = CS_None;
446    }
447
448    if (!(LTest.Var == NULL && RTest.Var == NULL))
449      PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
450        static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
451
452    break;
453  }
454
455  case BO_PtrMemD:
456  case BO_PtrMemI:
457    forwardInfo(BinOp->getLHS(), BinOp);
458    break;
459
460  default:
461    break;
462  }
463}
464
465void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
466  if (const FunctionDecl *FunDecl =
467    dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee())) {
468
469    // Special case for the std::move function.
470    // TODO: Make this more specific. (Deferred)
471    if (FunDecl->getNameAsString() == "move") {
472      InfoEntry Entry = PropagationMap.find(Call->getArg(0));
473
474      if (Entry != PropagationMap.end()) {
475        PropagationMap.insert(PairType(Call, Entry->second));
476      }
477
478      return;
479    }
480
481    unsigned Offset = Call->getNumArgs() - FunDecl->getNumParams();
482
483    for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
484      QualType ParamType = FunDecl->getParamDecl(Index - Offset)->getType();
485
486      InfoEntry Entry = PropagationMap.find(Call->getArg(Index));
487
488      if (Entry == PropagationMap.end() || !Entry->second.isVar()) {
489        continue;
490      }
491
492      PropagationInfo PInfo = Entry->second;
493
494      if (ParamType->isRValueReferenceType() ||
495          (ParamType->isLValueReferenceType() &&
496           !cast<LValueReferenceType>(*ParamType).isSpelledAsLValue())) {
497
498        StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
499
500      } else if (!(ParamType.isConstQualified() ||
501                   ((ParamType->isReferenceType() ||
502                     ParamType->isPointerType()) &&
503                    ParamType->getPointeeType().isConstQualified()))) {
504
505        StateMap->setState(PInfo.getVar(), consumed::CS_Unknown);
506      }
507    }
508
509    propagateReturnType(Call, FunDecl, FunDecl->getCallResultType());
510  }
511}
512
513void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
514  forwardInfo(Cast->getSubExpr(), Cast);
515}
516
517void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
518  CXXConstructorDecl *Constructor = Call->getConstructor();
519
520  ASTContext &CurrContext = AC.getASTContext();
521  QualType ThisType = Constructor->getThisType(CurrContext)->getPointeeType();
522
523  if (isConsumableType(ThisType)) {
524    if (Constructor->isDefaultConstructor()) {
525
526      PropagationMap.insert(PairType(Call,
527        PropagationInfo(consumed::CS_Consumed)));
528
529    } else if (Constructor->isMoveConstructor()) {
530
531      PropagationInfo PInfo =
532        PropagationMap.find(Call->getArg(0))->second;
533
534      if (PInfo.isVar()) {
535        const VarDecl* Var = PInfo.getVar();
536
537        PropagationMap.insert(PairType(Call,
538          PropagationInfo(StateMap->getState(Var))));
539
540        StateMap->setState(Var, consumed::CS_Consumed);
541
542      } else {
543        PropagationMap.insert(PairType(Call, PInfo));
544      }
545
546    } else if (Constructor->isCopyConstructor()) {
547      MapType::iterator Entry = PropagationMap.find(Call->getArg(0));
548
549      if (Entry != PropagationMap.end())
550        PropagationMap.insert(PairType(Call, Entry->second));
551
552    } else {
553      propagateReturnType(Call, Constructor, ThisType);
554    }
555  }
556}
557
558void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
559  const CXXMemberCallExpr *Call) {
560
561  VisitCallExpr(Call);
562
563  InfoEntry Entry = PropagationMap.find(Call->getCallee()->IgnoreParens());
564
565  if (Entry != PropagationMap.end()) {
566    PropagationInfo PInfo = Entry->second;
567    const CXXMethodDecl *MethodDecl = Call->getMethodDecl();
568
569    checkCallability(PInfo, MethodDecl, Call);
570
571    if (PInfo.isVar()) {
572      if (isTestingFunction(MethodDecl))
573        handleTestingFunctionCall(Call, PInfo.getVar());
574      else if (MethodDecl->hasAttr<ConsumesAttr>())
575        StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
576    }
577  }
578}
579
580void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
581  const CXXOperatorCallExpr *Call) {
582
583  const FunctionDecl *FunDecl =
584    dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
585
586  if (!FunDecl) return;
587
588  if (isa<CXXMethodDecl>(FunDecl) &&
589      isLikeMoveAssignment(cast<CXXMethodDecl>(FunDecl))) {
590
591    InfoEntry LEntry = PropagationMap.find(Call->getArg(0));
592    InfoEntry REntry = PropagationMap.find(Call->getArg(1));
593
594    PropagationInfo LPInfo, RPInfo;
595
596    if (LEntry != PropagationMap.end() &&
597        REntry != PropagationMap.end()) {
598
599      LPInfo = LEntry->second;
600      RPInfo = REntry->second;
601
602      if (LPInfo.isVar() && RPInfo.isVar()) {
603        StateMap->setState(LPInfo.getVar(),
604          StateMap->getState(RPInfo.getVar()));
605
606        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
607
608        PropagationMap.insert(PairType(Call, LPInfo));
609
610      } else if (LPInfo.isVar() && !RPInfo.isVar()) {
611        StateMap->setState(LPInfo.getVar(), RPInfo.getState());
612
613        PropagationMap.insert(PairType(Call, LPInfo));
614
615      } else if (!LPInfo.isVar() && RPInfo.isVar()) {
616        PropagationMap.insert(PairType(Call,
617          PropagationInfo(StateMap->getState(RPInfo.getVar()))));
618
619        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
620
621      } else {
622        PropagationMap.insert(PairType(Call, RPInfo));
623      }
624
625    } else if (LEntry != PropagationMap.end() &&
626               REntry == PropagationMap.end()) {
627
628      LPInfo = LEntry->second;
629
630      if (LPInfo.isVar()) {
631        StateMap->setState(LPInfo.getVar(), consumed::CS_Unknown);
632
633        PropagationMap.insert(PairType(Call, LPInfo));
634
635      } else {
636        PropagationMap.insert(PairType(Call,
637          PropagationInfo(consumed::CS_Unknown)));
638      }
639
640    } else if (LEntry == PropagationMap.end() &&
641               REntry != PropagationMap.end()) {
642
643      RPInfo = REntry->second;
644
645      if (RPInfo.isVar()) {
646        const VarDecl *Var = RPInfo.getVar();
647
648        PropagationMap.insert(PairType(Call,
649          PropagationInfo(StateMap->getState(Var))));
650
651        StateMap->setState(Var, consumed::CS_Consumed);
652
653      } else {
654        PropagationMap.insert(PairType(Call, RPInfo));
655      }
656    }
657
658  } else {
659
660    VisitCallExpr(Call);
661
662    InfoEntry Entry = PropagationMap.find(Call->getArg(0));
663
664    if (Entry != PropagationMap.end()) {
665      PropagationInfo PInfo = Entry->second;
666
667      checkCallability(PInfo, FunDecl, Call);
668
669      if (PInfo.isVar()) {
670        if (isTestingFunction(FunDecl))
671          handleTestingFunctionCall(Call, PInfo.getVar());
672        else if (FunDecl->hasAttr<ConsumesAttr>())
673          StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
674      }
675    }
676  }
677}
678
679void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
680  if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
681    if (StateMap->getState(Var) != consumed::CS_None)
682      PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
683}
684
685void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
686  for (DeclStmt::const_decl_iterator DI = DeclS->decl_begin(),
687       DE = DeclS->decl_end(); DI != DE; ++DI) {
688
689    if (isa<VarDecl>(*DI)) VisitVarDecl(cast<VarDecl>(*DI));
690  }
691
692  if (DeclS->isSingleDecl())
693    if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
694      PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
695}
696
697void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
698  const MaterializeTemporaryExpr *Temp) {
699
700  InfoEntry Entry = PropagationMap.find(Temp->GetTemporaryExpr());
701
702  if (Entry != PropagationMap.end())
703    PropagationMap.insert(PairType(Temp, Entry->second));
704}
705
706void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
707  forwardInfo(MExpr->getBase(), MExpr);
708}
709
710
711void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
712  if (isConsumableType(Param->getType()))
713    StateMap->setState(Param, consumed::CS_Unknown);
714}
715
716void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
717  if (ConsumedState ExpectedState = Analyzer.getExpectedReturnState()) {
718    InfoEntry Entry = PropagationMap.find(Ret->getRetValue());
719
720    if (Entry != PropagationMap.end()) {
721      assert(Entry->second.isState() || Entry->second.isVar());
722
723      ConsumedState RetState = Entry->second.isState() ?
724        Entry->second.getState() : StateMap->getState(Entry->second.getVar());
725
726      if (RetState != ExpectedState)
727        Analyzer.WarningsHandler.warnReturnTypestateMismatch(
728          Ret->getReturnLoc(), stateToString(ExpectedState),
729          stateToString(RetState));
730    }
731  }
732}
733
734void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
735  InfoEntry Entry = PropagationMap.find(UOp->getSubExpr()->IgnoreParens());
736  if (Entry == PropagationMap.end()) return;
737
738  switch (UOp->getOpcode()) {
739  case UO_AddrOf:
740    PropagationMap.insert(PairType(UOp, Entry->second));
741    break;
742
743  case UO_LNot:
744    if (Entry->second.isTest() || Entry->second.isBinTest())
745      PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
746    break;
747
748  default:
749    break;
750  }
751}
752
753void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
754  if (isConsumableType(Var->getType())) {
755    if (Var->hasInit()) {
756      PropagationInfo PInfo =
757        PropagationMap.find(Var->getInit())->second;
758
759      StateMap->setState(Var, PInfo.isVar() ?
760        StateMap->getState(PInfo.getVar()) : PInfo.getState());
761
762    } else {
763      StateMap->setState(Var, consumed::CS_Unknown);
764    }
765  }
766}
767}} // end clang::consumed::ConsumedStmtVisitor
768
769namespace clang {
770namespace consumed {
771
772void splitVarStateForIf(const IfStmt * IfNode, const VarTestResult &Test,
773                        ConsumedStateMap *ThenStates,
774                        ConsumedStateMap *ElseStates) {
775
776  ConsumedState VarState = ThenStates->getState(Test.Var);
777
778  if (VarState == CS_Unknown) {
779    ThenStates->setState(Test.Var, Test.TestsFor);
780    if (ElseStates)
781      ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
782
783  } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
784    ThenStates->markUnreachable();
785
786  } else if (VarState == Test.TestsFor && ElseStates) {
787    ElseStates->markUnreachable();
788  }
789}
790
791void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
792  ConsumedStateMap *ThenStates, ConsumedStateMap *ElseStates) {
793
794  const VarTestResult &LTest = PInfo.getLTest(),
795                      &RTest = PInfo.getRTest();
796
797  ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
798                RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;
799
800  if (LTest.Var) {
801    if (PInfo.testEffectiveOp() == EO_And) {
802      if (LState == CS_Unknown) {
803        ThenStates->setState(LTest.Var, LTest.TestsFor);
804
805      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
806        ThenStates->markUnreachable();
807
808      } else if (LState == LTest.TestsFor && isKnownState(RState)) {
809        if (RState == RTest.TestsFor) {
810          if (ElseStates)
811            ElseStates->markUnreachable();
812        } else {
813          ThenStates->markUnreachable();
814        }
815      }
816
817    } else {
818      if (LState == CS_Unknown && ElseStates) {
819        ElseStates->setState(LTest.Var,
820                             invertConsumedUnconsumed(LTest.TestsFor));
821
822      } else if (LState == LTest.TestsFor && ElseStates) {
823        ElseStates->markUnreachable();
824
825      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
826                 isKnownState(RState)) {
827
828        if (RState == RTest.TestsFor) {
829          if (ElseStates)
830            ElseStates->markUnreachable();
831        } else {
832          ThenStates->markUnreachable();
833        }
834      }
835    }
836  }
837
838  if (RTest.Var) {
839    if (PInfo.testEffectiveOp() == EO_And) {
840      if (RState == CS_Unknown)
841        ThenStates->setState(RTest.Var, RTest.TestsFor);
842      else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
843        ThenStates->markUnreachable();
844
845    } else if (ElseStates) {
846      if (RState == CS_Unknown)
847        ElseStates->setState(RTest.Var,
848                             invertConsumedUnconsumed(RTest.TestsFor));
849      else if (RState == RTest.TestsFor)
850        ElseStates->markUnreachable();
851    }
852  }
853}
854
855void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
856                                ConsumedStateMap *StateMap,
857                                bool &AlreadyOwned) {
858
859  if (VisitedBlocks.alreadySet(Block)) return;
860
861  ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
862
863  if (Entry) {
864    Entry->intersect(StateMap);
865
866  } else if (AlreadyOwned) {
867    StateMapsArray[Block->getBlockID()] = new ConsumedStateMap(*StateMap);
868
869  } else {
870    StateMapsArray[Block->getBlockID()] = StateMap;
871    AlreadyOwned = true;
872  }
873}
874
875void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
876                                ConsumedStateMap *StateMap) {
877
878  if (VisitedBlocks.alreadySet(Block)) {
879    delete StateMap;
880    return;
881  }
882
883  ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
884
885  if (Entry) {
886    Entry->intersect(StateMap);
887    delete StateMap;
888
889  } else {
890    StateMapsArray[Block->getBlockID()] = StateMap;
891  }
892}
893
894ConsumedStateMap* ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
895  return StateMapsArray[Block->getBlockID()];
896}
897
898void ConsumedBlockInfo::markVisited(const CFGBlock *Block) {
899  VisitedBlocks.insert(Block);
900}
901
902ConsumedState ConsumedStateMap::getState(const VarDecl *Var) {
903  MapType::const_iterator Entry = Map.find(Var);
904
905  if (Entry != Map.end()) {
906    return Entry->second;
907
908  } else {
909    return CS_None;
910  }
911}
912
913void ConsumedStateMap::intersect(const ConsumedStateMap *Other) {
914  ConsumedState LocalState;
915
916  if (this->From && this->From == Other->From && !Other->Reachable) {
917    this->markUnreachable();
918    return;
919  }
920
921  for (MapType::const_iterator DMI = Other->Map.begin(),
922       DME = Other->Map.end(); DMI != DME; ++DMI) {
923
924    LocalState = this->getState(DMI->first);
925
926    if (LocalState == CS_None)
927      continue;
928
929    if (LocalState != DMI->second)
930       Map[DMI->first] = CS_Unknown;
931  }
932}
933
934void ConsumedStateMap::markUnreachable() {
935  this->Reachable = false;
936  Map.clear();
937}
938
939void ConsumedStateMap::makeUnknown() {
940  for (MapType::const_iterator DMI = Map.begin(), DME = Map.end(); DMI != DME;
941       ++DMI) {
942
943    Map[DMI->first] = CS_Unknown;
944  }
945}
946
947void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
948  Map[Var] = State;
949}
950
951void ConsumedStateMap::remove(const VarDecl *Var) {
952  Map.erase(Var);
953}
954
955bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
956                                  const ConsumedStmtVisitor &Visitor) {
957
958  ConsumedStateMap *FalseStates = new ConsumedStateMap(*CurrStates);
959  PropagationInfo PInfo;
960
961  if (const IfStmt *IfNode =
962    dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
963
964    bool HasElse = IfNode->getElse() != NULL;
965    const Stmt *Cond = IfNode->getCond();
966
967    PInfo = Visitor.getInfo(Cond);
968    if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
969      PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
970
971    if (PInfo.isTest()) {
972      CurrStates->setSource(Cond);
973      FalseStates->setSource(Cond);
974
975      splitVarStateForIf(IfNode, PInfo.getTest(), CurrStates,
976                         HasElse ? FalseStates : NULL);
977
978    } else if (PInfo.isBinTest()) {
979      CurrStates->setSource(PInfo.testSourceNode());
980      FalseStates->setSource(PInfo.testSourceNode());
981
982      splitVarStateForIfBinOp(PInfo, CurrStates, HasElse ? FalseStates : NULL);
983
984    } else {
985      delete FalseStates;
986      return false;
987    }
988
989  } else if (const BinaryOperator *BinOp =
990    dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
991
992    PInfo = Visitor.getInfo(BinOp->getLHS());
993    if (!PInfo.isTest()) {
994      if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
995        PInfo = Visitor.getInfo(BinOp->getRHS());
996
997        if (!PInfo.isTest()) {
998          delete FalseStates;
999          return false;
1000        }
1001
1002      } else {
1003        delete FalseStates;
1004        return false;
1005      }
1006    }
1007
1008    CurrStates->setSource(BinOp);
1009    FalseStates->setSource(BinOp);
1010
1011    const VarTestResult &Test = PInfo.getTest();
1012    ConsumedState VarState = CurrStates->getState(Test.Var);
1013
1014    if (BinOp->getOpcode() == BO_LAnd) {
1015      if (VarState == CS_Unknown)
1016        CurrStates->setState(Test.Var, Test.TestsFor);
1017      else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
1018        CurrStates->markUnreachable();
1019
1020    } else if (BinOp->getOpcode() == BO_LOr) {
1021      if (VarState == CS_Unknown)
1022        FalseStates->setState(Test.Var,
1023                              invertConsumedUnconsumed(Test.TestsFor));
1024      else if (VarState == Test.TestsFor)
1025        FalseStates->markUnreachable();
1026    }
1027
1028  } else {
1029    delete FalseStates;
1030    return false;
1031  }
1032
1033  CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
1034
1035  if (*SI)
1036    BlockInfo.addInfo(*SI, CurrStates);
1037  else
1038    delete CurrStates;
1039
1040  if (*++SI)
1041    BlockInfo.addInfo(*SI, FalseStates);
1042  else
1043    delete FalseStates;
1044
1045  CurrStates = NULL;
1046  return true;
1047}
1048
1049void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1050  const FunctionDecl *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1051
1052  if (!D) return;
1053
1054  // FIXME: This should be removed when template instantiation propagates
1055  //        attributes at template specialization definition, not declaration.
1056  //        When it is removed the test needs to be enabled in SemaDeclAttr.cpp.
1057  QualType ReturnType;
1058  if (const CXXConstructorDecl *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
1059    ASTContext &CurrContext = AC.getASTContext();
1060    ReturnType = Constructor->getThisType(CurrContext)->getPointeeType();
1061
1062  } else {
1063    ReturnType = D->getCallResultType();
1064  }
1065
1066  // Determine the expected return value.
1067  if (D->hasAttr<ReturnTypestateAttr>()) {
1068
1069    ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>();
1070
1071    const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
1072    if (!RD || !RD->hasAttr<ConsumableAttr>()) {
1073        // FIXME: This branch can be removed with the code above.
1074        WarningsHandler.warnReturnTypestateForUnconsumableType(
1075          RTSAttr->getLocation(), ReturnType.getAsString());
1076        ExpectedReturnState = CS_None;
1077
1078    } else {
1079      switch (RTSAttr->getState()) {
1080      case ReturnTypestateAttr::Unknown:
1081        ExpectedReturnState = CS_Unknown;
1082        break;
1083
1084      case ReturnTypestateAttr::Unconsumed:
1085        ExpectedReturnState = CS_Unconsumed;
1086        break;
1087
1088      case ReturnTypestateAttr::Consumed:
1089        ExpectedReturnState = CS_Consumed;
1090        break;
1091      }
1092    }
1093
1094  } else if (isConsumableType(ReturnType)) {
1095    ExpectedReturnState = CS_Unknown;
1096
1097  } else {
1098    ExpectedReturnState = CS_None;
1099  }
1100
1101  BlockInfo = ConsumedBlockInfo(AC.getCFG());
1102
1103  PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1104
1105  CurrStates = new ConsumedStateMap();
1106  ConsumedStmtVisitor Visitor(AC, *this, CurrStates);
1107
1108  // Add all trackable parameters to the state map.
1109  for (FunctionDecl::param_const_iterator PI = D->param_begin(),
1110       PE = D->param_end(); PI != PE; ++PI) {
1111    Visitor.VisitParmVarDecl(*PI);
1112  }
1113
1114  // Visit all of the function's basic blocks.
1115  for (PostOrderCFGView::iterator I = SortedGraph->begin(),
1116       E = SortedGraph->end(); I != E; ++I) {
1117
1118    const CFGBlock *CurrBlock = *I;
1119    BlockInfo.markVisited(CurrBlock);
1120
1121    if (CurrStates == NULL)
1122      CurrStates = BlockInfo.getInfo(CurrBlock);
1123
1124    if (!CurrStates) {
1125      continue;
1126
1127    } else if (!CurrStates->isReachable()) {
1128      delete CurrStates;
1129      CurrStates = NULL;
1130      continue;
1131    }
1132
1133    Visitor.reset(CurrStates);
1134
1135    // Visit all of the basic block's statements.
1136    for (CFGBlock::const_iterator BI = CurrBlock->begin(),
1137         BE = CurrBlock->end(); BI != BE; ++BI) {
1138
1139      switch (BI->getKind()) {
1140      case CFGElement::Statement:
1141        Visitor.Visit(BI->castAs<CFGStmt>().getStmt());
1142        break;
1143      case CFGElement::AutomaticObjectDtor:
1144        CurrStates->remove(BI->castAs<CFGAutomaticObjDtor>().getVarDecl());
1145      default:
1146        break;
1147      }
1148    }
1149
1150    // TODO: Handle other forms of branching with precision, including while-
1151    //       and for-loops. (Deferred)
1152    if (!splitState(CurrBlock, Visitor)) {
1153      CurrStates->setSource(NULL);
1154
1155      if (CurrBlock->succ_size() > 1) {
1156        CurrStates->makeUnknown();
1157
1158        bool OwnershipTaken = false;
1159
1160        for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1161             SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1162
1163          if (*SI) BlockInfo.addInfo(*SI, CurrStates, OwnershipTaken);
1164        }
1165
1166        if (!OwnershipTaken)
1167          delete CurrStates;
1168
1169        CurrStates = NULL;
1170
1171      } else if (CurrBlock->succ_size() == 1 &&
1172                 (*CurrBlock->succ_begin())->pred_size() > 1) {
1173
1174        BlockInfo.addInfo(*CurrBlock->succ_begin(), CurrStates);
1175        CurrStates = NULL;
1176      }
1177    }
1178  } // End of block iterator.
1179
1180  // Delete the last existing state map.
1181  delete CurrStates;
1182
1183  WarningsHandler.emitDiagnostics();
1184}
1185}} // end namespace clang::consumed
1186