jbangert updated this revision to Diff 87575.
jbangert added a comment.

- change to push_back


https://reviews.llvm.org/D29621

Files:
  include/clang/Tooling/RefactoringCallbacks.h
  lib/Tooling/RefactoringCallbacks.cpp
  unittests/Tooling/RefactoringCallbacksTest.cpp

Index: unittests/Tooling/RefactoringCallbacksTest.cpp
===================================================================
--- unittests/Tooling/RefactoringCallbacksTest.cpp
+++ unittests/Tooling/RefactoringCallbacksTest.cpp
@@ -7,31 +7,30 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "clang/Tooling/RefactoringCallbacks.h"
 #include "RewriterTestContext.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Tooling/RefactoringCallbacks.h"
 #include "gtest/gtest.h"
 
 namespace clang {
 namespace tooling {
 
 using namespace ast_matchers;
 
 template <typename T>
-void expectRewritten(const std::string &Code,
-                     const std::string &Expected,
-                     const T &AMatcher,
-                     RefactoringCallback &Callback) {
-  MatchFinder Finder;
+void expectRewritten(const std::string &Code, const std::string &Expected,
+                     const T &AMatcher, RefactoringCallback &Callback) {
+  std::map<std::string, Replacements> FileToReplace; // Keyed by file path.
+  ASTMatchRefactorer Finder(FileToReplace); // Collects each TU's edits here.
   Finder.addMatcher(AMatcher, &Callback);
   std::unique_ptr<tooling::FrontendActionFactory> Factory(
       tooling::newFrontendActionFactory(&Finder));
   ASSERT_TRUE(tooling::runToolOnCode(Factory->create(), Code))
       << "Parsing error in \"" << Code << "\"";
   RewriterTestContext Context;
   FileID ID = Context.createInMemoryFile("input.cc", Code);
-  EXPECT_TRUE(tooling::applyAllReplacements(Callback.getReplacements(),
+  EXPECT_TRUE(tooling::applyAllReplacements(FileToReplace["input.cc"],
                                              Context.Rewrite));
   EXPECT_EQ(Expected, Context.getRewrittenText(ID));
 }
@@ -61,39 +60,77 @@
   std::string Code = "void f() { int i = 1; }";
   std::string Expected = "void f() { int i = 2; }";
   ReplaceStmtWithText Callback("id", "2");
-  expectRewritten(Code, Expected, id("id", expr(integerLiteral())),
-                  Callback);
+  expectRewritten(Code, Expected, id("id", expr(integerLiteral())), Callback);
 }
 
 TEST(RefactoringCallbacksTest, ReplacesStmtWithStmt) {
   std::string Code = "void f() { int i = false ? 1 : i * 2; }";
   std::string Expected = "void f() { int i = i * 2; }";
   ReplaceStmtWithStmt Callback("always-false", "should-be");
-  expectRewritten(Code, Expected,
-      id("always-false", conditionalOperator(
-          hasCondition(cxxBoolLiteral(equals(false))),
-          hasFalseExpression(id("should-be", expr())))),
+  expectRewritten(
+      Code, Expected,
+      id("always-false",
+         conditionalOperator(hasCondition(cxxBoolLiteral(equals(false))),
+                             hasFalseExpression(id("should-be", expr())))),
       Callback);
 }
 
 TEST(RefactoringCallbacksTest, ReplacesIfStmt) {
   std::string Code = "bool a; void f() { if (a) f(); else a = true; }";
   std::string Expected = "bool a; void f() { f(); }";
   ReplaceIfStmtWithItsBody Callback("id", true);
-  expectRewritten(Code, Expected,
-      id("id", ifStmt(
-          hasCondition(implicitCastExpr(hasSourceExpression(
-              declRefExpr(to(varDecl(hasName("a"))))))))),
+  expectRewritten(
+      Code, Expected,
+      id("id", ifStmt(hasCondition(implicitCastExpr(hasSourceExpression(
+                   declRefExpr(to(varDecl(hasName("a"))))))))),
       Callback);
 }
 
 TEST(RefactoringCallbacksTest, RemovesEntireIfOnEmptyElse) {
   std::string Code = "void f() { if (false) int i = 0; }";
   std::string Expected = "void f() {  }";
   ReplaceIfStmtWithItsBody Callback("id", false);
   expectRewritten(Code, Expected,
-      id("id", ifStmt(hasCondition(cxxBoolLiteral(equals(false))))),
-      Callback);
+                  id("id", ifStmt(hasCondition(cxxBoolLiteral(equals(false))))),
+                  Callback);
+}
+
+TEST(RefactoringCallbacksTest, TemplateJustText) {
+  std::string Code = "void f() { int i = 1; }";
+  std::string Expected = "void f() { FOO }";
+  auto Callback = ReplaceNodeWithTemplate::create("id", "FOO");
+  EXPECT_FALSE(Callback.takeError());
+  expectRewritten(Code, Expected, id("id", declStmt()), **Callback);
+}
+
+TEST(RefactoringCallbacksTest, TemplateSimpleSubst) {
+  std::string Code = "void f() { int i = 1; }";
+  std::string Expected = "void f() { long x = 1; }";
+  auto Callback = ReplaceNodeWithTemplate::create("decl", "long x = ${init}");
+  EXPECT_FALSE(Callback.takeError());
+  expectRewritten(Code, Expected,
+                  id("decl", varDecl(hasInitializer(id("init", expr())))),
+                  **Callback);
+}
+
+TEST(RefactoringCallbacksTest, TemplateLiteral) {
+  std::string Code = "void f() { int i = 1; }";
+  std::string Expected = "void f() { string x = \"$-1\"; }";
+  auto Callback =
+      ReplaceNodeWithTemplate::create("decl", "string x = \"$$-${init}\"");
+  EXPECT_FALSE(Callback.takeError());
+  expectRewritten(Code, Expected,
+                  id("decl", varDecl(hasInitializer(id("init", expr())))),
+                  **Callback);
+}
+
+TEST(RefactoringCallbacksTest, TemplateRejectsMalformedSyntax) {
+  // An unterminated ${...} and a '$' not followed by '$' or '{' are errors.
+  for (const char *Malformed : {"long x = ${init", "long x = $init"}) {
+    auto Callback = ReplaceNodeWithTemplate::create("decl", Malformed);
+    EXPECT_FALSE(static_cast<bool>(Callback)) << Malformed;
+    llvm::consumeError(Callback.takeError());
+  }
 }
 
 } // end namespace ast_matchers
Index: lib/Tooling/RefactoringCallbacks.cpp
===================================================================
--- lib/Tooling/RefactoringCallbacks.cpp
+++ lib/Tooling/RefactoringCallbacks.cpp
@@ -9,8 +9,13 @@
 //
 //
 //===----------------------------------------------------------------------===//
-#include "clang/Lex/Lexer.h"
 #include "clang/Tooling/RefactoringCallbacks.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Lex/Lexer.h"
+
+using llvm::StringError;
+using llvm::make_error;
 
 namespace clang {
 namespace tooling {
@@ -20,18 +25,68 @@
   return Replace;
 }
 
-static Replacement replaceStmtWithText(SourceManager &Sources,
-                                       const Stmt &From,
+ASTMatchRefactorer::ASTMatchRefactorer(
+    std::map<std::string, Replacements> &FileToReplaces)
+    : FileToReplaces(FileToReplaces) {}
+
+void ASTMatchRefactorer::addDynamicMatcher(
+    const ast_matchers::internal::DynTypedMatcher &Matcher,
+    RefactoringCallback *Callback) {
+  MatchFinder.addDynamicMatcher(Matcher, Callback);
+  Callbacks.push_back(Callback);
+}
+
+/// Runs the matchers of an ASTMatchRefactorer over one translation unit and
+/// merges the replacements produced by its callbacks into FileToReplaces.
+class RefactoringASTConsumer : public ASTConsumer {
+public:
+  explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring)
+      : Refactoring(Refactoring) {}
+
+  void HandleTranslationUnit(ASTContext &Context) override {
+    // The ASTMatchRefactorer is re-used between translation units. Clear the
+    // callbacks' replacements so that each Replacement is only emitted once.
+    for (RefactoringCallback *Callback : Refactoring.Callbacks)
+      Callback->getReplacements().clear();
+    Refactoring.MatchFinder.matchAST(Context);
+    // A callback registered with several matchers occurs several times in
+    // Callbacks; export its replacements only once to avoid re-adding them.
+    std::set<RefactoringCallback *> Exported;
+    for (RefactoringCallback *Callback : Refactoring.Callbacks) {
+      if (!Exported.insert(Callback).second)
+        continue;
+      for (const auto &Replacement : Callback->getReplacements()) {
+        llvm::Error Err =
+            Refactoring.FileToReplaces[Replacement.getFilePath()].add(
+                Replacement);
+        if (Err) {
+          llvm::errs() << "Skipping replacement " << Replacement.toString()
+                       << " due to this error:\n"
+                       << toString(std::move(Err)) << "\n";
+        }
+      }
+    }
+  }
+
+private:
+  ASTMatchRefactorer &Refactoring;
+};
+
+std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() {
+  return llvm::make_unique<RefactoringASTConsumer>(*this);
+}
+
+static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From,
                                        StringRef Text) {
-  return tooling::Replacement(Sources, CharSourceRange::getTokenRange(
-      From.getSourceRange()), Text);
+  return tooling::Replacement(
+      Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text);
 }
-static Replacement replaceStmtWithStmt(SourceManager &Sources,
-                                       const Stmt &From,
+static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From,
                                        const Stmt &To) {
-  return replaceStmtWithText(Sources, From, Lexer::getSourceText(
-      CharSourceRange::getTokenRange(To.getSourceRange()),
-      Sources, LangOptions()));
+  return replaceStmtWithText(
+      Sources, From,
+      Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()),
+                           Sources, LangOptions()));
 }
 
 ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText)
@@ -103,5 +158,93 @@
   }
 }
 
+ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(
+    llvm::StringRef FromId, std::vector<TemplateElement> &&Template)
+    : FromId(FromId), Template(std::move(Template)) {}
+
+llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>>
+ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) {
+  std::vector<TemplateElement> ParsedTemplate;
+  for (size_t Index = 0; Index < ToTemplate.size();) {
+    if (ToTemplate[Index] == '$') {
+      if (ToTemplate.substr(Index, 2) == "$$") {
+        Index += 2;
+        ParsedTemplate.push_back(
+            TemplateElement{TemplateElement::Literal, "$"});
+      } else if (ToTemplate.substr(Index, 2) == "${") {
+        size_t EndOfIdentifier = ToTemplate.find('}', Index);
+        if (EndOfIdentifier == StringRef::npos) {
+          return make_error<StringError>(
+              "Unterminated ${...} in replacement template near " +
+                  ToTemplate.substr(Index),
+              std::make_error_code(std::errc::bad_message));
+        }
+        std::string SourceNodeName =
+            ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2);
+        ParsedTemplate.push_back(
+            TemplateElement{TemplateElement::Identifier, SourceNodeName});
+        Index = EndOfIdentifier + 1;
+      } else {
+        return make_error<StringError>(
+            "Invalid $ in replacement template near " +
+                ToTemplate.substr(Index),
+            std::make_error_code(std::errc::bad_message));
+      }
+    } else {
+      size_t NextIndex = ToTemplate.find('$', Index + 1);
+      ParsedTemplate.push_back(
+          TemplateElement{TemplateElement::Literal,
+                          ToTemplate.substr(Index, NextIndex - Index)});
+      Index = NextIndex;
+    }
+  }
+  return std::unique_ptr<ReplaceNodeWithTemplate>(
+      new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate)));
+}
+
+void ReplaceNodeWithTemplate::run(
+    const ast_matchers::MatchFinder::MatchResult &Result) {
+  const auto &NodeMap = Result.Nodes.getMap();
+
+  std::string ToText;
+  for (const auto &Element : Template) {
+    switch (Element.Type) {
+    case TemplateElement::Literal:
+      ToText += Element.Value;
+      break;
+    case TemplateElement::Identifier: {
+      auto NodeIter = NodeMap.find(Element.Value);
+      if (NodeIter == NodeMap.end()) {
+        // Reachable from user input (a template naming an unbound node), so
+        // fail loudly in all build modes; llvm_unreachable is UB in release.
+        llvm::errs() << "Node " << Element.Value
+                     << " used in replacement template not bound in Matcher \n";
+        llvm::report_fatal_error("Unbound node in replacement template.");
+      }
+      CharSourceRange Source =
+          CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
+      ToText += Lexer::getSourceText(Source, *Result.SourceManager,
+                                     Result.Context->getLangOpts());
+      break;
+    }
+    }
+  }
+  auto FromNode = NodeMap.find(FromId);
+  if (FromNode == NodeMap.end()) {
+    llvm::errs() << "Node to be replaced " << FromId
+                 << " not bound in query.\n";
+    llvm::report_fatal_error("FromId node not bound in MatchResult");
+  }
+  auto Replacement =
+      tooling::Replacement(*Result.SourceManager, &FromNode->second, ToText,
+                           Result.Context->getLangOpts());
+  llvm::Error Err = Replace.add(Replacement);
+  if (Err) {
+    llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
+                 << "! " << llvm::toString(std::move(Err)) << "\n";
+    llvm::report_fatal_error("Replacement failed");
+  }
+}
+
 } // end namespace tooling
 } // end namespace clang
