RewriteScopedRefptr.cpp revision 03b57e008b61dfcb1fbad3aea950ae0e001748b0
1// Copyright (c) 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4//
5// This implements a Clang tool to rewrite all instances of
6// scoped_refptr<T>'s implicit cast to T (operator T*) to an explicit call to
7// the .get() method.
8
9#include <algorithm>
10#include <memory>
11#include <string>
12
13#include "clang/AST/ASTContext.h"
14#include "clang/ASTMatchers/ASTMatchers.h"
15#include "clang/ASTMatchers/ASTMatchersMacros.h"
16#include "clang/ASTMatchers/ASTMatchFinder.h"
17#include "clang/Basic/SourceManager.h"
18#include "clang/Frontend/FrontendActions.h"
19#include "clang/Lex/Lexer.h"
20#include "clang/Tooling/CommonOptionsParser.h"
21#include "clang/Tooling/Refactoring.h"
22#include "clang/Tooling/Tooling.h"
23#include "llvm/Support/CommandLine.h"
24
25using namespace clang::ast_matchers;
26using clang::tooling::CommonOptionsParser;
27using clang::tooling::Replacement;
28using clang::tooling::Replacements;
29using llvm::StringRef;
30
31namespace clang {
32namespace ast_matchers {
33
34const internal::VariadicDynCastAllOfMatcher<Decl, CXXConversionDecl>
35    conversionDecl;
36
37AST_MATCHER(QualType, isBoolean) {
38  return Node->isBooleanType();
39}
40
41}  // namespace ast_matchers
42}  // namespace clang
43
44namespace {
45
46// Returns true if expr needs to be put in parens (eg: when it is an operator
47// syntactically).
48bool NeedsParens(const clang::Expr* expr) {
49  if (llvm::dyn_cast<clang::UnaryOperator>(expr) ||
50      llvm::dyn_cast<clang::BinaryOperator>(expr) ||
51      llvm::dyn_cast<clang::ConditionalOperator>(expr)) {
52    return true;
53  }
54  // Calls to an overloaded operator also need parens, except for foo(...) and
55  // foo[...] expressions.
56  if (const clang::CXXOperatorCallExpr* op =
57          llvm::dyn_cast<clang::CXXOperatorCallExpr>(expr)) {
58    return op->getOperator() != clang::OO_Call &&
59           op->getOperator() != clang::OO_Subscript;
60  }
61  return false;
62}
63
64Replacement RewriteRawPtrToScopedRefptr(const MatchFinder::MatchResult& result,
65                                        clang::SourceLocation begin,
66                                        clang::SourceLocation end) {
67  clang::CharSourceRange range = clang::CharSourceRange::getTokenRange(
68      result.SourceManager->getSpellingLoc(begin),
69      result.SourceManager->getSpellingLoc(end));
70
71  std::string text = clang::Lexer::getSourceText(
72      range, *result.SourceManager, result.Context->getLangOpts());
73  text.erase(text.rfind('*'));
74
75  std::string replacement_text("scoped_refptr<");
76  replacement_text += text;
77  replacement_text += ">";
78
79  return Replacement(*result.SourceManager, range, replacement_text);
80}
81
82class GetRewriterCallback : public MatchFinder::MatchCallback {
83 public:
84  explicit GetRewriterCallback(Replacements* replacements)
85      : replacements_(replacements) {}
86  virtual void run(const MatchFinder::MatchResult& result) override;
87
88 private:
89  Replacements* const replacements_;
90};
91
92void GetRewriterCallback::run(const MatchFinder::MatchResult& result) {
93  const clang::CXXMemberCallExpr* const implicit_call =
94      result.Nodes.getNodeAs<clang::CXXMemberCallExpr>("call");
95  const clang::Expr* arg = result.Nodes.getNodeAs<clang::Expr>("arg");
96
97  if (!implicit_call || !arg)
98    return;
99
100  clang::CharSourceRange range = clang::CharSourceRange::getTokenRange(
101      result.SourceManager->getSpellingLoc(arg->getLocStart()),
102      result.SourceManager->getSpellingLoc(arg->getLocEnd()));
103  if (!range.isValid())
104    return;  // TODO(rsleevi): Log an error?
105
106  // Handle cases where an implicit cast is being done by dereferencing a
107  // pointer to a scoped_refptr<> (sadly, it happens...)
108  //
109  // This rewrites both "*foo" and "*(foo)" as "foo->get()".
110  if (const clang::UnaryOperator* op =
111          llvm::dyn_cast<clang::UnaryOperator>(arg)) {
112    if (op->getOpcode() == clang::UO_Deref) {
113      const clang::Expr* const sub_expr =
114          op->getSubExpr()->IgnoreParenImpCasts();
115      clang::CharSourceRange sub_expr_range =
116          clang::CharSourceRange::getTokenRange(
117              result.SourceManager->getSpellingLoc(sub_expr->getLocStart()),
118              result.SourceManager->getSpellingLoc(sub_expr->getLocEnd()));
119      if (!sub_expr_range.isValid())
120        return;  // TODO(rsleevi): Log an error?
121      std::string inner_text = clang::Lexer::getSourceText(
122          sub_expr_range, *result.SourceManager, result.Context->getLangOpts());
123      if (inner_text.empty())
124        return;  // TODO(rsleevi): Log an error?
125
126      if (NeedsParens(sub_expr)) {
127        inner_text.insert(0, "(");
128        inner_text.append(")");
129      }
130      inner_text.append("->get()");
131      replacements_->insert(
132          Replacement(*result.SourceManager, range, inner_text));
133      return;
134    }
135  }
136
137  std::string text = clang::Lexer::getSourceText(
138      range, *result.SourceManager, result.Context->getLangOpts());
139  if (text.empty())
140    return;  // TODO(rsleevi): Log an error?
141
142  // Unwrap any temporaries - for example, custom iterators that return
143  // scoped_refptr<T> as part of operator*. Any such iterators should also
144  // be declaring a scoped_refptr<T>* operator->, per C++03 24.4.1.1 (Table 72)
145  if (const clang::CXXBindTemporaryExpr* op =
146          llvm::dyn_cast<clang::CXXBindTemporaryExpr>(arg)) {
147    arg = op->getSubExpr();
148  }
149
150  // Handle iterators (which are operator* calls, followed by implicit
151  // conversions) by rewriting *it as it->get()
152  if (const clang::CXXOperatorCallExpr* op =
153          llvm::dyn_cast<clang::CXXOperatorCallExpr>(arg)) {
154    if (op->getOperator() == clang::OO_Star) {
155      // Note that this doesn't rewrite **it correctly, since it should be
156      // rewritten using parens, e.g. (*it)->get(). However, this shouldn't
157      // happen frequently, if at all, since it would likely indicate code is
158      // storing pointers to a scoped_refptr in a container.
159      text.erase(0, 1);
160      text.append("->get()");
161      replacements_->insert(Replacement(*result.SourceManager, range, text));
162      return;
163    }
164  }
165
166  // The only remaining calls should be non-dereferencing calls (eg: member
167  // calls), so a simple ".get()" appending should suffice.
168  if (NeedsParens(arg)) {
169    text.insert(0, "(");
170    text.append(")");
171  }
172  text.append(".get()");
173  replacements_->insert(Replacement(*result.SourceManager, range, text));
174}
175
176class VarRewriterCallback : public MatchFinder::MatchCallback {
177 public:
178  explicit VarRewriterCallback(Replacements* replacements)
179      : replacements_(replacements) {}
180  virtual void run(const MatchFinder::MatchResult& result) override;
181
182 private:
183  Replacements* const replacements_;
184};
185
186void VarRewriterCallback::run(const MatchFinder::MatchResult& result) {
187  const clang::DeclaratorDecl* const var_decl =
188      result.Nodes.getNodeAs<clang::DeclaratorDecl>("var");
189
190  if (!var_decl)
191    return;
192
193  const clang::TypeSourceInfo* tsi = var_decl->getTypeSourceInfo();
194
195  // TODO(dcheng): This mishandles a case where a variable has multiple
196  // declarations, e.g.:
197  //
198  // in .h:
199  // Foo* my_global_magical_foo;
200  //
201  // in .cc:
202  // Foo* my_global_magical_foo = CreateFoo();
203  //
204  // In this case, it will only rewrite the .cc definition. Oh well. This should
205  // be rare enough that these cases can be manually handled, since the style
206  // guide prohibits globals of non-POD type.
207  replacements_->insert(RewriteRawPtrToScopedRefptr(
208      result, tsi->getTypeLoc().getBeginLoc(), tsi->getTypeLoc().getEndLoc()));
209}
210
211class FunctionRewriterCallback : public MatchFinder::MatchCallback {
212 public:
213  explicit FunctionRewriterCallback(Replacements* replacements)
214      : replacements_(replacements) {}
215  virtual void run(const MatchFinder::MatchResult& result) override;
216
217 private:
218  Replacements* const replacements_;
219};
220
221void FunctionRewriterCallback::run(const MatchFinder::MatchResult& result) {
222  const clang::FunctionDecl* const function_decl =
223      result.Nodes.getNodeAs<clang::FunctionDecl>("fn");
224
225  if (!function_decl)
226    return;
227
228  // If matched against an implicit conversion to a DeclRefExpr, make sure the
229  // referenced declaration is of class type, e.g. the tool skips trying to
230  // chase pointers/references to determine if the pointee is a scoped_refptr<T>
231  // with local storage. Instead, let a human manually handle those cases.
232  const clang::VarDecl* const var_decl =
233      result.Nodes.getNodeAs<clang::VarDecl>("var");
234  if (var_decl && !var_decl->getTypeSourceInfo()->getType()->isClassType()) {
235    return;
236  }
237
238  for (clang::FunctionDecl* f : function_decl->redecls()) {
239    clang::SourceRange range = f->getReturnTypeSourceRange();
240    replacements_->insert(
241        RewriteRawPtrToScopedRefptr(result, range.getBegin(), range.getEnd()));
242  }
243}
244
245}  // namespace
246
247static llvm::cl::extrahelp common_help(CommonOptionsParser::HelpMessage);
248
249int main(int argc, const char* argv[]) {
250  llvm::cl::OptionCategory category("Remove scoped_refptr conversions");
251  CommonOptionsParser options(argc, argv, category);
252  clang::tooling::ClangTool tool(options.getCompilations(),
253                                 options.getSourcePathList());
254
255  MatchFinder match_finder;
256  Replacements replacements;
257
258  // Finds all calls to conversion operator member function. This catches calls
259  // to "operator T*", "operator Testable", and "operator bool" equally.
260  auto base_matcher =
261      id("call",
262         memberCallExpr(
263             thisPointerType(recordDecl(isSameOrDerivedFrom("::scoped_refptr"),
264                                        isTemplateInstantiation())),
265             callee(conversionDecl()),
266             on(id("arg", expr()))));
267
268  // The heuristic for whether or not converting a temporary is 'unsafe'. An
269  // unsafe conversion is one where a temporary scoped_refptr<T> is converted to
270  // another type. The matcher provides an exception for a temporary
271  // scoped_refptr that is the result of an operator call. In this case, assume
272  // that it's the result of an iterator dereference, and the container itself
273  // retains the necessary reference, since this is a common idiom to see in
274  // loop bodies.
275  auto is_unsafe_temporary_conversion =
276      on(bindTemporaryExpr(unless(has(operatorCallExpr()))));
277
278  // Returning a scoped_refptr<T> as a T* is considered unsafe if either are
279  // true:
280  // - The scoped_refptr<T> is a temporary.
281  // - The scoped_refptr<T> has local lifetime.
282  auto returned_as_raw_ptr = hasParent(
283      returnStmt(hasAncestor(id("fn", functionDecl(returns(pointerType()))))));
284  // This matcher intentionally matches more than it should. For example, this
285  // will match:
286  //   scoped_refptr<Foo>& foo = some_other_foo;
287  //   return foo;
288  // The matcher callback filters out VarDecls that aren't a scoped_refptr<T>,
289  // so those cases can be manually handled.
290  auto is_local_variable =
291      on(declRefExpr(to(id("var", varDecl(hasLocalStorage())))));
292  auto is_unsafe_return =
293      anyOf(allOf(hasParent(implicitCastExpr(returned_as_raw_ptr)),
294                  is_local_variable),
295            allOf(hasParent(implicitCastExpr(
296                      hasParent(exprWithCleanups(returned_as_raw_ptr)))),
297                  is_unsafe_temporary_conversion));
298
299  // This catches both user-defined conversions (eg: "operator bool") and
300  // standard conversion sequence (C++03 13.3.3.1.1), such as converting a
301  // pointer to a bool.
302  auto implicit_to_bool =
303      implicitCastExpr(hasImplicitDestinationType(isBoolean()));
304
305  // Avoid converting calls to of "operator Testable" -> "bool" and calls of
306  // "operator T*" -> "bool".
307  auto bool_conversion_matcher = hasParent(
308      expr(anyOf(implicit_to_bool, expr(hasParent(implicit_to_bool)))));
309
310  // Find all calls to an operator overload that are 'safe'.
311  //
312  // All bool conversions will be handled with the Testable trick, but that
313  // can only be used once "operator T*" is removed, since otherwise it leaves
314  // the call ambiguous.
315  GetRewriterCallback get_callback(&replacements);
316  match_finder.addMatcher(
317      memberCallExpr(
318          base_matcher,
319          unless(anyOf(is_unsafe_temporary_conversion, is_unsafe_return))),
320      &get_callback);
321
322  // Find temporary scoped_refptr<T>'s being unsafely assigned to a T*.
323  VarRewriterCallback var_callback(&replacements);
324  auto initialized_with_temporary = ignoringImpCasts(exprWithCleanups(
325      has(memberCallExpr(base_matcher, is_unsafe_temporary_conversion))));
326  match_finder.addMatcher(id("var",
327                             varDecl(hasInitializer(initialized_with_temporary),
328                                     hasType(pointerType()))),
329                          &var_callback);
330  match_finder.addMatcher(
331      constructorDecl(forEachConstructorInitializer(
332          allOf(withInitializer(initialized_with_temporary),
333                forField(id("var", fieldDecl(hasType(pointerType()))))))),
334      &var_callback);
335
336  // Rewrite functions that unsafely turn a scoped_refptr<T> into a T* when
337  // returning a value.
338  FunctionRewriterCallback fn_callback(&replacements);
339  match_finder.addMatcher(memberCallExpr(base_matcher, is_unsafe_return),
340                          &fn_callback);
341
342  std::unique_ptr<clang::tooling::FrontendActionFactory> factory =
343      clang::tooling::newFrontendActionFactory(&match_finder);
344  int result = tool.run(factory.get());
345  if (result != 0)
346    return result;
347
348  // Serialization format is documented in tools/clang/scripts/run_tool.py
349  llvm::outs() << "==== BEGIN EDITS ====\n";
350  for (const auto& r : replacements) {
351    std::string replacement_text = r.getReplacementText().str();
352    std::replace(replacement_text.begin(), replacement_text.end(), '\n', '\0');
353    llvm::outs() << "r:" << r.getFilePath() << ":" << r.getOffset() << ":"
354                 << r.getLength() << ":" << replacement_text << "\n";
355  }
356  llvm::outs() << "==== END EDITS ====\n";
357
358  return 0;
359}
360