Consumed.cpp revision 66540857c08de7f1be9bea48381548d3942cf9d1
1//===- Consumed.cpp --------------------------------------------*- C++ --*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
// An intra-procedural analysis for checking consumed properties.  This is
// based, in part, on research on linear types.
12//
13//===----------------------------------------------------------------------===//
14
15#include "clang/AST/ASTContext.h"
16#include "clang/AST/Attr.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/ExprCXX.h"
19#include "clang/AST/RecursiveASTVisitor.h"
20#include "clang/AST/StmtVisitor.h"
21#include "clang/AST/StmtCXX.h"
22#include "clang/AST/Type.h"
23#include "clang/Analysis/Analyses/PostOrderCFGView.h"
24#include "clang/Analysis/AnalysisContext.h"
25#include "clang/Analysis/CFG.h"
26#include "clang/Analysis/Analyses/Consumed.h"
27#include "clang/Basic/OperatorKinds.h"
28#include "clang/Basic/SourceLocation.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Compiler.h"
32#include "llvm/Support/raw_ostream.h"
33
34// TODO: Add notes about the actual and expected state for
35// TODO: Correctly identify unreachable blocks when chaining boolean operators.
36// TODO: Adjust the parser and AttributesList class to support lists of
37//       identifiers.
38// TODO: Warn about unreachable code.
39// TODO: Switch to using a bitmap to track unreachable blocks.
40// TODO: Mark variables as Unknown going into while- or for-loops only if they
41//       are referenced inside that block. (Deferred)
42// TODO: Handle variable definitions, e.g. bool valid = x.isValid();
43//       if (valid) ...; (Deferred)
44// TODO: Add a method(s) to identify which method calls perform what state
45//       transitions. (Deferred)
46// TODO: Take notes on state transitions to provide better warning messages.
47//       (Deferred)
// TODO: Test nested conditionals: 1) Checking the same value multiple times,
//       and 2) Checking different values. (Deferred)
50
51using namespace clang;
52using namespace consumed;
53
// Key method definition: defining the (presumably virtual) destructor out of
// line in this file anchors the class's vtable here.
ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() {}
56
57static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
58  switch (State) {
59  case CS_Unconsumed:
60    return CS_Consumed;
61  case CS_Consumed:
62    return CS_Unconsumed;
63  case CS_None:
64    return CS_None;
65  case CS_Unknown:
66    return CS_Unknown;
67  }
68  llvm_unreachable("invalid enum");
69}
70
71static bool isCallableInState(const CallableWhenAttr *CWAttr,
72                              ConsumedState State) {
73
74  CallableWhenAttr::callableState_iterator I = CWAttr->callableState_begin(),
75                                           E = CWAttr->callableState_end();
76
77  for (; I != E; ++I) {
78
79    ConsumedState MappedAttrState = CS_None;
80
81    switch (*I) {
82    case CallableWhenAttr::Unknown:
83      MappedAttrState = CS_Unknown;
84      break;
85
86    case CallableWhenAttr::Unconsumed:
87      MappedAttrState = CS_Unconsumed;
88      break;
89
90    case CallableWhenAttr::Consumed:
91      MappedAttrState = CS_Consumed;
92      break;
93    }
94
95    if (MappedAttrState == State)
96      return true;
97  }
98
99  return false;
100}
101
102static bool isConsumableType(const QualType &QT) {
103  if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
104    return RD->hasAttr<ConsumableAttr>();
105  else
106    return false;
107}
108
109static bool isKnownState(ConsumedState State) {
110  switch (State) {
111  case CS_Unconsumed:
112  case CS_Consumed:
113    return true;
114  case CS_None:
115  case CS_Unknown:
116    return false;
117  }
118  llvm_unreachable("invalid enum");
119}
120
/// Returns true if \p FunDecl carries TestsUnconsumedAttr.  A call to such a
/// function is treated as a test of whether its object is unconsumed.
static bool isTestingFunction(const FunctionDecl *FunDecl) {
  return FunDecl->hasAttr<TestsUnconsumedAttr>();
}
124
125static ConsumedState mapConsumableAttrState(const QualType QT) {
126  assert(isConsumableType(QT));
127
128  const ConsumableAttr *CAttr =
129      QT->getAsCXXRecordDecl()->getAttr<ConsumableAttr>();
130
131  switch (CAttr->getDefaultState()) {
132  case ConsumableAttr::Unknown:
133    return CS_Unknown;
134  case ConsumableAttr::Unconsumed:
135    return CS_Unconsumed;
136  case ConsumableAttr::Consumed:
137    return CS_Consumed;
138  }
139  llvm_unreachable("invalid enum");
140}
141
142static ConsumedState
143mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {
144  switch (RTSAttr->getState()) {
145  case ReturnTypestateAttr::Unknown:
146    return CS_Unknown;
147  case ReturnTypestateAttr::Unconsumed:
148    return CS_Unconsumed;
149  case ReturnTypestateAttr::Consumed:
150    return CS_Consumed;
151  }
152  llvm_unreachable("invalid enum");
153}
154
155static StringRef stateToString(ConsumedState State) {
156  switch (State) {
157  case consumed::CS_None:
158    return "none";
159
160  case consumed::CS_Unknown:
161    return "unknown";
162
163  case consumed::CS_Unconsumed:
164    return "unconsumed";
165
166  case consumed::CS_Consumed:
167    return "consumed";
168  }
169  llvm_unreachable("invalid enum");
170}
171
namespace {
/// Describes a test of a variable's typestate, e.g. the result of calling a
/// testing function on a tracked variable.
struct VarTestResult {
  // The variable being tested; NULL when an operand of && / || was not a
  // state test.
  const VarDecl *Var;
  // The state in which the test evaluates to true.
  ConsumedState TestsFor;
};
} // end anonymous namespace
178
179namespace clang {
180namespace consumed {
181
/// How the two operand tests of a binary test are combined: EO_And for '&&'
/// and EO_Or for '||'.  PropagationInfo::invertTest() swaps them (De Morgan).
/// NOTE: the enumerator values (EO_And == 0, EO_Or == 1) should be treated as
/// fixed; audit any bool-to-EffectiveOp conversion before reordering.
enum EffectiveOp {
  EO_And,
  EO_Or
};
186
187class PropagationInfo {
188  enum {
189    IT_None,
190    IT_State,
191    IT_Test,
192    IT_BinTest,
193    IT_Var
194  } InfoType;
195
196  struct BinTestTy {
197    const BinaryOperator *Source;
198    EffectiveOp EOp;
199    VarTestResult LTest;
200    VarTestResult RTest;
201  };
202
203  union {
204    ConsumedState State;
205    VarTestResult Test;
206    const VarDecl *Var;
207    BinTestTy BinTest;
208  };
209
210  QualType TempType;
211
212public:
213  PropagationInfo() : InfoType(IT_None) {}
214
215  PropagationInfo(const VarTestResult &Test) : InfoType(IT_Test), Test(Test) {}
216  PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
217    : InfoType(IT_Test) {
218
219    Test.Var      = Var;
220    Test.TestsFor = TestsFor;
221  }
222
223  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
224                  const VarTestResult &LTest, const VarTestResult &RTest)
225    : InfoType(IT_BinTest) {
226
227    BinTest.Source  = Source;
228    BinTest.EOp     = EOp;
229    BinTest.LTest   = LTest;
230    BinTest.RTest   = RTest;
231  }
232
233  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
234                  const VarDecl *LVar, ConsumedState LTestsFor,
235                  const VarDecl *RVar, ConsumedState RTestsFor)
236    : InfoType(IT_BinTest) {
237
238    BinTest.Source         = Source;
239    BinTest.EOp            = EOp;
240    BinTest.LTest.Var      = LVar;
241    BinTest.LTest.TestsFor = LTestsFor;
242    BinTest.RTest.Var      = RVar;
243    BinTest.RTest.TestsFor = RTestsFor;
244  }
245
246  PropagationInfo(ConsumedState State, QualType TempType)
247    : InfoType(IT_State), State(State), TempType(TempType) {}
248
249  PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
250
251  const ConsumedState & getState() const {
252    assert(InfoType == IT_State);
253    return State;
254  }
255
256  const QualType & getTempType() const {
257    assert(InfoType == IT_State);
258    return TempType;
259  }
260
261  const VarTestResult & getTest() const {
262    assert(InfoType == IT_Test);
263    return Test;
264  }
265
266  const VarTestResult & getLTest() const {
267    assert(InfoType == IT_BinTest);
268    return BinTest.LTest;
269  }
270
271  const VarTestResult & getRTest() const {
272    assert(InfoType == IT_BinTest);
273    return BinTest.RTest;
274  }
275
276  const VarDecl * getVar() const {
277    assert(InfoType == IT_Var);
278    return Var;
279  }
280
281  EffectiveOp testEffectiveOp() const {
282    assert(InfoType == IT_BinTest);
283    return BinTest.EOp;
284  }
285
286  const BinaryOperator * testSourceNode() const {
287    assert(InfoType == IT_BinTest);
288    return BinTest.Source;
289  }
290
291  bool isValid()   const { return InfoType != IT_None;     }
292  bool isState()   const { return InfoType == IT_State;    }
293  bool isTest()    const { return InfoType == IT_Test;     }
294  bool isBinTest() const { return InfoType == IT_BinTest;  }
295  bool isVar()     const { return InfoType == IT_Var;      }
296
297  PropagationInfo invertTest() const {
298    assert(InfoType == IT_Test || InfoType == IT_BinTest);
299
300    if (InfoType == IT_Test) {
301      return PropagationInfo(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
302
303    } else if (InfoType == IT_BinTest) {
304      return PropagationInfo(BinTest.Source,
305        BinTest.EOp == EO_And ? EO_Or : EO_And,
306        BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
307        BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
308    } else {
309      return PropagationInfo();
310    }
311  }
312};
313
314class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
315
316  typedef llvm::DenseMap<const Stmt *, PropagationInfo> MapType;
317  typedef std::pair<const Stmt *, PropagationInfo> PairType;
318  typedef MapType::iterator InfoEntry;
319  typedef MapType::const_iterator ConstInfoEntry;
320
321  AnalysisDeclContext &AC;
322  ConsumedAnalyzer &Analyzer;
323  ConsumedStateMap *StateMap;
324  MapType PropagationMap;
325
326  void checkCallability(const PropagationInfo &PInfo,
327                        const FunctionDecl *FunDecl,
328                        const CallExpr *Call);
329  void forwardInfo(const Stmt *From, const Stmt *To);
330  void handleTestingFunctionCall(const CallExpr *Call, const VarDecl *Var);
331  bool isLikeMoveAssignment(const CXXMethodDecl *MethodDecl);
332  void propagateReturnType(const Stmt *Call, const FunctionDecl *Fun,
333                           QualType ReturnType);
334
335public:
336
337  void Visit(const Stmt *StmtNode);
338
339  void VisitBinaryOperator(const BinaryOperator *BinOp);
340  void VisitCallExpr(const CallExpr *Call);
341  void VisitCastExpr(const CastExpr *Cast);
342  void VisitCXXConstructExpr(const CXXConstructExpr *Call);
343  void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
344  void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
345  void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
346  void VisitDeclStmt(const DeclStmt *DelcS);
347  void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
348  void VisitMemberExpr(const MemberExpr *MExpr);
349  void VisitParmVarDecl(const ParmVarDecl *Param);
350  void VisitReturnStmt(const ReturnStmt *Ret);
351  void VisitUnaryOperator(const UnaryOperator *UOp);
352  void VisitVarDecl(const VarDecl *Var);
353
354  ConsumedStmtVisitor(AnalysisDeclContext &AC, ConsumedAnalyzer &Analyzer,
355                      ConsumedStateMap *StateMap)
356      : AC(AC), Analyzer(Analyzer), StateMap(StateMap) {}
357
358  PropagationInfo getInfo(const Stmt *StmtNode) const {
359    ConstInfoEntry Entry = PropagationMap.find(StmtNode);
360
361    if (Entry != PropagationMap.end())
362      return Entry->second;
363    else
364      return PropagationInfo();
365  }
366
367  void reset(ConsumedStateMap *NewStateMap) {
368    StateMap = NewStateMap;
369  }
370};
371
372void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
373                                           const FunctionDecl *FunDecl,
374                                           const CallExpr *Call) {
375
376  if (!FunDecl->hasAttr<CallableWhenAttr>())
377    return;
378
379  const CallableWhenAttr *CWAttr = FunDecl->getAttr<CallableWhenAttr>();
380
381  if (PInfo.isVar()) {
382    const VarDecl *Var = PInfo.getVar();
383    ConsumedState VarState = StateMap->getState(Var);
384
385    assert(VarState != CS_None && "Invalid state");
386
387    if (isCallableInState(CWAttr, VarState))
388      return;
389
390    Analyzer.WarningsHandler.warnUseInInvalidState(
391      FunDecl->getNameAsString(), Var->getNameAsString(),
392      stateToString(VarState), Call->getExprLoc());
393
394  } else if (PInfo.isState()) {
395
396    assert(PInfo.getState() != CS_None && "Invalid state");
397
398    if (isCallableInState(CWAttr, PInfo.getState()))
399      return;
400
401    Analyzer.WarningsHandler.warnUseOfTempInInvalidState(
402      FunDecl->getNameAsString(), stateToString(PInfo.getState()),
403      Call->getExprLoc());
404  }
405}
406
407void ConsumedStmtVisitor::forwardInfo(const Stmt *From, const Stmt *To) {
408  InfoEntry Entry = PropagationMap.find(From);
409
410  if (Entry != PropagationMap.end())
411    PropagationMap.insert(PairType(To, Entry->second));
412}
413
414void ConsumedStmtVisitor::handleTestingFunctionCall(const CallExpr *Call,
415                                                    const VarDecl  *Var) {
416
417  ConsumedState VarState = StateMap->getState(Var);
418
419  if (VarState != CS_Unknown) {
420    SourceLocation CallLoc = Call->getExprLoc();
421
422    if (!CallLoc.isMacroID())
423      Analyzer.WarningsHandler.warnUnnecessaryTest(Var->getNameAsString(),
424        stateToString(VarState), CallLoc);
425  }
426
427  PropagationMap.insert(PairType(Call, PropagationInfo(Var, CS_Unconsumed)));
428}
429
430bool ConsumedStmtVisitor::isLikeMoveAssignment(
431  const CXXMethodDecl *MethodDecl) {
432
433  return MethodDecl->isMoveAssignmentOperator() ||
434         (MethodDecl->getOverloadedOperator() == OO_Equal &&
435          MethodDecl->getNumParams() == 1 &&
436          MethodDecl->getParamDecl(0)->getType()->isRValueReferenceType());
437}
438
439void ConsumedStmtVisitor::propagateReturnType(const Stmt *Call,
440                                              const FunctionDecl *Fun,
441                                              QualType ReturnType) {
442  if (isConsumableType(ReturnType)) {
443
444    ConsumedState ReturnState;
445
446    if (Fun->hasAttr<ReturnTypestateAttr>())
447      ReturnState = mapReturnTypestateAttrState(
448        Fun->getAttr<ReturnTypestateAttr>());
449    else
450      ReturnState = mapConsumableAttrState(ReturnType);
451
452    PropagationMap.insert(PairType(Call,
453      PropagationInfo(ReturnState, ReturnType)));
454  }
455}
456
457void ConsumedStmtVisitor::Visit(const Stmt *StmtNode) {
458
459  ConstStmtVisitor<ConsumedStmtVisitor>::Visit(StmtNode);
460
461  for (Stmt::const_child_iterator CI = StmtNode->child_begin(),
462       CE = StmtNode->child_end(); CI != CE; ++CI) {
463
464    PropagationMap.erase(*CI);
465  }
466}
467
468void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
469  switch (BinOp->getOpcode()) {
470  case BO_LAnd:
471  case BO_LOr : {
472    InfoEntry LEntry = PropagationMap.find(BinOp->getLHS()),
473              REntry = PropagationMap.find(BinOp->getRHS());
474
475    VarTestResult LTest, RTest;
476
477    if (LEntry != PropagationMap.end() && LEntry->second.isTest()) {
478      LTest = LEntry->second.getTest();
479
480    } else {
481      LTest.Var      = NULL;
482      LTest.TestsFor = CS_None;
483    }
484
485    if (REntry != PropagationMap.end() && REntry->second.isTest()) {
486      RTest = REntry->second.getTest();
487
488    } else {
489      RTest.Var      = NULL;
490      RTest.TestsFor = CS_None;
491    }
492
493    if (!(LTest.Var == NULL && RTest.Var == NULL))
494      PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
495        static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
496
497    break;
498  }
499
500  case BO_PtrMemD:
501  case BO_PtrMemI:
502    forwardInfo(BinOp->getLHS(), BinOp);
503    break;
504
505  default:
506    break;
507  }
508}
509
510void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
511  if (const FunctionDecl *FunDecl =
512    dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee())) {
513
514    // Special case for the std::move function.
515    // TODO: Make this more specific. (Deferred)
516    if (FunDecl->getNameAsString() == "move") {
517      InfoEntry Entry = PropagationMap.find(Call->getArg(0));
518
519      if (Entry != PropagationMap.end()) {
520        PropagationMap.insert(PairType(Call, Entry->second));
521      }
522
523      return;
524    }
525
526    unsigned Offset = Call->getNumArgs() - FunDecl->getNumParams();
527
528    for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
529      QualType ParamType = FunDecl->getParamDecl(Index - Offset)->getType();
530
531      InfoEntry Entry = PropagationMap.find(Call->getArg(Index));
532
533      if (Entry == PropagationMap.end() || !Entry->second.isVar()) {
534        continue;
535      }
536
537      PropagationInfo PInfo = Entry->second;
538
539      if (ParamType->isRValueReferenceType() ||
540          (ParamType->isLValueReferenceType() &&
541           !cast<LValueReferenceType>(*ParamType).isSpelledAsLValue())) {
542
543        StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
544
545      } else if (!(ParamType.isConstQualified() ||
546                   ((ParamType->isReferenceType() ||
547                     ParamType->isPointerType()) &&
548                    ParamType->getPointeeType().isConstQualified()))) {
549
550        StateMap->setState(PInfo.getVar(), consumed::CS_Unknown);
551      }
552    }
553
554    QualType RetType = FunDecl->getCallResultType();
555    if (RetType->isReferenceType())
556      RetType = RetType->getPointeeType();
557
558    propagateReturnType(Call, FunDecl, RetType);
559  }
560}
561
562void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
563  forwardInfo(Cast->getSubExpr(), Cast);
564}
565
566void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
567  CXXConstructorDecl *Constructor = Call->getConstructor();
568
569  ASTContext &CurrContext = AC.getASTContext();
570  QualType ThisType = Constructor->getThisType(CurrContext)->getPointeeType();
571
572  if (isConsumableType(ThisType)) {
573    if (Constructor->isDefaultConstructor()) {
574
575      PropagationMap.insert(PairType(Call,
576        PropagationInfo(consumed::CS_Consumed, ThisType)));
577
578    } else if (Constructor->isMoveConstructor()) {
579
580      PropagationInfo PInfo =
581        PropagationMap.find(Call->getArg(0))->second;
582
583      if (PInfo.isVar()) {
584        const VarDecl* Var = PInfo.getVar();
585
586        PropagationMap.insert(PairType(Call,
587          PropagationInfo(StateMap->getState(Var), ThisType)));
588
589        StateMap->setState(Var, consumed::CS_Consumed);
590
591      } else {
592        PropagationMap.insert(PairType(Call, PInfo));
593      }
594
595    } else if (Constructor->isCopyConstructor()) {
596      MapType::iterator Entry = PropagationMap.find(Call->getArg(0));
597
598      if (Entry != PropagationMap.end())
599        PropagationMap.insert(PairType(Call, Entry->second));
600
601    } else {
602      propagateReturnType(Call, Constructor, ThisType);
603    }
604  }
605}
606
607void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
608  const CXXMemberCallExpr *Call) {
609
610  VisitCallExpr(Call);
611
612  InfoEntry Entry = PropagationMap.find(Call->getCallee()->IgnoreParens());
613
614  if (Entry != PropagationMap.end()) {
615    PropagationInfo PInfo = Entry->second;
616    const CXXMethodDecl *MethodDecl = Call->getMethodDecl();
617
618    checkCallability(PInfo, MethodDecl, Call);
619
620    if (PInfo.isVar()) {
621      if (isTestingFunction(MethodDecl))
622        handleTestingFunctionCall(Call, PInfo.getVar());
623      else if (MethodDecl->hasAttr<ConsumesAttr>())
624        StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
625    }
626  }
627}
628
/// Handles an overloaded-operator call.
///
/// A move-like assignment (see isLikeMoveAssignment) transfers state from the
/// right operand (argument 1) to the left operand (argument 0), and the right
/// operand becomes consumed.  Any other operator is handled like a method
/// call on argument 0: callability check, testing function, ConsumesAttr.
void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
  const CXXOperatorCallExpr *Call) {

  const FunctionDecl *FunDecl =
    dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());

  if (!FunDecl) return;

  if (isa<CXXMethodDecl>(FunDecl) &&
      isLikeMoveAssignment(cast<CXXMethodDecl>(FunDecl))) {

    // For a member operator=, argument 0 is the assigned-to object and
    // argument 1 is the single parameter.
    InfoEntry LEntry = PropagationMap.find(Call->getArg(0));
    InfoEntry REntry = PropagationMap.find(Call->getArg(1));

    PropagationInfo LPInfo, RPInfo;

    if (LEntry != PropagationMap.end() &&
        REntry != PropagationMap.end()) {

      LPInfo = LEntry->second;
      RPInfo = REntry->second;

      if (LPInfo.isVar() && RPInfo.isVar()) {
        // Order matters: read the source state before consuming the source,
        // so that a self-move (x = move(x)) leaves x consumed.
        StateMap->setState(LPInfo.getVar(),
          StateMap->getState(RPInfo.getVar()));

        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);

        PropagationMap.insert(PairType(Call, LPInfo));

      } else if (LPInfo.isVar() && !RPInfo.isVar()) {
        // NOTE(review): getState() asserts RPInfo is IT_State; this assumes
        // class-typed operands never carry test information.
        StateMap->setState(LPInfo.getVar(), RPInfo.getState());

        PropagationMap.insert(PairType(Call, LPInfo));

      } else if (!LPInfo.isVar() && RPInfo.isVar()) {
        // NOTE(review): getTempType() asserts LPInfo is IT_State.
        PropagationMap.insert(PairType(Call,
          PropagationInfo(StateMap->getState(RPInfo.getVar()),
                          LPInfo.getTempType())));

        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);

      } else {
        PropagationMap.insert(PairType(Call, RPInfo));
      }

    } else if (LEntry != PropagationMap.end() &&
               REntry == PropagationMap.end()) {

      // The source's state is not known, so the target becomes unknown.
      LPInfo = LEntry->second;

      if (LPInfo.isVar()) {
        StateMap->setState(LPInfo.getVar(), consumed::CS_Unknown);

        PropagationMap.insert(PairType(Call, LPInfo));

      } else if (LPInfo.isState()) {
        PropagationMap.insert(PairType(Call,
          PropagationInfo(consumed::CS_Unknown, LPInfo.getTempType())));
      }

    } else if (LEntry == PropagationMap.end() &&
               REntry != PropagationMap.end()) {

      // Untracked target: only the source's consumption is recorded.
      if (REntry->second.isVar())
        StateMap->setState(REntry->second.getVar(), consumed::CS_Consumed);
    }

  } else {

    VisitCallExpr(Call);

    // Argument 0 is treated as the object the operator is applied to.
    InfoEntry Entry = PropagationMap.find(Call->getArg(0));

    if (Entry != PropagationMap.end()) {
      PropagationInfo PInfo = Entry->second;

      checkCallability(PInfo, FunDecl, Call);

      if (PInfo.isVar()) {
        if (isTestingFunction(FunDecl))
          handleTestingFunctionCall(Call, PInfo.getVar());
        else if (FunDecl->hasAttr<ConsumesAttr>())
          StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
      }
    }
  }
}
717
718void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
719  if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
720    if (StateMap->getState(Var) != consumed::CS_None)
721      PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
722}
723
724void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
725  for (DeclStmt::const_decl_iterator DI = DeclS->decl_begin(),
726       DE = DeclS->decl_end(); DI != DE; ++DI) {
727
728    if (isa<VarDecl>(*DI)) VisitVarDecl(cast<VarDecl>(*DI));
729  }
730
731  if (DeclS->isSingleDecl())
732    if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
733      PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
734}
735
736void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
737  const MaterializeTemporaryExpr *Temp) {
738
739  InfoEntry Entry = PropagationMap.find(Temp->GetTemporaryExpr());
740
741  if (Entry != PropagationMap.end())
742    PropagationMap.insert(PairType(Temp, Entry->second));
743}
744
745void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
746  forwardInfo(MExpr->getBase(), MExpr);
747}
748
749
750void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {
751  QualType ParamType = Param->getType();
752  ConsumedState ParamState = consumed::CS_None;
753
754  if (!(ParamType->isPointerType() || ParamType->isReferenceType()) &&
755      isConsumableType(ParamType))
756    ParamState = mapConsumableAttrState(ParamType);
757  else if (ParamType->isReferenceType() &&
758           isConsumableType(ParamType->getPointeeType()))
759    ParamState = consumed::CS_Unknown;
760
761  if (ParamState)
762    StateMap->setState(Param, ParamState);
763}
764
765void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {
766  if (ConsumedState ExpectedState = Analyzer.getExpectedReturnState()) {
767    InfoEntry Entry = PropagationMap.find(Ret->getRetValue());
768
769    if (Entry != PropagationMap.end()) {
770      assert(Entry->second.isState() || Entry->second.isVar());
771
772      ConsumedState RetState = Entry->second.isState() ?
773        Entry->second.getState() : StateMap->getState(Entry->second.getVar());
774
775      if (RetState != ExpectedState)
776        Analyzer.WarningsHandler.warnReturnTypestateMismatch(
777          Ret->getReturnLoc(), stateToString(ExpectedState),
778          stateToString(RetState));
779    }
780  }
781}
782
783void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
784  InfoEntry Entry = PropagationMap.find(UOp->getSubExpr()->IgnoreParens());
785  if (Entry == PropagationMap.end()) return;
786
787  switch (UOp->getOpcode()) {
788  case UO_AddrOf:
789    PropagationMap.insert(PairType(UOp, Entry->second));
790    break;
791
792  case UO_LNot:
793    if (Entry->second.isTest() || Entry->second.isBinTest())
794      PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
795    break;
796
797  default:
798    break;
799  }
800}
801
802// TODO: See if I need to check for reference types here.
803void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
804  if (isConsumableType(Var->getType())) {
805    if (Var->hasInit()) {
806      PropagationInfo PInfo =
807        PropagationMap.find(Var->getInit())->second;
808
809      StateMap->setState(Var, PInfo.isVar() ?
810        StateMap->getState(PInfo.getVar()) : PInfo.getState());
811
812    } else {
813      StateMap->setState(Var, consumed::CS_Unknown);
814    }
815  }
816}
817}} // end clang::consumed::ConsumedStmtVisitor
818
819namespace clang {
820namespace consumed {
821
822void splitVarStateForIf(const IfStmt * IfNode, const VarTestResult &Test,
823                        ConsumedStateMap *ThenStates,
824                        ConsumedStateMap *ElseStates) {
825
826  ConsumedState VarState = ThenStates->getState(Test.Var);
827
828  if (VarState == CS_Unknown) {
829    ThenStates->setState(Test.Var, Test.TestsFor);
830    ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
831
832  } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
833    ThenStates->markUnreachable();
834
835  } else if (VarState == Test.TestsFor) {
836    ElseStates->markUnreachable();
837  }
838}
839
/// Splits variable states across the branches of an if whose condition is a
/// binary test (two variable tests joined by && or ||).
///
/// For EO_And, the then-branch may assume both tests succeeded.  For EO_Or,
/// the else-branch may assume both tests failed.  When the current states
/// already decide the condition, the branch that cannot be taken is marked
/// unreachable.  States are read from \p ThenStates, which presumably starts
/// as a copy of \p ElseStates.
void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
  ConsumedStateMap *ThenStates, ConsumedStateMap *ElseStates) {

  const VarTestResult &LTest = PInfo.getLTest(),
                      &RTest = PInfo.getRTest();

  // An operand that is not a test has a NULL Var and is treated as CS_None.
  ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
                RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;

  if (LTest.Var) {
    if (PInfo.testEffectiveOp() == EO_And) {
      // A && B: in the then-branch, A holds.
      if (LState == CS_Unknown) {
        ThenStates->setState(LTest.Var, LTest.TestsFor);

      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
        // A is known to be false, so the whole condition is false.
        ThenStates->markUnreachable();

      } else if (LState == LTest.TestsFor && isKnownState(RState)) {
        // A is known to be true, so B decides the condition.
        if (RState == RTest.TestsFor)
          ElseStates->markUnreachable();
        else
          ThenStates->markUnreachable();
      }

    } else {
      // A || B: in the else-branch, A is false.
      if (LState == CS_Unknown) {
        ElseStates->setState(LTest.Var,
                             invertConsumedUnconsumed(LTest.TestsFor));

      } else if (LState == LTest.TestsFor) {
        // A is known to be true, so the whole condition is true.
        ElseStates->markUnreachable();

      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
                 isKnownState(RState)) {

        // A is known to be false, so B decides the condition.
        if (RState == RTest.TestsFor)
          ElseStates->markUnreachable();
        else
          ThenStates->markUnreachable();
      }
    }
  }

  // NOTE(review): if the block above already called markUnreachable() on
  // ThenStates, the setState below still adds an entry to that unreachable
  // map.  Confirm this is harmless.
  if (RTest.Var) {
    if (PInfo.testEffectiveOp() == EO_And) {
      if (RState == CS_Unknown)
        ThenStates->setState(RTest.Var, RTest.TestsFor);
      else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
        ThenStates->markUnreachable();

    } else {
      if (RState == CS_Unknown)
        ElseStates->setState(RTest.Var,
                             invertConsumedUnconsumed(RTest.TestsFor));
      else if (RState == RTest.TestsFor)
        ElseStates->markUnreachable();
    }
  }
}
899
900void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
901                                ConsumedStateMap *StateMap,
902                                bool &AlreadyOwned) {
903
904  if (VisitedBlocks.alreadySet(Block)) return;
905
906  ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
907
908  if (Entry) {
909    Entry->intersect(StateMap);
910
911  } else if (AlreadyOwned) {
912    StateMapsArray[Block->getBlockID()] = new ConsumedStateMap(*StateMap);
913
914  } else {
915    StateMapsArray[Block->getBlockID()] = StateMap;
916    AlreadyOwned = true;
917  }
918}
919
920void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
921                                ConsumedStateMap *StateMap) {
922
923  if (VisitedBlocks.alreadySet(Block)) {
924    delete StateMap;
925    return;
926  }
927
928  ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
929
930  if (Entry) {
931    Entry->intersect(StateMap);
932    delete StateMap;
933
934  } else {
935    StateMapsArray[Block->getBlockID()] = StateMap;
936  }
937}
938
939ConsumedStateMap* ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
940  return StateMapsArray[Block->getBlockID()];
941}
942
/// Records that \p Block has been visited.  Subsequent addInfo() calls for it
/// then have no effect on its stored state.
void ConsumedBlockInfo::markVisited(const CFGBlock *Block) {
  VisitedBlocks.insert(Block);
}
946
947ConsumedState ConsumedStateMap::getState(const VarDecl *Var) {
948  MapType::const_iterator Entry = Map.find(Var);
949
950  if (Entry != Map.end()) {
951    return Entry->second;
952
953  } else {
954    return CS_None;
955  }
956}
957
958void ConsumedStateMap::intersect(const ConsumedStateMap *Other) {
959  ConsumedState LocalState;
960
961  if (this->From && this->From == Other->From && !Other->Reachable) {
962    this->markUnreachable();
963    return;
964  }
965
966  for (MapType::const_iterator DMI = Other->Map.begin(),
967       DME = Other->Map.end(); DMI != DME; ++DMI) {
968
969    LocalState = this->getState(DMI->first);
970
971    if (LocalState == CS_None)
972      continue;
973
974    if (LocalState != DMI->second)
975       Map[DMI->first] = CS_Unknown;
976  }
977}
978
979void ConsumedStateMap::markUnreachable() {
980  this->Reachable = false;
981  Map.clear();
982}
983
984void ConsumedStateMap::makeUnknown() {
985  for (MapType::const_iterator DMI = Map.begin(), DME = Map.end(); DMI != DME;
986       ++DMI) {
987
988    Map[DMI->first] = CS_Unknown;
989  }
990}
991
/// Records \p State for \p Var, starting to track it if it was not tracked
/// and overwriting any previous state.
void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
  Map[Var] = State;
}
995
/// Stops tracking \p Var.  Afterwards getState() reports CS_None for it.
void ConsumedStateMap::remove(const VarDecl *Var) {
  Map.erase(Var);
}
999
1000void ConsumedAnalyzer::determineExpectedReturnState(AnalysisDeclContext &AC,
1001                                                    const FunctionDecl *D) {
1002  QualType ReturnType;
1003  if (const CXXConstructorDecl *Constructor = dyn_cast<CXXConstructorDecl>(D)) {
1004    ASTContext &CurrContext = AC.getASTContext();
1005    ReturnType = Constructor->getThisType(CurrContext)->getPointeeType();
1006  } else
1007    ReturnType = D->getCallResultType();
1008
1009  if (D->hasAttr<ReturnTypestateAttr>()) {
1010    const ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>();
1011
1012    const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();
1013    if (!RD || !RD->hasAttr<ConsumableAttr>()) {
1014      // FIXME: This should be removed when template instantiation propagates
1015      //        attributes at template specialization definition, not
1016      //        declaration. When it is removed the test needs to be enabled
1017      //        in SemaDeclAttr.cpp.
1018      WarningsHandler.warnReturnTypestateForUnconsumableType(
1019          RTSAttr->getLocation(), ReturnType.getAsString());
1020      ExpectedReturnState = CS_None;
1021    } else
1022      ExpectedReturnState = mapReturnTypestateAttrState(RTSAttr);
1023  } else if (isConsumableType(ReturnType))
1024    ExpectedReturnState = mapConsumableAttrState(ReturnType);
1025  else
1026    ExpectedReturnState = CS_None;
1027}
1028
1029bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
1030                                  const ConsumedStmtVisitor &Visitor) {
1031
1032  ConsumedStateMap *FalseStates = new ConsumedStateMap(*CurrStates);
1033  PropagationInfo PInfo;
1034
1035  if (const IfStmt *IfNode =
1036    dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
1037
1038    const Stmt *Cond = IfNode->getCond();
1039
1040    PInfo = Visitor.getInfo(Cond);
1041    if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
1042      PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
1043
1044    if (PInfo.isTest()) {
1045      CurrStates->setSource(Cond);
1046      FalseStates->setSource(Cond);
1047      splitVarStateForIf(IfNode, PInfo.getTest(), CurrStates, FalseStates);
1048
1049    } else if (PInfo.isBinTest()) {
1050      CurrStates->setSource(PInfo.testSourceNode());
1051      FalseStates->setSource(PInfo.testSourceNode());
1052      splitVarStateForIfBinOp(PInfo, CurrStates, FalseStates);
1053
1054    } else {
1055      delete FalseStates;
1056      return false;
1057    }
1058
1059  } else if (const BinaryOperator *BinOp =
1060    dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
1061
1062    PInfo = Visitor.getInfo(BinOp->getLHS());
1063    if (!PInfo.isTest()) {
1064      if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
1065        PInfo = Visitor.getInfo(BinOp->getRHS());
1066
1067        if (!PInfo.isTest()) {
1068          delete FalseStates;
1069          return false;
1070        }
1071
1072      } else {
1073        delete FalseStates;
1074        return false;
1075      }
1076    }
1077
1078    CurrStates->setSource(BinOp);
1079    FalseStates->setSource(BinOp);
1080
1081    const VarTestResult &Test = PInfo.getTest();
1082    ConsumedState VarState = CurrStates->getState(Test.Var);
1083
1084    if (BinOp->getOpcode() == BO_LAnd) {
1085      if (VarState == CS_Unknown)
1086        CurrStates->setState(Test.Var, Test.TestsFor);
1087      else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
1088        CurrStates->markUnreachable();
1089
1090    } else if (BinOp->getOpcode() == BO_LOr) {
1091      if (VarState == CS_Unknown)
1092        FalseStates->setState(Test.Var,
1093                              invertConsumedUnconsumed(Test.TestsFor));
1094      else if (VarState == Test.TestsFor)
1095        FalseStates->markUnreachable();
1096    }
1097
1098  } else {
1099    delete FalseStates;
1100    return false;
1101  }
1102
1103  CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
1104
1105  if (*SI)
1106    BlockInfo.addInfo(*SI, CurrStates);
1107  else
1108    delete CurrStates;
1109
1110  if (*++SI)
1111    BlockInfo.addInfo(*SI, FalseStates);
1112  else
1113    delete FalseStates;
1114
1115  CurrStates = NULL;
1116  return true;
1117}
1118
1119void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1120  const FunctionDecl *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1121  if (!D)
1122    return;
1123
1124  CFG *CFGraph = AC.getCFG();
1125  if (!CFGraph)
1126    return;
1127
1128  determineExpectedReturnState(AC, D);
1129
1130  BlockInfo = ConsumedBlockInfo(CFGraph);
1131
1132  PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1133
1134  CurrStates = new ConsumedStateMap();
1135  ConsumedStmtVisitor Visitor(AC, *this, CurrStates);
1136
1137  // Add all trackable parameters to the state map.
1138  for (FunctionDecl::param_const_iterator PI = D->param_begin(),
1139       PE = D->param_end(); PI != PE; ++PI) {
1140    Visitor.VisitParmVarDecl(*PI);
1141  }
1142
1143  // Visit all of the function's basic blocks.
1144  for (PostOrderCFGView::iterator I = SortedGraph->begin(),
1145       E = SortedGraph->end(); I != E; ++I) {
1146
1147    const CFGBlock *CurrBlock = *I;
1148    BlockInfo.markVisited(CurrBlock);
1149
1150    if (CurrStates == NULL)
1151      CurrStates = BlockInfo.getInfo(CurrBlock);
1152
1153    if (!CurrStates) {
1154      continue;
1155
1156    } else if (!CurrStates->isReachable()) {
1157      delete CurrStates;
1158      CurrStates = NULL;
1159      continue;
1160    }
1161
1162    Visitor.reset(CurrStates);
1163
1164    // Visit all of the basic block's statements.
1165    for (CFGBlock::const_iterator BI = CurrBlock->begin(),
1166         BE = CurrBlock->end(); BI != BE; ++BI) {
1167
1168      switch (BI->getKind()) {
1169      case CFGElement::Statement:
1170        Visitor.Visit(BI->castAs<CFGStmt>().getStmt());
1171        break;
1172      case CFGElement::AutomaticObjectDtor:
1173        CurrStates->remove(BI->castAs<CFGAutomaticObjDtor>().getVarDecl());
1174      default:
1175        break;
1176      }
1177    }
1178
1179    // TODO: Handle other forms of branching with precision, including while-
1180    //       and for-loops. (Deferred)
1181    if (!splitState(CurrBlock, Visitor)) {
1182      CurrStates->setSource(NULL);
1183
1184      if (CurrBlock->succ_size() > 1) {
1185        CurrStates->makeUnknown();
1186
1187        bool OwnershipTaken = false;
1188
1189        for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1190             SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1191
1192          if (*SI) BlockInfo.addInfo(*SI, CurrStates, OwnershipTaken);
1193        }
1194
1195        if (!OwnershipTaken)
1196          delete CurrStates;
1197
1198        CurrStates = NULL;
1199
1200      } else if (CurrBlock->succ_size() == 1 &&
1201                 (*CurrBlock->succ_begin())->pred_size() > 1) {
1202
1203        BlockInfo.addInfo(*CurrBlock->succ_begin(), CurrStates);
1204        CurrStates = NULL;
1205      }
1206    }
1207  } // End of block iterator.
1208
1209  // Delete the last existing state map.
1210  delete CurrStates;
1211
1212  WarningsHandler.emitDiagnostics();
1213}
1214}} // end namespace clang::consumed
1215