MatchVerifier.h revision 4d0e9f58076037d84a7da0b407c3de8f76a9d552
1//===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10//  Provides MatchVerifier, a base class to implement gtest matchers that
11//  verify things that can be matched on the AST.
12//
13//  Also implements matchers based on MatchVerifier:
14//  LocationVerifier and RangeVerifier to verify whether a matched node has
15//  the expected source location or source range.
16//
17//===----------------------------------------------------------------------===//
18
19#include "clang/AST/ASTContext.h"
20#include "clang/ASTMatchers/ASTMatchFinder.h"
21#include "clang/ASTMatchers/ASTMatchers.h"
22#include "clang/Tooling/Tooling.h"
23#include "gtest/gtest.h"
24
25namespace clang {
26namespace ast_matchers {
27
/// \brief Source language used when parsing the test code handed to
/// MatchVerifier::match().
///
/// Each value selects a -std= flag and an input file name inside
/// MatchVerifier::match(); the two-argument match() overload uses Lang_CXX.
enum Language {
  Lang_C = 0,   ///< -std=c99, parsed as "input.c".
  Lang_C89 = 1, ///< -std=c89, parsed as "input.c".
  Lang_CXX = 2  ///< -std=c++98, parsed as "input.cc".
};
29
30/// \brief Base class for verifying some property of nodes found by a matcher.
31template <typename NodeType>
32class MatchVerifier : public MatchFinder::MatchCallback {
33public:
34  template <typename MatcherType>
35  testing::AssertionResult match(const std::string &Code,
36                                 const MatcherType &AMatcher) {
37    return match(Code, AMatcher, Lang_CXX);
38  }
39
40  template <typename MatcherType>
41  testing::AssertionResult match(const std::string &Code,
42                                 const MatcherType &AMatcher, Language L);
43
44protected:
45  virtual void run(const MatchFinder::MatchResult &Result);
46  virtual void verify(const MatchFinder::MatchResult &Result,
47                      const NodeType &Node) = 0;
48
49  void setFailure(const Twine &Result) {
50    Verified = false;
51    VerifyResult = Result.str();
52  }
53
54  void setSuccess() {
55    Verified = true;
56  }
57
58private:
59  bool Verified;
60  std::string VerifyResult;
61};
62
63/// \brief Runs a matcher over some code, and returns the result of the
64/// verifier for the matched node.
65template <typename NodeType> template <typename MatcherType>
66testing::AssertionResult MatchVerifier<NodeType>::match(
67    const std::string &Code, const MatcherType &AMatcher, Language L) {
68  MatchFinder Finder;
69  Finder.addMatcher(AMatcher.bind(""), this);
70  OwningPtr<tooling::FrontendActionFactory> Factory(
71      tooling::newFrontendActionFactory(&Finder));
72
73  std::vector<std::string> Args;
74  StringRef FileName;
75  switch (L) {
76  case Lang_C:
77    Args.push_back("-std=c99");
78    FileName = "input.c";
79    break;
80  case Lang_C89:
81    Args.push_back("-std=c89");
82    FileName = "input.c";
83    break;
84  case Lang_CXX:
85    Args.push_back("-std=c++98");
86    FileName = "input.cc";
87    break;
88  }
89
90  // Default to failure in case callback is never called
91  setFailure("Could not find match");
92  if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
93    return testing::AssertionFailure() << "Parsing error";
94  if (!Verified)
95    return testing::AssertionFailure() << VerifyResult;
96  return testing::AssertionSuccess();
97}
98
99template <typename NodeType>
100void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
101  const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
102  if (!Node) {
103    setFailure("Matched node has wrong type");
104  } else {
105    // Callback has been called, default to success.
106    setSuccess();
107    verify(Result, *Node);
108  }
109}
110
111/// \brief Verify whether a node has the correct source location.
112///
113/// By default, Node.getSourceLocation() is checked. This can be changed
114/// by overriding getLocation().
115template <typename NodeType>
116class LocationVerifier : public MatchVerifier<NodeType> {
117public:
118  void expectLocation(unsigned Line, unsigned Column) {
119    ExpectLine = Line;
120    ExpectColumn = Column;
121  }
122
123protected:
124  void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
125    SourceLocation Loc = getLocation(Node);
126    unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
127    unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
128    if (Line != ExpectLine || Column != ExpectColumn) {
129      std::string MsgStr;
130      llvm::raw_string_ostream Msg(MsgStr);
131      Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
132          << ">, found <";
133      Loc.print(Msg, *Result.SourceManager);
134      Msg << '>';
135      this->setFailure(Msg.str());
136    }
137  }
138
139  virtual SourceLocation getLocation(const NodeType &Node) {
140    return Node.getLocation();
141  }
142
143private:
144  unsigned ExpectLine, ExpectColumn;
145};
146
147/// \brief Verify whether a node has the correct source range.
148///
149/// By default, Node.getSourceRange() is checked. This can be changed
150/// by overriding getRange().
151template <typename NodeType>
152class RangeVerifier : public MatchVerifier<NodeType> {
153public:
154  void expectRange(unsigned BeginLine, unsigned BeginColumn,
155                   unsigned EndLine, unsigned EndColumn) {
156    ExpectBeginLine = BeginLine;
157    ExpectBeginColumn = BeginColumn;
158    ExpectEndLine = EndLine;
159    ExpectEndColumn = EndColumn;
160  }
161
162protected:
163  void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
164    SourceRange R = getRange(Node);
165    SourceLocation Begin = R.getBegin();
166    SourceLocation End = R.getEnd();
167    unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
168    unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
169    unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
170    unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
171    if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
172        EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
173      std::string MsgStr;
174      llvm::raw_string_ostream Msg(MsgStr);
175      Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
176          << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
177      Begin.print(Msg, *Result.SourceManager);
178      Msg << '-';
179      End.print(Msg, *Result.SourceManager);
180      Msg << '>';
181      this->setFailure(Msg.str());
182    }
183  }
184
185  virtual SourceRange getRange(const NodeType &Node) {
186    return Node.getSourceRange();
187  }
188
189private:
190  unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
191};
192
193} // end namespace ast_matchers
194} // end namespace clang
195