1//===- unittests/AST/StmtPrinterTest.cpp --- Statement printer tests ------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file contains tests for Stmt::printPretty() and related methods.
11//
12// Search this file for WRONG to see test cases that are producing something
13// completely wrong, invalid C++ or just misleading.
14//
15// These tests have a coding convention:
16// * statements to be printed should be contained within a function named 'A'
17//   unless it should have some special name (e.g., 'operator+');
18// * additional helper declarations are 'Z', 'Y', 'X' and so on.
19//
20//===----------------------------------------------------------------------===//
21
22#include "clang/AST/ASTContext.h"
23#include "clang/ASTMatchers/ASTMatchFinder.h"
24#include "clang/Tooling/Tooling.h"
25#include "llvm/ADT/SmallString.h"
26#include "gtest/gtest.h"
27
28using namespace clang;
29using namespace ast_matchers;
30using namespace tooling;
31
32namespace {
33
34void PrintStmt(raw_ostream &Out, const ASTContext *Context, const Stmt *S) {
35  assert(S != nullptr && "Expected non-null Stmt");
36  PrintingPolicy Policy = Context->getPrintingPolicy();
37  S->printPretty(Out, /*Helper*/ nullptr, Policy);
38}
39
40class PrintMatch : public MatchFinder::MatchCallback {
41  SmallString<1024> Printed;
42  unsigned NumFoundStmts;
43
44public:
45  PrintMatch() : NumFoundStmts(0) {}
46
47  void run(const MatchFinder::MatchResult &Result) override {
48    const Stmt *S = Result.Nodes.getStmtAs<Stmt>("id");
49    if (!S)
50      return;
51    NumFoundStmts++;
52    if (NumFoundStmts > 1)
53      return;
54
55    llvm::raw_svector_ostream Out(Printed);
56    PrintStmt(Out, Result.Context, S);
57  }
58
59  StringRef getPrinted() const {
60    return Printed;
61  }
62
63  unsigned getNumFoundStmts() const {
64    return NumFoundStmts;
65  }
66};
67
68template <typename T>
69::testing::AssertionResult
70PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
71                   const T &NodeMatch, StringRef ExpectedPrinted) {
72
73  PrintMatch Printer;
74  MatchFinder Finder;
75  Finder.addMatcher(NodeMatch, &Printer);
76  std::unique_ptr<FrontendActionFactory> Factory(
77      newFrontendActionFactory(&Finder));
78
79  if (!runToolOnCodeWithArgs(Factory->create(), Code, Args))
80    return testing::AssertionFailure()
81      << "Parsing error in \"" << Code.str() << "\"";
82
83  if (Printer.getNumFoundStmts() == 0)
84    return testing::AssertionFailure()
85        << "Matcher didn't find any statements";
86
87  if (Printer.getNumFoundStmts() > 1)
88    return testing::AssertionFailure()
89        << "Matcher should match only one statement "
90           "(found " << Printer.getNumFoundStmts() << ")";
91
92  if (Printer.getPrinted() != ExpectedPrinted)
93    return ::testing::AssertionFailure()
94      << "Expected \"" << ExpectedPrinted.str() << "\", "
95         "got \"" << Printer.getPrinted().str() << "\"";
96
97  return ::testing::AssertionSuccess();
98}
99
100::testing::AssertionResult
101PrintedStmtCXX98Matches(StringRef Code, const StatementMatcher &NodeMatch,
102                        StringRef ExpectedPrinted) {
103  std::vector<std::string> Args;
104  Args.push_back("-std=c++98");
105  Args.push_back("-Wno-unused-value");
106  return PrintedStmtMatches(Code, Args, NodeMatch, ExpectedPrinted);
107}
108
109::testing::AssertionResult PrintedStmtCXX98Matches(
110                                              StringRef Code,
111                                              StringRef ContainingFunction,
112                                              StringRef ExpectedPrinted) {
113  std::vector<std::string> Args;
114  Args.push_back("-std=c++98");
115  Args.push_back("-Wno-unused-value");
116  return PrintedStmtMatches(Code,
117                            Args,
118                            functionDecl(hasName(ContainingFunction),
119                                         has(compoundStmt(has(stmt().bind("id"))))),
120                            ExpectedPrinted);
121}
122
123::testing::AssertionResult
124PrintedStmtCXX11Matches(StringRef Code, const StatementMatcher &NodeMatch,
125                        StringRef ExpectedPrinted) {
126  std::vector<std::string> Args;
127  Args.push_back("-std=c++11");
128  Args.push_back("-Wno-unused-value");
129  return PrintedStmtMatches(Code, Args, NodeMatch, ExpectedPrinted);
130}
131
132::testing::AssertionResult PrintedStmtMSMatches(
133                                              StringRef Code,
134                                              StringRef ContainingFunction,
135                                              StringRef ExpectedPrinted) {
136  std::vector<std::string> Args;
137  Args.push_back("-target");
138  Args.push_back("i686-pc-win32");
139  Args.push_back("-std=c++98");
140  Args.push_back("-fms-extensions");
141  Args.push_back("-Wno-unused-value");
142  return PrintedStmtMatches(Code,
143                            Args,
144                            functionDecl(hasName(ContainingFunction),
145                                         has(compoundStmt(has(stmt().bind("id"))))),
146                            ExpectedPrinted);
147}
148
149} // unnamed namespace
150
151TEST(StmtPrinter, TestIntegerLiteral) {
152  ASSERT_TRUE(PrintedStmtCXX98Matches(
153    "void A() {"
154    "  1, -1, 1U, 1u,"
155    "  1L, 1l, -1L, 1UL, 1ul,"
156    "  1LL, -1LL, 1ULL;"
157    "}",
158    "A",
159    "1 , -1 , 1U , 1U , "
160    "1L , 1L , -1L , 1UL , 1UL , "
161    "1LL , -1LL , 1ULL"));
162    // Should be: with semicolon
163}
164
165TEST(StmtPrinter, TestMSIntegerLiteral) {
166  ASSERT_TRUE(PrintedStmtMSMatches(
167    "void A() {"
168    "  1i8, -1i8, 1ui8, "
169    "  1i16, -1i16, 1ui16, "
170    "  1i32, -1i32, 1ui32, "
171    "  1i64, -1i64, 1ui64;"
172    "}",
173    "A",
174    "1i8 , -1i8 , 1Ui8 , "
175    "1i16 , -1i16 , 1Ui16 , "
176    "1 , -1 , 1U , "
177    "1LL , -1LL , 1ULL"));
178    // Should be: with semicolon
179}
180
181TEST(StmtPrinter, TestFloatingPointLiteral) {
182  ASSERT_TRUE(PrintedStmtCXX98Matches(
183    "void A() { 1.0f, -1.0f, 1.0, -1.0, 1.0l, -1.0l; }",
184    "A",
185    "1.F , -1.F , 1. , -1. , 1.L , -1.L"));
186    // Should be: with semicolon
187}
188
189TEST(StmtPrinter, TestCXXConversionDeclImplicit) {
190  ASSERT_TRUE(PrintedStmtCXX98Matches(
191    "struct A {"
192      "operator void *();"
193      "A operator&(A);"
194    "};"
195    "void bar(void *);"
196    "void foo(A a, A b) {"
197    "  bar(a & b);"
198    "}",
199    cxxMemberCallExpr(anything()).bind("id"),
200    "a & b"));
201}
202
203TEST(StmtPrinter, TestCXXConversionDeclExplicit) {
204  ASSERT_TRUE(PrintedStmtCXX11Matches(
205    "struct A {"
206      "operator void *();"
207      "A operator&(A);"
208    "};"
209    "void bar(void *);"
210    "void foo(A a, A b) {"
211    "  auto x = (a & b).operator void *();"
212    "}",
213    cxxMemberCallExpr(anything()).bind("id"),
214    "(a & b)"));
215    // WRONG; Should be: (a & b).operator void *()
216}
217