Index: include/clang/Tooling/RefactoringCallbacks.h
===================================================================
--- include/clang/Tooling/RefactoringCallbacks.h
+++ include/clang/Tooling/RefactoringCallbacks.h
@@ -47,6 +47,43 @@
   Replacements Replace;
 };
 
+/// \brief Adaptor between \c ast_matchers::MatchFinder and \c
+/// tooling::RefactoringTool.
+///
+/// Runs AST matchers and stores the \c tooling::Replacements in a map.
+class ASTMatchRefactorer {
+public:
+  /// Replacements produced by the registered callbacks are added to
+  /// \p FileToReplaces, keyed by file path. The map is held by reference and
+  /// must outlive this object.
+  explicit ASTMatchRefactorer(
+      std::map<std::string, Replacements> &FileToReplaces);
+
+  /// Registers \p Callback to run on matches of \p Matcher. The callback is
+  /// not owned and must stay valid while consumers of this object run.
+  template <typename T>
+  void addMatcher(const T &Matcher, RefactoringCallback *Callback) {
+    MatchFinder.addMatcher(Matcher, Callback);
+    Callbacks.push_back(Callback);
+  }
+
+  /// Like addMatcher(), but for a dynamically typed matcher.
+  void addDynamicMatcher(const ast_matchers::internal::DynTypedMatcher &Matcher,
+                         RefactoringCallback *Callback);
+
+  /// Returns a consumer that runs the registered matchers over a translation
+  /// unit and adds the callbacks' replacements to the map; a replacement that
+  /// cannot be added is reported on llvm::errs() and skipped. The consumer
+  /// refers to this object and must not outlive it.
+  std::unique_ptr<ASTConsumer> newASTConsumer();
+
+private:
+  friend class RefactoringASTConsumer;
+  std::vector<RefactoringCallback *> Callbacks;
+  ast_matchers::MatchFinder MatchFinder;
+  std::map<std::string, Replacements> &FileToReplaces;
+};
+
 /// \brief Replace the text of the statement bound to \c FromId with the text in
 /// \c ToText.
 class ReplaceStmtWithText : public RefactoringCallback {
@@ -59,6 +96,32 @@
   std::string ToText;
 };
 
+/// \brief Replace the text of an AST node bound to \c FromId with the result of
+/// evaluating the template in \c ToTemplate.
+///
+/// Expressions of the form ${NodeName} in \c ToTemplate will be
+/// replaced by the text of the node bound to NodeName. The string
+/// "$$" will be replaced by "$".
+class ReplaceNodeWithTemplate : public RefactoringCallback {
+public:
+  /// Parses \p ToTemplate. Returns an error if it contains an unterminated
+  /// "${" or a "$" that is not followed by "$" or "{".
+  static llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>>
+  create(StringRef FromId, StringRef ToTemplate);
+  void run(const ast_matchers::MatchFinder::MatchResult &Result) override;
+
+private:
+  /// One piece of a parsed template: literal text, or a bound node's name.
+  struct TemplateElement {
+    enum { Literal, Identifier } Type;
+    std::string Value;
+  };
+  ReplaceNodeWithTemplate(llvm::StringRef FromId,
+                          std::vector<TemplateElement> &&Template);
+  std::string FromId;
+  std::vector<TemplateElement> Template;
+};
+
 /// \brief Replace the text of the statement bound to \c FromId with the text of
 /// the statement bound to \c ToId.
 class ReplaceStmtWithStmt : public RefactoringCallback {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to