1//===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10//  Provides MatchVerifier, a base class to implement gtest matchers that
11//  verify things that can be matched on the AST.
12//
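//  Also implements matchers based on MatchVerifier:
//  LocationVerifier and RangeVerifier to verify whether a matched node has
//  the expected source location or source range, DumpVerifier to verify that
//  a node's dump contains a given substring, and PrintVerifier to verify that
//  a node's pretty print equals a given string.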
16//
17//===----------------------------------------------------------------------===//
18
19#include "clang/AST/ASTContext.h"
20#include "clang/ASTMatchers/ASTMatchFinder.h"
21#include "clang/ASTMatchers/ASTMatchers.h"
22#include "clang/Tooling/Tooling.h"
23#include "gtest/gtest.h"
24
25namespace clang {
26namespace ast_matchers {
27
/// \brief Input language for MatchVerifier::match; selects the -std= flag and
/// the virtual input file name handed to the tool.
enum Language {
  Lang_C,      ///< "-std=c99", file "input.c".
  Lang_C89,    ///< "-std=c89", file "input.c".
  Lang_CXX,    ///< "-std=c++98", file "input.cc".
  Lang_CXX11,  ///< "-std=c++11", file "input.cc".
  Lang_OpenCL  ///< No -std= flag, file "input.cl".
};
29
30/// \brief Base class for verifying some property of nodes found by a matcher.
31template <typename NodeType>
32class MatchVerifier : public MatchFinder::MatchCallback {
33public:
34  template <typename MatcherType>
35  testing::AssertionResult match(const std::string &Code,
36                                 const MatcherType &AMatcher) {
37    std::vector<std::string> Args;
38    return match(Code, AMatcher, Args, Lang_CXX);
39  }
40
41  template <typename MatcherType>
42  testing::AssertionResult match(const std::string &Code,
43                                 const MatcherType &AMatcher,
44                                 Language L) {
45    std::vector<std::string> Args;
46    return match(Code, AMatcher, Args, L);
47  }
48
49  template <typename MatcherType>
50  testing::AssertionResult match(const std::string &Code,
51                                 const MatcherType &AMatcher,
52                                 std::vector<std::string>& Args,
53                                 Language L);
54
55protected:
56  virtual void run(const MatchFinder::MatchResult &Result);
57  virtual void verify(const MatchFinder::MatchResult &Result,
58                      const NodeType &Node) {}
59
60  void setFailure(const Twine &Result) {
61    Verified = false;
62    VerifyResult = Result.str();
63  }
64
65  void setSuccess() {
66    Verified = true;
67  }
68
69private:
70  bool Verified;
71  std::string VerifyResult;
72};
73
74/// \brief Runs a matcher over some code, and returns the result of the
75/// verifier for the matched node.
76template <typename NodeType> template <typename MatcherType>
77testing::AssertionResult MatchVerifier<NodeType>::match(
78    const std::string &Code, const MatcherType &AMatcher,
79    std::vector<std::string>& Args, Language L) {
80  MatchFinder Finder;
81  Finder.addMatcher(AMatcher.bind(""), this);
82  std::unique_ptr<tooling::FrontendActionFactory> Factory(
83      tooling::newFrontendActionFactory(&Finder));
84
85  StringRef FileName;
86  switch (L) {
87  case Lang_C:
88    Args.push_back("-std=c99");
89    FileName = "input.c";
90    break;
91  case Lang_C89:
92    Args.push_back("-std=c89");
93    FileName = "input.c";
94    break;
95  case Lang_CXX:
96    Args.push_back("-std=c++98");
97    FileName = "input.cc";
98    break;
99  case Lang_CXX11:
100    Args.push_back("-std=c++11");
101    FileName = "input.cc";
102    break;
103  case Lang_OpenCL:
104    FileName = "input.cl";
105  }
106
107  // Default to failure in case callback is never called
108  setFailure("Could not find match");
109  if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
110    return testing::AssertionFailure() << "Parsing error";
111  if (!Verified)
112    return testing::AssertionFailure() << VerifyResult;
113  return testing::AssertionSuccess();
114}
115
116template <typename NodeType>
117void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
118  const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
119  if (!Node) {
120    setFailure("Matched node has wrong type");
121  } else {
122    // Callback has been called, default to success.
123    setSuccess();
124    verify(Result, *Node);
125  }
126}
127
128template <>
129inline void MatchVerifier<ast_type_traits::DynTypedNode>::run(
130    const MatchFinder::MatchResult &Result) {
131  BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
132  BoundNodes::IDToNodeMap::const_iterator I = M.find("");
133  if (I == M.end()) {
134    setFailure("Node was not bound");
135  } else {
136    // Callback has been called, default to success.
137    setSuccess();
138    verify(Result, I->second);
139  }
140}
141
142/// \brief Verify whether a node has the correct source location.
143///
144/// By default, Node.getSourceLocation() is checked. This can be changed
145/// by overriding getLocation().
146template <typename NodeType>
147class LocationVerifier : public MatchVerifier<NodeType> {
148public:
149  void expectLocation(unsigned Line, unsigned Column) {
150    ExpectLine = Line;
151    ExpectColumn = Column;
152  }
153
154protected:
155  void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
156    SourceLocation Loc = getLocation(Node);
157    unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
158    unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
159    if (Line != ExpectLine || Column != ExpectColumn) {
160      std::string MsgStr;
161      llvm::raw_string_ostream Msg(MsgStr);
162      Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
163          << ">, found <";
164      Loc.print(Msg, *Result.SourceManager);
165      Msg << '>';
166      this->setFailure(Msg.str());
167    }
168  }
169
170  virtual SourceLocation getLocation(const NodeType &Node) {
171    return Node.getLocation();
172  }
173
174private:
175  unsigned ExpectLine, ExpectColumn;
176};
177
178/// \brief Verify whether a node has the correct source range.
179///
180/// By default, Node.getSourceRange() is checked. This can be changed
181/// by overriding getRange().
182template <typename NodeType>
183class RangeVerifier : public MatchVerifier<NodeType> {
184public:
185  void expectRange(unsigned BeginLine, unsigned BeginColumn,
186                   unsigned EndLine, unsigned EndColumn) {
187    ExpectBeginLine = BeginLine;
188    ExpectBeginColumn = BeginColumn;
189    ExpectEndLine = EndLine;
190    ExpectEndColumn = EndColumn;
191  }
192
193protected:
194  void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
195    SourceRange R = getRange(Node);
196    SourceLocation Begin = R.getBegin();
197    SourceLocation End = R.getEnd();
198    unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
199    unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
200    unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
201    unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
202    if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
203        EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
204      std::string MsgStr;
205      llvm::raw_string_ostream Msg(MsgStr);
206      Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
207          << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
208      Begin.print(Msg, *Result.SourceManager);
209      Msg << '-';
210      End.print(Msg, *Result.SourceManager);
211      Msg << '>';
212      this->setFailure(Msg.str());
213    }
214  }
215
216  virtual SourceRange getRange(const NodeType &Node) {
217    return Node.getSourceRange();
218  }
219
220private:
221  unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
222};
223
224/// \brief Verify whether a node's dump contains a given substring.
225class DumpVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
226public:
227  void expectSubstring(const std::string &Str) {
228    ExpectSubstring = Str;
229  }
230
231protected:
232  void verify(const MatchFinder::MatchResult &Result,
233              const ast_type_traits::DynTypedNode &Node) {
234    std::string DumpStr;
235    llvm::raw_string_ostream Dump(DumpStr);
236    Node.dump(Dump, *Result.SourceManager);
237
238    if (Dump.str().find(ExpectSubstring) == std::string::npos) {
239      std::string MsgStr;
240      llvm::raw_string_ostream Msg(MsgStr);
241      Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
242          << Dump.str() << '>';
243      this->setFailure(Msg.str());
244    }
245  }
246
247private:
248  std::string ExpectSubstring;
249};
250
251/// \brief Verify whether a node's pretty print matches a given string.
252class PrintVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
253public:
254  void expectString(const std::string &Str) {
255    ExpectString = Str;
256  }
257
258protected:
259  void verify(const MatchFinder::MatchResult &Result,
260              const ast_type_traits::DynTypedNode &Node) {
261    std::string PrintStr;
262    llvm::raw_string_ostream Print(PrintStr);
263    Node.print(Print, Result.Context->getPrintingPolicy());
264
265    if (Print.str() != ExpectString) {
266      std::string MsgStr;
267      llvm::raw_string_ostream Msg(MsgStr);
268      Msg << "Expected pretty print <" << ExpectString << ">, found <"
269          << Print.str() << '>';
270      this->setFailure(Msg.str());
271    }
272  }
273
274private:
275  std::string ExpectString;
276};
277
278} // end namespace ast_matchers
279} // end namespace clang
280