ASTMatchFinder.cpp revision e0e6b9e79a0c4edae92abd3928263875c78e23aa
1//===--- ASTMatchFinder.cpp - Structural query framework ------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10//  Implements an algorithm to efficiently search for matches on AST nodes.
11//  Uses memoization to support recursive matches like HasDescendant.
12//
13//  The general idea is to visit all AST nodes with a RecursiveASTVisitor,
14//  calling the Matches(...) method of each matcher we are running on each
15//  AST node. The matcher can recurse via the ASTMatchFinder interface.
16//
17//===----------------------------------------------------------------------===//
18
19#include "clang/ASTMatchers/ASTMatchFinder.h"
20#include "clang/AST/ASTConsumer.h"
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/RecursiveASTVisitor.h"
23#include <set>
24
25namespace clang {
26namespace ast_matchers {
27namespace internal {
28namespace {
29
// Looks up 'Key' in 'AMap'.  Yields a pointer to the mapped value when an
// entry for 'Key' exists and NULL otherwise.  The map is not modified.
template <typename Map>
static const typename Map::mapped_type *
find(const Map &AMap, const typename Map::key_type &Key) {
  const typename Map::const_iterator Position = AMap.find(Key);
  if (Position == AMap.end()) {
    return NULL;
  }
  return &Position->second;
}
38
// Memoization key.  Results are cached so that the same matcher is not run
// on the same AST node twice; this pair identifies one such combination.
// Its first element is an ID of the MatcherInterface (identifying the
// matcher) and its second element is a pointer to the AST node.
typedef std::pair<uint64_t, const void *> UntypedMatchInput;
44
45// Used to store the result of a match and possibly bound nodes.
46struct MemoizedMatchResult {
47  bool ResultOfMatch;
48  BoundNodesTree Nodes;
49};
50
51// A RecursiveASTVisitor that traverses all children or all descendants of
52// a node.
53class MatchChildASTVisitor
54    : public RecursiveASTVisitor<MatchChildASTVisitor> {
55public:
56  typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase;
57
58  // Creates an AST visitor that matches 'matcher' on all children or
59  // descendants of a traversed node. max_depth is the maximum depth
60  // to traverse: use 1 for matching the children and INT_MAX for
61  // matching the descendants.
62  MatchChildASTVisitor(const UntypedBaseMatcher *BaseMatcher,
63                       ASTMatchFinder *Finder,
64                       BoundNodesTreeBuilder *Builder,
65                       int MaxDepth,
66                       ASTMatchFinder::TraversalKind Traversal,
67                       ASTMatchFinder::BindKind Bind)
68      : BaseMatcher(BaseMatcher),
69        Finder(Finder),
70        Builder(Builder),
71        CurrentDepth(-1),
72        MaxDepth(MaxDepth),
73        Traversal(Traversal),
74        Bind(Bind),
75        Matches(false) {}
76
77  // Returns true if a match is found in the subtree rooted at the
78  // given AST node. This is done via a set of mutually recursive
79  // functions. Here's how the recursion is done (the  *wildcard can
80  // actually be Decl, Stmt, or Type):
81  //
82  //   - Traverse(node) calls BaseTraverse(node) when it needs
83  //     to visit the descendants of node.
84  //   - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node))
85  //     Traverse*(c) for each child c of 'node'.
86  //   - Traverse*(c) in turn calls Traverse(c), completing the
87  //     recursion.
88  template <typename T>
89  bool findMatch(const T &Node) {
90    reset();
91    traverse(Node);
92    return Matches;
93  }
94
95  // The following are overriding methods from the base visitor class.
96  // They are public only to allow CRTP to work. They are *not *part
97  // of the public API of this class.
98  bool TraverseDecl(Decl *DeclNode) {
99    return (DeclNode == NULL) || traverse(*DeclNode);
100  }
101  bool TraverseStmt(Stmt *StmtNode) {
102    const Stmt *StmtToTraverse = StmtNode;
103    if (Traversal ==
104        ASTMatchFinder::TK_IgnoreImplicitCastsAndParentheses) {
105      const Expr *ExprNode = dyn_cast_or_null<Expr>(StmtNode);
106      if (ExprNode != NULL) {
107        StmtToTraverse = ExprNode->IgnoreParenImpCasts();
108      }
109    }
110    return (StmtToTraverse == NULL) || traverse(*StmtToTraverse);
111  }
112  bool TraverseType(QualType TypeNode) {
113    return traverse(TypeNode);
114  }
115
116  bool shouldVisitTemplateInstantiations() const { return true; }
117  bool shouldVisitImplicitCode() const { return true; }
118
119private:
120  // Used for updating the depth during traversal.
121  struct ScopedIncrement {
122    explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); }
123    ~ScopedIncrement() { --(*Depth); }
124
125   private:
126    int *Depth;
127  };
128
129  // Resets the state of this object.
130  void reset() {
131    Matches = false;
132    CurrentDepth = -1;
133  }
134
135  // Forwards the call to the corresponding Traverse*() method in the
136  // base visitor class.
137  bool baseTraverse(const Decl &DeclNode) {
138    return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode));
139  }
140  bool baseTraverse(const Stmt &StmtNode) {
141    return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode));
142  }
143  bool baseTraverse(QualType TypeNode) {
144    return VisitorBase::TraverseType(TypeNode);
145  }
146
147  // Traverses the subtree rooted at 'node'; returns true if the
148  // traversal should continue after this function returns; also sets
149  // matched_ to true if a match is found during the traversal.
150  template <typename T>
151  bool traverse(const T &Node) {
152    TOOLING_COMPILE_ASSERT(IsBaseType<T>::value,
153                           traverse_can_only_be_instantiated_with_base_type);
154    ScopedIncrement ScopedDepth(&CurrentDepth);
155    if (CurrentDepth == 0) {
156      // We don't want to match the root node, so just recurse.
157      return baseTraverse(Node);
158    }
159    if (Bind != ASTMatchFinder::BK_All) {
160      if (BaseMatcher->matches(Node, Finder, Builder)) {
161        Matches = true;
162        return false;  // Abort as soon as a match is found.
163      }
164      if (CurrentDepth < MaxDepth) {
165        // The current node doesn't match, and we haven't reached the
166        // maximum depth yet, so recurse.
167        return baseTraverse(Node);
168      }
169      // The current node doesn't match, and we have reached the
170      // maximum depth, so don't recurse (but continue the traversal
171      // such that other nodes at the current level can be visited).
172      return true;
173    } else {
174      BoundNodesTreeBuilder RecursiveBuilder;
175      if (BaseMatcher->matches(Node, Finder, &RecursiveBuilder)) {
176        // After the first match the matcher succeeds.
177        Matches = true;
178        Builder->addMatch(RecursiveBuilder.build());
179      }
180      if (CurrentDepth < MaxDepth) {
181        baseTraverse(Node);
182      }
183      // In kBindAll mode we always search for more matches.
184      return true;
185    }
186  }
187
188  const UntypedBaseMatcher *const BaseMatcher;
189  ASTMatchFinder *const Finder;
190  BoundNodesTreeBuilder *const Builder;
191  int CurrentDepth;
192  const int MaxDepth;
193  const ASTMatchFinder::TraversalKind Traversal;
194  const ASTMatchFinder::BindKind Bind;
195  bool Matches;
196};
197
198// Controls the outermost traversal of the AST and allows to match multiple
199// matchers.
200class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
201                        public ASTMatchFinder {
202public:
203  MatchASTVisitor(std::vector< std::pair<const UntypedBaseMatcher*,
204                               MatchFinder::MatchCallback*> > *Triggers)
205     : Triggers(Triggers),
206       ActiveASTContext(NULL) {
207  }
208
209  void set_active_ast_context(ASTContext *NewActiveASTContext) {
210    ActiveASTContext = NewActiveASTContext;
211  }
212
213  // The following Visit*() and Traverse*() functions "override"
214  // methods in RecursiveASTVisitor.
215
216  bool VisitTypedefDecl(TypedefDecl *DeclNode) {
217    // When we see 'typedef A B', we add name 'B' to the set of names
218    // A's canonical type maps to.  This is necessary for implementing
219    // IsDerivedFrom(x) properly, where x can be the name of the base
220    // class or any of its aliases.
221    //
222    // In general, the is-alias-of (as defined by typedefs) relation
223    // is tree-shaped, as you can typedef a type more than once.  For
224    // example,
225    //
226    //   typedef A B;
227    //   typedef A C;
228    //   typedef C D;
229    //   typedef C E;
230    //
231    // gives you
232    //
233    //   A
234    //   |- B
235    //   `- C
236    //      |- D
237    //      `- E
238    //
239    // It is wrong to assume that the relation is a chain.  A correct
240    // implementation of IsDerivedFrom() needs to recognize that B and
241    // E are aliases, even though neither is a typedef of the other.
242    // Therefore, we cannot simply walk through one typedef chain to
243    // find out whether the type name matches.
244    const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr();
245    const Type *CanonicalType =  // root of the typedef tree
246        ActiveASTContext->getCanonicalType(TypeNode);
247    TypeToUnqualifiedAliases[CanonicalType].insert(
248        DeclNode->getName().str());
249    return true;
250  }
251
252  bool TraverseDecl(Decl *DeclNode);
253  bool TraverseStmt(Stmt *StmtNode);
254  bool TraverseType(QualType TypeNode);
255  bool TraverseTypeLoc(TypeLoc TypeNode);
256
257  // Matches children or descendants of 'Node' with 'BaseMatcher'.
258  template <typename T>
259  bool memoizedMatchesRecursively(const T &Node,
260                                  const UntypedBaseMatcher &BaseMatcher,
261                                  BoundNodesTreeBuilder *Builder, int MaxDepth,
262                                  TraversalKind Traversal, BindKind Bind) {
263    TOOLING_COMPILE_ASSERT((llvm::is_same<T, Decl>::value) ||
264                           (llvm::is_same<T, Stmt>::value),
265                           type_does_not_support_memoization);
266    const UntypedMatchInput input(BaseMatcher.getID(), &Node);
267    std::pair<MemoizationMap::iterator, bool> InsertResult
268      = ResultCache.insert(std::make_pair(input, MemoizedMatchResult()));
269    if (InsertResult.second) {
270      BoundNodesTreeBuilder DescendantBoundNodesBuilder;
271      InsertResult.first->second.ResultOfMatch =
272        matchesRecursively(Node, BaseMatcher, &DescendantBoundNodesBuilder,
273                           MaxDepth, Traversal, Bind);
274      InsertResult.first->second.Nodes =
275        DescendantBoundNodesBuilder.build();
276    }
277    InsertResult.first->second.Nodes.copyTo(Builder);
278    return InsertResult.first->second.ResultOfMatch;
279  }
280
281  // Matches children or descendants of 'Node' with 'BaseMatcher'.
282  template <typename T>
283  bool matchesRecursively(const T &Node, const UntypedBaseMatcher &BaseMatcher,
284                          BoundNodesTreeBuilder *Builder, int MaxDepth,
285                          TraversalKind Traversal, BindKind Bind) {
286    MatchChildASTVisitor Visitor(
287      &BaseMatcher, this, Builder, MaxDepth, Traversal, Bind);
288    return Visitor.findMatch(Node);
289  }
290
291  virtual bool classIsDerivedFrom(const CXXRecordDecl *Declaration,
292                                  StringRef BaseName) const;
293
294  // Implements ASTMatchFinder::MatchesChildOf.
295  virtual bool matchesChildOf(const Decl &DeclNode,
296                              const UntypedBaseMatcher &BaseMatcher,
297                              BoundNodesTreeBuilder *Builder,
298                              TraversalKind Traversal,
299                              BindKind Bind) {
300    return matchesRecursively(DeclNode, BaseMatcher, Builder, 1, Traversal,
301                              Bind);
302  }
303  virtual bool matchesChildOf(const Stmt &StmtNode,
304                              const UntypedBaseMatcher &BaseMatcher,
305                              BoundNodesTreeBuilder *Builder,
306                              TraversalKind Traversal,
307                              BindKind Bind) {
308    return matchesRecursively(StmtNode, BaseMatcher, Builder, 1, Traversal,
309                              Bind);
310  }
311
312  // Implements ASTMatchFinder::MatchesDescendantOf.
313  virtual bool matchesDescendantOf(const Decl &DeclNode,
314                                   const UntypedBaseMatcher &BaseMatcher,
315                                   BoundNodesTreeBuilder *Builder,
316                                   BindKind Bind) {
317    return memoizedMatchesRecursively(DeclNode, BaseMatcher, Builder, INT_MAX,
318                                      TK_AsIs, Bind);
319  }
320  virtual bool matchesDescendantOf(const Stmt &StmtNode,
321                                   const UntypedBaseMatcher &BaseMatcher,
322                                   BoundNodesTreeBuilder *Builder,
323                                   BindKind Bind) {
324    return memoizedMatchesRecursively(StmtNode, BaseMatcher, Builder, INT_MAX,
325                                      TK_AsIs, Bind);
326  }
327
328  bool shouldVisitTemplateInstantiations() const { return true; }
329  bool shouldVisitImplicitCode() const { return true; }
330
331private:
332  // Implements a BoundNodesTree::Visitor that calls a MatchCallback with
333  // the aggregated bound nodes for each match.
334  class MatchVisitor : public BoundNodesTree::Visitor {
335  public:
336    MatchVisitor(ASTContext* Context,
337                 MatchFinder::MatchCallback* Callback)
338      : Context(Context),
339        Callback(Callback) {}
340
341    virtual void visitMatch(const BoundNodes& BoundNodesView) {
342      Callback->run(MatchFinder::MatchResult(BoundNodesView, Context));
343    }
344
345  private:
346    ASTContext* Context;
347    MatchFinder::MatchCallback* Callback;
348  };
349
350  // Returns true if 'TypeNode' is also known by the name 'Name'.  In other
351  // words, there is a type (including typedef) with the name 'Name'
352  // that is equal to 'TypeNode'.
353  bool typeHasAlias(const Type *TypeNode,
354                    StringRef Name) const {
355    const Type *const CanonicalType =
356      ActiveASTContext->getCanonicalType(TypeNode);
357    const std::set<std::string> *UnqualifiedAlias =
358      find(TypeToUnqualifiedAliases, CanonicalType);
359    return UnqualifiedAlias != NULL && UnqualifiedAlias->count(Name) > 0;
360  }
361
362  // Matches all registered matchers on the given node and calls the
363  // result callback for every node that matches.
364  template <typename T>
365  void match(const T &node) {
366    for (std::vector< std::pair<const UntypedBaseMatcher*,
367                      MatchFinder::MatchCallback*> >::const_iterator
368             It = Triggers->begin(), End = Triggers->end();
369         It != End; ++It) {
370      BoundNodesTreeBuilder Builder;
371      if (It->first->matches(node, this, &Builder)) {
372        BoundNodesTree BoundNodes = Builder.build();
373        MatchVisitor Visitor(ActiveASTContext, It->second);
374        BoundNodes.visitMatches(&Visitor);
375      }
376    }
377  }
378
379  std::vector< std::pair<const UntypedBaseMatcher*,
380               MatchFinder::MatchCallback*> > *const Triggers;
381  ASTContext *ActiveASTContext;
382
383  // Maps a canonical type to the names of its typedefs.
384  llvm::DenseMap<const Type*, std::set<std::string> >
385    TypeToUnqualifiedAliases;
386
387  // Maps (matcher, node) -> the match result for memoization.
388  typedef llvm::DenseMap<UntypedMatchInput, MemoizedMatchResult> MemoizationMap;
389  MemoizationMap ResultCache;
390};
391
392// Returns true if the given class is directly or indirectly derived
393// from a base type with the given name.  A class is considered to be
394// also derived from itself.
395bool
396MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration,
397                                    StringRef BaseName) const {
398  if (Declaration->getName() == BaseName) {
399    return true;
400  }
401  if (!Declaration->hasDefinition()) {
402    return false;
403  }
404  typedef CXXRecordDecl::base_class_const_iterator BaseIterator;
405  for (BaseIterator It = Declaration->bases_begin(),
406                    End = Declaration->bases_end(); It != End; ++It) {
407    const Type *TypeNode = It->getType().getTypePtr();
408
409    if (typeHasAlias(TypeNode, BaseName))
410      return true;
411
412    // Type::getAs<...>() drills through typedefs.
413    if (TypeNode->getAs<DependentNameType>() != NULL ||
414        TypeNode->getAs<TemplateTypeParmType>() != NULL) {
415      // Dependent names and template TypeNode parameters will be matched when
416      // the template is instantiated.
417      continue;
418    }
419    CXXRecordDecl *ClassDecl = NULL;
420    TemplateSpecializationType const *TemplateType =
421      TypeNode->getAs<TemplateSpecializationType>();
422    if (TemplateType != NULL) {
423      if (TemplateType->getTemplateName().isDependent()) {
424        // Dependent template specializations will be matched when the
425        // template is instantiated.
426        continue;
427      }
428      // For template specialization types which are specializing a template
429      // declaration which is an explicit or partial specialization of another
430      // template declaration, getAsCXXRecordDecl() returns the corresponding
431      // ClassTemplateSpecializationDecl.
432      //
433      // For template specialization types which are specializing a template
434      // declaration which is neither an explicit nor partial specialization of
435      // another template declaration, getAsCXXRecordDecl() returns NULL and
436      // we get the CXXRecordDecl of the templated declaration.
437      CXXRecordDecl *SpecializationDecl =
438        TemplateType->getAsCXXRecordDecl();
439      if (SpecializationDecl != NULL) {
440        ClassDecl = SpecializationDecl;
441      } else {
442        ClassDecl = llvm::dyn_cast<CXXRecordDecl>(
443            TemplateType->getTemplateName()
444                .getAsTemplateDecl()->getTemplatedDecl());
445      }
446    } else {
447      ClassDecl = TypeNode->getAsCXXRecordDecl();
448    }
449    assert(ClassDecl != NULL);
450    assert(ClassDecl != Declaration);
451    if (classIsDerivedFrom(ClassDecl, BaseName)) {
452      return true;
453    }
454  }
455  return false;
456}
457
458bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) {
459  if (DeclNode == NULL) {
460    return true;
461  }
462  match(*DeclNode);
463  return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode);
464}
465
466bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode) {
467  if (StmtNode == NULL) {
468    return true;
469  }
470  match(*StmtNode);
471  return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode);
472}
473
474bool MatchASTVisitor::TraverseType(QualType TypeNode) {
475  match(TypeNode);
476  return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode);
477}
478
479bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLoc) {
480  return RecursiveASTVisitor<MatchASTVisitor>::
481      TraverseType(TypeLoc.getType());
482}
483
484class MatchASTConsumer : public ASTConsumer {
485public:
486  MatchASTConsumer(std::vector< std::pair<const UntypedBaseMatcher*,
487                                MatchFinder::MatchCallback*> > *Triggers,
488                   MatchFinder::ParsingDoneTestCallback *ParsingDone)
489      : Visitor(Triggers),
490        ParsingDone(ParsingDone) {}
491
492private:
493  virtual void HandleTranslationUnit(ASTContext &Context) {
494    if (ParsingDone != NULL) {
495      ParsingDone->run();
496    }
497    Visitor.set_active_ast_context(&Context);
498    Visitor.TraverseDecl(Context.getTranslationUnitDecl());
499    Visitor.set_active_ast_context(NULL);
500  }
501
502  MatchASTVisitor Visitor;
503  MatchFinder::ParsingDoneTestCallback *ParsingDone;
504};
505
506} // end namespace
507} // end namespace internal
508
509MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes,
510                                      ASTContext *Context)
511  : Nodes(Nodes), Context(Context),
512    SourceManager(&Context->getSourceManager()) {}
513
514MatchFinder::MatchCallback::~MatchCallback() {}
515MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {}
516
517MatchFinder::MatchFinder() : ParsingDone(NULL) {}
518
519MatchFinder::~MatchFinder() {
520  for (std::vector< std::pair<const internal::UntypedBaseMatcher*,
521                    MatchFinder::MatchCallback*> >::const_iterator
522           It = Triggers.begin(), End = Triggers.end();
523       It != End; ++It) {
524    delete It->first;
525  }
526}
527
528void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch,
529                             MatchCallback *Action) {
530  Triggers.push_back(std::make_pair(
531    new internal::TypedBaseMatcher<Decl>(NodeMatch), Action));
532}
533
534void MatchFinder::addMatcher(const TypeMatcher &NodeMatch,
535                             MatchCallback *Action) {
536  Triggers.push_back(std::make_pair(
537    new internal::TypedBaseMatcher<QualType>(NodeMatch), Action));
538}
539
540void MatchFinder::addMatcher(const StatementMatcher &NodeMatch,
541                             MatchCallback *Action) {
542  Triggers.push_back(std::make_pair(
543    new internal::TypedBaseMatcher<Stmt>(NodeMatch), Action));
544}
545
546ASTConsumer *MatchFinder::newASTConsumer() {
547  return new internal::MatchASTConsumer(&Triggers, ParsingDone);
548}
549
550void MatchFinder::registerTestCallbackAfterParsing(
551    MatchFinder::ParsingDoneTestCallback *NewParsingDone) {
552  ParsingDone = NewParsingDone;
553}
554
555} // end namespace ast_matchers
556} // end namespace clang
557