johannes updated this revision to Diff 121656.
johannes added a comment.

update


https://reviews.llvm.org/D37005

Files:
  include/clang/Tooling/ASTDiff/ASTDiff.h
  include/clang/Tooling/ASTDiff/ASTPatch.h
  lib/Tooling/ASTDiff/ASTDiff.cpp
  lib/Tooling/ASTDiff/ASTPatch.cpp
  lib/Tooling/ASTDiff/CMakeLists.txt
  test/Tooling/clang-diff-patch.test
  tools/clang-diff/CMakeLists.txt
  tools/clang-diff/ClangDiff.cpp
  unittests/Tooling/ASTPatchTest.cpp
  unittests/Tooling/CMakeLists.txt

Index: unittests/Tooling/CMakeLists.txt
===================================================================
--- unittests/Tooling/CMakeLists.txt
+++ unittests/Tooling/CMakeLists.txt
@@ -11,6 +11,7 @@
 endif()
 
 add_clang_unittest(ToolingTests
+  ASTPatchTest.cpp
   ASTSelectionTest.cpp
   CastExprTest.cpp
   CommentHandlerTest.cpp
@@ -45,4 +46,5 @@
   clangTooling
   clangToolingCore
   clangToolingRefactor
+  clangToolingASTDiff
   )
