// ASTMatchFinder.cpp revision e0e6b9e79a0c4edae92abd3928263875c78e23aa
1//===--- ASTMatchFinder.cpp - Structural query framework ------------------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// Implements an algorithm to efficiently search for matches on AST nodes. 11// Uses memoization to support recursive matches like HasDescendant. 12// 13// The general idea is to visit all AST nodes with a RecursiveASTVisitor, 14// calling the Matches(...) method of each matcher we are running on each 15// AST node. The matcher can recurse via the ASTMatchFinder interface. 16// 17//===----------------------------------------------------------------------===// 18 19#include "clang/ASTMatchers/ASTMatchFinder.h" 20#include "clang/AST/ASTConsumer.h" 21#include "clang/AST/ASTContext.h" 22#include "clang/AST/RecursiveASTVisitor.h" 23#include <set> 24 25namespace clang { 26namespace ast_matchers { 27namespace internal { 28namespace { 29 30// Returns the value that 'AMap' maps 'Key' to, or NULL if 'Key' is 31// not in 'AMap'. 32template <typename Map> 33static const typename Map::mapped_type * 34find(const Map &AMap, const typename Map::key_type &Key) { 35 typename Map::const_iterator It = AMap.find(Key); 36 return It == AMap.end() ? NULL : &It->second; 37} 38 39// We use memoization to avoid running the same matcher on the same 40// AST node twice. This pair is the key for looking up match 41// result. It consists of an ID of the MatcherInterface (for 42// identifying the matcher) and a pointer to the AST node. 43typedef std::pair<uint64_t, const void*> UntypedMatchInput; 44 45// Used to store the result of a match and possibly bound nodes. 46struct MemoizedMatchResult { 47 bool ResultOfMatch; 48 BoundNodesTree Nodes; 49}; 50 51// A RecursiveASTVisitor that traverses all children or all descendants of 52// a node. 
53class MatchChildASTVisitor 54 : public RecursiveASTVisitor<MatchChildASTVisitor> { 55public: 56 typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase; 57 58 // Creates an AST visitor that matches 'matcher' on all children or 59 // descendants of a traversed node. max_depth is the maximum depth 60 // to traverse: use 1 for matching the children and INT_MAX for 61 // matching the descendants. 62 MatchChildASTVisitor(const UntypedBaseMatcher *BaseMatcher, 63 ASTMatchFinder *Finder, 64 BoundNodesTreeBuilder *Builder, 65 int MaxDepth, 66 ASTMatchFinder::TraversalKind Traversal, 67 ASTMatchFinder::BindKind Bind) 68 : BaseMatcher(BaseMatcher), 69 Finder(Finder), 70 Builder(Builder), 71 CurrentDepth(-1), 72 MaxDepth(MaxDepth), 73 Traversal(Traversal), 74 Bind(Bind), 75 Matches(false) {} 76 77 // Returns true if a match is found in the subtree rooted at the 78 // given AST node. This is done via a set of mutually recursive 79 // functions. Here's how the recursion is done (the *wildcard can 80 // actually be Decl, Stmt, or Type): 81 // 82 // - Traverse(node) calls BaseTraverse(node) when it needs 83 // to visit the descendants of node. 84 // - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node)) 85 // Traverse*(c) for each child c of 'node'. 86 // - Traverse*(c) in turn calls Traverse(c), completing the 87 // recursion. 88 template <typename T> 89 bool findMatch(const T &Node) { 90 reset(); 91 traverse(Node); 92 return Matches; 93 } 94 95 // The following are overriding methods from the base visitor class. 96 // They are public only to allow CRTP to work. They are *not *part 97 // of the public API of this class. 
98 bool TraverseDecl(Decl *DeclNode) { 99 return (DeclNode == NULL) || traverse(*DeclNode); 100 } 101 bool TraverseStmt(Stmt *StmtNode) { 102 const Stmt *StmtToTraverse = StmtNode; 103 if (Traversal == 104 ASTMatchFinder::TK_IgnoreImplicitCastsAndParentheses) { 105 const Expr *ExprNode = dyn_cast_or_null<Expr>(StmtNode); 106 if (ExprNode != NULL) { 107 StmtToTraverse = ExprNode->IgnoreParenImpCasts(); 108 } 109 } 110 return (StmtToTraverse == NULL) || traverse(*StmtToTraverse); 111 } 112 bool TraverseType(QualType TypeNode) { 113 return traverse(TypeNode); 114 } 115 116 bool shouldVisitTemplateInstantiations() const { return true; } 117 bool shouldVisitImplicitCode() const { return true; } 118 119private: 120 // Used for updating the depth during traversal. 121 struct ScopedIncrement { 122 explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); } 123 ~ScopedIncrement() { --(*Depth); } 124 125 private: 126 int *Depth; 127 }; 128 129 // Resets the state of this object. 130 void reset() { 131 Matches = false; 132 CurrentDepth = -1; 133 } 134 135 // Forwards the call to the corresponding Traverse*() method in the 136 // base visitor class. 137 bool baseTraverse(const Decl &DeclNode) { 138 return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode)); 139 } 140 bool baseTraverse(const Stmt &StmtNode) { 141 return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode)); 142 } 143 bool baseTraverse(QualType TypeNode) { 144 return VisitorBase::TraverseType(TypeNode); 145 } 146 147 // Traverses the subtree rooted at 'node'; returns true if the 148 // traversal should continue after this function returns; also sets 149 // matched_ to true if a match is found during the traversal. 
150 template <typename T> 151 bool traverse(const T &Node) { 152 TOOLING_COMPILE_ASSERT(IsBaseType<T>::value, 153 traverse_can_only_be_instantiated_with_base_type); 154 ScopedIncrement ScopedDepth(&CurrentDepth); 155 if (CurrentDepth == 0) { 156 // We don't want to match the root node, so just recurse. 157 return baseTraverse(Node); 158 } 159 if (Bind != ASTMatchFinder::BK_All) { 160 if (BaseMatcher->matches(Node, Finder, Builder)) { 161 Matches = true; 162 return false; // Abort as soon as a match is found. 163 } 164 if (CurrentDepth < MaxDepth) { 165 // The current node doesn't match, and we haven't reached the 166 // maximum depth yet, so recurse. 167 return baseTraverse(Node); 168 } 169 // The current node doesn't match, and we have reached the 170 // maximum depth, so don't recurse (but continue the traversal 171 // such that other nodes at the current level can be visited). 172 return true; 173 } else { 174 BoundNodesTreeBuilder RecursiveBuilder; 175 if (BaseMatcher->matches(Node, Finder, &RecursiveBuilder)) { 176 // After the first match the matcher succeeds. 177 Matches = true; 178 Builder->addMatch(RecursiveBuilder.build()); 179 } 180 if (CurrentDepth < MaxDepth) { 181 baseTraverse(Node); 182 } 183 // In kBindAll mode we always search for more matches. 184 return true; 185 } 186 } 187 188 const UntypedBaseMatcher *const BaseMatcher; 189 ASTMatchFinder *const Finder; 190 BoundNodesTreeBuilder *const Builder; 191 int CurrentDepth; 192 const int MaxDepth; 193 const ASTMatchFinder::TraversalKind Traversal; 194 const ASTMatchFinder::BindKind Bind; 195 bool Matches; 196}; 197 198// Controls the outermost traversal of the AST and allows to match multiple 199// matchers. 
200class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>, 201 public ASTMatchFinder { 202public: 203 MatchASTVisitor(std::vector< std::pair<const UntypedBaseMatcher*, 204 MatchFinder::MatchCallback*> > *Triggers) 205 : Triggers(Triggers), 206 ActiveASTContext(NULL) { 207 } 208 209 void set_active_ast_context(ASTContext *NewActiveASTContext) { 210 ActiveASTContext = NewActiveASTContext; 211 } 212 213 // The following Visit*() and Traverse*() functions "override" 214 // methods in RecursiveASTVisitor. 215 216 bool VisitTypedefDecl(TypedefDecl *DeclNode) { 217 // When we see 'typedef A B', we add name 'B' to the set of names 218 // A's canonical type maps to. This is necessary for implementing 219 // IsDerivedFrom(x) properly, where x can be the name of the base 220 // class or any of its aliases. 221 // 222 // In general, the is-alias-of (as defined by typedefs) relation 223 // is tree-shaped, as you can typedef a type more than once. For 224 // example, 225 // 226 // typedef A B; 227 // typedef A C; 228 // typedef C D; 229 // typedef C E; 230 // 231 // gives you 232 // 233 // A 234 // |- B 235 // `- C 236 // |- D 237 // `- E 238 // 239 // It is wrong to assume that the relation is a chain. A correct 240 // implementation of IsDerivedFrom() needs to recognize that B and 241 // E are aliases, even though neither is a typedef of the other. 242 // Therefore, we cannot simply walk through one typedef chain to 243 // find out whether the type name matches. 
244 const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr(); 245 const Type *CanonicalType = // root of the typedef tree 246 ActiveASTContext->getCanonicalType(TypeNode); 247 TypeToUnqualifiedAliases[CanonicalType].insert( 248 DeclNode->getName().str()); 249 return true; 250 } 251 252 bool TraverseDecl(Decl *DeclNode); 253 bool TraverseStmt(Stmt *StmtNode); 254 bool TraverseType(QualType TypeNode); 255 bool TraverseTypeLoc(TypeLoc TypeNode); 256 257 // Matches children or descendants of 'Node' with 'BaseMatcher'. 258 template <typename T> 259 bool memoizedMatchesRecursively(const T &Node, 260 const UntypedBaseMatcher &BaseMatcher, 261 BoundNodesTreeBuilder *Builder, int MaxDepth, 262 TraversalKind Traversal, BindKind Bind) { 263 TOOLING_COMPILE_ASSERT((llvm::is_same<T, Decl>::value) || 264 (llvm::is_same<T, Stmt>::value), 265 type_does_not_support_memoization); 266 const UntypedMatchInput input(BaseMatcher.getID(), &Node); 267 std::pair<MemoizationMap::iterator, bool> InsertResult 268 = ResultCache.insert(std::make_pair(input, MemoizedMatchResult())); 269 if (InsertResult.second) { 270 BoundNodesTreeBuilder DescendantBoundNodesBuilder; 271 InsertResult.first->second.ResultOfMatch = 272 matchesRecursively(Node, BaseMatcher, &DescendantBoundNodesBuilder, 273 MaxDepth, Traversal, Bind); 274 InsertResult.first->second.Nodes = 275 DescendantBoundNodesBuilder.build(); 276 } 277 InsertResult.first->second.Nodes.copyTo(Builder); 278 return InsertResult.first->second.ResultOfMatch; 279 } 280 281 // Matches children or descendants of 'Node' with 'BaseMatcher'. 
282 template <typename T> 283 bool matchesRecursively(const T &Node, const UntypedBaseMatcher &BaseMatcher, 284 BoundNodesTreeBuilder *Builder, int MaxDepth, 285 TraversalKind Traversal, BindKind Bind) { 286 MatchChildASTVisitor Visitor( 287 &BaseMatcher, this, Builder, MaxDepth, Traversal, Bind); 288 return Visitor.findMatch(Node); 289 } 290 291 virtual bool classIsDerivedFrom(const CXXRecordDecl *Declaration, 292 StringRef BaseName) const; 293 294 // Implements ASTMatchFinder::MatchesChildOf. 295 virtual bool matchesChildOf(const Decl &DeclNode, 296 const UntypedBaseMatcher &BaseMatcher, 297 BoundNodesTreeBuilder *Builder, 298 TraversalKind Traversal, 299 BindKind Bind) { 300 return matchesRecursively(DeclNode, BaseMatcher, Builder, 1, Traversal, 301 Bind); 302 } 303 virtual bool matchesChildOf(const Stmt &StmtNode, 304 const UntypedBaseMatcher &BaseMatcher, 305 BoundNodesTreeBuilder *Builder, 306 TraversalKind Traversal, 307 BindKind Bind) { 308 return matchesRecursively(StmtNode, BaseMatcher, Builder, 1, Traversal, 309 Bind); 310 } 311 312 // Implements ASTMatchFinder::MatchesDescendantOf. 313 virtual bool matchesDescendantOf(const Decl &DeclNode, 314 const UntypedBaseMatcher &BaseMatcher, 315 BoundNodesTreeBuilder *Builder, 316 BindKind Bind) { 317 return memoizedMatchesRecursively(DeclNode, BaseMatcher, Builder, INT_MAX, 318 TK_AsIs, Bind); 319 } 320 virtual bool matchesDescendantOf(const Stmt &StmtNode, 321 const UntypedBaseMatcher &BaseMatcher, 322 BoundNodesTreeBuilder *Builder, 323 BindKind Bind) { 324 return memoizedMatchesRecursively(StmtNode, BaseMatcher, Builder, INT_MAX, 325 TK_AsIs, Bind); 326 } 327 328 bool shouldVisitTemplateInstantiations() const { return true; } 329 bool shouldVisitImplicitCode() const { return true; } 330 331private: 332 // Implements a BoundNodesTree::Visitor that calls a MatchCallback with 333 // the aggregated bound nodes for each match. 
334 class MatchVisitor : public BoundNodesTree::Visitor { 335 public: 336 MatchVisitor(ASTContext* Context, 337 MatchFinder::MatchCallback* Callback) 338 : Context(Context), 339 Callback(Callback) {} 340 341 virtual void visitMatch(const BoundNodes& BoundNodesView) { 342 Callback->run(MatchFinder::MatchResult(BoundNodesView, Context)); 343 } 344 345 private: 346 ASTContext* Context; 347 MatchFinder::MatchCallback* Callback; 348 }; 349 350 // Returns true if 'TypeNode' is also known by the name 'Name'. In other 351 // words, there is a type (including typedef) with the name 'Name' 352 // that is equal to 'TypeNode'. 353 bool typeHasAlias(const Type *TypeNode, 354 StringRef Name) const { 355 const Type *const CanonicalType = 356 ActiveASTContext->getCanonicalType(TypeNode); 357 const std::set<std::string> *UnqualifiedAlias = 358 find(TypeToUnqualifiedAliases, CanonicalType); 359 return UnqualifiedAlias != NULL && UnqualifiedAlias->count(Name) > 0; 360 } 361 362 // Matches all registered matchers on the given node and calls the 363 // result callback for every node that matches. 364 template <typename T> 365 void match(const T &node) { 366 for (std::vector< std::pair<const UntypedBaseMatcher*, 367 MatchFinder::MatchCallback*> >::const_iterator 368 It = Triggers->begin(), End = Triggers->end(); 369 It != End; ++It) { 370 BoundNodesTreeBuilder Builder; 371 if (It->first->matches(node, this, &Builder)) { 372 BoundNodesTree BoundNodes = Builder.build(); 373 MatchVisitor Visitor(ActiveASTContext, It->second); 374 BoundNodes.visitMatches(&Visitor); 375 } 376 } 377 } 378 379 std::vector< std::pair<const UntypedBaseMatcher*, 380 MatchFinder::MatchCallback*> > *const Triggers; 381 ASTContext *ActiveASTContext; 382 383 // Maps a canonical type to the names of its typedefs. 384 llvm::DenseMap<const Type*, std::set<std::string> > 385 TypeToUnqualifiedAliases; 386 387 // Maps (matcher, node) -> the match result for memoization. 
388 typedef llvm::DenseMap<UntypedMatchInput, MemoizedMatchResult> MemoizationMap; 389 MemoizationMap ResultCache; 390}; 391 392// Returns true if the given class is directly or indirectly derived 393// from a base type with the given name. A class is considered to be 394// also derived from itself. 395bool 396MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration, 397 StringRef BaseName) const { 398 if (Declaration->getName() == BaseName) { 399 return true; 400 } 401 if (!Declaration->hasDefinition()) { 402 return false; 403 } 404 typedef CXXRecordDecl::base_class_const_iterator BaseIterator; 405 for (BaseIterator It = Declaration->bases_begin(), 406 End = Declaration->bases_end(); It != End; ++It) { 407 const Type *TypeNode = It->getType().getTypePtr(); 408 409 if (typeHasAlias(TypeNode, BaseName)) 410 return true; 411 412 // Type::getAs<...>() drills through typedefs. 413 if (TypeNode->getAs<DependentNameType>() != NULL || 414 TypeNode->getAs<TemplateTypeParmType>() != NULL) { 415 // Dependent names and template TypeNode parameters will be matched when 416 // the template is instantiated. 417 continue; 418 } 419 CXXRecordDecl *ClassDecl = NULL; 420 TemplateSpecializationType const *TemplateType = 421 TypeNode->getAs<TemplateSpecializationType>(); 422 if (TemplateType != NULL) { 423 if (TemplateType->getTemplateName().isDependent()) { 424 // Dependent template specializations will be matched when the 425 // template is instantiated. 426 continue; 427 } 428 // For template specialization types which are specializing a template 429 // declaration which is an explicit or partial specialization of another 430 // template declaration, getAsCXXRecordDecl() returns the corresponding 431 // ClassTemplateSpecializationDecl. 
432 // 433 // For template specialization types which are specializing a template 434 // declaration which is neither an explicit nor partial specialization of 435 // another template declaration, getAsCXXRecordDecl() returns NULL and 436 // we get the CXXRecordDecl of the templated declaration. 437 CXXRecordDecl *SpecializationDecl = 438 TemplateType->getAsCXXRecordDecl(); 439 if (SpecializationDecl != NULL) { 440 ClassDecl = SpecializationDecl; 441 } else { 442 ClassDecl = llvm::dyn_cast<CXXRecordDecl>( 443 TemplateType->getTemplateName() 444 .getAsTemplateDecl()->getTemplatedDecl()); 445 } 446 } else { 447 ClassDecl = TypeNode->getAsCXXRecordDecl(); 448 } 449 assert(ClassDecl != NULL); 450 assert(ClassDecl != Declaration); 451 if (classIsDerivedFrom(ClassDecl, BaseName)) { 452 return true; 453 } 454 } 455 return false; 456} 457 458bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) { 459 if (DeclNode == NULL) { 460 return true; 461 } 462 match(*DeclNode); 463 return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode); 464} 465 466bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode) { 467 if (StmtNode == NULL) { 468 return true; 469 } 470 match(*StmtNode); 471 return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode); 472} 473 474bool MatchASTVisitor::TraverseType(QualType TypeNode) { 475 match(TypeNode); 476 return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode); 477} 478 479bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLoc) { 480 return RecursiveASTVisitor<MatchASTVisitor>:: 481 TraverseType(TypeLoc.getType()); 482} 483 484class MatchASTConsumer : public ASTConsumer { 485public: 486 MatchASTConsumer(std::vector< std::pair<const UntypedBaseMatcher*, 487 MatchFinder::MatchCallback*> > *Triggers, 488 MatchFinder::ParsingDoneTestCallback *ParsingDone) 489 : Visitor(Triggers), 490 ParsingDone(ParsingDone) {} 491 492private: 493 virtual void HandleTranslationUnit(ASTContext &Context) { 494 if (ParsingDone != NULL) { 495 
ParsingDone->run(); 496 } 497 Visitor.set_active_ast_context(&Context); 498 Visitor.TraverseDecl(Context.getTranslationUnitDecl()); 499 Visitor.set_active_ast_context(NULL); 500 } 501 502 MatchASTVisitor Visitor; 503 MatchFinder::ParsingDoneTestCallback *ParsingDone; 504}; 505 506} // end namespace 507} // end namespace internal 508 509MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes, 510 ASTContext *Context) 511 : Nodes(Nodes), Context(Context), 512 SourceManager(&Context->getSourceManager()) {} 513 514MatchFinder::MatchCallback::~MatchCallback() {} 515MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {} 516 517MatchFinder::MatchFinder() : ParsingDone(NULL) {} 518 519MatchFinder::~MatchFinder() { 520 for (std::vector< std::pair<const internal::UntypedBaseMatcher*, 521 MatchFinder::MatchCallback*> >::const_iterator 522 It = Triggers.begin(), End = Triggers.end(); 523 It != End; ++It) { 524 delete It->first; 525 } 526} 527 528void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch, 529 MatchCallback *Action) { 530 Triggers.push_back(std::make_pair( 531 new internal::TypedBaseMatcher<Decl>(NodeMatch), Action)); 532} 533 534void MatchFinder::addMatcher(const TypeMatcher &NodeMatch, 535 MatchCallback *Action) { 536 Triggers.push_back(std::make_pair( 537 new internal::TypedBaseMatcher<QualType>(NodeMatch), Action)); 538} 539 540void MatchFinder::addMatcher(const StatementMatcher &NodeMatch, 541 MatchCallback *Action) { 542 Triggers.push_back(std::make_pair( 543 new internal::TypedBaseMatcher<Stmt>(NodeMatch), Action)); 544} 545 546ASTConsumer *MatchFinder::newASTConsumer() { 547 return new internal::MatchASTConsumer(&Triggers, ParsingDone); 548} 549 550void MatchFinder::registerTestCallbackAfterParsing( 551 MatchFinder::ParsingDoneTestCallback *NewParsingDone) { 552 ParsingDone = NewParsingDone; 553} 554 555} // end namespace ast_matchers 556} // end namespace clang 557