ASTMatchFinder.cpp revision a78d0d6203a990b88c9c3e4c4f2a277001e8bd46
1//===--- ASTMatchFinder.cpp - Structural query framework ------------------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// Implements an algorithm to efficiently search for matches on AST nodes. 11// Uses memoization to support recursive matches like HasDescendant. 12// 13// The general idea is to visit all AST nodes with a RecursiveASTVisitor, 14// calling the Matches(...) method of each matcher we are running on each 15// AST node. The matcher can recurse via the ASTMatchFinder interface. 16// 17//===----------------------------------------------------------------------===// 18 19#include "clang/ASTMatchers/ASTMatchFinder.h" 20#include "clang/AST/ASTConsumer.h" 21#include "clang/AST/ASTContext.h" 22#include "clang/AST/RecursiveASTVisitor.h" 23#include <set> 24 25namespace clang { 26namespace ast_matchers { 27namespace internal { 28namespace { 29 30typedef MatchFinder::MatchCallback MatchCallback; 31 32// We use memoization to avoid running the same matcher on the same 33// AST node twice. This pair is the key for looking up match 34// result. It consists of an ID of the MatcherInterface (for 35// identifying the matcher) and a pointer to the AST node. 36// 37// We currently only memoize on nodes whose pointers identify the 38// nodes (\c Stmt and \c Decl, but not \c QualType or \c TypeLoc). 39// For \c QualType and \c TypeLoc it is possible to implement 40// generation of keys for each type. 41// FIXME: Benchmark whether memoization of non-pointer typed nodes 42// provides enough benefit for the additional amount of code. 43typedef std::pair<uint64_t, const void*> UntypedMatchInput; 44 45// Used to store the result of a match and possibly bound nodes. 46struct MemoizedMatchResult { 47 bool ResultOfMatch; 48 BoundNodesTree Nodes; 49}; 50 51// A RecursiveASTVisitor that traverses all children or all descendants of 52// a node. 53class MatchChildASTVisitor 54 : public RecursiveASTVisitor<MatchChildASTVisitor> { 55public: 56 typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase; 57 58 // Creates an AST visitor that matches 'matcher' on all children or 59 // descendants of a traversed node. max_depth is the maximum depth 60 // to traverse: use 1 for matching the children and INT_MAX for 61 // matching the descendants. 62 MatchChildASTVisitor(const DynTypedMatcher *Matcher, 63 ASTMatchFinder *Finder, 64 BoundNodesTreeBuilder *Builder, 65 int MaxDepth, 66 ASTMatchFinder::TraversalKind Traversal, 67 ASTMatchFinder::BindKind Bind) 68 : Matcher(Matcher), 69 Finder(Finder), 70 Builder(Builder), 71 CurrentDepth(-1), 72 MaxDepth(MaxDepth), 73 Traversal(Traversal), 74 Bind(Bind), 75 Matches(false) {} 76 77 // Returns true if a match is found in the subtree rooted at the 78 // given AST node. This is done via a set of mutually recursive 79 // functions. Here's how the recursion is done (the *wildcard can 80 // actually be Decl, Stmt, or Type): 81 // 82 // - Traverse(node) calls BaseTraverse(node) when it needs 83 // to visit the descendants of node. 84 // - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node)) 85 // Traverse*(c) for each child c of 'node'. 86 // - Traverse*(c) in turn calls Traverse(c), completing the 87 // recursion. 88 bool findMatch(const ast_type_traits::DynTypedNode &DynNode) { 89 reset(); 90 if (const Decl *D = DynNode.get<Decl>()) 91 traverse(*D); 92 else if (const Stmt *S = DynNode.get<Stmt>()) 93 traverse(*S); 94 // FIXME: Add other base types after adding tests. 95 return Matches; 96 } 97 98 // The following are overriding methods from the base visitor class. 99 // They are public only to allow CRTP to work. They are *not *part 100 // of the public API of this class. 101 bool TraverseDecl(Decl *DeclNode) { 102 return (DeclNode == NULL) || traverse(*DeclNode); 103 } 104 bool TraverseStmt(Stmt *StmtNode) { 105 const Stmt *StmtToTraverse = StmtNode; 106 if (Traversal == 107 ASTMatchFinder::TK_IgnoreImplicitCastsAndParentheses) { 108 const Expr *ExprNode = dyn_cast_or_null<Expr>(StmtNode); 109 if (ExprNode != NULL) { 110 StmtToTraverse = ExprNode->IgnoreParenImpCasts(); 111 } 112 } 113 return (StmtToTraverse == NULL) || traverse(*StmtToTraverse); 114 } 115 bool TraverseType(QualType TypeNode) { 116 return traverse(TypeNode); 117 } 118 119 bool shouldVisitTemplateInstantiations() const { return true; } 120 bool shouldVisitImplicitCode() const { return true; } 121 122private: 123 // Used for updating the depth during traversal. 124 struct ScopedIncrement { 125 explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); } 126 ~ScopedIncrement() { --(*Depth); } 127 128 private: 129 int *Depth; 130 }; 131 132 // Resets the state of this object. 133 void reset() { 134 Matches = false; 135 CurrentDepth = -1; 136 } 137 138 // Forwards the call to the corresponding Traverse*() method in the 139 // base visitor class. 140 bool baseTraverse(const Decl &DeclNode) { 141 return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode)); 142 } 143 bool baseTraverse(const Stmt &StmtNode) { 144 return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode)); 145 } 146 bool baseTraverse(QualType TypeNode) { 147 return VisitorBase::TraverseType(TypeNode); 148 } 149 150 // Traverses the subtree rooted at 'node'; returns true if the 151 // traversal should continue after this function returns; also sets 152 // matched_ to true if a match is found during the traversal. 153 template <typename T> 154 bool traverse(const T &Node) { 155 TOOLING_COMPILE_ASSERT(IsBaseType<T>::value, 156 traverse_can_only_be_instantiated_with_base_type); 157 ScopedIncrement ScopedDepth(&CurrentDepth); 158 if (CurrentDepth == 0) { 159 // We don't want to match the root node, so just recurse. 160 return baseTraverse(Node); 161 } 162 if (Bind != ASTMatchFinder::BK_All) { 163 if (Matcher->matches(ast_type_traits::DynTypedNode::create(Node), 164 Finder, Builder)) { 165 Matches = true; 166 return false; // Abort as soon as a match is found. 167 } 168 if (CurrentDepth < MaxDepth) { 169 // The current node doesn't match, and we haven't reached the 170 // maximum depth yet, so recurse. 171 return baseTraverse(Node); 172 } 173 // The current node doesn't match, and we have reached the 174 // maximum depth, so don't recurse (but continue the traversal 175 // such that other nodes at the current level can be visited). 176 return true; 177 } else { 178 BoundNodesTreeBuilder RecursiveBuilder; 179 if (Matcher->matches(ast_type_traits::DynTypedNode::create(Node), 180 Finder, &RecursiveBuilder)) { 181 // After the first match the matcher succeeds. 182 Matches = true; 183 Builder->addMatch(RecursiveBuilder.build()); 184 } 185 if (CurrentDepth < MaxDepth) { 186 baseTraverse(Node); 187 } 188 // In kBindAll mode we always search for more matches. 189 return true; 190 } 191 } 192 193 const DynTypedMatcher *const Matcher; 194 ASTMatchFinder *const Finder; 195 BoundNodesTreeBuilder *const Builder; 196 int CurrentDepth; 197 const int MaxDepth; 198 const ASTMatchFinder::TraversalKind Traversal; 199 const ASTMatchFinder::BindKind Bind; 200 bool Matches; 201}; 202 203// Controls the outermost traversal of the AST and allows to match multiple 204// matchers. 205class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>, 206 public ASTMatchFinder { 207public: 208 MatchASTVisitor(std::vector<std::pair<const internal::DynTypedMatcher*, 209 MatchCallback*> > *MatcherCallbackPairs) 210 : MatcherCallbackPairs(MatcherCallbackPairs), 211 ActiveASTContext(NULL) { 212 } 213 214 void set_active_ast_context(ASTContext *NewActiveASTContext) { 215 ActiveASTContext = NewActiveASTContext; 216 } 217 218 // The following Visit*() and Traverse*() functions "override" 219 // methods in RecursiveASTVisitor. 220 221 bool VisitTypedefDecl(TypedefDecl *DeclNode) { 222 // When we see 'typedef A B', we add name 'B' to the set of names 223 // A's canonical type maps to. This is necessary for implementing 224 // IsDerivedFrom(x) properly, where x can be the name of the base 225 // class or any of its aliases. 226 // 227 // In general, the is-alias-of (as defined by typedefs) relation 228 // is tree-shaped, as you can typedef a type more than once. For 229 // example, 230 // 231 // typedef A B; 232 // typedef A C; 233 // typedef C D; 234 // typedef C E; 235 // 236 // gives you 237 // 238 // A 239 // |- B 240 // `- C 241 // |- D 242 // `- E 243 // 244 // It is wrong to assume that the relation is a chain. A correct 245 // implementation of IsDerivedFrom() needs to recognize that B and 246 // E are aliases, even though neither is a typedef of the other. 247 // Therefore, we cannot simply walk through one typedef chain to 248 // find out whether the type name matches. 249 const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr(); 250 const Type *CanonicalType = // root of the typedef tree 251 ActiveASTContext->getCanonicalType(TypeNode); 252 TypeAliases[CanonicalType].insert(DeclNode); 253 return true; 254 } 255 256 bool TraverseDecl(Decl *DeclNode); 257 bool TraverseStmt(Stmt *StmtNode); 258 bool TraverseType(QualType TypeNode); 259 bool TraverseTypeLoc(TypeLoc TypeNode); 260 261 // Matches children or descendants of 'Node' with 'BaseMatcher'. 262 bool memoizedMatchesRecursively(const ast_type_traits::DynTypedNode &Node, 263 const DynTypedMatcher &Matcher, 264 BoundNodesTreeBuilder *Builder, int MaxDepth, 265 TraversalKind Traversal, BindKind Bind) { 266 const UntypedMatchInput input(Matcher.getID(), Node.getMemoizationData()); 267 assert(input.second && 268 "Fix getMemoizationData once more types allow recursive matching."); 269 std::pair<MemoizationMap::iterator, bool> InsertResult 270 = ResultCache.insert(std::make_pair(input, MemoizedMatchResult())); 271 if (InsertResult.second) { 272 BoundNodesTreeBuilder DescendantBoundNodesBuilder; 273 InsertResult.first->second.ResultOfMatch = 274 matchesRecursively(Node, Matcher, &DescendantBoundNodesBuilder, 275 MaxDepth, Traversal, Bind); 276 InsertResult.first->second.Nodes = 277 DescendantBoundNodesBuilder.build(); 278 } 279 InsertResult.first->second.Nodes.copyTo(Builder); 280 return InsertResult.first->second.ResultOfMatch; 281 } 282 283 // Matches children or descendants of 'Node' with 'BaseMatcher'. 284 bool matchesRecursively(const ast_type_traits::DynTypedNode &Node, 285 const DynTypedMatcher &Matcher, 286 BoundNodesTreeBuilder *Builder, int MaxDepth, 287 TraversalKind Traversal, BindKind Bind) { 288 MatchChildASTVisitor Visitor( 289 &Matcher, this, Builder, MaxDepth, Traversal, Bind); 290 return Visitor.findMatch(Node); 291 } 292 293 virtual bool classIsDerivedFrom(const CXXRecordDecl *Declaration, 294 const Matcher<NamedDecl> &Base, 295 BoundNodesTreeBuilder *Builder); 296 297 // Implements ASTMatchFinder::MatchesChildOf. 298 virtual bool matchesChildOf(const ast_type_traits::DynTypedNode &Node, 299 const DynTypedMatcher &Matcher, 300 BoundNodesTreeBuilder *Builder, 301 TraversalKind Traversal, 302 BindKind Bind) { 303 return matchesRecursively(Node, Matcher, Builder, 1, Traversal, 304 Bind); 305 } 306 // Implements ASTMatchFinder::MatchesDescendantOf. 307 virtual bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node, 308 const DynTypedMatcher &Matcher, 309 BoundNodesTreeBuilder *Builder, 310 BindKind Bind) { 311 return memoizedMatchesRecursively(Node, Matcher, Builder, INT_MAX, 312 TK_AsIs, Bind); 313 } 314 315 bool shouldVisitTemplateInstantiations() const { return true; } 316 bool shouldVisitImplicitCode() const { return true; } 317 318private: 319 // Implements a BoundNodesTree::Visitor that calls a MatchCallback with 320 // the aggregated bound nodes for each match. 321 class MatchVisitor : public BoundNodesTree::Visitor { 322 public: 323 MatchVisitor(ASTContext* Context, 324 MatchFinder::MatchCallback* Callback) 325 : Context(Context), 326 Callback(Callback) {} 327 328 virtual void visitMatch(const BoundNodes& BoundNodesView) { 329 Callback->run(MatchFinder::MatchResult(BoundNodesView, Context)); 330 } 331 332 private: 333 ASTContext* Context; 334 MatchFinder::MatchCallback* Callback; 335 }; 336 337 // Returns true if 'TypeNode' has an alias that matches the given matcher. 338 bool typeHasMatchingAlias(const Type *TypeNode, 339 const Matcher<NamedDecl> Matcher, 340 BoundNodesTreeBuilder *Builder) { 341 const Type *const CanonicalType = 342 ActiveASTContext->getCanonicalType(TypeNode); 343 const std::set<const TypedefDecl*> &Aliases = TypeAliases[CanonicalType]; 344 for (std::set<const TypedefDecl*>::const_iterator 345 It = Aliases.begin(), End = Aliases.end(); 346 It != End; ++It) { 347 if (Matcher.matches(**It, this, Builder)) 348 return true; 349 } 350 return false; 351 } 352 353 // Matches all registered matchers on the given node and calls the 354 // result callback for every node that matches. 355 template <typename T> 356 void match(const T &node) { 357 for (std::vector<std::pair<const internal::DynTypedMatcher*, 358 MatchCallback*> >::const_iterator 359 I = MatcherCallbackPairs->begin(), E = MatcherCallbackPairs->end(); 360 I != E; ++I) { 361 BoundNodesTreeBuilder Builder; 362 if (I->first->matches(ast_type_traits::DynTypedNode::create(node), 363 this, &Builder)) { 364 BoundNodesTree BoundNodes = Builder.build(); 365 MatchVisitor Visitor(ActiveASTContext, I->second); 366 BoundNodes.visitMatches(&Visitor); 367 } 368 } 369 } 370 371 std::vector<std::pair<const internal::DynTypedMatcher*, 372 MatchCallback*> > *const MatcherCallbackPairs; 373 ASTContext *ActiveASTContext; 374 375 // Maps a canonical type to its TypedefDecls. 376 llvm::DenseMap<const Type*, std::set<const TypedefDecl*> > TypeAliases; 377 378 // Maps (matcher, node) -> the match result for memoization. 379 typedef llvm::DenseMap<UntypedMatchInput, MemoizedMatchResult> MemoizationMap; 380 MemoizationMap ResultCache; 381}; 382 383// Returns true if the given class is directly or indirectly derived 384// from a base type with the given name. A class is considered to be 385// also derived from itself. 386bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration, 387 const Matcher<NamedDecl> &Base, 388 BoundNodesTreeBuilder *Builder) { 389 if (Base.matches(*Declaration, this, Builder)) 390 return true; 391 if (!Declaration->hasDefinition()) 392 return false; 393 typedef CXXRecordDecl::base_class_const_iterator BaseIterator; 394 for (BaseIterator It = Declaration->bases_begin(), 395 End = Declaration->bases_end(); It != End; ++It) { 396 const Type *TypeNode = It->getType().getTypePtr(); 397 398 if (typeHasMatchingAlias(TypeNode, Base, Builder)) 399 return true; 400 401 // Type::getAs<...>() drills through typedefs. 402 if (TypeNode->getAs<DependentNameType>() != NULL || 403 TypeNode->getAs<TemplateTypeParmType>() != NULL) 404 // Dependent names and template TypeNode parameters will be matched when 405 // the template is instantiated. 406 continue; 407 CXXRecordDecl *ClassDecl = NULL; 408 TemplateSpecializationType const *TemplateType = 409 TypeNode->getAs<TemplateSpecializationType>(); 410 if (TemplateType != NULL) { 411 if (TemplateType->getTemplateName().isDependent()) 412 // Dependent template specializations will be matched when the 413 // template is instantiated. 414 continue; 415 416 // For template specialization types which are specializing a template 417 // declaration which is an explicit or partial specialization of another 418 // template declaration, getAsCXXRecordDecl() returns the corresponding 419 // ClassTemplateSpecializationDecl. 420 // 421 // For template specialization types which are specializing a template 422 // declaration which is neither an explicit nor partial specialization of 423 // another template declaration, getAsCXXRecordDecl() returns NULL and 424 // we get the CXXRecordDecl of the templated declaration. 425 CXXRecordDecl *SpecializationDecl = 426 TemplateType->getAsCXXRecordDecl(); 427 if (SpecializationDecl != NULL) { 428 ClassDecl = SpecializationDecl; 429 } else { 430 ClassDecl = llvm::dyn_cast<CXXRecordDecl>( 431 TemplateType->getTemplateName() 432 .getAsTemplateDecl()->getTemplatedDecl()); 433 } 434 } else { 435 ClassDecl = TypeNode->getAsCXXRecordDecl(); 436 } 437 assert(ClassDecl != NULL); 438 assert(ClassDecl != Declaration); 439 if (classIsDerivedFrom(ClassDecl, Base, Builder)) 440 return true; 441 } 442 return false; 443} 444 445bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) { 446 if (DeclNode == NULL) { 447 return true; 448 } 449 match(*DeclNode); 450 return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode); 451} 452 453bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode) { 454 if (StmtNode == NULL) { 455 return true; 456 } 457 match(*StmtNode); 458 return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode); 459} 460 461bool MatchASTVisitor::TraverseType(QualType TypeNode) { 462 match(TypeNode); 463 return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode); 464} 465 466bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLoc) { 467 match(TypeLoc.getType()); 468 return RecursiveASTVisitor<MatchASTVisitor>:: 469 TraverseTypeLoc(TypeLoc); 470} 471 472class MatchASTConsumer : public ASTConsumer { 473public: 474 MatchASTConsumer( 475 std::vector<std::pair<const internal::DynTypedMatcher*, 476 MatchCallback*> > *MatcherCallbackPairs, 477 MatchFinder::ParsingDoneTestCallback *ParsingDone) 478 : Visitor(MatcherCallbackPairs), 479 ParsingDone(ParsingDone) {} 480 481private: 482 virtual void HandleTranslationUnit(ASTContext &Context) { 483 if (ParsingDone != NULL) { 484 ParsingDone->run(); 485 } 486 Visitor.set_active_ast_context(&Context); 487 Visitor.TraverseDecl(Context.getTranslationUnitDecl()); 488 Visitor.set_active_ast_context(NULL); 489 } 490 491 MatchASTVisitor Visitor; 492 MatchFinder::ParsingDoneTestCallback *ParsingDone; 493}; 494 495} // end namespace 496} // end namespace internal 497 498MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes, 499 ASTContext *Context) 500 : Nodes(Nodes), Context(Context), 501 SourceManager(&Context->getSourceManager()) {} 502 503MatchFinder::MatchCallback::~MatchCallback() {} 504MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {} 505 506MatchFinder::MatchFinder() : ParsingDone(NULL) {} 507 508MatchFinder::~MatchFinder() { 509 for (std::vector<std::pair<const internal::DynTypedMatcher*, 510 MatchCallback*> >::const_iterator 511 It = MatcherCallbackPairs.begin(), End = MatcherCallbackPairs.end(); 512 It != End; ++It) { 513 delete It->first; 514 } 515} 516 517void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch, 518 MatchCallback *Action) { 519 MatcherCallbackPairs.push_back(std::make_pair( 520 new internal::Matcher<Decl>(NodeMatch), Action)); 521} 522 523void MatchFinder::addMatcher(const TypeMatcher &NodeMatch, 524 MatchCallback *Action) { 525 MatcherCallbackPairs.push_back(std::make_pair( 526 new internal::Matcher<QualType>(NodeMatch), Action)); 527} 528 529void MatchFinder::addMatcher(const StatementMatcher &NodeMatch, 530 MatchCallback *Action) { 531 MatcherCallbackPairs.push_back(std::make_pair( 532 new internal::Matcher<Stmt>(NodeMatch), Action)); 533} 534 535ASTConsumer *MatchFinder::newASTConsumer() { 536 return new internal::MatchASTConsumer(&MatcherCallbackPairs, ParsingDone); 537} 538 539void MatchFinder::registerTestCallbackAfterParsing( 540 MatchFinder::ParsingDoneTestCallback *NewParsingDone) { 541 ParsingDone = NewParsingDone; 542} 543 544} // end namespace ast_matchers 545} // end namespace clang 546