Index: unittests/Tooling/ASTPatchTest.cpp
===================================================================
--- /dev/null
+++ unittests/Tooling/ASTPatchTest.cpp
@@ -0,0 +1,265 @@
+//===- unittest/Tooling/ASTPatchTest.cpp ----------------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/ASTDiff/ASTDiff.h"
+#include "clang/Tooling/ASTDiff/ASTPatch.h"
+#include "clang/Tooling/Tooling.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/Program.h"
+#include "gtest/gtest.h"
+#include <array>
+#include <cstdio>
+#include <fstream>
+#include <memory>
+#include <sstream>
+#include <string>
+
+using namespace clang;
+using namespace tooling;
+
+/// Runs \p Command through the shell and returns everything it writes to
+/// stdout. Returns an empty string if the process could not be started.
+/// stderr is not captured and the exit status is not checked.
+/// Internal linkage: the ToolingTests binary links many test files together.
+static std::string ReadShellCommand(const Twine &Command) {
+  std::string Result;
+  // The pipe has a single owner, so unique_ptr (not shared_ptr) with pclose
+  // as the deleter is sufficient.
+  std::unique_ptr<FILE, int (*)(FILE *)> Pipe(
+      popen(Command.str().c_str(), "r"), pclose);
+  if (!Pipe)
+    return Result;
+  char Buffer[128];
+  // Loop on fgets itself: it returns nullptr on EOF *and* on a read error,
+  // whereas looping on !feof() never terminates once the stream has failed.
+  while (fgets(Buffer, sizeof(Buffer), Pipe.get()) != nullptr)
+    Result += Buffer;
+  return Result;
+}
+
+/// Fixture for the ASTPatch tests. A test writes the target code to a
+/// temporary file, patches it with diff::patch() and compares the
+/// clang-format'ed result with the clang-format'ed expectation.
+/// NOTE: requires a `clang-format` executable on PATH.
+class ASTPatchTest : public ::testing::Test {
+  llvm::SmallString<256> TargetFile, ExpectedFile;
+  std::array<std::string, 1> TargetFileArray;
+
+public:
+  void SetUp() override {
+    std::string Suffix = "cpp";
+    ASSERT_FALSE(llvm::sys::fs::createTemporaryFile(
+        "clang-libtooling-patch-target", Suffix, TargetFile));
+    ASSERT_FALSE(llvm::sys::fs::createTemporaryFile(
+        "clang-libtooling-patch-expected", Suffix, ExpectedFile));
+    TargetFileArray[0] = TargetFile.str();
+  }
+  void TearDown() override {
+    llvm::sys::fs::remove(TargetFile);
+    llvm::sys::fs::remove(ExpectedFile);
+  }
+
+  /// Overwrites \p Filename with \p Contents.
+  void WriteFile(StringRef Filename, StringRef Contents) {
+    std::ofstream OS(Filename.str());
+    OS << Contents.str();
+    OS.close();
+    // A gtest check rather than assert(): assert() is compiled out under
+    // NDEBUG, which would let a failed write go unnoticed. Closing first
+    // also surfaces errors raised while flushing.
+    EXPECT_FALSE(OS.fail()) << "failed to write " << Filename.str();
+  }
+
+  /// Returns the contents of \p Filename.
+  std::string ReadFile(StringRef Filename) {
+    std::ifstream IS(Filename.str());
+    EXPECT_TRUE(IS.good()) << "failed to open " << Filename.str();
+    std::stringstream OS;
+    OS << IS.rdbuf();
+    return OS.str();
+  }
+
+  /// Returns \p Code as formatted by clang-format.
+  std::string formatExpected(StringRef Code) {
+    WriteFile(ExpectedFile, Code);
+    // Quote the path so that temporary directories with spaces still work.
+    return ReadShellCommand("clang-format \"" + ExpectedFile + "\"");
+  }
+
+  /// Applies the changes between \p SrcCode and \p DstCode to \p TargetCode
+  /// and returns the clang-format'ed result, or the error from patching.
+  llvm::Expected<std::string> patchResult(const char *SrcCode,
+                                          const char *DstCode,
+                                          const char *TargetCode) {
+    std::unique_ptr<ASTUnit> SrcAST = buildASTFromCode(SrcCode),
+                             DstAST = buildASTFromCode(DstCode);
+    if (!SrcAST || !DstAST) {
+      if (!SrcAST)
+        llvm::errs() << "Failed to build AST from code:\n" << SrcCode << "\n";
+      if (!DstAST)
+        llvm::errs() << "Failed to build AST from code:\n" << DstCode << "\n";
+      return llvm::make_error<diff::PatchingError>(
+          diff::patching_error::failed_to_build_AST);
+    }
+
+    diff::SyntaxTree Src(*SrcAST);
+    diff::SyntaxTree Dst(*DstAST);
+
+    WriteFile(TargetFile, TargetCode);
+    FixedCompilationDatabase Compilations(".", std::vector<std::string>());
+    RefactoringTool TargetTool(Compilations, TargetFileArray);
+    diff::ComparisonOptions Options;
+
+    if (auto Err = diff::patch(TargetTool, Src, Dst, Options, /*Debug=*/false))
+      return std::move(Err);
+    return ReadShellCommand("clang-format \"" + TargetFile + "\"");
+  }
+
+#define APPEND_NEWLINE(x) x "\n"
+// use macros for this to make test failures have proper line numbers.
+// On a patching error, report its message and consume it: merely testing
+// bool(Result) would leave the llvm::Error unchecked.
+#define PATCH(Src, Dst, Target, ExpectedResult)                                \
+  {                                                                            \
+    llvm::Expected<std::string> Result = patchResult(                          \
+        APPEND_NEWLINE(Src), APPEND_NEWLINE(Dst), APPEND_NEWLINE(Target));     \
+    if (!Result) {                                                             \
+      FAIL() << "patching failed: " << llvm::toString(Result.takeError());     \
+    }                                                                          \
+    EXPECT_EQ(Result.get(), formatExpected(APPEND_NEWLINE(ExpectedResult)));   \
+  }
+// Expects patching to fail with \p ErrorCode. Errors of any other type are
+// reported as test failures instead of aborting inside handleAllErrors.
+#define PATCH_ERROR(Src, Dst, Target, ErrorCode)                               \
+  {                                                                            \
+    llvm::Expected<std::string> Result = patchResult(Src, Dst, Target);        \
+    ASSERT_FALSE(bool(Result));                                                \
+    llvm::handleAllErrors(                                                     \
+        Result.takeError(),                                                    \
+        [&](const diff::PatchingError &PE) {                                   \
+          EXPECT_EQ(PE.get(), ErrorCode);                                      \
+        },                                                                     \
+        [&](const llvm::ErrorInfoBase &EIB) {                                  \
+          ADD_FAILURE() << "unexpected error: " << EIB.message();              \
+        });                                                                    \
+  }
+};
+
+// PATCH(Src, Dst, Target, Expected): the changes that turn Src into Dst are
+// applied to Target; the clang-format'ed result must equal the
+// clang-format'ed Expected.
+TEST_F(ASTPatchTest, Delete) {
+  PATCH(R"(void f() { { int x = 1; } })",
+        R"(void f() { })",
+        R"(void f() { { int x = 2; } })",
+        R"(void f() {  })");
+  // The deleted function is removed from Target although its body there
+  // differs from the one in Src.
+  PATCH(R"(void foo(){})",
+        R"()",
+        R"(int x; void foo() {;;} int y;)",
+        R"(int x;  int y;)");
+}
+TEST_F(ASTPatchTest, DeleteCallArguments) {
+  PATCH(R"(void foo(...); void test1() { foo ( 1 + 1); })",
+        R"(void foo(...); void test1() { foo ( ); })",
+        R"(void foo(...); void test2() { foo ( 1 + 1 ); })",
+        R"(void foo(...); void test2() { foo (  ); })");
+}
+TEST_F(ASTPatchTest, DeleteParmVarDecl) {
+  PATCH(R"(void foo(int a);)",
+        R"(void foo();)",
+        R"(void bar(int x);)",
+        R"(void bar();)");
+}
+TEST_F(ASTPatchTest, Insert) {
+  PATCH(R"(class C {              C() {} };)",
+        R"(class C { int b;       C() {} };)",
+        R"(class C { int c;       C() {} };)",
+        R"(class C { int c;int b; C() {} };)");
+  PATCH(R"(class C {        C() {} };)",
+        R"(class C { int b; C() {} };)",
+        R"(class C {        C() {} };)",
+        R"(class C { int b; C() {} };)");
+  PATCH(R"(class C { int x;              };)",
+        R"(class C { int x;int b;        };)",
+        R"(class C { int x ;int c;        };)",
+        R"(class C { int x;int b;int c;  };)");
+  PATCH(R"(class C { int x;              };)",
+        R"(class C { int x;int b;        };)",
+        R"(class C { int x; int c;        };)",
+        R"(class C { int x;int b;int c;  };)");
+  PATCH(R"(class C { int x;              };)",
+        R"(class C { int x;int b;        };)",
+        R"(class C { int x;int c;        };)",
+        R"(class C { int x;int b;int c;  };)");
+  PATCH(R"(int a;)",
+        R"(int a; int x();)",
+        R"(int a;)",
+        R"(int a; int x();)");
+  PATCH(R"(int a; int b;)",
+        R"(int a; int x; int b;)",
+        R"(int a; int b;)",
+        R"(int a; int x; int b;)");
+  PATCH(R"(int b;)",
+        R"(int x; int b;)",
+        R"(int b;)",
+        R"(int x; int b;)");
+  // Wraps an existing statement in a newly inserted compound statement.
+  PATCH(R"(void f() {   int x = 1 + 1;   })",
+        R"(void f() { { int x = 1 + 1; } })",
+        R"(void f() {   int x = 1 + 1;   })",
+        R"(void f() { { int x = 1 + 1; } })");
+}
+// Target is empty, so the inserted node has no parent to go into and the
+// result stays empty.
+TEST_F(ASTPatchTest, InsertNoParent) {
+  PATCH(R"(void f() { })",
+        R"(void f() { int x; })",
+        R"()",
+        R"()");
+}
+TEST_F(ASTPatchTest, InsertTopLevel) {
+  PATCH(R"(namespace a {})",
+        R"(namespace a {} void x();)",
+        R"(namespace a {})",
+        R"(namespace a {} void x();)");
+}
+TEST_F(ASTPatchTest, Move) {
+  PATCH(R"(namespace a {  void f(){} })",
+        R"(namespace a {} void f(){}  )",
+        R"(namespace a {  void f(){} })",
+        R"(namespace a {} void f(){}  )");
+  PATCH(R"(namespace a {  void f(){} } int x;)",
+        R"(namespace a {} void f(){}   int x;)",
+        R"(namespace a {  void f(){} } int x;)",
+        R"(namespace a {} void f(){}   int x;)");
+  PATCH(R"(namespace a { namespace { } })",
+        R"(namespace a { })",
+        R"(namespace a { namespace { } })",
+        R"(namespace a { })");
+  PATCH(R"(namespace { int x = 1 + 1; })",
+        R"(namespace { int x = 1 + 1; int y;})",
+        R"(namespace { int x = 1 + 1; })",
+        R"(namespace { int x = 1 + 1; int y;})");
+  PATCH(R"(namespace { int y; int x = 1 + 1; })",
+        R"(namespace { int x = 1 + 1; int y; })",
+        R"(namespace { int y; int x = 1 + 1; })",
+        R"(namespace { int x = 1 + 1; int y; })");
+  PATCH(R"(void f() { ; int x = 1 + 1; })",
+        R"(void f() { int x = 1 + 1; ; })",
+        R"(void f() { ; int x = 1 + 1; })",
+        R"(void f() { int x = 1 + 1; ; })");
+  PATCH(R"(void f() { {{;;;}}        })",
+        R"(void f() { {{{;;;}}}      })",
+        R"(void f() { {{;;;}}        })",
+        R"(void f() { {{{;;;}}}      })");
+}
+// NOTE(review): this is identical to InsertNoParent: Dst only inserts
+// `int x;`, so no move is exercised. A move whose source node is missing
+// from Target was presumably intended here; confirm.
+TEST_F(ASTPatchTest, MoveNoSource) {
+  PATCH(R"(void f() { })",
+        R"(void f() { int x; })",
+        R"()",
+        R"()");
+}
+// `int x;` moves into f(), which does not exist in Target; the expectation
+// is that the declaration is removed from Target.
+TEST_F(ASTPatchTest, MoveNoTarget) {
+  PATCH(R"(int x; void f() { })",
+        R"(void f() { int x; })",
+        R"(int x;)",
+        R"()");
+}
+// The inserted declaration must keep the newline layout of the input.
+TEST_F(ASTPatchTest, Newline) {
+  PATCH(R"(void f(){
+;
+})",
+        R"(void f(){
+;
+int x;
+})",
+        R"(void f(){
+;
+})",
+        R"(void f(){
+;
+int x;
+})");
+}
+TEST_F(ASTPatchTest, Nothing) {
+  PATCH(R"()",
+        R"()",
+        R"()",
+        R"()");
+}
+// NOTE(review): Src and Dst are identical, so the edit script is empty and
+// this only checks that Target is left untouched; no update is exercised.
+TEST_F(ASTPatchTest, Update) {
+  PATCH(R"(class A { int x; };)",
+        R"(class A { int x; };)",
+        R"(class A { int y; };)",
+        R"(class A { int y; };)");
+}
+// NOTE(review): identical to the first case of Delete; no update-move is
+// exercised.
+TEST_F(ASTPatchTest, UpdateMove) {
+  PATCH(R"(void f() { { int x = 1; } })",
+        R"(void f() { })",
+        R"(void f() { { int x = 2; } })",
+        R"(void f() {  })");
+}
Index: tools/clang-diff/ClangDiff.cpp
===================================================================
--- tools/clang-diff/ClangDiff.cpp
+++ tools/clang-diff/ClangDiff.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/Tooling/ASTDiff/ASTDiff.h"
+#include "clang/Tooling/ASTDiff/ASTPatch.h"
 #include "clang/Tooling/CommonOptionsParser.h"
 #include "clang/Tooling/Tooling.h"
 #include "llvm/Support/CommandLine.h"
