1//===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10//  Provides MatchVerifier, a base class to implement gtest matchers that
11//  verify things that can be matched on the AST.
12//
13//  Also implements matchers based on MatchVerifier:
14//  LocationVerifier and RangeVerifier to verify whether a matched node has
15//  the expected source location or source range.
16//
17//===----------------------------------------------------------------------===//
18
19#ifndef LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
20#define LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
21
22#include "clang/AST/ASTContext.h"
23#include "clang/ASTMatchers/ASTMatchFinder.h"
24#include "clang/ASTMatchers/ASTMatchers.h"
25#include "clang/Tooling/Tooling.h"
26#include "gtest/gtest.h"
27
28namespace clang {
29namespace ast_matchers {
30
/// \brief Source language a MatchVerifier parses the test code as.
///
/// The value selects the input file name's extension and, for the C and C++
/// dialects, a -std= flag handed to the tool (see MatchVerifier::match).
enum Language {
  Lang_C = 0,      // C99, input.c
  Lang_C89 = 1,    // C89, input.c
  Lang_CXX = 2,    // C++98, input.cc
  Lang_CXX11 = 3,  // C++11, input.cc
  Lang_OpenCL = 4, // input.cl, no -std= flag added
  Lang_OBJCXX = 5  // input.mm, no -std= flag added
};
39
40/// \brief Base class for verifying some property of nodes found by a matcher.
41template <typename NodeType>
42class MatchVerifier : public MatchFinder::MatchCallback {
43public:
44  template <typename MatcherType>
45  testing::AssertionResult match(const std::string &Code,
46                                 const MatcherType &AMatcher) {
47    std::vector<std::string> Args;
48    return match(Code, AMatcher, Args, Lang_CXX);
49  }
50
51  template <typename MatcherType>
52  testing::AssertionResult match(const std::string &Code,
53                                 const MatcherType &AMatcher,
54                                 Language L) {
55    std::vector<std::string> Args;
56    return match(Code, AMatcher, Args, L);
57  }
58
59  template <typename MatcherType>
60  testing::AssertionResult match(const std::string &Code,
61                                 const MatcherType &AMatcher,
62                                 std::vector<std::string>& Args,
63                                 Language L);
64
65protected:
66  void run(const MatchFinder::MatchResult &Result) override;
67  virtual void verify(const MatchFinder::MatchResult &Result,
68                      const NodeType &Node) {}
69
70  void setFailure(const Twine &Result) {
71    Verified = false;
72    VerifyResult = Result.str();
73  }
74
75  void setSuccess() {
76    Verified = true;
77  }
78
79private:
80  bool Verified;
81  std::string VerifyResult;
82};
83
84/// \brief Runs a matcher over some code, and returns the result of the
85/// verifier for the matched node.
86template <typename NodeType> template <typename MatcherType>
87testing::AssertionResult MatchVerifier<NodeType>::match(
88    const std::string &Code, const MatcherType &AMatcher,
89    std::vector<std::string>& Args, Language L) {
90  MatchFinder Finder;
91  Finder.addMatcher(AMatcher.bind(""), this);
92  std::unique_ptr<tooling::FrontendActionFactory> Factory(
93      tooling::newFrontendActionFactory(&Finder));
94
95  StringRef FileName;
96  switch (L) {
97  case Lang_C:
98    Args.push_back("-std=c99");
99    FileName = "input.c";
100    break;
101  case Lang_C89:
102    Args.push_back("-std=c89");
103    FileName = "input.c";
104    break;
105  case Lang_CXX:
106    Args.push_back("-std=c++98");
107    FileName = "input.cc";
108    break;
109  case Lang_CXX11:
110    Args.push_back("-std=c++11");
111    FileName = "input.cc";
112    break;
113  case Lang_OpenCL:
114    FileName = "input.cl";
115    break;
116  case Lang_OBJCXX:
117    FileName = "input.mm";
118    break;
119  }
120
121  // Default to failure in case callback is never called
122  setFailure("Could not find match");
123  if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
124    return testing::AssertionFailure() << "Parsing error";
125  if (!Verified)
126    return testing::AssertionFailure() << VerifyResult;
127  return testing::AssertionSuccess();
128}
129
130template <typename NodeType>
131void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
132  const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
133  if (!Node) {
134    setFailure("Matched node has wrong type");
135  } else {
136    // Callback has been called, default to success.
137    setSuccess();
138    verify(Result, *Node);
139  }
140}
141
142template <>
143inline void MatchVerifier<ast_type_traits::DynTypedNode>::run(
144    const MatchFinder::MatchResult &Result) {
145  BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
146  BoundNodes::IDToNodeMap::const_iterator I = M.find("");
147  if (I == M.end()) {
148    setFailure("Node was not bound");
149  } else {
150    // Callback has been called, default to success.
151    setSuccess();
152    verify(Result, I->second);
153  }
154}
155
156/// \brief Verify whether a node has the correct source location.
157///
158/// By default, Node.getSourceLocation() is checked. This can be changed
159/// by overriding getLocation().
160template <typename NodeType>
161class LocationVerifier : public MatchVerifier<NodeType> {
162public:
163  void expectLocation(unsigned Line, unsigned Column) {
164    ExpectLine = Line;
165    ExpectColumn = Column;
166  }
167
168protected:
169  void verify(const MatchFinder::MatchResult &Result,
170              const NodeType &Node) override {
171    SourceLocation Loc = getLocation(Node);
172    unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
173    unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
174    if (Line != ExpectLine || Column != ExpectColumn) {
175      std::string MsgStr;
176      llvm::raw_string_ostream Msg(MsgStr);
177      Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
178          << ">, found <";
179      Loc.print(Msg, *Result.SourceManager);
180      Msg << '>';
181      this->setFailure(Msg.str());
182    }
183  }
184
185  virtual SourceLocation getLocation(const NodeType &Node) {
186    return Node.getLocation();
187  }
188
189private:
190  unsigned ExpectLine, ExpectColumn;
191};
192
193/// \brief Verify whether a node has the correct source range.
194///
195/// By default, Node.getSourceRange() is checked. This can be changed
196/// by overriding getRange().
197template <typename NodeType>
198class RangeVerifier : public MatchVerifier<NodeType> {
199public:
200  void expectRange(unsigned BeginLine, unsigned BeginColumn,
201                   unsigned EndLine, unsigned EndColumn) {
202    ExpectBeginLine = BeginLine;
203    ExpectBeginColumn = BeginColumn;
204    ExpectEndLine = EndLine;
205    ExpectEndColumn = EndColumn;
206  }
207
208protected:
209  void verify(const MatchFinder::MatchResult &Result,
210              const NodeType &Node) override {
211    SourceRange R = getRange(Node);
212    SourceLocation Begin = R.getBegin();
213    SourceLocation End = R.getEnd();
214    unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
215    unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
216    unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
217    unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
218    if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
219        EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
220      std::string MsgStr;
221      llvm::raw_string_ostream Msg(MsgStr);
222      Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
223          << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
224      Begin.print(Msg, *Result.SourceManager);
225      Msg << '-';
226      End.print(Msg, *Result.SourceManager);
227      Msg << '>';
228      this->setFailure(Msg.str());
229    }
230  }
231
232  virtual SourceRange getRange(const NodeType &Node) {
233    return Node.getSourceRange();
234  }
235
236private:
237  unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
238};
239
240/// \brief Verify whether a node's dump contains a given substring.
241class DumpVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
242public:
243  void expectSubstring(const std::string &Str) {
244    ExpectSubstring = Str;
245  }
246
247protected:
248  void verify(const MatchFinder::MatchResult &Result,
249              const ast_type_traits::DynTypedNode &Node) override {
250    std::string DumpStr;
251    llvm::raw_string_ostream Dump(DumpStr);
252    Node.dump(Dump, *Result.SourceManager);
253
254    if (Dump.str().find(ExpectSubstring) == std::string::npos) {
255      std::string MsgStr;
256      llvm::raw_string_ostream Msg(MsgStr);
257      Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
258          << Dump.str() << '>';
259      this->setFailure(Msg.str());
260    }
261  }
262
263private:
264  std::string ExpectSubstring;
265};
266
267/// \brief Verify whether a node's pretty print matches a given string.
268class PrintVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
269public:
270  void expectString(const std::string &Str) {
271    ExpectString = Str;
272  }
273
274protected:
275  void verify(const MatchFinder::MatchResult &Result,
276              const ast_type_traits::DynTypedNode &Node) override {
277    std::string PrintStr;
278    llvm::raw_string_ostream Print(PrintStr);
279    Node.print(Print, Result.Context->getPrintingPolicy());
280
281    if (Print.str() != ExpectString) {
282      std::string MsgStr;
283      llvm::raw_string_ostream Msg(MsgStr);
284      Msg << "Expected pretty print <" << ExpectString << ">, found <"
285          << Print.str() << '>';
286      this->setFailure(Msg.str());
287    }
288  }
289
290private:
291  std::string ExpectString;
292};
293
294} // end namespace ast_matchers
295} // end namespace clang
296
297#endif
298