Consumed.cpp revision 1f55157a5b3f60389e32b337bdf50eeea50f8f77
1//===- Consumed.cpp --------------------------------------------*- C++ --*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
// An intra-procedural analysis for checking consumed properties.  This is based,
11// in part, on research on linear types.
12//
13//===----------------------------------------------------------------------===//
14
15#include "clang/AST/ASTContext.h"
16#include "clang/AST/Attr.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/ExprCXX.h"
19#include "clang/AST/RecursiveASTVisitor.h"
20#include "clang/AST/StmtVisitor.h"
21#include "clang/AST/StmtCXX.h"
22#include "clang/AST/Type.h"
23#include "clang/Analysis/Analyses/PostOrderCFGView.h"
24#include "clang/Analysis/AnalysisContext.h"
25#include "clang/Analysis/CFG.h"
26#include "clang/Analysis/Analyses/Consumed.h"
27#include "clang/Basic/OperatorKinds.h"
28#include "clang/Basic/SourceLocation.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Compiler.h"
32#include "llvm/Support/raw_ostream.h"
33
34// TODO: Correctly identify unreachable blocks when chaining boolean operators.
35// TODO: Warn about unreachable code.
36// TODO: Switch to using a bitmap to track unreachable blocks.
37// TODO: Mark variables as Unknown going into while- or for-loops only if they
38//       are referenced inside that block. (Deferred)
39// TODO: Handle variable definitions, e.g. bool valid = x.isValid();
40//       if (valid) ...; (Deferred)
41// TODO: Add a method(s) to identify which method calls perform what state
42//       transitions. (Deferred)
43// TODO: Take notes on state transitions to provide better warning messages.
44//       (Deferred)
// TODO: Test nested conditionals: 1) checking the same value multiple times,
//       and 2) checking different values. (Deferred)
47// TODO: Test IsFalseVisitor with values in the unknown state. (Deferred)
48// TODO: Look into combining IsFalseVisitor and TestedVarsVisitor. (Deferred)
49
50using namespace clang;
51using namespace consumed;
52
53// Key method definition
54ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() {}
55
56static ConsumedState invertConsumedUnconsumed(ConsumedState State) {
57  switch (State) {
58  case CS_Unconsumed:
59    return CS_Consumed;
60  case CS_Consumed:
61    return CS_Unconsumed;
62  case CS_None:
63    return CS_None;
64  case CS_Unknown:
65    return CS_Unknown;
66  }
67  llvm_unreachable("invalid enum");
68}
69
70static bool isKnownState(ConsumedState State) {
71  switch (State) {
72  case CS_Unconsumed:
73  case CS_Consumed:
74    return true;
75  case CS_None:
76  case CS_Unknown:
77  default:
78    return false;
79  }
80}
81
82static bool isTestingFunction(const FunctionDecl *FunDecl) {
83  return FunDecl->hasAttr<TestsUnconsumedAttr>();
84}
85
86static StringRef stateToString(ConsumedState State) {
87  switch (State) {
88  case consumed::CS_None:
89    return "none";
90
91  case consumed::CS_Unknown:
92    return "unknown";
93
94  case consumed::CS_Unconsumed:
95    return "unconsumed";
96
97  case consumed::CS_Consumed:
98    return "consumed";
99  }
100  llvm_unreachable("invalid enum");
101}
102
103namespace {
104struct VarTestResult {
105  const VarDecl *Var;
106  ConsumedState TestsFor;
107};
108} // end anonymous::VarTestResult
109
110namespace clang {
111namespace consumed {
112
113enum EffectiveOp {
114  EO_And,
115  EO_Or
116};
117
118class PropagationInfo {
119  enum {
120    IT_None,
121    IT_State,
122    IT_Test,
123    IT_BinTest,
124    IT_Var
125  } InfoType;
126
127  struct BinTestTy {
128    const BinaryOperator *Source;
129    EffectiveOp EOp;
130    VarTestResult LTest;
131    VarTestResult RTest;
132  };
133
134  union {
135    ConsumedState State;
136    VarTestResult Test;
137    const VarDecl *Var;
138    BinTestTy BinTest;
139  };
140
141public:
142  PropagationInfo() : InfoType(IT_None) {}
143
144  PropagationInfo(const VarTestResult &Test) : InfoType(IT_Test), Test(Test) {}
145  PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)
146    : InfoType(IT_Test) {
147
148    Test.Var      = Var;
149    Test.TestsFor = TestsFor;
150  }
151
152  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
153                  const VarTestResult &LTest, const VarTestResult &RTest)
154    : InfoType(IT_BinTest) {
155
156    BinTest.Source  = Source;
157    BinTest.EOp     = EOp;
158    BinTest.LTest   = LTest;
159    BinTest.RTest   = RTest;
160  }
161
162  PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,
163                  const VarDecl *LVar, ConsumedState LTestsFor,
164                  const VarDecl *RVar, ConsumedState RTestsFor)
165    : InfoType(IT_BinTest) {
166
167    BinTest.Source         = Source;
168    BinTest.EOp            = EOp;
169    BinTest.LTest.Var      = LVar;
170    BinTest.LTest.TestsFor = LTestsFor;
171    BinTest.RTest.Var      = RVar;
172    BinTest.RTest.TestsFor = RTestsFor;
173  }
174
175  PropagationInfo(ConsumedState State) : InfoType(IT_State), State(State) {}
176  PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}
177
178  const ConsumedState & getState() const {
179    assert(InfoType == IT_State);
180    return State;
181  }
182
183  const VarTestResult & getTest() const {
184    assert(InfoType == IT_Test);
185    return Test;
186  }
187
188  const VarTestResult & getLTest() const {
189    assert(InfoType == IT_BinTest);
190    return BinTest.LTest;
191  }
192
193  const VarTestResult & getRTest() const {
194    assert(InfoType == IT_BinTest);
195    return BinTest.RTest;
196  }
197
198  const VarDecl * getVar() const {
199    assert(InfoType == IT_Var);
200    return Var;
201  }
202
203  EffectiveOp testEffectiveOp() const {
204    assert(InfoType == IT_BinTest);
205    return BinTest.EOp;
206  }
207
208  const BinaryOperator * testSourceNode() const {
209    assert(InfoType == IT_BinTest);
210    return BinTest.Source;
211  }
212
213  bool isValid()   const { return InfoType != IT_None;     }
214  bool isState()   const { return InfoType == IT_State;    }
215  bool isTest()    const { return InfoType == IT_Test;     }
216  bool isBinTest() const { return InfoType == IT_BinTest;  }
217  bool isVar()     const { return InfoType == IT_Var;      }
218
219  PropagationInfo invertTest() const {
220    assert(InfoType == IT_Test || InfoType == IT_BinTest);
221
222    if (InfoType == IT_Test) {
223      return PropagationInfo(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
224
225    } else if (InfoType == IT_BinTest) {
226      return PropagationInfo(BinTest.Source,
227        BinTest.EOp == EO_And ? EO_Or : EO_And,
228        BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),
229        BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));
230    } else {
231      return PropagationInfo();
232    }
233  }
234};
235
236class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {
237
238  typedef llvm::DenseMap<const Stmt *, PropagationInfo> MapType;
239  typedef std::pair<const Stmt *, PropagationInfo> PairType;
240  typedef MapType::iterator InfoEntry;
241  typedef MapType::const_iterator ConstInfoEntry;
242
243  AnalysisDeclContext &AC;
244  ConsumedAnalyzer &Analyzer;
245  ConsumedStateMap *StateMap;
246  MapType PropagationMap;
247
248  void checkCallability(const PropagationInfo &PInfo,
249                        const FunctionDecl *FunDecl,
250                        const CallExpr *Call);
251  void forwardInfo(const Stmt *From, const Stmt *To);
252  void handleTestingFunctionCall(const CallExpr *Call, const VarDecl *Var);
253  bool isLikeMoveAssignment(const CXXMethodDecl *MethodDecl);
254
255public:
256
257  void Visit(const Stmt *StmtNode);
258
259  void VisitBinaryOperator(const BinaryOperator *BinOp);
260  void VisitCallExpr(const CallExpr *Call);
261  void VisitCastExpr(const CastExpr *Cast);
262  void VisitCXXConstructExpr(const CXXConstructExpr *Call);
263  void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);
264  void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);
265  void VisitDeclRefExpr(const DeclRefExpr *DeclRef);
266  void VisitDeclStmt(const DeclStmt *DelcS);
267  void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);
268  void VisitMemberExpr(const MemberExpr *MExpr);
269  void VisitUnaryOperator(const UnaryOperator *UOp);
270  void VisitVarDecl(const VarDecl *Var);
271
272  ConsumedStmtVisitor(AnalysisDeclContext &AC, ConsumedAnalyzer &Analyzer)
273      : AC(AC), Analyzer(Analyzer), StateMap(NULL) {}
274
275  PropagationInfo getInfo(const Stmt *StmtNode) const {
276    ConstInfoEntry Entry = PropagationMap.find(StmtNode);
277
278    if (Entry != PropagationMap.end())
279      return Entry->second;
280    else
281      return PropagationInfo();
282  }
283
284  void reset(ConsumedStateMap *NewStateMap) {
285    StateMap = NewStateMap;
286  }
287};
288
289// TODO: When we support CallableWhenConsumed this will have to check for
290//       the different attributes and change the behavior bellow. (Deferred)
291void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,
292                                           const FunctionDecl *FunDecl,
293                                           const CallExpr *Call) {
294
295  if (!FunDecl->hasAttr<CallableWhenUnconsumedAttr>()) return;
296
297  if (PInfo.isVar()) {
298    const VarDecl *Var = PInfo.getVar();
299
300    switch (StateMap->getState(Var)) {
301    case CS_Consumed:
302      Analyzer.WarningsHandler.warnUseWhileConsumed(
303        FunDecl->getNameAsString(), Var->getNameAsString(),
304        Call->getExprLoc());
305      break;
306
307    case CS_Unknown:
308      Analyzer.WarningsHandler.warnUseInUnknownState(
309        FunDecl->getNameAsString(), Var->getNameAsString(),
310        Call->getExprLoc());
311      break;
312
313    case CS_None:
314    case CS_Unconsumed:
315      break;
316    }
317
318  } else {
319    switch (PInfo.getState()) {
320    case CS_Consumed:
321      Analyzer.WarningsHandler.warnUseOfTempWhileConsumed(
322        FunDecl->getNameAsString(), Call->getExprLoc());
323      break;
324
325    case CS_Unknown:
326      Analyzer.WarningsHandler.warnUseOfTempInUnknownState(
327        FunDecl->getNameAsString(), Call->getExprLoc());
328      break;
329
330    case CS_None:
331    case CS_Unconsumed:
332      break;
333    }
334  }
335}
336
337void ConsumedStmtVisitor::forwardInfo(const Stmt *From, const Stmt *To) {
338  InfoEntry Entry = PropagationMap.find(From);
339
340  if (Entry != PropagationMap.end())
341    PropagationMap.insert(PairType(To, Entry->second));
342}
343
344void ConsumedStmtVisitor::handleTestingFunctionCall(const CallExpr *Call,
345                                                    const VarDecl  *Var) {
346
347  ConsumedState VarState = StateMap->getState(Var);
348
349  if (VarState != CS_Unknown) {
350    SourceLocation CallLoc = Call->getExprLoc();
351
352    if (!CallLoc.isMacroID())
353      Analyzer.WarningsHandler.warnUnnecessaryTest(Var->getNameAsString(),
354        stateToString(VarState), CallLoc);
355  }
356
357  PropagationMap.insert(PairType(Call, PropagationInfo(Var, CS_Unconsumed)));
358}
359
360bool ConsumedStmtVisitor::isLikeMoveAssignment(
361  const CXXMethodDecl *MethodDecl) {
362
363  return MethodDecl->isMoveAssignmentOperator() ||
364         (MethodDecl->getOverloadedOperator() == OO_Equal &&
365          MethodDecl->getNumParams() == 1 &&
366          MethodDecl->getParamDecl(0)->getType()->isRValueReferenceType());
367}
368
369void ConsumedStmtVisitor::Visit(const Stmt *StmtNode) {
370
371  ConstStmtVisitor<ConsumedStmtVisitor>::Visit(StmtNode);
372
373  for (Stmt::const_child_iterator CI = StmtNode->child_begin(),
374       CE = StmtNode->child_end(); CI != CE; ++CI) {
375
376    PropagationMap.erase(*CI);
377  }
378}
379
380void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {
381  switch (BinOp->getOpcode()) {
382  case BO_LAnd:
383  case BO_LOr : {
384    InfoEntry LEntry = PropagationMap.find(BinOp->getLHS()),
385              REntry = PropagationMap.find(BinOp->getRHS());
386
387    VarTestResult LTest, RTest;
388
389    if (LEntry != PropagationMap.end() && LEntry->second.isTest()) {
390      LTest = LEntry->second.getTest();
391
392    } else {
393      LTest.Var      = NULL;
394      LTest.TestsFor = CS_None;
395    }
396
397    if (REntry != PropagationMap.end() && REntry->second.isTest()) {
398      RTest = REntry->second.getTest();
399
400    } else {
401      RTest.Var      = NULL;
402      RTest.TestsFor = CS_None;
403    }
404
405    if (!(LTest.Var == NULL && RTest.Var == NULL))
406      PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,
407        static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));
408
409    break;
410  }
411
412  case BO_PtrMemD:
413  case BO_PtrMemI:
414    forwardInfo(BinOp->getLHS(), BinOp);
415    break;
416
417  default:
418    break;
419  }
420}
421
422void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {
423  if (const FunctionDecl *FunDecl =
424    dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee())) {
425
426    // Special case for the std::move function.
427    // TODO: Make this more specific. (Deferred)
428    if (FunDecl->getNameAsString() == "move") {
429      InfoEntry Entry = PropagationMap.find(Call->getArg(0));
430
431      if (Entry != PropagationMap.end()) {
432        PropagationMap.insert(PairType(Call, Entry->second));
433      }
434
435      return;
436    }
437
438    unsigned Offset = Call->getNumArgs() - FunDecl->getNumParams();
439
440    for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {
441      QualType ParamType = FunDecl->getParamDecl(Index - Offset)->getType();
442
443      InfoEntry Entry = PropagationMap.find(Call->getArg(Index));
444
445      if (Entry == PropagationMap.end() || !Entry->second.isVar()) {
446        continue;
447      }
448
449      PropagationInfo PInfo = Entry->second;
450
451      if (ParamType->isRValueReferenceType() ||
452          (ParamType->isLValueReferenceType() &&
453           !cast<LValueReferenceType>(*ParamType).isSpelledAsLValue())) {
454
455        StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
456
457      } else if (!(ParamType.isConstQualified() ||
458                   ((ParamType->isReferenceType() ||
459                     ParamType->isPointerType()) &&
460                    ParamType->getPointeeType().isConstQualified()))) {
461
462        StateMap->setState(PInfo.getVar(), consumed::CS_Unknown);
463      }
464    }
465  }
466}
467
468void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {
469  forwardInfo(Cast->getSubExpr(), Cast);
470}
471
472void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {
473  CXXConstructorDecl *Constructor = Call->getConstructor();
474
475  ASTContext &CurrContext = AC.getASTContext();
476  QualType ThisType = Constructor->getThisType(CurrContext)->getPointeeType();
477
478  if (Analyzer.isConsumableType(ThisType)) {
479    if (Constructor->hasAttr<ConsumesAttr>() ||
480        Constructor->isDefaultConstructor()) {
481
482      PropagationMap.insert(PairType(Call,
483        PropagationInfo(consumed::CS_Consumed)));
484
485    } else if (Constructor->isMoveConstructor()) {
486
487      PropagationInfo PInfo =
488        PropagationMap.find(Call->getArg(0))->second;
489
490      if (PInfo.isVar()) {
491        const VarDecl* Var = PInfo.getVar();
492
493        PropagationMap.insert(PairType(Call,
494          PropagationInfo(StateMap->getState(Var))));
495
496        StateMap->setState(Var, consumed::CS_Consumed);
497
498      } else {
499        PropagationMap.insert(PairType(Call, PInfo));
500      }
501
502    } else if (Constructor->isCopyConstructor()) {
503      MapType::iterator Entry = PropagationMap.find(Call->getArg(0));
504
505      if (Entry != PropagationMap.end())
506        PropagationMap.insert(PairType(Call, Entry->second));
507
508    } else {
509      PropagationMap.insert(PairType(Call,
510        PropagationInfo(consumed::CS_Unconsumed)));
511    }
512  }
513}
514
515void ConsumedStmtVisitor::VisitCXXMemberCallExpr(
516  const CXXMemberCallExpr *Call) {
517
518  VisitCallExpr(Call);
519
520  InfoEntry Entry = PropagationMap.find(Call->getCallee()->IgnoreParens());
521
522  if (Entry != PropagationMap.end()) {
523    PropagationInfo PInfo = Entry->second;
524    const CXXMethodDecl *MethodDecl = Call->getMethodDecl();
525
526    checkCallability(PInfo, MethodDecl, Call);
527
528    if (PInfo.isVar()) {
529      if (isTestingFunction(MethodDecl))
530        handleTestingFunctionCall(Call, PInfo.getVar());
531      else if (MethodDecl->hasAttr<ConsumesAttr>())
532        StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
533      else if (!MethodDecl->isConst())
534        StateMap->setState(PInfo.getVar(), consumed::CS_Unknown);
535    }
536  }
537}
538
539void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(
540  const CXXOperatorCallExpr *Call) {
541
542  const FunctionDecl *FunDecl =
543    dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());
544
545  if (!FunDecl) return;
546
547  if (isa<CXXMethodDecl>(FunDecl) &&
548      isLikeMoveAssignment(cast<CXXMethodDecl>(FunDecl))) {
549
550    InfoEntry LEntry = PropagationMap.find(Call->getArg(0));
551    InfoEntry REntry = PropagationMap.find(Call->getArg(1));
552
553    PropagationInfo LPInfo, RPInfo;
554
555    if (LEntry != PropagationMap.end() &&
556        REntry != PropagationMap.end()) {
557
558      LPInfo = LEntry->second;
559      RPInfo = REntry->second;
560
561      if (LPInfo.isVar() && RPInfo.isVar()) {
562        StateMap->setState(LPInfo.getVar(),
563          StateMap->getState(RPInfo.getVar()));
564
565        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
566
567        PropagationMap.insert(PairType(Call, LPInfo));
568
569      } else if (LPInfo.isVar() && !RPInfo.isVar()) {
570        StateMap->setState(LPInfo.getVar(), RPInfo.getState());
571
572        PropagationMap.insert(PairType(Call, LPInfo));
573
574      } else if (!LPInfo.isVar() && RPInfo.isVar()) {
575        PropagationMap.insert(PairType(Call,
576          PropagationInfo(StateMap->getState(RPInfo.getVar()))));
577
578        StateMap->setState(RPInfo.getVar(), consumed::CS_Consumed);
579
580      } else {
581        PropagationMap.insert(PairType(Call, RPInfo));
582      }
583
584    } else if (LEntry != PropagationMap.end() &&
585               REntry == PropagationMap.end()) {
586
587      LPInfo = LEntry->second;
588
589      if (LPInfo.isVar()) {
590        StateMap->setState(LPInfo.getVar(), consumed::CS_Unknown);
591
592        PropagationMap.insert(PairType(Call, LPInfo));
593
594      } else {
595        PropagationMap.insert(PairType(Call,
596          PropagationInfo(consumed::CS_Unknown)));
597      }
598
599    } else if (LEntry == PropagationMap.end() &&
600               REntry != PropagationMap.end()) {
601
602      RPInfo = REntry->second;
603
604      if (RPInfo.isVar()) {
605        const VarDecl *Var = RPInfo.getVar();
606
607        PropagationMap.insert(PairType(Call,
608          PropagationInfo(StateMap->getState(Var))));
609
610        StateMap->setState(Var, consumed::CS_Consumed);
611
612      } else {
613        PropagationMap.insert(PairType(Call, RPInfo));
614      }
615    }
616
617  } else {
618
619    VisitCallExpr(Call);
620
621    InfoEntry Entry = PropagationMap.find(Call->getArg(0));
622
623    if (Entry != PropagationMap.end()) {
624      PropagationInfo PInfo = Entry->second;
625
626      checkCallability(PInfo, FunDecl, Call);
627
628      if (PInfo.isVar()) {
629        if (isTestingFunction(FunDecl)) {
630          handleTestingFunctionCall(Call, PInfo.getVar());
631
632        } else if (FunDecl->hasAttr<ConsumesAttr>()) {
633          StateMap->setState(PInfo.getVar(), consumed::CS_Consumed);
634
635        } else if (const CXXMethodDecl *MethodDecl =
636                     dyn_cast_or_null<CXXMethodDecl>(FunDecl)) {
637
638          if (!MethodDecl->isConst())
639            StateMap->setState(PInfo.getVar(), consumed::CS_Unknown);
640        }
641      }
642    }
643  }
644}
645
646void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {
647  if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))
648    if (StateMap->getState(Var) != consumed::CS_None)
649      PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));
650}
651
652void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {
653  for (DeclStmt::const_decl_iterator DI = DeclS->decl_begin(),
654       DE = DeclS->decl_end(); DI != DE; ++DI) {
655
656    if (isa<VarDecl>(*DI)) VisitVarDecl(cast<VarDecl>(*DI));
657  }
658
659  if (DeclS->isSingleDecl())
660    if (const VarDecl *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))
661      PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));
662}
663
664void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(
665  const MaterializeTemporaryExpr *Temp) {
666
667  InfoEntry Entry = PropagationMap.find(Temp->GetTemporaryExpr());
668
669  if (Entry != PropagationMap.end())
670    PropagationMap.insert(PairType(Temp, Entry->second));
671}
672
673void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {
674  forwardInfo(MExpr->getBase(), MExpr);
675}
676
677void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {
678  InfoEntry Entry = PropagationMap.find(UOp->getSubExpr()->IgnoreParens());
679  if (Entry == PropagationMap.end()) return;
680
681  switch (UOp->getOpcode()) {
682  case UO_AddrOf:
683    PropagationMap.insert(PairType(UOp, Entry->second));
684    break;
685
686  case UO_LNot:
687    if (Entry->second.isTest() || Entry->second.isBinTest())
688      PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));
689    break;
690
691  default:
692    break;
693  }
694}
695
696void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {
697  if (Analyzer.isConsumableType(Var->getType())) {
698    PropagationInfo PInfo =
699      PropagationMap.find(Var->getInit())->second;
700
701    StateMap->setState(Var, PInfo.isVar() ?
702      StateMap->getState(PInfo.getVar()) : PInfo.getState());
703  }
704}
705}} // end clang::consumed::ConsumedStmtVisitor
706
707namespace clang {
708namespace consumed {
709
710void splitVarStateForIf(const IfStmt * IfNode, const VarTestResult &Test,
711                        ConsumedStateMap *ThenStates,
712                        ConsumedStateMap *ElseStates) {
713
714  ConsumedState VarState = ThenStates->getState(Test.Var);
715
716  if (VarState == CS_Unknown) {
717    ThenStates->setState(Test.Var, Test.TestsFor);
718    if (ElseStates)
719      ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));
720
721  } else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {
722    ThenStates->markUnreachable();
723
724  } else if (VarState == Test.TestsFor && ElseStates) {
725    ElseStates->markUnreachable();
726  }
727}
728
729void splitVarStateForIfBinOp(const PropagationInfo &PInfo,
730  ConsumedStateMap *ThenStates, ConsumedStateMap *ElseStates) {
731
732  const VarTestResult &LTest = PInfo.getLTest(),
733                      &RTest = PInfo.getRTest();
734
735  ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,
736                RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;
737
738  if (LTest.Var) {
739    if (PInfo.testEffectiveOp() == EO_And) {
740      if (LState == CS_Unknown) {
741        ThenStates->setState(LTest.Var, LTest.TestsFor);
742
743      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {
744        ThenStates->markUnreachable();
745
746      } else if (LState == LTest.TestsFor && isKnownState(RState)) {
747        if (RState == RTest.TestsFor) {
748          if (ElseStates)
749            ElseStates->markUnreachable();
750        } else {
751          ThenStates->markUnreachable();
752        }
753      }
754
755    } else {
756      if (LState == CS_Unknown && ElseStates) {
757        ElseStates->setState(LTest.Var,
758                             invertConsumedUnconsumed(LTest.TestsFor));
759
760      } else if (LState == LTest.TestsFor && ElseStates) {
761        ElseStates->markUnreachable();
762
763      } else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&
764                 isKnownState(RState)) {
765
766        if (RState == RTest.TestsFor) {
767          if (ElseStates)
768            ElseStates->markUnreachable();
769        } else {
770          ThenStates->markUnreachable();
771        }
772      }
773    }
774  }
775
776  if (RTest.Var) {
777    if (PInfo.testEffectiveOp() == EO_And) {
778      if (RState == CS_Unknown)
779        ThenStates->setState(RTest.Var, RTest.TestsFor);
780      else if (RState == invertConsumedUnconsumed(RTest.TestsFor))
781        ThenStates->markUnreachable();
782
783    } else if (ElseStates) {
784      if (RState == CS_Unknown)
785        ElseStates->setState(RTest.Var,
786                             invertConsumedUnconsumed(RTest.TestsFor));
787      else if (RState == RTest.TestsFor)
788        ElseStates->markUnreachable();
789    }
790  }
791}
792
793void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
794                                ConsumedStateMap *StateMap,
795                                bool &AlreadyOwned) {
796
797  if (VisitedBlocks.alreadySet(Block)) return;
798
799  ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
800
801  if (Entry) {
802    Entry->intersect(StateMap);
803
804  } else if (AlreadyOwned) {
805    StateMapsArray[Block->getBlockID()] = new ConsumedStateMap(*StateMap);
806
807  } else {
808    StateMapsArray[Block->getBlockID()] = StateMap;
809    AlreadyOwned = true;
810  }
811}
812
813void ConsumedBlockInfo::addInfo(const CFGBlock *Block,
814                                ConsumedStateMap *StateMap) {
815
816  if (VisitedBlocks.alreadySet(Block)) {
817    delete StateMap;
818    return;
819  }
820
821  ConsumedStateMap *Entry = StateMapsArray[Block->getBlockID()];
822
823  if (Entry) {
824    Entry->intersect(StateMap);
825    delete StateMap;
826
827  } else {
828    StateMapsArray[Block->getBlockID()] = StateMap;
829  }
830}
831
832ConsumedStateMap* ConsumedBlockInfo::getInfo(const CFGBlock *Block) {
833  return StateMapsArray[Block->getBlockID()];
834}
835
836void ConsumedBlockInfo::markVisited(const CFGBlock *Block) {
837  VisitedBlocks.insert(Block);
838}
839
840ConsumedState ConsumedStateMap::getState(const VarDecl *Var) {
841  MapType::const_iterator Entry = Map.find(Var);
842
843  if (Entry != Map.end()) {
844    return Entry->second;
845
846  } else {
847    return CS_None;
848  }
849}
850
851void ConsumedStateMap::intersect(const ConsumedStateMap *Other) {
852  ConsumedState LocalState;
853
854  if (this->From && this->From == Other->From && !Other->Reachable) {
855    this->markUnreachable();
856    return;
857  }
858
859  for (MapType::const_iterator DMI = Other->Map.begin(),
860       DME = Other->Map.end(); DMI != DME; ++DMI) {
861
862    LocalState = this->getState(DMI->first);
863
864    if (LocalState == CS_None)
865      continue;
866
867    if (LocalState != DMI->second)
868       Map[DMI->first] = CS_Unknown;
869  }
870}
871
872void ConsumedStateMap::markUnreachable() {
873  this->Reachable = false;
874  Map.clear();
875}
876
877void ConsumedStateMap::makeUnknown() {
878  for (MapType::const_iterator DMI = Map.begin(), DME = Map.end(); DMI != DME;
879       ++DMI) {
880
881    Map[DMI->first] = CS_Unknown;
882  }
883}
884
885void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {
886  Map[Var] = State;
887}
888
889void ConsumedStateMap::remove(const VarDecl *Var) {
890  Map.erase(Var);
891}
892
893bool ConsumedAnalyzer::isConsumableType(QualType Type) {
894  const CXXRecordDecl *RD =
895    dyn_cast_or_null<CXXRecordDecl>(Type->getAsCXXRecordDecl());
896
897  if (!RD) return false;
898
899  std::pair<CacheMapType::iterator, bool> Entry =
900    ConsumableTypeCache.insert(std::make_pair(RD, false));
901
902  if (Entry.second)
903    Entry.first->second = hasConsumableAttributes(RD);
904
905  return Entry.first->second;
906}
907
908// TODO: Walk the base classes to see if any of them are unique types.
909//       (Deferred)
910bool ConsumedAnalyzer::hasConsumableAttributes(const CXXRecordDecl *RD) {
911  for (CXXRecordDecl::method_iterator MI = RD->method_begin(),
912       ME = RD->method_end(); MI != ME; ++MI) {
913
914    for (Decl::attr_iterator AI = (*MI)->attr_begin(), AE = (*MI)->attr_end();
915         AI != AE; ++AI) {
916
917      switch ((*AI)->getKind()) {
918      case attr::CallableWhenUnconsumed:
919      case attr::TestsUnconsumed:
920        return true;
921
922      default:
923        break;
924      }
925    }
926  }
927
928  return false;
929}
930
931bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,
932                                  const ConsumedStmtVisitor &Visitor) {
933
934  ConsumedStateMap *FalseStates = new ConsumedStateMap(*CurrStates);
935  PropagationInfo PInfo;
936
937  if (const IfStmt *IfNode =
938    dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {
939
940    bool HasElse = IfNode->getElse() != NULL;
941    const Stmt *Cond = IfNode->getCond();
942
943    PInfo = Visitor.getInfo(Cond);
944    if (!PInfo.isValid() && isa<BinaryOperator>(Cond))
945      PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());
946
947    if (PInfo.isTest()) {
948      CurrStates->setSource(Cond);
949      FalseStates->setSource(Cond);
950
951      splitVarStateForIf(IfNode, PInfo.getTest(), CurrStates,
952                         HasElse ? FalseStates : NULL);
953
954    } else if (PInfo.isBinTest()) {
955      CurrStates->setSource(PInfo.testSourceNode());
956      FalseStates->setSource(PInfo.testSourceNode());
957
958      splitVarStateForIfBinOp(PInfo, CurrStates, HasElse ? FalseStates : NULL);
959
960    } else {
961      delete FalseStates;
962      return false;
963    }
964
965  } else if (const BinaryOperator *BinOp =
966    dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {
967
968    PInfo = Visitor.getInfo(BinOp->getLHS());
969    if (!PInfo.isTest()) {
970      if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {
971        PInfo = Visitor.getInfo(BinOp->getRHS());
972
973        if (!PInfo.isTest()) {
974          delete FalseStates;
975          return false;
976        }
977
978      } else {
979        delete FalseStates;
980        return false;
981      }
982    }
983
984    CurrStates->setSource(BinOp);
985    FalseStates->setSource(BinOp);
986
987    const VarTestResult &Test = PInfo.getTest();
988    ConsumedState VarState = CurrStates->getState(Test.Var);
989
990    if (BinOp->getOpcode() == BO_LAnd) {
991      if (VarState == CS_Unknown)
992        CurrStates->setState(Test.Var, Test.TestsFor);
993      else if (VarState == invertConsumedUnconsumed(Test.TestsFor))
994        CurrStates->markUnreachable();
995
996    } else if (BinOp->getOpcode() == BO_LOr) {
997      if (VarState == CS_Unknown)
998        FalseStates->setState(Test.Var,
999                              invertConsumedUnconsumed(Test.TestsFor));
1000      else if (VarState == Test.TestsFor)
1001        FalseStates->markUnreachable();
1002    }
1003
1004  } else {
1005    delete FalseStates;
1006    return false;
1007  }
1008
1009  CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();
1010
1011  if (*SI)
1012    BlockInfo.addInfo(*SI, CurrStates);
1013  else
1014    delete CurrStates;
1015
1016  if (*++SI)
1017    BlockInfo.addInfo(*SI, FalseStates);
1018  else
1019    delete FalseStates;
1020
1021  CurrStates = NULL;
1022  return true;
1023}
1024
1025void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {
1026  const FunctionDecl *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());
1027
1028  if (!D) return;
1029
1030  BlockInfo = ConsumedBlockInfo(AC.getCFG());
1031
1032  PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();
1033
1034  CurrStates = new ConsumedStateMap();
1035  ConsumedStmtVisitor Visitor(AC, *this);
1036
1037  // Visit all of the function's basic blocks.
1038  for (PostOrderCFGView::iterator I = SortedGraph->begin(),
1039       E = SortedGraph->end(); I != E; ++I) {
1040
1041    const CFGBlock *CurrBlock = *I;
1042    BlockInfo.markVisited(CurrBlock);
1043
1044    if (CurrStates == NULL)
1045      CurrStates = BlockInfo.getInfo(CurrBlock);
1046
1047    if (!CurrStates) {
1048      continue;
1049
1050    } else if (!CurrStates->isReachable()) {
1051      delete CurrStates;
1052      CurrStates = NULL;
1053      continue;
1054    }
1055
1056    Visitor.reset(CurrStates);
1057
1058    // Visit all of the basic block's statements.
1059    for (CFGBlock::const_iterator BI = CurrBlock->begin(),
1060         BE = CurrBlock->end(); BI != BE; ++BI) {
1061
1062      switch (BI->getKind()) {
1063      case CFGElement::Statement:
1064        Visitor.Visit(BI->castAs<CFGStmt>().getStmt());
1065        break;
1066      case CFGElement::AutomaticObjectDtor:
1067        CurrStates->remove(BI->castAs<CFGAutomaticObjDtor>().getVarDecl());
1068      default:
1069        break;
1070      }
1071    }
1072
1073    // TODO: Handle other forms of branching with precision, including while-
1074    //       and for-loops. (Deferred)
1075    if (!splitState(CurrBlock, Visitor)) {
1076      CurrStates->setSource(NULL);
1077
1078      if (CurrBlock->succ_size() > 1) {
1079        CurrStates->makeUnknown();
1080
1081        bool OwnershipTaken = false;
1082
1083        for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),
1084             SE = CurrBlock->succ_end(); SI != SE; ++SI) {
1085
1086          if (*SI) BlockInfo.addInfo(*SI, CurrStates, OwnershipTaken);
1087        }
1088
1089        if (!OwnershipTaken)
1090          delete CurrStates;
1091
1092        CurrStates = NULL;
1093
1094      } else if (CurrBlock->succ_size() == 1 &&
1095                 (*CurrBlock->succ_begin())->pred_size() > 1) {
1096
1097        BlockInfo.addInfo(*CurrBlock->succ_begin(), CurrStates);
1098        CurrStates = NULL;
1099      }
1100    }
1101  } // End of block iterator.
1102
1103  // Delete the last existing state map.
1104  delete CurrStates;
1105
1106  WarningsHandler.emitDiagnostics();
1107}
1108}} // end namespace clang::consumed
1109