@@ -41,6 +42,12 @@
                               cl::desc("Output a side-by-side diff in HTML."),
                               cl::init(false), cl::cat(ClangDiffCategory));
 
+// If set, the changes between the two input files are applied to this file
+// instead of printing a diff. The second label is a value_desc: a repeated
+// cl::desc would replace the help text with "<target>".
+static cl::opt<std::string>
+    FileToPatch("patch",
+                cl::desc("Try to apply the edit actions between the two input "
+                         "files to the specified target."),
+                cl::value_desc("target"), cl::cat(ClangDiffCategory));
+
 static cl::opt<std::string> SourcePath(cl::Positional, cl::desc("<source>"),
                                        cl::Required,
                                        cl::cat(ClangDiffCategory));
@@ -453,6 +460,24 @@
   }
   diff::SyntaxTree SrcTree(*Src);
   diff::SyntaxTree DstTree(*Dst);
+
+  if (!FileToPatch.empty()) {
+    std::array<std::string, 1> Files = {{FileToPatch}};
+    RefactoringTool TargetTool(CommonCompilations
+                                   ? *CommonCompilations
+                                   : *getCompilationDatabase(FileToPatch),
+                               Files);
+    if (auto Err = diff::patch(TargetTool, SrcTree, DstTree, Options)) {
+      llvm::handleAllErrors(
+          std::move(Err),
+          [](const diff::PatchingError &PE) { PE.log(llvm::errs()); },
+          [](const ReplacementError &RE) { RE.log(llvm::errs()); });
+      llvm::errs() << "*** errors occured, patching failed.\n";
+      return 1;
+    }
+    return 0;
+  }
+
   diff::ASTDiff Diff(SrcTree, DstTree, Options);
 
   if (HtmlDiff) {
Index: tools/clang-diff/CMakeLists.txt
===================================================================
--- tools/clang-diff/CMakeLists.txt
+++ tools/clang-diff/CMakeLists.txt
@@ -9,6 +9,8 @@
 target_link_libraries(clang-diff
   clangBasic
   clangFrontend
+  # clangRewrite and clangToolingCore were added for the -patch option
+  # (ClangDiff.cpp now handles tooling's ReplacementError). clangToolingASTDiff
+  # lists both as LINK_LIBS as well; confirm both are still needed here.
+  clangRewrite
   clangTooling
+  clangToolingCore
   clangToolingASTDiff
   )
Index: test/Tooling/clang-diff-patch.test
===================================================================
--- /dev/null
+++ test/Tooling/clang-diff-patch.test
@@ -0,0 +1,9 @@
+// Diff the file against an empty file and patch it, removing all code.
+RUN: rm -rf %t && mkdir -p %t
+RUN: cp %S/clang-diff-ast.cpp %t
+RUN: echo > %t/dst.cpp
+RUN: clang-diff %t/clang-diff-ast.cpp %t/dst.cpp \
+RUN:  -patch %t/clang-diff-ast.cpp -- -std=c++11
+// Only preprocessor lines, comments and whitespace may remain; [[:space:]]
+// is used because \s is not portable in grep basic regular expressions.
+RUN: grep -v '^#' %t/clang-diff-ast.cpp | grep -v '^[[:space:]]*//' | not grep -v '^[[:space:]]*$'
Index: lib/Tooling/ASTDiff/CMakeLists.txt
===================================================================
--- lib/Tooling/ASTDiff/CMakeLists.txt
+++ lib/Tooling/ASTDiff/CMakeLists.txt
@@ -4,8 +4,13 @@
 
 add_clang_library(clangToolingASTDiff
   ASTDiff.cpp
+  ASTPatch.cpp
   LINK_LIBS
   clangBasic
   clangAST
   clangLex
+  # Used by ASTPatch.cpp: Rewriter (clangRewrite), RefactoringTool
+  # (clangTooling) and tooling::Replacement (clangToolingCore).
+  # NOTE(review): clangFrontend is presumably needed for ASTUnit; confirm.
+  clangRewrite
+  clangFrontend
+  clangTooling
+  clangToolingCore
   )
Index: lib/Tooling/ASTDiff/ASTPatch.cpp
===================================================================
--- /dev/null
+++ lib/Tooling/ASTDiff/ASTPatch.cpp
@@ -0,0 +1,582 @@
+//===- ASTPatch.cpp - Structural patching based on ASTDiff ----*- C++ -*- -===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/ASTDiff/ASTPatch.h"
+
+#include "clang/AST/DeclTemplate.h"
+#include "clang/AST/ExprCXX.h"
+#include "clang/Rewrite/Core/Rewriter.h"
+#include "clang/Tooling/Core/Replacement.h"
+
+using namespace llvm;
+using namespace clang;
+using namespace tooling;
+
+namespace clang {
+namespace diff {
+
+/// Wraps \p Code in a PatchingError.
+static Error error(patching_error Code) {
+  return llvm::make_error<PatchingError>(Code);
+}
+
+/// Returns a zero-length character range located at \p Point, i.e. the range
+/// a pure insertion replaces.
+static CharSourceRange makeEmptyCharRange(SourceLocation Point) {
+  return CharSourceRange::getCharRange(SourceRange(Point));
+}
+
+// Returns a comparison function that orders source locations as \p SM does,
+// except that an invalid location compares less than everything.
+//
+// NOTE(review): with two invalid arguments this returns true, so it is not a
+// strict weak ordering. It is only suitable for algorithms that merely need
+// the range partitioned by `Elem < Value` (like std::lower_bound), not for
+// sorting.
+static std::function<bool(SourceLocation, SourceLocation)>
+makeTolerantLess(SourceManager &SM) {
+  return [&SM](SourceLocation A, SourceLocation B) {
+    return A.isInvalid() ||
+           (B.isValid() && BeforeThanCompare<SourceLocation>(SM)(A, B));
+  };
+}
+
+namespace {
+// A node of the tree the Patcher assembles. It wraps a node from either
+// Patcher::Target or Patcher::Dst and records the children it has after
+// patching, plus the bookkeeping needed to splice their text into its own.
+class PatchedTreeNode {
+  NodeRef BaseNode;
+
+public:
+  PatchedTreeNode(NodeRef BaseNode) : BaseNode(BaseNode) {}
+  PatchedTreeNode(const PatchedTreeNode &Other) = delete;
+  PatchedTreeNode(PatchedTreeNode &&Other) = default;
+
+  // Lets a PatchedTreeNode be passed wherever a NodeRef is expected.
+  operator NodeRef() const { return BaseNode; }
+  NodeRef originalNode() const { return BaseNode; }
+
+  // Forwarders to the wrapped node.
+  CharSourceRange getSourceRange() const { return BaseNode.getSourceRange(); }
+  NodeId getId() const { return BaseNode.getId(); }
+  SyntaxTree &getTree() const { return BaseNode.getTree(); }
+  StringRef getTypeLabel() const { return BaseNode.getTypeLabel(); }
+  decltype(BaseNode.getOwnedSourceRanges()) getOwnedSourceRanges() {
+    return BaseNode.getOwnedSourceRanges();
+  }
+
+  // True if this node, or any of its descendants, differs from the original
+  // tree.
+  bool Changed = false;
+  // The children after patching, including nodes inserted or moved here.
+  // Parallel to ChildrenLocations.
+  SmallVector<PatchedTreeNode *, 4> Children;
+  // The location associated with each entry of Children.
+  SmallVector<SourceLocation, 4> ChildrenLocations;
+  // For each entry of Children, the offset in OwnText where its text goes.
+  SmallVector<unsigned, 4> ChildrenOffsets;
+  // The text of this node itself, without the text of its children.
+  Optional<std::string> OwnText;
+
+  // Adds PatchedNode as a child positioned at InsertionLoc.
+  void addInsertion(PatchedTreeNode &PatchedNode, SourceLocation InsertionLoc) {
+    addChildAt(PatchedNode, InsertionLoc);
+  }
+  // Adds PatchedNode as a child positioned at the start of its own range.
+  void addChild(PatchedTreeNode &PatchedNode) {
+    addChildAt(PatchedNode, PatchedNode.getSourceRange().getBegin());
+  }
+
+private:
+  // Inserts PatchedNode (and its location) at the position std::lower_bound
+  // finds in ChildrenLocations under makeTolerantLess.
+  void addChildAt(PatchedTreeNode &PatchedNode, SourceLocation InsertionLoc) {
+    auto Slot =
+        std::lower_bound(ChildrenLocations.begin(), ChildrenLocations.end(),
+                         InsertionLoc,
+                         makeTolerantLess(getTree().getSourceManager()));
+    auto Index = Slot - ChildrenLocations.begin();
+    Children.insert(Children.begin() + Index, &PatchedNode);
+    ChildrenLocations.insert(Slot, InsertionLoc);
+  }
+};
+} // end anonymous namespace
+
+namespace {
+/// Transfers the edit script Src -> Dst onto a third tree, Target. The
+/// patched tree is assembled from PatchedTreeNodes, and the resulting
+/// replacements are recorded in TargetTool.
+class Patcher {
+  SyntaxTree &Dst, &Target;
+  SourceManager &SM;
+  const LangOptions &LangOpts;
+  BeforeThanCompare<SourceLocation> Less;
+  // Diff holds the changes Src -> Dst; TargetDiff matches Src against Target
+  // so that those changes can be located in Target.
+  ASTDiff Diff, TargetDiff;
+  RefactoringTool &TargetTool;
+  bool Debug;
+  // The first Target.getSize() entries mirror Target (indexed by NodeId),
+  // followed by the nodes inserted from Dst. InsertedNodes and
+  // PatchedTreeNode::Children point into this vector, so it must not
+  // reallocate once those pointers have been taken.
+  std::vector<PatchedTreeNode> PatchedTreeNodes;
+  // Maps the NodeId of an inserted Dst node to its PatchedTreeNode.
+  std::map<NodeId, PatchedTreeNode *> InsertedNodes;
+  // Maps NodeId in Dst to a flag that is true if this node and all of its
+  // descendants are inserted, so the subtree can be copied verbatim.
+  std::vector<bool> AtomicInsertions;
+
+public:
+  Patcher(SyntaxTree &Src, SyntaxTree &Dst, SyntaxTree &Target,
+          const ComparisonOptions &Options, RefactoringTool &TargetTool,
+          bool Debug)
+      : Dst(Dst), Target(Target), SM(Target.getSourceManager()),
+        LangOpts(Target.getLangOpts()), Less(SM), Diff(Src, Dst, Options),
+        TargetDiff(Src, Target, Options), TargetTool(TargetTool), Debug(Debug) {
+  }
+
+  /// Computes the replacements for Target, applies them and writes the
+  /// changed files back.
+  Error apply();
+
+private:
+  void buildPatchedTree();
+  void addInsertedAndMovedNodes();
+  // Declared with the same parameter list as the out-of-line definition.
+  SourceLocation findLocationForInsertion(NodeRef DstNode,
+                                          PatchedTreeNode &InsertionTarget);
+  SourceLocation findLocationForMove(NodeRef DstNode, NodeRef TargetNode,
+                                     PatchedTreeNode &NewParent);
+  void markChangedNodes();
+  Error addReplacementsForChangedNodes();
+  Error addReplacementsForTopLevelChanges();
+
+  // Recursively builds the text that is represented by this subtree.
+  std::string buildSourceText(PatchedTreeNode &PatchedNode);
+  void setOwnedSourceText(PatchedTreeNode &PatchedNode);
+  std::pair<int, bool>
+  findPointOfInsertion(NodeRef N, PatchedTreeNode &TargetParent) const;
+  // Nodes from Dst only appear in the patched tree because they are inserted.
+  bool isInserted(const PatchedTreeNode &PatchedNode) const {
+    return isFromDst(PatchedNode);
+  }
+  // The change Diff records for the Src counterpart of TargetNode; NoChange
+  // if TargetNode is not from Target or has no counterpart in Src.
+  ChangeKind getChange(NodeRef TargetNode) const {
+    if (!isFromTarget(TargetNode))
+      return NoChange;
+    const Node *SrcNode = TargetDiff.getMapped(TargetNode);
+    if (!SrcNode)
+      return NoChange;
+    return Diff.getNodeChange(*SrcNode);
+  }
+  bool isRemoved(NodeRef TargetNode) const {
+    return getChange(TargetNode) == Delete;
+  }
+  bool isMoved(NodeRef TargetNode) const {
+    ChangeKind Change = getChange(TargetNode);
+    return Change == Move || Change == UpdateMove;
+  }
+  bool isRemovedOrMoved(NodeRef TargetNode) const {
+    return isRemoved(TargetNode) || isMoved(TargetNode);
+  }
+  // Returns the PatchedTreeNode that is N's parent in the patched tree.
+  PatchedTreeNode &findParent(NodeRef N) {
+    if (isFromDst(N))
+      return findDstParent(N);
+    return findTargetParent(N);
+  }
+  PatchedTreeNode &findDstParent(NodeRef DstNode) {
+    const Node *SrcNode = Diff.getMapped(DstNode);
+    NodeRef DstParent = *DstNode.getParent();
+    if (SrcNode) {
+      // NOTE(review): this expects a *matched* node to be reported as Insert;
+      // confirm that can happen, otherwise this assertion always fires here.
+      assert(Diff.getNodeChange(*SrcNode) == Insert);
+      const Node *TargetParent = mapDstToTarget(DstParent);
+      assert(TargetParent);
+      return getTargetPatchedNode(*TargetParent);
+    }
+    return getPatchedNode(DstParent);
+  }
+  PatchedTreeNode &findTargetParent(NodeRef TargetNode) {
+    assert(isFromTarget(TargetNode));
+    const Node *SrcNode = TargetDiff.getMapped(TargetNode);
+    if (SrcNode) {
+      ChangeKind Change = Diff.getNodeChange(*SrcNode);
+      if (Change == Move || Change == UpdateMove) {
+        NodeRef DstNode = *Diff.getMapped(*SrcNode);
+        return getPatchedNode(*DstNode.getParent());
+      }
+    }
+    return getTargetPatchedNode(*TargetNode.getParent());
+  }
+  // Removed or moved nodes are replaced over findRangeForDeletion(); all
+  // others over their own source range.
+  CharSourceRange getRangeForReplacing(NodeRef TargetNode) const {
+    if (isRemovedOrMoved(TargetNode))
+      return TargetNode.findRangeForDeletion();
+    return TargetNode.getSourceRange();
+  }
+  Error addReplacement(Replacement &&R) {
+    return TargetTool.getReplacements()[R.getFilePath()].add(R);
+  }
+  bool isFromTarget(NodeRef N) const { return &N.getTree() == &Target; }
+  bool isFromDst(NodeRef N) const { return &N.getTree() == &Dst; }
+  PatchedTreeNode &getTargetPatchedNode(NodeRef N) {
+    assert(isFromTarget(N));
+    return PatchedTreeNodes[N.getId()];
+  }
+  // For Dst nodes, only inserted ones have a PatchedTreeNode
+  // (InsertedNodes.at() requires the entry to exist).
+  PatchedTreeNode &getPatchedNode(NodeRef N) {
+    if (isFromDst(N))
+      return *InsertedNodes.at(N.getId());
+    return PatchedTreeNodes[N.getId()];
+  }
+  const Node *mapDstToTarget(NodeRef DstNode) const {
+    const Node *SrcNode = Diff.getMapped(DstNode);
+    if (!SrcNode)
+      return nullptr;
+    return TargetDiff.getMapped(*SrcNode);
+  }
+  const Node *mapTargetToDst(NodeRef TargetNode) const {
+    const Node *SrcNode = TargetDiff.getMapped(TargetNode);
+    if (!SrcNode)
+      return nullptr;
+    return Diff.getMapped(*SrcNode);
+  }
+};
+} // end anonymous namespace
+
+// Resizes Marked to the size of Tree and sets Marked[N] to true exactly when
+// Predicate holds for N and for every node in N's subtree.
+static void markBiggestSubtrees(std::vector<bool> &Marked, SyntaxTree &Tree,
+                                llvm::function_ref<bool(NodeRef)> Predicate) {
+  Marked.resize(Tree.getSize());
+  // Post-order: all children are decided before their parent.
+  for (NodeRef N : Tree.postorder()) {
+    bool Mark = Predicate(N);
+    for (NodeRef Child : N) {
+      if (!Mark)
+        break;
+      Mark = Marked[Child.getId()];
+    }
+    Marked[N.getId()] = Mark;
+  }
+}
+
+// Runs the whole patching pipeline and writes the patched file(s) back.
+Error Patcher::apply() {
+  if (Debug)
+    Diff.dumpChanges(llvm::errs(), /*DumpMatches=*/true);
+  // Flag the Dst subtrees that are inserted as a whole.
+  auto IsInsertion = [this](NodeRef DstNode) {
+    return Diff.getNodeChange(DstNode) == Insert;
+  };
+  markBiggestSubtrees(AtomicInsertions, Dst, IsInsertion);
+  buildPatchedTree();
+  addInsertedAndMovedNodes();
+  markChangedNodes();
+  Error Err = addReplacementsForChangedNodes();
+  if (Err)
+    return Err;
+  Rewriter Rewrite(SM, LangOpts);
+  bool Applied = TargetTool.applyAllReplacements(Rewrite);
+  if (!Applied)
+    return error(patching_error::failed_to_apply_replacements);
+  // overwriteChangedFiles() returns true if some file could not be saved.
+  bool SaveFailed = Rewrite.overwriteChangedFiles();
+  if (SaveFailed)
+    return error(patching_error::failed_to_overwrite_files);
+  return Error::success();
+}
+
+// True if text placed at Insertion belongs in front of Point, i.e. Insertion
+// comes strictly before Point according to Less. Both must be valid.
+static bool wantToInsertBefore(SourceLocation Insertion, SourceLocation Point,
+                               BeforeThanCompare<SourceLocation> &Less) {
+  assert(Insertion.isValid() && Point.isValid() &&
+         "wantToInsertBefore needs valid locations");
+  return Less(Insertion, Point);
+}
+
+void Patcher::buildPatchedTree() {
+  // Firstly, add all nodes of the tree that will be patched to
+  // PatchedTreeNodes. This way, their offset (getId()) is the same as in the
+  // original tree.
+  //
+  // InsertedNodes and PatchedTreeNode::Children hold raw pointers into
+  // PatchedTreeNodes, so the vector must never reallocate. Reserve room for
+  // every Target node plus the inserted nodes before any pointer is taken:
+  // at most one entry per Dst node is added below.
+  PatchedTreeNodes.reserve(Target.getSize() + Dst.getSize());
+  for (NodeRef TargetNode : Target)
+    PatchedTreeNodes.emplace_back(TargetNode);
+  // Then add all inserted nodes, from Dst.
+  for (NodeId DstId = Dst.getRootId(), E = Dst.getSize(); DstId < E; ++DstId) {
+    NodeRef DstNode = Dst.getNode(DstId);
+    ChangeKind Change = Diff.getNodeChange(DstNode);
+    if (Change == Insert) {
+      PatchedTreeNodes.emplace_back(DstNode);
+      InsertedNodes.emplace(DstNode.getId(), &PatchedTreeNodes.back());
+      // If the whole subtree is inserted, we can skip the children, as we
+      // will just copy the text of the entire subtree.
+      if (AtomicInsertions[DstId])
+        DstId = DstNode.RightMostDescendant;
+    }
+  }
+  // Add existing children. Removed and moved children are left out; moved
+  // ones are attached to their new parent in addInsertedAndMovedNodes().
+  for (auto &PatchedNode : PatchedTreeNodes) {
+    if (isFromTarget(PatchedNode))
+      for (auto &Child : PatchedNode.originalNode())
+        if (!isRemovedOrMoved(Child))
+          PatchedNode.addChild(getPatchedNode(Child));
+  }
+}
+
+// Attaches every inserted or moved Dst node to its parent in the patched
+// tree. Nodes whose parent has no counterpart there are skipped.
+void Patcher::addInsertedAndMovedNodes() {
+  for (NodeId DstId = Dst.getRootId(), E = Dst.getSize(); DstId < E;) {
+    NodeRef DstNode = Dst.getNode(DstId);
+    ChangeKind Change = Diff.getNodeChange(DstNode);
+    // Advance first, so that every `continue` below is safe. The descendants
+    // of a subtree inserted as a whole travel with its root and are skipped.
+    DstId = Change == Insert && AtomicInsertions[DstId]
+                ? DstNode.RightMostDescendant + 1
+                : DstId + 1;
+    bool IsMove = Change == Move || Change == UpdateMove;
+    if (Change != Insert && !IsMove)
+      continue;
+    NodeRef DstParent = *DstNode.getParent();
+    PatchedTreeNode *InsertionTarget;
+    if (Diff.getNodeChange(DstParent) != Insert) {
+      const Node *TargetParent = mapDstToTarget(DstParent);
+      if (!TargetParent)
+        continue;
+      InsertionTarget = &getTargetPatchedNode(*TargetParent);
+    } else {
+      InsertionTarget = &getPatchedNode(DstParent);
+    }
+    PatchedTreeNode *NodeToInsert;
+    SourceLocation InsertionLoc;
+    if (IsMove) {
+      const Node *TargetNode = mapDstToTarget(DstNode);
+      assert(TargetNode && "Node to update not found.");
+      NodeToInsert = &getTargetPatchedNode(*TargetNode);
+      InsertionLoc =
+          findLocationForMove(DstNode, *TargetNode, *InsertionTarget);
+    } else {
+      NodeToInsert = &getPatchedNode(DstNode);
+      InsertionLoc = findLocationForInsertion(DstNode, *InsertionTarget);
+    }
+    assert(InsertionLoc.isValid());
+    InsertionTarget->addInsertion(*NodeToInsert, InsertionLoc);
+  }
+}
+
+// Returns the location next to a child of InsertionTarget at which DstNode
+// is to be inserted, as chosen by findPointOfInsertion().
+SourceLocation
+Patcher::findLocationForInsertion(NodeRef DstNode,
+                                  PatchedTreeNode &InsertionTarget) {
+  assert(isFromDst(DstNode));
+  assert(isFromDst(InsertionTarget) || isFromTarget(InsertionTarget));
+  std::pair<int, bool> Point = findPointOfInsertion(DstNode, InsertionTarget);
+  int NeighborIndex = Point.first;
+  bool AfterNeighbor = Point.second;
+  if (!InsertionTarget.Children.empty() && NeighborIndex != -1) {
+    CharSourceRange Neighbor =
+        InsertionTarget.Children[NeighborIndex]->getSourceRange();
+    SourceLocation Loc = AfterNeighbor ? Neighbor.getEnd() : Neighbor.getBegin();
+    if (Loc.isValid())
+      return Loc;
+  }
+  // NOTE(review): reached e.g. when InsertionTarget has no children yet;
+  // llvm_unreachable is undefined behavior in release builds.
+  llvm_unreachable("Not implemented.");
+}
+
+// Returns the location at which a moved node is attached to its new parent:
+// currently the end of DstNode's range. TargetNode and NewParent are unused.
+// NOTE(review): this location comes from Dst's SourceManager, while the
+// caller orders children with NewParent's tree's SourceManager; confirm.
+SourceLocation Patcher::findLocationForMove(NodeRef DstNode, NodeRef TargetNode,
+                                            PatchedTreeNode &NewParent) {
+  assert(isFromDst(DstNode));
+  assert(isFromTarget(TargetNode));
+  auto DstRange = DstNode.getSourceRange();
+  return DstRange.getEnd();
+}
+
+// Sets PatchedTreeNode::Changed: inserted nodes are always changed; a Target
+// node is changed if its own change kind is not NoChange, or if a child was
+// changed, moved in, or removed/moved away. Target nodes without a
+// counterpart in Src are left unmarked.
+void Patcher::markChangedNodes() {
+  for (const auto &Pair : InsertedNodes)
+    Pair.second->Changed = true;
+  // Mark nodes in original as changed, bottom-up.
+  for (NodeRef TargetNode : Target.postorder()) {
+    auto &PatchedNode = PatchedTreeNodes[TargetNode.getId()];
+    const Node *SrcNode = TargetDiff.getMapped(TargetNode);
+    if (!SrcNode)
+      continue;
+    ChangeKind Change = Diff.getNodeChange(*SrcNode);
+    auto &Children = PatchedNode.Children;
+    // A child moved here from elsewhere in Target may come *later* in
+    // postorder than its new parent, so its own Changed flag is not set yet.
+    // Test isMoved() directly; otherwise the parent would be copied verbatim
+    // and the moved subtree would be lost.
+    bool AnyChildChanged =
+        std::any_of(Children.begin(), Children.end(),
+                    [this](PatchedTreeNode *Child) {
+                      return Child->Changed || isMoved(*Child);
+                    });
+    bool AnyChildRemoved = std::any_of(
+        PatchedNode.originalNode().begin(), PatchedNode.originalNode().end(),
+        [this](NodeRef Child) { return isRemovedOrMoved(Child); });
+    assert(!PatchedNode.Changed);
+    PatchedNode.Changed =
+        AnyChildChanged || AnyChildRemoved || Change != NoChange;
+  }
+}
+
+// Walks Target in preorder and emits one replacement per outermost changed
+// subtree. If the root itself is changed, the work is delegated to
+// addReplacementsForTopLevelChanges().
+Error Patcher::addReplacementsForChangedNodes() {
+  for (NodeId TargetId = Target.getRootId(), E = Target.getSize(); TargetId < E;
+       ++TargetId) {
+    NodeRef TargetNode = Target.getNode(TargetId);
+    if (!getTargetPatchedNode(TargetNode).Changed)
+      continue;
+    if (TargetId == Target.getRootId())
+      return addReplacementsForTopLevelChanges();
+    // Removed or moved-away nodes leave nothing behind at their old place.
+    std::string NewText;
+    if (!isRemovedOrMoved(TargetNode))
+      NewText = buildSourceText(getTargetPatchedNode(TargetNode));
+    if (auto Err = addReplacement(
+            {SM, getRangeForReplacing(TargetNode), NewText, LangOpts}))
+      return Err;
+    // The subtree has been rewritten as a whole; skip its descendants.
+    TargetId = TargetNode.RightMostDescendant;
+  }
+  return Error::success();
+}
+
+// Used when Target's root is changed. Each changed child of the root is
+// written separately: inserted/moved children at an empty range next to the
+// neighbor chosen by findPointOfInsertion(), other children over their own
+// range. Then every original root child that was removed or moved is deleted.
+Error Patcher::addReplacementsForTopLevelChanges() {
+  auto &Root = getTargetPatchedNode(Target.getRoot());
+  for (unsigned I = 0, E = Root.Children.size(); I < E; ++I) {
+    PatchedTreeNode *Child = Root.Children[I];
+    if (!Child->Changed)
+      continue;
+    std::string ChildText = buildSourceText(*Child);
+    CharSourceRange ChildRange;
+    if (isInserted(*Child) || isMoved(*Child)) {
+      SourceLocation InsertionLoc;
+      int ChildIndex;
+      bool RightOfChild;
+      std::tie(ChildIndex, RightOfChild) = findPointOfInsertion(*Child, Root);
+      // Root.Children cannot be empty here: Child is one of its entries.
+      if (ChildIndex != -1) {
+        auto NeighborRange = Root.Children[ChildIndex]->getSourceRange();
+        InsertionLoc =
+            RightOfChild ? NeighborRange.getEnd() : NeighborRange.getBegin();
+      } else {
+        // No neighbor to anchor on: append at the end of the main file, in
+        // front of its final newline if it has one. Checking the actual last
+        // character avoids stepping before the start of an empty file, or
+        // in front of the last code character when there is no trailing
+        // newline.
+        FileID MainFile = SM.getMainFileID();
+        StringRef Buffer = SM.getBufferData(MainFile);
+        int Offset = (!Buffer.empty() && Buffer.back() == '\n') ? -1 : 0;
+        InsertionLoc = SM.getLocForEndOfFile(MainFile).getLocWithOffset(Offset);
+      }
+      ChildRange = makeEmptyCharRange(InsertionLoc);
+    } else {
+      ChildRange = Child->getSourceRange();
+    }
+    if (auto Err = addReplacement({SM, ChildRange, ChildText, LangOpts}))
+      return Err;
+  }
+  for (NodeRef Child : Root.originalNode()) {
+    if (isRemovedOrMoved(Child)) {
+      auto ChildRange = Child.findRangeForDeletion();
+      if (auto Err = addReplacement({SM, ChildRange, "", LangOpts}))
+        return Err;
+    }
+  }
+  return Error::success();
+}
+
+// Returns the source text of Tree from Loc up to (not including) the first
+// raw token at or after Loc, lexed with whitespace skipped. Returns an empty
+// StringRef if no token could be lexed.
+static StringRef trailingText(SourceLocation Loc, SyntaxTree &Tree) {
+  SourceManager &SM = Tree.getSourceManager();
+  const LangOptions &LangOpts = Tree.getLangOpts();
+  Token NextToken;
+  // getRawToken() returns true on failure.
+  if (Lexer::getRawToken(Loc, NextToken, SM, LangOpts,
+                         /*IgnoreWhiteSpace=*/true))
+    return StringRef();
+  return Lexer::getSourceText(
+      CharSourceRange::getCharRange({Loc, NextToken.getLocation()}), SM,
+      LangOpts);
+}
+
+// Returns the text of the patched subtree rooted at PatchedNode. Unchanged
+// nodes and atomically inserted Dst subtrees are copied from their source;
+// otherwise the node's own text is interleaved with its children's text at
+// the offsets computed by setOwnedSourceText().
+std::string Patcher::buildSourceText(PatchedTreeNode &PatchedNode) {
+  SyntaxTree &Tree = PatchedNode.getTree();
+  assert(!isRemoved(PatchedNode));
+  bool FromDst = isFromDst(PatchedNode);
+  bool CopyVerbatim = !PatchedNode.Changed ||
+                      (FromDst && AtomicInsertions[PatchedNode.getId()]);
+  if (CopyVerbatim) {
+    auto Range = PatchedNode.getSourceRange();
+    std::string Verbatim = Lexer::getSourceText(
+        Range, Tree.getSourceManager(), Tree.getLangOpts());
+    // NOTE(review): for Target nodes the text up to the next token
+    // (presumably the separating whitespace) is kept as well; confirm why
+    // this is not done for Dst nodes.
+    if (!FromDst)
+      Verbatim += trailingText(Range.getEnd(), Tree);
+    return Verbatim;
+  }
+  setOwnedSourceText(PatchedNode);
+  const std::string &Own = *PatchedNode.OwnText;
+  assert(PatchedNode.ChildrenOffsets.size() == PatchedNode.Children.size());
+  std::string Result;
+  unsigned Consumed = 0;
+  for (size_t Idx = 0; Idx < PatchedNode.Children.size(); ++Idx) {
+    unsigned ChildOffset = PatchedNode.ChildrenOffsets[Idx];
+    Result.append(Own, Consumed, ChildOffset - Consumed);
+    Result += buildSourceText(*PatchedNode.Children[Idx]);
+    Consumed = ChildOffset;
+  }
+  assert(Consumed <= Own.size());
+  Result.append(Own, Consumed, std::string::npos);
+  return Result;
+}
+
+/// Computes PatchedNode.OwnText and PatchedNode.ChildrenOffsets.
+///
+/// OwnText is the concatenation of the node's owned source ranges, i.e. the
+/// text not covered by any descendant. For each child, in order,
+/// ChildrenOffsets receives the offset into OwnText at which that child's
+/// text is to be spliced. A child is placed before an owned range if its
+/// location is invalid or wantToInsertBefore() places it before the range's
+/// end. All remaining children are placed after the last range.
+///
+/// Both outputs are recomputed from scratch, so calling this again on the
+/// same node does not leave stale offsets behind.
+void Patcher::setOwnedSourceText(PatchedTreeNode &PatchedNode) {
+  assert(isFromTarget(PatchedNode) || isFromDst(PatchedNode));
+  SyntaxTree &Tree = PatchedNode.getTree();
+  SourceManager &SM = Tree.getSourceManager();
+  auto &OwnText = PatchedNode.OwnText;
+  auto &Children = PatchedNode.Children;
+  auto &ChildrenLocations = PatchedNode.ChildrenLocations;
+  auto &ChildrenOffsets = PatchedNode.ChildrenOffsets;
+  OwnText = "";
+  // Reset together with OwnText: buildSourceText() asserts that there is
+  // exactly one offset per child.
+  ChildrenOffsets.clear();
+  unsigned NumChildren = Children.size();
+  bool IsUpdate = false;
+  if (isFromTarget(PatchedNode)) {
+    const Node *SrcNode = TargetDiff.getMapped(PatchedNode);
+    ChangeKind Change = SrcNode ? Diff.getNodeChange(*SrcNode) : NoChange;
+    IsUpdate = Change == Update || Change == UpdateMove;
+  }
+  unsigned ChildIndex = 0;
+  auto MySourceRanges = PatchedNode.getOwnedSourceRanges();
+  BeforeThanCompare<SourceLocation> Less(SM);
+  for (auto &SubRange : MySourceRanges) {
+    // Emit the slots of all children that belong before this owned range.
+    while (ChildIndex < NumChildren) {
+      SourceLocation ChildBegin = ChildrenLocations[ChildIndex];
+      if (ChildBegin.isValid() &&
+          !wantToInsertBefore(ChildBegin, SubRange.getEnd(), Less))
+        break;
+      ChildrenOffsets.push_back(OwnText->size());
+      ++ChildIndex;
+    }
+    // Rewriting the text of updated nodes is not implemented yet. Diff input
+    // can reach this path, so fail deterministically rather than with
+    // llvm_unreachable, which is undefined behaviour in release builds.
+    if (IsUpdate)
+      llvm::report_fatal_error("Not implemented.");
+    *OwnText += Lexer::getSourceText(SubRange, SM, Tree.getLangOpts());
+  }
+  while (ChildIndex++ < NumChildren)
+    ChildrenOffsets.push_back(OwnText->size());
+}
+
+/// Determines where \p N is to be placed among the children of the
+/// target-tree node \p TargetParent.
+///
+/// Children of TargetParent are translated into N's tree and examined in
+/// order. A child from the same tree is used directly; otherwise it goes
+/// through mapTargetToDst / mapDstToTarget. Untranslatable children and
+/// children without a valid begin location are skipped.
+///
+/// Returns:
+/// - {I, true} when child I is the sibling that precedes N in N's own parent;
+/// - {I, false} for the first child that begins after N;
+/// - {-1, true} when neither is found.
+std::pair<int, bool>
+Patcher::findPointOfInsertion(NodeRef N, PatchedTreeNode &TargetParent) const {
+  assert(isFromDst(N) || isFromTarget(N));
+  assert(isFromTarget(TargetParent));
+  auto ToTreeOfN = [this, &N](PatchedTreeNode &Sibling) {
+    if (isFromDst(N) == isFromDst(Sibling))
+      return &NodeRef(Sibling);
+    return isFromDst(N) ? mapTargetToDst(Sibling) : mapDstToTarget(Sibling);
+  };
+  BeforeThanCompare<SourceLocation> Before(N.getTree().getSourceManager());
+  auto PositionInParent = N.findPositionInParent();
+  SourceLocation Begin = N.getSourceRange().getBegin();
+  assert(Begin.isValid());
+  for (unsigned I = 0, E = TargetParent.Children.size(); I != E; ++I) {
+    const Node *Candidate = ToTreeOfN(*TargetParent.Children[I]);
+    if (!Candidate)
+      continue;
+    SourceLocation CandidateBegin = Candidate->getSourceRange().getBegin();
+    if (CandidateBegin.isInvalid())
+      continue;
+    // N's left neighbour in its own parent: insert just right of it.
+    if (PositionInParent &&
+        Candidate == &N.getParent()->getChild(PositionInParent - 1))
+      return {I, /*RightOfSibling=*/true};
+    // First candidate that starts after N: insert just left of it.
+    if (Before(Begin, CandidateBegin))
+      return {I, /*RightOfSibling=*/false};
+  }
+  return {-1, true};
+}
+
+/// Builds the AST of the single file handled by \p TargetTool and runs a
+/// Patcher over it, using the differences between \p Src and \p Dst.
+///
+/// Returns patching_error::failed_to_build_AST in two cases: building the
+/// ASTs reports a failure, or it does not yield exactly one AST. Otherwise it
+/// returns the result of Patcher::apply().
+Error patch(RefactoringTool &TargetTool, SyntaxTree &Src, SyntaxTree &Dst,
+            const ComparisonOptions &Options, bool Debug) {
+  std::vector<std::unique_ptr<ASTUnit>> TargetASTs;
+  // buildASTs() signals failure through a non-zero return value; do not
+  // silently drop it.
+  if (TargetTool.buildASTs(TargetASTs) != 0 || TargetASTs.size() != 1)
+    return error(patching_error::failed_to_build_AST);
+  SyntaxTree Target(*TargetASTs[0]);
+  return Patcher(Src, Dst, Target, Options, TargetTool, Debug).apply();
+}
+
+/// Returns the human-readable description of the stored patching_error.
+/// Every message ends with a newline.
+std::string PatchingError::message() const {
+  switch (Err) {
+  case patching_error::failed_to_build_AST:
+    return "Failed to build AST.\n";
+  case patching_error::failed_to_apply_replacements:
+    return "Failed to apply replacements.\n";
+  case patching_error::failed_to_overwrite_files:
+    return "Failed to overwrite some file(s).\n";
+  }
+  // All enumerators are handled above; this keeps control from falling off
+  // the end of a non-void function (GCC -Wreturn-type).
+  llvm_unreachable("Unknown patching_error");
+}
+
+// Definition of the static `ID` member declared in PatchingError. It serves
+// as the type tag for the llvm::ErrorInfo<PatchingError> base.
+// NOTE(review): presumably only its address matters (LLVM error classes
+// usually leave ID zero-initialised), so the "= 1" looks arbitrary — confirm.
+char PatchingError::ID = 1;
+
+} // end namespace diff
+} // end namespace clang
Index: lib/Tooling/ASTDiff/ASTDiff.cpp
===================================================================
--- lib/Tooling/ASTDiff/ASTDiff.cpp
+++ lib/Tooling/ASTDiff/ASTDiff.cpp
@@ -841,6 +841,37 @@
   return SourceRanges;
 }
 
