Consumed.cpp revision 57b781dde2676cd1bd838a1cdd56d3aeea091d11
1//===- Consumed.cpp --------------------------------------------*- C++ --*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
// An intra-procedural analysis for checking consumed properties.  This is based,
// in part, on research on linear types.
12//
13//===----------------------------------------------------------------------===//
14
15#include "clang/AST/ASTContext.h"
16#include "clang/AST/Attr.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/ExprCXX.h"
19#include "clang/AST/RecursiveASTVisitor.h"
20#include "clang/AST/StmtVisitor.h"
21#include "clang/AST/StmtCXX.h"
22#include "clang/AST/Type.h"
23#include "clang/Analysis/Analyses/PostOrderCFGView.h"
24#include "clang/Analysis/AnalysisContext.h"
25#include "clang/Analysis/CFG.h"
26#include "clang/Analysis/Analyses/Consumed.h"
27#include "clang/Basic/OperatorKinds.h"
28#include "clang/Basic/SourceLocation.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Compiler.h"
32#include "llvm/Support/raw_ostream.h"
33
34// TODO: Add notes about the actual and expected state for
35// TODO: Correctly identify unreachable blocks when chaining boolean operators.
36// TODO: Warn about unreachable code.
37// TODO: Switch to using a bitmap to track unreachable blocks.
38// TODO: Mark variables as Unknown going into while- or for-loops only if they
39//       are referenced inside that block. (Deferred)
40// TODO: Handle variable definitions, e.g. bool valid = x.isValid();
41//       if (valid) ...; (Deferred)
42// TODO: Add a method(s) to identify which method calls perform what state
43//       transitions. (Deferred)
44// TODO: Take notes on state transitions to provide better warning messages.
45//       (Deferred)
// TODO: Test nested conditionals: A) Checking the same value multiple times,
//       and B) Checking different values. (Deferred)
48
49using namespace clang;
50using namespace consumed;
51
52// Key method definition
53ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() {}
54
55static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
56  switch (State) {
57  case CS_Unconsumed:
58    return CS_Consumed;
59  case CS_Consumed:
60    return CS_Unconsumed;
61  case CS_None:
62    return CS_None;
63  case CS_Unknown:
64    return CS_Unknown;
65  }
66  llvm_unreachable("invalid enum");
67}
68
69static bool isConsumableType(const QualType &QT) {
70  if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
71    return RD->hasAttr<ConsumableAttr>();
72  else
73    return false;
74}
75
76static bool isKnownState(ConsumedState State) {
77  switch (State) {
78  case CS_Unconsumed:
79  case CS_Consumed:
80    return true;
81  case CS_None:
82  case CS_Unknown:
83    return false;
84  }
85  llvm_unreachable("invalid enum");
86}
87
88static bool isTestingFunction(const FunctionDecl *FunDecl) {
89  return FunDecl->hasAttr<TestsUnconsumedAttr>();
90}
91
92static ConsumedState mapConsumableAttrState(const QualType QT) {
93  assert(isConsumableType(QT));
94
95  const ConsumableAttr *CAttr =
96      QT->getAsCXXRecordDecl()->getAttr<ConsumableAttr>();
97
98  switch (CAttr->getDefaultState()) {
99  case ConsumableAttr::Unknown:
100    return CS_Unknown;
101  case ConsumableAttr::Unconsumed:
102    return CS_Unconsumed;
103  case ConsumableAttr::Consumed:
104    return CS_Consumed;
105  }
106  llvm_unreachable("invalid enum");
107}
108
109static ConsumedState
110mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {
111  switch (RTSAttr->getState()) {
112  case ReturnTypestateAttr::Unknown:
113    return CS_Unknown;
114  case ReturnTypestateAttr::Unconsumed:
115    return CS_Unconsumed;
116  case ReturnTypestateAttr::Consumed:
117    return CS_Consumed;
118  }
119  llvm_unreachable("invalid enum");
120}
121
122static StringRef stateToString(ConsumedState State) {
123  switch (State) {
124  case consumed::CS_None:
125    return "none";
126
127  case consumed::CS_Unknown:
128    return "unknown";
129
130  case consumed::CS_Unconsumed:
131    return "unconsumed";
132
133  case consumed::CS_Consumed:
134    return "consumed";
135  }
136  llvm_unreachable("invalid enum");
137}
138
139namespace {
140struct VarTestResult {
141  const VarDecl *Var;
142  ConsumedState TestsFor;
143};
144} // end anonymous::VarTestResult
145
146namespace clang {
147namespace consumed {
148
/// Logical operator that joins the two operand tests of a binary test.
/// The numeric values matter: VisitBinaryOperator converts the boolean
/// "opcode is logical-or" directly into this enum, so EO_And must be 0 and
/// EO_Or must be 1.
enum EffectiveOp {
  EO_And = 0,
  EO_Or  = 1
};
153
154class PropagationInfo {
155  enum {
156    IT_None,
157    IT_State,
158    IT_Test,
159    IT_BinTest,
160    IT_Var
161  } InfoType;
162
163  struct BinTestTy {
164    const BinaryOperator *Source;
165    EffectiveOp EOp;
166    VarTestResult LTest;
167    VarTestResult RTest;
168  };
169
170  union {
171    ConsumedState State;
172    VarTestResult Test;
173    const VarDecl *Var;
174    BinTestTy BinTest;
175  };
176
177public:
178  PropagationInfo() : InfoType(IT_None) {}
179
180  PropagationInfo(const VarTestResult &Test) : InfoType(IT_Test), Test(Test) {}
181  PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
182    : InfoType(IT_Test) {
183
184    Test.Var      = Var;
185    Test.TestsFor = TestsFor;
186  }
187
188  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
189                  const VarTestResult &LTest, const VarTestResult &RTest)
190    : InfoType(IT_BinTest) {
191
192    BinTest.Source  = Source;
193    BinTest.EOp     = EOp;
194    BinTest.LTest   = LTest;
195    BinTest.RTest   = RTest;
196  }
197
198  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
199                  const VarDecl *LVar, ConsumedState LTestsFor,
200                  const VarDecl *RVar, ConsumedState RTestsFor)
201    : InfoType(IT_BinTest) {
202
203    BinTest.Source         = Source;
204    BinTest.EOp            = EOp;
205    BinTest.LTest.Var      = LVar;
206    BinTest.LTest.TestsFor = LTestsFor;
207    BinTest.RTest.Var      = RVar;
208    BinTest.RTest.TestsFor = RTestsFor;
209  }
210
211  PropagationInfo(ConsumedState State) : InfoType(IT_State), State(State) {}
212  PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
213
214  const ConsumedState & getState() const {
215    assert(InfoType == IT_State);
216    return State;
217  }
218
219  const VarTestResult & getTest() const {
220    assert(InfoType == IT_Test);
221    return Test;
222  }
223
224  const VarTestResult & getLTest() const {
225    assert(InfoType == IT_BinTest);
226    return BinTest.LTest;
227  }
228
229  const VarTestResult & getRTest() const {
230    assert(InfoType == IT_BinTest);
231    return BinTest.RTest;
232  }
233
234  const VarDecl * getVar() const {
235    assert(InfoType == IT_Var);
236    return Var;
237  }
238
239  EffectiveOp testEffectiveOp() const {
240    assert(InfoType == IT_BinTest);
241    return BinTest.EOp;
242  }
243
244  const BinaryOperator * testSourceNode() const {
245    assert(InfoType == IT_BinTest);
246    return BinTest.Source;
247  }
248
249  bool isValid()   const { return InfoType != IT_None;     }
250  bool isState()   const { return InfoType == IT_State;    }
251  bool isTest()    const { return InfoType == IT_Test;     }
252  bool isBinTest() const { return InfoType == IT_BinTest;  }
253  bool isVar()     const { return InfoType == IT_Var;      }
254
255  PropagationInfo invertTest() const {
256    assert(InfoType == IT_Test || InfoType == IT_BinTest);
257
258    if (InfoType == IT_Test) {
259      return PropagationInfo(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
260
261    } else if (InfoType == IT_BinTest) {
262      return PropagationInfo(BinTest.Source,
263        BinTest.EOp == EO_And ? EO_Or : EO_And,
264        BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
265        BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
266    } else {
267      return PropagationInfo();
268    }
269  }
270};
271
272class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
273
274  typedef llvm::DenseMap<const Stmt *, PropagationInfo> MapType;
275  typedef std::pair<const Stmt *, PropagationInfo> PairType;
276  typedef MapType::iterator InfoEntry;
277  typedef MapType::const_iterator ConstInfoEntry;
278
279  AnalysisDeclContext &AC;
280  ConsumedAnalyzer &Analyzer;
281  ConsumedStateMap *StateMap;
282  MapType PropagationMap;
283
284  void checkCallability(const PropagationInfo &PInfo,
285                        const FunctionDecl *FunDecl,
286                        const CallExpr *Call);
287  void forwardInfo(const Stmt *From, const Stmt *To);
288  void handleTestingFunctionCall(const CallExpr *Call, const VarDecl *Var);
289  bool isLikeMoveAssignment(const CXXMethodDecl *MethodDecl);
290  void propagateReturnType(const Stmt *Call, const FunctionDecl *Fun,
291                           QualType ReturnType);
292
293public:
294
295  void Visit(const Stmt *StmtNode);
296
297  void VisitBinaryOperator(const BinaryOperator *BinOp);
298  void VisitCallExpr(const CallExpr *Call);
299  void VisitCastExpr(const CastExpr *Cast);
300  void VisitCXXConstructExpr(const CXXConstructExpr *Call);
301  void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
302  void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
303  void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
304  void VisitDeclStmt(const DeclStmt *DelcS);
305  void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
306  void VisitMemberExpr(const MemberExpr *MExpr);
307  void VisitParmVarDecl(const ParmVarDecl *Param);
308  void VisitReturnStmt(const ReturnStmt *Ret);
309  void VisitUnaryOperator(const UnaryOperator *UOp);
310  void VisitVarDecl(const VarDecl *Var);
311
312  ConsumedStmtVisitor(AnalysisDeclContext &AC, ConsumedAnalyzer &Analyzer,
313                      ConsumedStateMap *StateMap)
314      : AC(AC), Analyzer(Analyzer), StateMap(StateMap) {}
315
316  PropagationInfo getInfo(const Stmt *StmtNode) const {
317    ConstInfoEntry Entry = PropagationMap.find(StmtNode);
318
319    if (Entry != PropagationMap.end())
320      return Entry->second;
321    else
322      return PropagationInfo();
323  }
324
325  void reset(ConsumedStateMap *NewStateMap) {
326    StateMap = NewStateMap;
327  }
328};
329
330// TODO: When we support CallableWhenConsumed this will have to check for
331//       the different attributes and change the behavior bellow. (Deferred)
332void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
333                                           const FunctionDecl *FunDecl,
334                                           const CallExpr *Call) {
335
336  if (!FunDecl->hasAttr<CallableWhenUnconsumedAttr>()) return;
337
338  if (PInfo.isVar()) {
339    const VarDecl *Var = PInfo.getVar();
340
341    switch (StateMap->getState(Var)) {
342    case CS_Consumed:
343      Analyzer.WarningsHandler.warnUseWhileConsumed(
344        FunDecl->getNameAsString(), Var->getNameAsString(),
345        Call->getExprLoc());
346      break;
347
348    case CS_Unknown:
349      Analyzer.WarningsHandler.warnUseInUnknownState(
350        FunDecl->getNameAsString(), Var->getNameAsString(),
351        Call->getExprLoc());
352      break;
353
354    case CS_None:
355    case CS_Unconsumed:
356      break;
357    }
358
359  } else {
360    switch (PInfo.getState()) {
361    case CS_Consumed:
362      Analyzer.WarningsHandler.warnUseOfTempWhileConsumed(
363        FunDecl->getNameAsString(), Call->getExprLoc());
364      break;
365
366    case CS_Unknown:
367      Analyzer.WarningsHandler.warnUseOfTempInUnknownState(
368        FunDecl->getNameAsString(), Call->getExprLoc());
369      break;
370
371    case CS_None:
372    case CS_Unconsumed:
373      break;
374    }
375  }
376}
377
378void ConsumedStmtVisitor::forwardInfo(const Stmt *From, const Stmt *To) {
379  InfoEntry Entry = PropagationMap.find(From);
380
381  if (Entry != PropagationMap.end())
382    PropagationMap.insert(PairType(To, Entry->second));
383}
384
385void ConsumedStmtVisitor::handleTestingFunctionCall(const CallExpr *Call,
386                                                    const VarDecl  *Var) {
387
388  ConsumedState VarState = StateMap->getState(Var);
389
390  if (VarState != CS_Unknown) {
391    SourceLocation CallLoc = Call->getExprLoc();
392
393    if (!CallLoc.isMacroID())
394      Analyzer.WarningsHandler.warnUnnecessaryTest(Var->getNameAsString(),
395        stateToString(VarState), CallLoc);
396  }
397
398  PropagationMap.insert(PairType(Call, PropagationInfo(Var, CS_Unconsumed)));
399}
400
401bool ConsumedStmtVisitor::isLikeMoveAssignment(
402  const CXXMethodDecl *MethodDecl) {
403
404  return MethodDecl->isMoveAssignmentOperator() ||
405         (MethodDecl->getOverloadedOperator() == OO_Equal &&
406          MethodDecl->getNumParams() == 1 &&
407          MethodDecl->getParamDecl(0)->getType()->isRValueReferenceType());
408}
409
410void ConsumedStmtVisitor::propagateReturnType(const Stmt *Call,
411                                              const FunctionDecl *Fun,
412                                              QualType ReturnType) {
413  if (isConsumableType(ReturnType)) {
414
415    ConsumedState ReturnState;
416
417    if (Fun->hasAttr<ReturnTypestateAttr>())
418      ReturnState = mapReturnTypestateAttrState(
419        Fun->getAttr<ReturnTypestateAttr>());
420    else
421      ReturnState = mapConsumableAttrState(ReturnType);
422
423    PropagationMap.insert(PairType(Call,
424      PropagationInfo(ReturnState)));
425  }
426}
427
428void ConsumedStmtVisitor::Visit(const Stmt *StmtNode) {
429
430  ConstStmtVisitor<ConsumedStmtVisitor>::Visit(StmtNode);
431
432  for (Stmt::const_child_iterator CI = StmtNode->child_begin(),
433       CE = StmtNode->child_end(); CI != CE; ++CI) {
434
435    PropagationMap.erase(*CI);
436  }
437}
438
439void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
440  switch (BinOp->getOpcode()) {
441  case BO_LAnd:
442  case BO_LOr : {
443    InfoEntry LEntry = PropagationMap.find(BinOp->getLHS()),
444              REntry = PropagationMap.find(BinOp->getRHS());
445
446    VarTestResult LTest, RTest;
447
448    if (LEntry != PropagationMap.end() && LEntry->second.isTest()) {
449      LTest = LEntry->second.getTest();
450
451    } else {
452      LTest.Var      = NULL;
453      LTest.TestsFor = CS_None;
454    }
455
456    if (REntry != PropagationMap.end() && REntry->second.isTest()) {
457      RTest = REntry->second.getTest();
458
459    } else {
460      RTest.Var      = NULL;
461      RTest.TestsFor = CS_None;
462    }
463
464    if (!(LTest.Var == NULL && RTest.Var == NULL))
465      PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
466        static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
467
468    break;
469  }
470
471  case BO_PtrMemD:
472  case BO_PtrMemI:
473    forwardInfo(BinOp->getLHS(), BinOp);
474    break;
475
476  default:
477    break;
478  }
479}
480
481void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
482  if (const FunctionDecl *FunDecl =
483    dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee())) {
484
485    // Special case for the std::move function.
486    // TODO: Make this more specific. (Deferred)
487    if (FunDecl->getNameAsString() == "move") {
488      InfoEntry Entry = PropagationMap.find(Call->getArg(0));
489
490      if (Entry != PropagationMap.end()) {
491        PropagationMap.insert(PairType(Call, Entry->second));
492      }
493
494      return;
495    }
496
497    unsigned Offset = Call->getNumArgs() - FunDecl->getNumParams();
498
499    for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
500      QualType ParamType = FunDecl->getParamDecl(Index - Offset)->getType();
501
502      InfoEntry Entry = PropagationMap.find(Call->getArg(Index));
503
504      if (Entry == PropagationMap.end() || !Entry->second.isVar()) {
505        continue;
506      }
507
508      PropagationInfo PInfo = Entry->second;
509
510      if (ParamType->isRValueReferenceType() ||
511          (ParamType->isLValueReferenceType() &&
512           !cast<LValueReferenceType>(*ParamType).isSpelledAsLValue())) {
513
514        StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
515
516      } else if (!(ParamType.isConstQualified() ||
517                   ((ParamType->isReferenceType() ||
518                     ParamType->isPointerType()) &&
519                    ParamType->getPointeeType().isConstQualified()))) {
520
521        StateMap->setState(PInfo.getVar(), consumed::CS_Unknown);
522      }
523    }
524
525    propagateReturnType(Call, FunDecl, FunDecl->getCallResultType());
526  }
527}
528
529void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
530  forwardInfo(Cast->getSubExpr(), Cast);
531}
532
533void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
534  CXXConstructorDecl *Constructor = Call->getConstructor();
535
536  ASTContext &CurrContext = AC.getASTContext();
537  QualType ThisType = Constructor->getThisType(CurrContext)->getPointeeType();
538
539  if (isConsumableType(ThisType)) {
540    if (Constructor->isDefaultConstructor()) {
541
542      PropagationMap.insert(PairType(Call,
543        PropagationInfo(consumed::CS_Consumed)));
544
545    } else if (Constructor->isMoveConstructor()) {
546
547      PropagationInfo PInfo =
548        PropagationMap.find(Call->getArg(0))->second;
549
550      if (PInfo.isVar()) {
551        const VarDecl* Var = PInfo.getVar();
552
553        PropagationMap.insert(PairType(Call,
554          PropagationInfo(StateMap->getState(Var))));
555
556        StateMap->setState(Var, consumed::CS_Consumed);
557
558      } else {
559        PropagationMap.insert(PairType(Call, PInfo));
560      }
561
562    } else if (Constructor->isCopyConstructor()) {
563      MapType::iterator Entry = PropagationMap.find(Call->getArg(0));
564
565      if (Entry != PropagationMap.end())
566        PropagationMap.insert(PairType(Call, Entry->second));
567
568    } else {
569      propagateReturnType(Call, Constructor, ThisType);
570    }
571  }
572}
573
574void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
575  const CXXMemberCallExpr *Call) {
576
577  VisitCallExpr(Call);
578
579  InfoEntry Entry = PropagationMap.find(Call->getCallee()->IgnoreParens());
580
581  if (Entry != PropagationMap.end()) {
582    PropagationInfo PInfo = Entry->second;
583    const CXXMethodDecl *MethodDecl = Call->getMethodDecl();
584
585    checkCallability(PInfo, MethodDecl, Call);
586
587    if (PInfo.isVar()) {
588      if (isTestingFunction(MethodDecl))
589        handleTestingFunctionCall(Call, PInfo.getVar());
590      else if (MethodDecl->hasAttr<ConsumesAttr>())
591        StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
592    }
593  }
594}
595
596void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
597  const CXXOperatorCallExpr *Call) {
598
599  const FunctionDecl *FunDecl =
600    dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
601
602  if (!FunDecl) return;
603
604  if (isa<CXXMethodDecl>(FunDecl) &&
605      isLikeMoveAssignment(cast<CXXMethodDecl>(FunDecl))) {
606
607    InfoEntry LEntry = PropagationMap.find(Call->getArg(0));
608    InfoEntry REntry = PropagationMap.find(Call->getArg(1));
609
610    PropagationInfo LPInfo, RPInfo;
611
612    if (LEntry != PropagationMap.end() &&
613        REntry != PropagationMap.end()) {
614
615      LPInfo = LEntry->second;
616      RPInfo = REntry->second;
617
618      if (LPInfo.isVar() && RPInfo.isVar()) {
619        StateMap->setState(LPInfo.getVar(),
620          StateMap->getState(RPInfo.getVar()));
621
622        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
623
624        PropagationMap.insert(PairType(Call, LPInfo));
625
626      } else if (LPInfo.isVar() && !RPInfo.isVar()) {
627        StateMap->setState(LPInfo.getVar(), RPInfo.getState());
628
629        PropagationMap.insert(PairType(Call, LPInfo));
630
631      } else if (!LPInfo.isVar() && RPInfo.isVar()) {
632        PropagationMap.insert(PairType(Call,
633          PropagationInfo(StateMap->getState(RPInfo.getVar()))));
634
635        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
636
637      } else {
638        PropagationMap.insert(PairType(Call, RPInfo));
639      }
640
641    } else if (LEntry != PropagationMap.end() &&
642               REntry == PropagationMap.end()) {
643
644      LPInfo = LEntry->second;
645
646      if (LPInfo.isVar()) {
647        StateMap->setState(LPInfo.getVar(), consumed::CS_Unknown);
648
649        PropagationMap.insert(PairType(Call, LPInfo));
650
651      } else {
652        PropagationMap.insert(PairType(Call,
653          PropagationInfo(consumed::CS_Unknown)));
654      }
655
656    } else if (LEntry == PropagationMap.end() &&
657               REntry != PropagationMap.end()) {
658
659      RPInfo = REntry->second;
660
661      if (RPInfo.isVar()) {
662        const VarDecl *Var = RPInfo.getVar();
663
664        PropagationMap.insert(PairType(Call,
665          PropagationInfo(StateMap->getState(Var))));
666
667        StateMap->setState(Var, consumed::CS_Consumed);
668
669      } else {
670        PropagationMap.insert(PairType(Call, RPInfo));
671      }
672    }
673
674  } else {
675
676    VisitCallExpr(Call);
677
678    InfoEntry Entry = PropagationMap.find(Call->getArg(0));
679
680    if (Entry != PropagationMap.end()) {
681      PropagationInfo PInfo = Entry->second;
682
683      checkCallability(PInfo, FunDecl, Call);
684
685      if (PInfo.isVar()) {
686        if (isTestingFunction(FunDecl))
687          handleTestingFunctionCall(Call, PInfo.getVar());
688        else if (FunDecl->hasAttr<ConsumesAttr>())
689          StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
690      }
691    }
692  }
693}
694
695void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
696  if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
697    if (StateMap->getState(Var) != consumed::CS_None)
698      PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
699}
700
701void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
702  for (DeclStmt::const_decl_iterator DI = DeclS->decl_begin(),
703       DE = DeclS->decl_end(); DI != DE; ++DI) {
704
705    if (isa<VarDecl>(*DI)) VisitVarDecl(cast<VarDecl>(*DI));
706  }
707
708  if (DeclS->isSingleDecl())
709    if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
710      PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
711}
712
713void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
714  const MaterializeTemporaryExpr *Temp) {
715
716  InfoEntry Entry = PropagationMap.find(Temp->GetTemporaryExpr());
717
718  if (Entry != PropagationMap.end())
719    PropagationMap.insert(PairType(Temp, Entry->second));
720}
721
722void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
723  forwardInfo(MExpr->getBase(), MExpr);
724}
725
726
727void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
728  QualType ParamType = Param->getType();
729  ConsumedState ParamState = consumed::CS_None;
730
731  if (!(ParamType->isPointerType() || ParamType->isReferenceType()) &&
732      isConsumableType(ParamType))
733    ParamState = mapConsumableAttrState(ParamType);
734  else if (ParamType->isReferenceType() &&
735           isConsumableType(ParamType->getPointeeType()))
736    ParamState = consumed::CS_Unknown;
737
738  if (ParamState)
739    StateMap->setState(Param, ParamState);
740}
741
742void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
743  if (ConsumedState ExpectedState = Analyzer.getExpectedReturnState()) {
744    InfoEntry Entry = PropagationMap.find(Ret->getRetValue());
745
746    if (Entry != PropagationMap.end()) {
747      assert(Entry->second.isState() || Entry->second.isVar());
748
749      ConsumedState RetState = Entry->second.isState() ?
750        Entry->second.getState() : StateMap->getState(Entry->second.getVar());
751
752      if (RetState != ExpectedState)
753        Analyzer.WarningsHandler.warnReturnTypestateMismatch(
754          Ret->getReturnLoc(), stateToString(ExpectedState),
755          stateToString(RetState));
756    }
757  }
758}
759
760void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
761  InfoEntry Entry = PropagationMap.find(UOp->getSubExpr()->IgnoreParens());
762  if (Entry == PropagationMap.end()) return;
763
764  switch (UOp->getOpcode()) {
765  case UO_AddrOf:
766    PropagationMap.insert(PairType(UOp, Entry->second));
767    break;
768
769  case UO_LNot:
770    if (Entry->second.isTest() || Entry->second.isBinTest())
771      PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
772    break;
773
774  default:
775    break;
776  }
777}
778
779void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
780  if (isConsumableType(Var->getType())) {
781    if (Var->hasInit()) {
782      PropagationInfo PInfo =
783        PropagationMap.find(Var->getInit())->second;
784
785      StateMap->setState(Var, PInfo.isVar() ?
786        StateMap->getState(PInfo.getVar()) : PInfo.getState());
787
788    } else {
789      StateMap->setState(Var, consumed::CS_Unknown);
790    }
791  }
792}
793}} // end clang::consumed::ConsumedStmtVisitor
794
795namespace clang {
796namespace consumed {
797
798void splitVarStateForIf(const IfStmt * IfNode, const VarTestResult &Test,
799                        ConsumedStateMap *ThenStates,
800                        ConsumedStateMap *ElseStates) {
801
802  ConsumedState VarState = ThenStates->getState(Test.Var);
803
804  if (VarState == CS_Unknown) {
805    ThenStates->setState(Test.Var, Test.TestsFor);
806    if (ElseStates)
807      ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
808
809  } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
810    ThenStates->markUnreachable();
811
812  } else if (VarState == Test.TestsFor && ElseStates) {
813    ElseStates->markUnreachable();
814  }
815}
816
817void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
818  ConsumedStateMap *ThenStates, ConsumedStateMap *ElseStates) {
819
820  const VarTestResult &LTest = PInfo.getLTest(),
821                      &RTest = PInfo.getRTest();
822
823  ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
824                RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;
825
826  if (LTest.Var) {
827    if (PInfo.testEffectiveOp() == EO_And) {
828      if (LState == CS_Unknown) {
829        ThenStates->setState(LTest.Var, LTest.TestsFor);
830
831      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
832        ThenStates->markUnreachable();
833
834      } else if (LState == LTest.TestsFor && isKnownState(RState)) {
835        if (RState == RTest.TestsFor) {
836          if (ElseStates)
837            ElseStates->markUnreachable();
838        } else {
839          ThenStates->markUnreachable();
840        }
841      }
842
843    } else {
844      if (LState == CS_Unknown && ElseStates) {
845        ElseStates->setState(LTest.Var,
846                             invertConsumedUnconsumed(LTest.TestsFor));
847
848      } else if (LState == LTest.TestsFor && ElseStates) {
849        ElseStates->markUnreachable();
850
851      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
852                 isKnownState(RState)) {
853
854        if (RState == RTest.TestsFor) {
855          if (ElseStates)
856            ElseStates->markUnreachable();
857        } else {
858          ThenStates->markUnreachable();
859        }
860      }
861    }
862  }
863
864  if (RTest.Var) {
865    if (PInfo.testEffectiveOp() == EO_And) {
866      if (RState == CS_Unknown)
867        ThenStates->setState(RTest.Var, RTest.TestsFor);
868      else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
869        ThenStates->markUnreachable();
870
871    } else if (ElseStates) {
872      if (RState == CS_Unknown)
873        ElseStates->setState(RTest.Var,
874                             invertConsumedUnconsumed(RTest.TestsFor));
875      else if (RState == RTest.TestsFor)
876        ElseStates->markUnreachable();
877    }
878  }
879}
880
881void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
882                                ConsumedStateMap *StateMap,
883                                bool &AlreadyOwned) {
884
885  if (VisitedBlocks.alreadySet(Block)) return;
886
887  ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
888
889  if (Entry) {
890    Entry->intersect(StateMap);
891
892  } else if (AlreadyOwned) {
893    StateMapsArray[Block->getBlockID()] = new ConsumedStateMap(*StateMap);
894
895  } else {
896    StateMapsArray[Block->getBlockID()] = StateMap;
897    AlreadyOwned = true;
898  }
899}
900
901void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
902                                ConsumedStateMap *StateMap) {
903
904  if (VisitedBlocks.alreadySet(Block)) {
905    delete StateMap;
906    return;
907  }
908
909  ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
910
911  if (Entry) {
912    Entry->intersect(StateMap);
913    delete StateMap;
914
915  } else {
916    StateMapsArray[Block->getBlockID()] = StateMap;
917  }
918}
919
920ConsumedStateMap* ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
921  return StateMapsArray[Block->getBlockID()];
922}
923
924void ConsumedBlockInfo::markVisited(const CFGBlock *Block) {
925  VisitedBlocks.insert(Block);
926}
927
928ConsumedState ConsumedStateMap::getState(const VarDecl *Var) {
929  MapType::const_iterator Entry = Map.find(Var);
930
931  if (Entry != Map.end()) {
932    return Entry->second;
933
934  } else {
935    return CS_None;
936  }
937}
938
939void ConsumedStateMap::intersect(const ConsumedStateMap *Other) {
940  ConsumedState LocalState;
941
942  if (this->From && this->From == Other->From && !Other->Reachable) {
943    this->markUnreachable();
944    return;
945  }
946
947  for (MapType::const_iterator DMI = Other->Map.begin(),
948       DME = Other->Map.end(); DMI != DME; ++DMI) {
949
950    LocalState = this->getState(DMI->first);
951
952    if (LocalState == CS_None)
953      continue;
954
955    if (LocalState != DMI->second)
956       Map[DMI->first] = CS_Unknown;
957  }
958}
959
960void ConsumedStateMap::markUnreachable() {
961  this->Reachable = false;
962  Map.clear();
963}
964
965void ConsumedStateMap::makeUnknown() {
966  for (MapType::const_iterator DMI = Map.begin(), DME = Map.end(); DMI != DME;
967       ++DMI) {
968
969    Map[DMI->first] = CS_Unknown;
970  }
971}
972
973void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
974  Map[Var] = State;
975}
976
977void ConsumedStateMap::remove(const VarDecl *Var) {
978  Map.erase(Var);
979}
980
981void ConsumedAnalyzer::determineExpectedReturnState(AnalysisDeclContext &AC,
982                                                    const FunctionDecl *D) {
983  QualType ReturnType;
984  if (const CXXConstructorDecl *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
985    ASTContext &CurrContext = AC.getASTContext();
986    ReturnType = Constructor->getThisType(CurrContext)->getPointeeType();
987  } else
988    ReturnType = D->getCallResultType();
989
990  if (D->hasAttr<ReturnTypestateAttr>()) {
991    const ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>();
992
993    const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
994    if (!RD || !RD->hasAttr<ConsumableAttr>()) {
995      // FIXME: This should be removed when template instantiation propagates
996      //        attributes at template specialization definition, not
997      //        declaration. When it is removed the test needs to be enabled
998      //        in SemaDeclAttr.cpp.
999      WarningsHandler.warnReturnTypestateForUnconsumableType(
1000          RTSAttr->getLocation(), ReturnType.getAsString());
1001      ExpectedReturnState = CS_None;
1002    } else
1003      ExpectedReturnState = mapReturnTypestateAttrState(RTSAttr);
1004  } else if (isConsumableType(ReturnType))
1005    ExpectedReturnState = mapConsumableAttrState(ReturnType);
1006  else
1007    ExpectedReturnState = CS_None;
1008}
1009
1010bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
1011                                  const ConsumedStmtVisitor &Visitor) {
1012
1013  ConsumedStateMap *FalseStates = new ConsumedStateMap(*CurrStates);
1014  PropagationInfo PInfo;
1015
1016  if (const IfStmt *IfNode =
1017    dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
1018
1019    bool HasElse = IfNode->getElse() != NULL;
1020    const Stmt *Cond = IfNode->getCond();
1021
1022    PInfo = Visitor.getInfo(Cond);
1023    if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
1024      PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
1025
1026    if (PInfo.isTest()) {
1027      CurrStates->setSource(Cond);
1028      FalseStates->setSource(Cond);
1029
1030      splitVarStateForIf(IfNode, PInfo.getTest(), CurrStates,
1031                         HasElse ? FalseStates : NULL);
1032
1033    } else if (PInfo.isBinTest()) {
1034      CurrStates->setSource(PInfo.testSourceNode());
1035      FalseStates->setSource(PInfo.testSourceNode());
1036
1037      splitVarStateForIfBinOp(PInfo, CurrStates, HasElse ? FalseStates : NULL);
1038
1039    } else {
1040      delete FalseStates;
1041      return false;
1042    }
1043
1044  } else if (const BinaryOperator *BinOp =
1045    dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
1046
1047    PInfo = Visitor.getInfo(BinOp->getLHS());
1048    if (!PInfo.isTest()) {
1049      if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
1050        PInfo = Visitor.getInfo(BinOp->getRHS());
1051
1052        if (!PInfo.isTest()) {
1053          delete FalseStates;
1054          return false;
1055        }
1056
1057      } else {
1058        delete FalseStates;
1059        return false;
1060      }
1061    }
1062
1063    CurrStates->setSource(BinOp);
1064    FalseStates->setSource(BinOp);
1065
1066    const VarTestResult &Test = PInfo.getTest();
1067    ConsumedState VarState = CurrStates->getState(Test.Var);
1068
1069    if (BinOp->getOpcode() == BO_LAnd) {
1070      if (VarState == CS_Unknown)
1071        CurrStates->setState(Test.Var, Test.TestsFor);
1072      else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
1073        CurrStates->markUnreachable();
1074
1075    } else if (BinOp->getOpcode() == BO_LOr) {
1076      if (VarState == CS_Unknown)
1077        FalseStates->setState(Test.Var,
1078                              invertConsumedUnconsumed(Test.TestsFor));
1079      else if (VarState == Test.TestsFor)
1080        FalseStates->markUnreachable();
1081    }
1082
1083  } else {
1084    delete FalseStates;
1085    return false;
1086  }
1087
1088  CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
1089
1090  if (*SI)
1091    BlockInfo.addInfo(*SI, CurrStates);
1092  else
1093    delete CurrStates;
1094
1095  if (*++SI)
1096    BlockInfo.addInfo(*SI, FalseStates);
1097  else
1098    delete FalseStates;
1099
1100  CurrStates = NULL;
1101  return true;
1102}
1103
1104void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1105  const FunctionDecl *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1106  if (!D)
1107    return;
1108
1109  CFG *CFGraph = AC.getCFG();
1110  if (!CFGraph)
1111    return;
1112
1113  determineExpectedReturnState(AC, D);
1114
1115  BlockInfo = ConsumedBlockInfo(CFGraph);
1116
1117  PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1118
1119  CurrStates = new ConsumedStateMap();
1120  ConsumedStmtVisitor Visitor(AC, *this, CurrStates);
1121
1122  // Add all trackable parameters to the state map.
1123  for (FunctionDecl::param_const_iterator PI = D->param_begin(),
1124       PE = D->param_end(); PI != PE; ++PI) {
1125    Visitor.VisitParmVarDecl(*PI);
1126  }
1127
1128  // Visit all of the function's basic blocks.
1129  for (PostOrderCFGView::iterator I = SortedGraph->begin(),
1130       E = SortedGraph->end(); I != E; ++I) {
1131
1132    const CFGBlock *CurrBlock = *I;
1133    BlockInfo.markVisited(CurrBlock);
1134
1135    if (CurrStates == NULL)
1136      CurrStates = BlockInfo.getInfo(CurrBlock);
1137
1138    if (!CurrStates) {
1139      continue;
1140
1141    } else if (!CurrStates->isReachable()) {
1142      delete CurrStates;
1143      CurrStates = NULL;
1144      continue;
1145    }
1146
1147    Visitor.reset(CurrStates);
1148
1149    // Visit all of the basic block's statements.
1150    for (CFGBlock::const_iterator BI = CurrBlock->begin(),
1151         BE = CurrBlock->end(); BI != BE; ++BI) {
1152
1153      switch (BI->getKind()) {
1154      case CFGElement::Statement:
1155        Visitor.Visit(BI->castAs<CFGStmt>().getStmt());
1156        break;
1157      case CFGElement::AutomaticObjectDtor:
1158        CurrStates->remove(BI->castAs<CFGAutomaticObjDtor>().getVarDecl());
1159      default:
1160        break;
1161      }
1162    }
1163
1164    // TODO: Handle other forms of branching with precision, including while-
1165    //       and for-loops. (Deferred)
1166    if (!splitState(CurrBlock, Visitor)) {
1167      CurrStates->setSource(NULL);
1168
1169      if (CurrBlock->succ_size() > 1) {
1170        CurrStates->makeUnknown();
1171
1172        bool OwnershipTaken = false;
1173
1174        for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1175             SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1176
1177          if (*SI) BlockInfo.addInfo(*SI, CurrStates, OwnershipTaken);
1178        }
1179
1180        if (!OwnershipTaken)
1181          delete CurrStates;
1182
1183        CurrStates = NULL;
1184
1185      } else if (CurrBlock->succ_size() == 1 &&
1186                 (*CurrBlock->succ_begin())->pred_size() > 1) {
1187
1188        BlockInfo.addInfo(*CurrBlock->succ_begin(), CurrStates);
1189        CurrStates = NULL;
1190      }
1191    }
1192  } // End of block iterator.
1193
1194  // Delete the last existing state map.
1195  delete CurrStates;
1196
1197  WarningsHandler.emitDiagnostics();
1198}
1199}} // end namespace clang::consumed
1200