+/// Returns the range to remove when this node is deleted from its parent.
+///
+/// This is getSourceRange(), extended when the node is a call argument or a
+/// function parameter. In that case the range also covers the comma that
+/// separates the node from its neighbours:
+/// - the trailing comma if there is one;
+/// - otherwise (for the last element of the list) the separator in front of
+///   it, by starting the range at the end of the previous element.
+CharSourceRange Node::findRangeForDeletion() const {
+  CharSourceRange Range = getSourceRange();
+  if (!getParent())
+    return Range;
+  NodeRef Parent = *getParent();
+  SyntaxTree &Tree = getTree();
+  SourceManager &SM = Tree.getSourceManager();
+  const LangOptions &LangOpts = Tree.getLangOpts();
+  size_t SiblingIndex = findPositionInParent();
+  // Child 0 of a CallExpr is the callee, which is not part of the argument
+  // list.
+  bool IsCallArgument = Parent.ASTNode.get<CallExpr>() && SiblingIndex > 0;
+  bool IsParameter = ASTNode.get<ParmVarDecl>() != nullptr;
+  if (!IsCallArgument && !IsParameter)
+    return Range;
+  // The end of getSourceRange() already points past the node's last token,
+  // so lex the token located there (skipping whitespace). Lexer::findNextToken
+  // would first skip over the token at that location and thereby miss a
+  // comma that directly follows the node.
+  Token Next;
+  bool LexFailed = Lexer::getRawToken(Range.getEnd(), Next, SM, LangOpts,
+                                      /*IgnoreWhiteSpace=*/true);
+  if (!LexFailed && Next.is(tok::comma)) {
+    Range.setEnd(Next.getEndLoc());
+    return Range;
+  }
+  // No trailing comma: this is the last element of the list (which need not
+  // be the last sibling, e.g. a function body may follow the parameters), so
+  // remove the separator in front of it instead.
+  if (SiblingIndex == 0)
+    return Range;
+  NodeRef Previous = Parent.getChild(SiblingIndex - 1);
+  bool PreviousIsListElement =
+      IsCallArgument ? SiblingIndex > 1
+                     : Previous.ASTNode.get<ParmVarDecl>() != nullptr;
+  if (PreviousIsListElement)
+    Range.setBegin(Previous.getSourceRange().getEnd());
+  return Range;
+}
+
 void forEachTokenInRange(CharSourceRange Range, SyntaxTree &Tree,
                          std::function<void(Token &)> Body) {
   SourceLocation Begin = Range.getBegin(), End = Range.getEnd();
Index: include/clang/Tooling/ASTDiff/ASTPatch.h
===================================================================
--- /dev/null
+++ include/clang/Tooling/ASTDiff/ASTPatch.h
@@ -0,0 +1,49 @@
+//===- ASTPatch.h - Structural patching based on ASTDiff ------*- C++ -*- -===//
+//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_ASTDIFF_ASTPATCH_H
+#define LLVM_CLANG_TOOLING_ASTDIFF_ASTPATCH_H
+
+#include "clang/Tooling/ASTDiff/ASTDiff.h"
+#include "clang/Tooling/Refactoring.h"
+#include "llvm/Support/Error.h"
+
+namespace clang {
+namespace diff {
+
+enum class patching_error {
+  failed_to_build_AST,
+  failed_to_apply_replacements,
+  failed_to_overwrite_files,
+};
+
+/// llvm::ErrorInfo payload carrying a patching_error code.
+class PatchingError : public llvm::ErrorInfo<PatchingError> {
+public:
+  PatchingError(patching_error Err) : Err(Err) {}
+  /// Returns a description of the error. The text already ends in a newline.
+  std::string message() const override;
+  /// Prints message() unchanged. Appending another "\n" here would add an
+  /// empty line after every logged PatchingError.
+  void log(raw_ostream &OS) const override { OS << message(); }
+  /// Returns the stored error condition.
+  patching_error get() const { return Err; }
+  static char ID;
+
+private:
+  /// PatchingError has no std::error_code counterpart.
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+  patching_error Err;
+};
+
+/// Structurally patches the single file handled by \p TargetTool, based on
+/// the differences between \p Src and \p Dst computed with \p Options.
+///
+/// Failures are reported through the returned llvm::Error. For example, a
+/// PatchingError with patching_error::failed_to_build_AST is returned when
+/// the tool does not produce exactly one AST.
+/// NOTE(review): \p Debug is forwarded to the internal Patcher; its exact
+/// effect is not visible from this header — confirm.
+llvm::Error patch(tooling::RefactoringTool &TargetTool, SyntaxTree &Src,
+                  SyntaxTree &Dst, const ComparisonOptions &Options,
+                  bool Debug = false);
+
+} // end namespace diff
+} // end namespace clang
+
+#endif // LLVM_CLANG_TOOLING_ASTDIFF_ASTPATCH_H
Index: include/clang/Tooling/ASTDiff/ASTDiff.h
===================================================================
--- include/clang/Tooling/ASTDiff/ASTDiff.h
+++ include/clang/Tooling/ASTDiff/ASTDiff.h
@@ -163,6 +163,11 @@
   /// this node, that is, none of its descendants includes them.
   SmallVector<CharSourceRange, 4> getOwnedSourceRanges() const;
 
+  /// Returns the source range to remove when deleting this node.
+  ///
+  /// This is getSourceRange(), extended to also cover the comma that
+  /// separates the node from its neighbours when it is an element of a
+  /// comma-separated list (call arguments, function parameters). That is
+  /// usually the trailing comma, or, for the last element, the separator
+  /// in front of it.
+  CharSourceRange findRangeForDeletion() const;
+
   /// Returns the offsets for the range returned by getSourceRange().
   std::pair<unsigned, unsigned> getSourceRangeOffsets() const;
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to