arphaman updated this revision to Diff 107121.
arphaman added a comment.

Factor out the lexical ordering code into a new visitor and simplify the
implementation of the AST selection visitor.


Repository:
  rL LLVM

https://reviews.llvm.org/D35012

Files:
  include/clang/AST/LexicallyOrderedRecursiveASTVisitor.h
  include/clang/Basic/SourceLocation.h
  include/clang/Basic/SourceManager.h
  include/clang/Tooling/Refactoring/ASTSelection.h
  lib/Tooling/Refactoring/ASTSelection.cpp
  lib/Tooling/Refactoring/CMakeLists.txt
  unittests/Tooling/ASTSelectionTest.cpp
  unittests/Tooling/CMakeLists.txt
  unittests/Tooling/LexicallyOrderedRecursiveASTVisitorTest.cpp

Index: unittests/Tooling/LexicallyOrderedRecursiveASTVisitorTest.cpp
===================================================================
--- /dev/null
+++ unittests/Tooling/LexicallyOrderedRecursiveASTVisitorTest.cpp
@@ -0,0 +1,141 @@
+//===- unittest/Tooling/LexicallyOrderedRecursiveASTVisitorTest.cpp -------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestVisitor.h"
+#include "clang/AST/LexicallyOrderedRecursiveASTVisitor.h"
+#include <stack>
+
+using namespace clang;
+
+namespace {
+
+class DummyMatchVisitor;
+
+/// Lexically ordered visitor that keeps the chain of declarations enclosing
+/// the current traversal point and hands every visited NamedDecl, together
+/// with that chain, to a DummyMatchVisitor.
+class LexicallyOrderedDeclVisitor
+    : public LexicallyOrderedRecursiveASTVisitor<LexicallyOrderedDeclVisitor> {
+  using Base = LexicallyOrderedRecursiveASTVisitor<LexicallyOrderedDeclVisitor>;
+
+public:
+  LexicallyOrderedDeclVisitor(DummyMatchVisitor &Matcher,
+                              const SourceManager &SM)
+      : Base(SM), Matcher(Matcher) {}
+
+  /// Keeps \p D on the traversal stack while its subtree is walked. The base
+  /// result is discarded and true is always returned, so the traversal is
+  /// never aborted from here.
+  bool TraverseDecl(Decl *D) {
+    TraversalStack.push_back(D);
+    (void)Base::TraverseDecl(D);
+    TraversalStack.pop_back();
+    return true;
+  }
+
+  bool VisitNamedDecl(const NamedDecl *D);
+
+private:
+  DummyMatchVisitor &Matcher;
+  /// Declarations currently being traversed, outermost first.
+  llvm::SmallVector<Decl *, 8> TraversalStack;
+};
+
+class DummyMatchVisitor : public ExpectedLocationVisitor<DummyMatchVisitor> {
+public:
+  /// Runs a LexicallyOrderedDeclVisitor over the whole translation unit, then
+  /// returns false so the outer (unordered) traversal goes no further.
+  bool VisitTranslationUnitDecl(TranslationUnitDecl *TU) {
+    LexicallyOrderedDeclVisitor OrderedVisitor(
+        *this, TU->getASTContext().getSourceManager());
+    OrderedVisitor.TraverseDecl(TU);
+    return false;
+  }
+
+  /// Reports the start location of \p D under the name \p Path.
+  void match(StringRef Path, const Decl *D) {
+    Match(Path, D->getLocStart());
+  }
+};
+
+/// Builds the path of \p D from the traversal stack and reports it.
+///
+/// The translation unit contributes a leading "/". Every other enclosing
+/// declaration contributes its name ("???" if unnamed), followed by "/" when
+/// it is a DeclContext.
+bool LexicallyOrderedDeclVisitor::VisitNamedDecl(const NamedDecl *D) {
+  std::string Path;
+  llvm::raw_string_ostream OS(Path);
+  // TraverseDecl pushes D before the base traversal visits it.
+  assert(!TraversalStack.empty() && TraversalStack.back() == D &&
+         "visited declaration must be on top of the traversal stack");
+  // Use a distinct name so the parameter D is not shadowed inside the loop.
+  for (const Decl *Ancestor : TraversalStack) {
+    if (isa<TranslationUnitDecl>(Ancestor)) {
+      OS << "/";
+      continue;
+    }
+    if (const auto *ND = dyn_cast<NamedDecl>(Ancestor))
+      OS << ND->getNameAsString();
+    else
+      OS << "???";
+    if (isa<DeclContext>(Ancestor))
+      OS << "/";
+  }
+  Matcher.match(OS.str(), D);
+  return true;
+}
+
+// Declarations written between @implementation and @end are expected under
+// the implementation's path ("/I/...", "/Cat/..."), and disallowed at the top
+// level. Declarations outside any @implementation keep top-level paths.
+TEST(LexicallyOrderedRecursiveASTVisitor, VisitDeclsInImplementation) {
+  StringRef Source = R"(
+@interface I
+@end
+@implementation I
+
+int nestedFunction() { }
+
+- (void) method{ }
+
+int anotherNestedFunction(int x) {
+  return x;
+}
+
+int innerVariable = 0;
+
+@end
+
+int outerVariable = 0;
+
+@implementation I(Cat)
+
+void catF() { }
+
+@end
+
+void outerFunction() { }
+)";
+  DummyMatchVisitor Visitor;
+  // Line/column pairs count from the empty line directly after R"( as line 1.
+  Visitor.DisallowMatch("/nestedFunction/", 6, 1);
+  Visitor.ExpectMatch("/I/nestedFunction/", 6, 1);
+  Visitor.ExpectMatch("/I/method/", 8, 1);
+  Visitor.DisallowMatch("/anotherNestedFunction/", 10, 1);
+  Visitor.ExpectMatch("/I/anotherNestedFunction/", 10, 1);
+  Visitor.DisallowMatch("/innerVariable", 14, 1);
+  Visitor.ExpectMatch("/I/innerVariable", 14, 1);
+  Visitor.ExpectMatch("/outerVariable", 18, 1);
+  Visitor.DisallowMatch("/catF/", 22, 1);
+  Visitor.ExpectMatch("/Cat/catF/", 22, 1);
+  Visitor.ExpectMatch("/outerFunction/", 26, 1);
+  EXPECT_TRUE(Visitor.runOver(Source, DummyMatchVisitor::Lang_OBJC));
+}
+
+// Macro-generated declarations are placed according to where the macro is
+// expanded:
+// - MACRO_F(1), inside @implementation I, is expected under "/I/".
+// - MACRO_F(2), after @end, is expected at the top level.
+// Both are expected at line 7, column 20, which is the 'void' inside the
+// #define.
+TEST(LexicallyOrderedRecursiveASTVisitor, VisitMacroDeclsInImplementation) {
+  StringRef Source = R"(
+@interface I
+@end
+
+void outerFunction() { }
+
+#define MACRO_F(x) void nestedFunction##x() { }
+
+@implementation I
+
+MACRO_F(1)
+
+@end
+
+MACRO_F(2)
+)";
+  DummyMatchVisitor Visitor;
+  Visitor.ExpectMatch("/outerFunction/", 5, 1);
+  Visitor.ExpectMatch("/I/nestedFunction1/", 7, 20);
+  Visitor.ExpectMatch("/nestedFunction2/", 7, 20);
+  EXPECT_TRUE(Visitor.runOver(Source, DummyMatchVisitor::Lang_OBJC));
+}
+
+} // end anonymous namespace
Index: unittests/Tooling/CMakeLists.txt
===================================================================
--- unittests/Tooling/CMakeLists.txt
+++ unittests/Tooling/CMakeLists.txt
@@ -11,11 +11,13 @@
 endif()
 
 add_clang_unittest(ToolingTests
+  ASTSelectionTest.cpp
   CastExprTest.cpp
   CommentHandlerTest.cpp
   CompilationDatabaseTest.cpp
   DiagnosticsYamlTest.cpp
   FixItTest.cpp
+  LexicallyOrderedRecursiveASTVisitorTest.cpp
   LookupTest.cpp
   QualTypeNamesTest.cpp
   RecursiveASTVisitorTest.cpp
Index: unittests/Tooling/ASTSelectionTest.cpp
===================================================================
--- /dev/null
+++ unittests/Tooling/ASTSelectionTest.cpp
@@ -0,0 +1,456 @@
+//===- unittest/Tooling/ASTSelectionTest.cpp ------------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestVisitor.h"
+#include "clang/Basic/SourceManager.h"
+#include "clang/Tooling/Refactoring/ASTSelection.h"
+
+using namespace clang;
+using namespace tooling;
+
+namespace {
+
+/// A line/column position in the main file of the test's translation unit.
+struct FileLocation {
+  unsigned Line, Column;
+
+  /// Converts this line/column pair into a SourceLocation in the main file.
+  /// Does not modify the object, so it is callable on const instances.
+  SourceLocation translate(const SourceManager &SM) const {
+    return SM.translateLineCol(SM.getMainFileID(), Line, Column);
+  }
+};
+
+/// A selection expressed as a (begin, end) pair of main-file positions.
+using FileRange = std::pair<FileLocation, FileLocation>;
+
+/// Test visitor that runs the AST selection finder once the translation unit
+/// is available and keeps the result in \c Selection.
+class SelectionFinderVisitor : public TestVisitor<SelectionFinderVisitor> {
+  FileLocation Location;
+  Optional<FileRange> SelectionRange;
+
+public:
+  /// Result of the finder; None when no node was selected.
+  Optional<SelectedASTNode> Selection;
+
+  SelectionFinderVisitor(FileLocation Location,
+                         Optional<FileRange> SelectionRange)
+      : Location(Location), SelectionRange(SelectionRange) {}
+
+  bool VisitTranslationUnitDecl(const TranslationUnitDecl *TU) {
+    const ASTContext &Ctx = TU->getASTContext();
+    const SourceManager &SM = Ctx.getSourceManager();
+
+    // Without an explicit range, the selection collapses to the cursor.
+    SourceRange SelRange;
+    if (!SelectionRange) {
+      const SourceLocation Cursor = Location.translate(SM);
+      SelRange = SourceRange(Cursor, Cursor);
+    } else {
+      SelRange = SourceRange(SelectionRange->first.translate(SM),
+                             SelectionRange->second.translate(SM));
+    }
+    Selection = findSelectedASTNodes(Ctx, SelRange);
+    // The finder has already walked the whole TU; stop the outer traversal.
+    return false;
+  }
+};
+
+/// Parses \p Source as \p Language and returns the selection tree.
+///
+/// The selection is \p SelectionRange if given, otherwise the cursor at
+/// \p Location.
+Optional<SelectedASTNode>
+findSelectedASTNodes(StringRef Source, FileLocation Location,
+                     Optional<FileRange> SelectionRange,
+                     SelectionFinderVisitor::Language Language =
+                         SelectionFinderVisitor::Lang_CXX11) {
+  SelectionFinderVisitor Visitor{Location, SelectionRange};
+  EXPECT_TRUE(Visitor.runOver(Source, Language));
+  Optional<SelectedASTNode> Result = std::move(Visitor.Selection);
+  return Result;
+}
+
+/// Shared checks for checkNode: the node has the expected type, the expected
+/// number of selected children, and the expected selection kind.
+void checkNodeImpl(bool IsTypeMatched, const SelectedASTNode &Node,
+                   SourceSelectionKind SelectionKind, unsigned NumChildren) {
+  ASSERT_TRUE(IsTypeMatched);
+  EXPECT_EQ(Node.Children.size(), NumChildren);
+  ASSERT_EQ(Node.SelectionKind, SelectionKind);
+}
+
+/// Checks that \p Node is a NamedDecl called \p Name.
+void checkDeclName(const SelectedASTNode &Node, StringRef Name) {
+  const auto *ND = Node.Node.get<NamedDecl>();
+  // Fatal: the name check below dereferences ND, so a non-NamedDecl node must
+  // stop this helper instead of crashing on a null pointer.
+  ASSERT_TRUE(ND != nullptr);
+  ASSERT_EQ(ND->getName(), Name);
+}
+
+/// Checks a statement node (T derives from Stmt) and returns it so that its
+/// children can be checked in turn.
+template <typename T>
+const SelectedASTNode &
+checkNode(const SelectedASTNode &StmtNode, SourceSelectionKind SelectionKind,
+          unsigned NumChildren = 0,
+          typename std::enable_if<std::is_base_of<Stmt, T>::value, T>::type
+              *StmtOverloadChecker = nullptr) {
+  // get<Stmt>() is null when the node is not a statement. isa<> asserts on
+  // null, whereas dyn_cast_or_null<> turns that into a normal test failure.
+  checkNodeImpl(dyn_cast_or_null<T>(StmtNode.Node.get<Stmt>()) != nullptr,
+                StmtNode, SelectionKind, NumChildren);
+  return StmtNode;
+}
+
+/// Checks a declaration node (T derives from Decl), and its name when \p Name
+/// is non-empty. Returns the node so that its children can be checked.
+template <typename T>
+const SelectedASTNode &
+checkNode(const SelectedASTNode &DeclNode, SourceSelectionKind SelectionKind,
+          unsigned NumChildren = 0, StringRef Name = "",
+          typename std::enable_if<std::is_base_of<Decl, T>::value, T>::type
+              *DeclOverloadChecker = nullptr) {
+  // get<Decl>() is null when the node is not a declaration; see above.
+  checkNodeImpl(dyn_cast_or_null<T>(DeclNode.Node.get<Decl>()) != nullptr,
+                DeclNode, SelectionKind, NumChildren);
+  if (!Name.empty())
+    checkDeclName(DeclNode, Name);
+  return DeclNode;
+}
+
+/// Helper used as allChildrenOf(Node).shouldHaveSelectionKind(Kind) to check a
+/// selection kind across a subtree.
+struct ForAllChildrenOf {
+  const SelectedASTNode &Node;
+
+  // NOTE(review): for every child this asserts Node.SelectionKind (the parent),
+  // not Child.SelectionKind. As a result:
+  // - only nodes that have at least one child are checked, including the root
+  //   passed in;
+  // - leaves are never checked.
+  // Checking Child.SelectionKind looks like the intent. However, the 'catF'
+  // case in SelectionInFunctionInObjCImplementation appears to rely on the
+  // current behavior: by my reading its body CompoundStmt is InsideSelection,
+  // not ContainsSelection. Confirm, and update that expectation together with
+  // any fix here.
+  static void childKindVerifier(const SelectedASTNode &Node,
+                                SourceSelectionKind SelectionKind) {
+    for (const SelectedASTNode &Child : Node.Children) {
+      ASSERT_EQ(Node.SelectionKind, SelectionKind);
+      childKindVerifier(Child, SelectionKind);
+    }
+  }
+
+public:
+  ForAllChildrenOf(const SelectedASTNode &Node) : Node(Node) {}
+
+  /// Runs childKindVerifier over the wrapped node's subtree.
+  void shouldHaveSelectionKind(SourceSelectionKind Kind) {
+    childKindVerifier(Node, Kind);
+  }
+};
+
+/// Convenience factory for ForAllChildrenOf.
+ForAllChildrenOf allChildrenOf(const SelectedASTNode &Node) {
+  return ForAllChildrenOf(Node);
+}
+
+TEST(ASTSelectionFinder, CursorNoSelection) {
+  // The cursor sits on leading whitespace, outside the function, so no node
+  // is selected.
+  auto Node = findSelectedASTNodes(" void f() { }", {1, 1}, None);
+  EXPECT_FALSE(Node);
+}
+
+TEST(ASTSelectionFinder, CursorAtStartOfFunction) {
+  Optional<SelectedASTNode> Node =
+      findSelectedASTNodes("void f() { }", {1, 1}, None);
+  // Fatal check: everything below dereferences Node.
+  ASSERT_TRUE(Node);
+  checkNode<TranslationUnitDecl>(*Node, SourceSelectionKind::None,
+                                 /*NumChildren=*/1);
+  checkNode<FunctionDecl>(Node->Children[0],
+                          SourceSelectionKind::ContainsSelection,
+                          /*NumChildren=*/0, /*Name=*/"f");
+
+  // Check that the dumping works.
+  std::string DumpValue;
+  llvm::raw_string_ostream OS(DumpValue);
+  Node->Children[0].dump(OS);
+  ASSERT_EQ(OS.str(), "FunctionDecl \"f\" contains-selection\n");
+}
+
+TEST(ASTSelectionFinder, RangeNoSelection) {
+  // An empty range on the leading whitespace selects nothing.
+  {
+    auto Node = findSelectedASTNodes(" void f() { }", {1, 1},
+                                     FileRange{{1, 1}, {1, 1}});
+    EXPECT_FALSE(Node);
+  }
+  // A range that stays within the leading whitespace selects nothing either.
+  {
+    auto Node = findSelectedASTNodes("  void f() { }", {1, 1},
+                                     FileRange{{1, 1}, {1, 2}});
+    EXPECT_FALSE(Node);
+  }
+}
+
+TEST(ASTSelectionFinder, EmptyRangeFallbackToCursor) {
+  // An empty range behaves like a cursor at its start.
+  Optional<SelectedASTNode> Node =
+      findSelectedASTNodes("void f() { }", {1, 1}, FileRange{{1, 1}, {1, 1}});
+  // Fatal checks: Node is dereferenced and Children[0] is indexed below.
+  ASSERT_TRUE(Node);
+  ASSERT_EQ(Node->Children.size(), 1u);
+  checkNode<FunctionDecl>(Node->Children[0],
+                          SourceSelectionKind::ContainsSelection,
+                          /*NumChildren=*/0, /*Name=*/"f");
+}
+
+TEST(ASTSelectionFinder, WholeFunctionSelection) {
+  StringRef Source = "int f(int x) { return x;\n}\nvoid f2() { }";
+  // In each case Node is dereferenced and Children[0] indexed, so those
+  // checks are fatal (ASSERT_*) rather than EXPECT_*.
+  // From 'int' until just after '}':
+  {
+    auto Node = findSelectedASTNodes(Source, {1, 1}, FileRange{{1, 1}, {2, 2}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/2, /*Name=*/"f");
+    checkNode<ParmVarDecl>(Fn.Children[0],
+                           SourceSelectionKind::InsideSelection);
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[1], SourceSelectionKind::InsideSelection,
+        /*NumChildren=*/1);
+    const auto &Return = checkNode<ReturnStmt>(
+        Body.Children[0], SourceSelectionKind::InsideSelection,
+        /*NumChildren=*/1);
+    checkNode<ImplicitCastExpr>(Return.Children[0],
+                                SourceSelectionKind::InsideSelection,
+                                /*NumChildren=*/1);
+    checkNode<DeclRefExpr>(Return.Children[0].Children[0],
+                           SourceSelectionKind::InsideSelection);
+  }
+  // From 'int' until just before '}':
+  {
+    auto Node = findSelectedASTNodes(Source, {2, 1}, FileRange{{1, 1}, {2, 1}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/2, /*Name=*/"f");
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[1], SourceSelectionKind::ContainsSelectionEnd,
+        /*NumChildren=*/1);
+    checkNode<ReturnStmt>(Body.Children[0],
+                          SourceSelectionKind::InsideSelection,
+                          /*NumChildren=*/1);
+  }
+  // From '{' until just after '}':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {1, 14}, FileRange{{1, 14}, {2, 2}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"f");
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1);
+    checkNode<ReturnStmt>(Body.Children[0],
+                          SourceSelectionKind::InsideSelection,
+                          /*NumChildren=*/1);
+  }
+  // From 'x' until just after '}':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {2, 2}, FileRange{{1, 11}, {2, 2}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/2, /*Name=*/"f");
+    checkNode<ParmVarDecl>(Fn.Children[0],
+                           SourceSelectionKind::ContainsSelectionStart);
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[1], SourceSelectionKind::InsideSelection,
+        /*NumChildren=*/1);
+    checkNode<ReturnStmt>(Body.Children[0],
+                          SourceSelectionKind::InsideSelection,
+                          /*NumChildren=*/1);
+  }
+}
+
+TEST(ASTSelectionFinder, MultipleFunctionSelection) {
+  StringRef Source = R"(void f0() {
+}
+void f1() { }
+void f2() { }
+void f3() { }
+)";
+  // Checks that exactly f1 and f2 are selected. ASSERT_* is valid here
+  // because the lambda returns void; it guards the dereference/indexing.
+  auto SelectedF1F2 = [](Optional<SelectedASTNode> Node) {
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 2u);
+    checkNode<FunctionDecl>(Node->Children[0],
+                            SourceSelectionKind::InsideSelection,
+                            /*NumChildren=*/1, /*Name=*/"f1");
+    checkNode<FunctionDecl>(Node->Children[1],
+                            SourceSelectionKind::InsideSelection,
+                            /*NumChildren=*/1, /*Name=*/"f2");
+  };
+  // Just after '}' of f0 and just before 'void' of f3:
+  SelectedF1F2(findSelectedASTNodes(Source, {2, 2}, FileRange{{2, 2}, {5, 1}}));
+  // Just before 'void' of f1 and just after '}' of f2:
+  SelectedF1F2(
+      findSelectedASTNodes(Source, {3, 1}, FileRange{{3, 1}, {4, 14}}));
+}
+
+TEST(ASTSelectionFinder, MultipleStatementSelection) {
+  StringRef Source = R"(void f(int x, int y) {
+  int z = x;
+  f(2, 3);
+  if (x == 0) {
+    return;
+  }
+  x = 1;
+  return;
+})";
+  // Node is dereferenced and Children[0] indexed in every case, so those
+  // checks are fatal.
+  // From 'f(2, 3)' until just before 'x = 1;':
+  {
+    auto Node = findSelectedASTNodes(Source, {3, 2}, FileRange{{3, 2}, {7, 1}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"f");
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/2);
+    allChildrenOf(checkNode<CallExpr>(Body.Children[0],
+                                      SourceSelectionKind::InsideSelection,
+                                      /*NumChildren=*/3))
+        .shouldHaveSelectionKind(SourceSelectionKind::InsideSelection);
+    allChildrenOf(checkNode<IfStmt>(Body.Children[1],
+                                    SourceSelectionKind::InsideSelection,
+                                    /*NumChildren=*/2))
+        .shouldHaveSelectionKind(SourceSelectionKind::InsideSelection);
+  }
+  // From 'f(2, 3)' until just before ';' in 'x = 1;':
+  {
+    auto Node = findSelectedASTNodes(Source, {3, 2}, FileRange{{3, 2}, {7, 8}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"f");
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/3);
+    checkNode<CallExpr>(Body.Children[0], SourceSelectionKind::InsideSelection,
+                        /*NumChildren=*/3);
+    checkNode<IfStmt>(Body.Children[1], SourceSelectionKind::InsideSelection,
+                      /*NumChildren=*/2);
+    checkNode<BinaryOperator>(Body.Children[2],
+                              SourceSelectionKind::InsideSelection,
+                              /*NumChildren=*/2);
+  }
+  // From the middle of 'int z = x' until the middle of 'x = 1;':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {2, 10}, FileRange{{2, 10}, {7, 5}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"f");
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/4);
+    checkNode<DeclStmt>(Body.Children[0],
+                        SourceSelectionKind::ContainsSelectionStart,
+                        /*NumChildren=*/1);
+    checkNode<CallExpr>(Body.Children[1], SourceSelectionKind::InsideSelection,
+                        /*NumChildren=*/3);
+    checkNode<IfStmt>(Body.Children[2], SourceSelectionKind::InsideSelection,
+                      /*NumChildren=*/2);
+    checkNode<BinaryOperator>(Body.Children[3],
+                              SourceSelectionKind::ContainsSelectionEnd,
+                              /*NumChildren=*/1);
+  }
+}
+
+TEST(ASTSelectionFinder, SelectionInFunctionInObjCImplementation) {
+  StringRef Source = R"(
+@interface I
+@end
+@implementation I
+
+int notSelected() { }
+
+int selected(int x) {
+  return x;
+}
+
+@end
+@implementation I(Cat)
+
+void catF() { }
+
+@end
+
+void outerFunction() { }
+)";
+  // Node is dereferenced and its children indexed in every case, so those
+  // checks are fatal.
+  // Just the 'x' expression in 'selected':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {9, 10}, FileRange{{9, 10}, {9, 11}},
+                             SelectionFinderVisitor::Lang_OBJC);
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Impl = checkNode<ObjCImplementationDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"I");
+    const auto &Fn = checkNode<FunctionDecl>(
+        Impl.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"selected");
+    allChildrenOf(Fn).shouldHaveSelectionKind(
+        SourceSelectionKind::ContainsSelection);
+  }
+  // The entire 'catF':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {15, 1}, FileRange{{15, 1}, {15, 16}},
+                             SelectionFinderVisitor::Lang_OBJC);
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Impl = checkNode<ObjCCategoryImplDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"Cat");
+    const auto &Fn = checkNode<FunctionDecl>(
+        Impl.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"catF");
+    allChildrenOf(Fn).shouldHaveSelectionKind(
+        SourceSelectionKind::ContainsSelection);
+  }
+  // From the line before 'selected' to the line after 'catF':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {16, 1}, FileRange{{7, 1}, {16, 1}},
+                             SelectionFinderVisitor::Lang_OBJC);
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 2u);
+    const auto &Impl = checkNode<ObjCImplementationDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelectionStart,
+        /*NumChildren=*/1, /*Name=*/"I");
+    const auto &Selected = checkNode<FunctionDecl>(
+        Impl.Children[0], SourceSelectionKind::InsideSelection,
+        /*NumChildren=*/2, /*Name=*/"selected");
+    allChildrenOf(Selected).shouldHaveSelectionKind(
+        SourceSelectionKind::InsideSelection);
+    const auto &Cat = checkNode<ObjCCategoryImplDecl>(
+        Node->Children[1], SourceSelectionKind::ContainsSelectionEnd,
+        /*NumChildren=*/1, /*Name=*/"Cat");
+    const auto &CatF = checkNode<FunctionDecl>(
+        Cat.Children[0], SourceSelectionKind::InsideSelection,
+        /*NumChildren=*/1, /*Name=*/"catF");
+    allChildrenOf(CatF).shouldHaveSelectionKind(
+        SourceSelectionKind::InsideSelection);
+  }
+  // Just 'outerFunction':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {19, 1}, FileRange{{19, 1}, {19, 25}},
+                             SelectionFinderVisitor::Lang_OBJC);
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    checkNode<FunctionDecl>(Node->Children[0],
+                            SourceSelectionKind::ContainsSelection,
+                            /*NumChildren=*/1, /*Name=*/"outerFunction");
+  }
+}
+
+TEST(ASTSelectionFinder, AvoidImplicitDeclarations) {
+  // The copy in foo() makes the compiler declare implicit members of 'Copy';
+  // they must not appear in the selection tree.
+  StringRef Source = R"(
+struct Copy {
+  int x;
+};
+void foo() {
+  Copy x;
+  Copy y = x;
+}
+)";
+  // The entire struct 'Copy':
+  auto Node = findSelectedASTNodes(Source, {2, 1}, FileRange{{2, 1}, {4, 3}});
+  // Fatal checks: Node is dereferenced and Children[0] is indexed below.
+  ASSERT_TRUE(Node);
+  ASSERT_EQ(Node->Children.size(), 1u);
+  const auto &Record = checkNode<CXXRecordDecl>(
+      Node->Children[0], SourceSelectionKind::InsideSelection,
+      /*NumChildren=*/1, /*Name=*/"Copy");
+  checkNode<FieldDecl>(Record.Children[0],
+                       SourceSelectionKind::InsideSelection);
+}
+
+} // end anonymous namespace
Index: lib/Tooling/Refactoring/CMakeLists.txt
===================================================================
--- lib/Tooling/Refactoring/CMakeLists.txt
+++ lib/Tooling/Refactoring/CMakeLists.txt
@@ -4,6 +4,7 @@
   )
 
 add_clang_library(clangToolingRefactor
+  ASTSelection.cpp
   AtomicChange.cpp
   Rename/RenamingAction.cpp
   Rename/USRFinder.cpp
Index: lib/Tooling/Refactoring/ASTSelection.cpp
===================================================================
--- /dev/null
+++ lib/Tooling/Refactoring/ASTSelection.cpp
@@ -0,0 +1,192 @@
+//===--- ASTSelection.cpp - Clang refactoring library ---------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/Refactoring/ASTSelection.h"
+#include "clang/AST/LexicallyOrderedRecursiveASTVisitor.h"
+#include "clang/Lex/Lexer.h"
+
+using namespace clang;
+using namespace tooling;
+using ast_type_traits::DynTypedNode;
+
+namespace {
+
+/// Constructs the tree of selected AST nodes that either contain the location
+/// of the cursor or overlap with the selection range.
+///
+/// While its children are traversed, each node sits on SelectionStack. When
+/// the traversal returns, the node is attached to its parent only if it is
+/// selected itself or has at least one selected child.
+class ASTSelectionFinder
+    : public LexicallyOrderedRecursiveASTVisitor<ASTSelectionFinder> {
+public:
+  /// \param Selection The selected range. When begin == end it is treated as
+  /// a cursor position.
+  /// \param TargetFile Only declarations written in this file are considered.
+  ASTSelectionFinder(SourceRange Selection, FileID TargetFile,
+                     const ASTContext &Context)
+      : LexicallyOrderedRecursiveASTVisitor(Context.getSourceManager()),
+        SelectionBegin(Selection.getBegin()),
+        // An invalid SelectionEnd marks a zero-length (cursor) selection.
+        SelectionEnd(Selection.getBegin() == Selection.getEnd()
+                         ? SourceLocation()
+                         : Selection.getEnd()),
+        TargetFile(TargetFile), Context(Context) {
+    // The TU decl is the root of the selected node tree.
+    SelectionStack.push_back(
+        SelectedASTNode(DynTypedNode::create(*Context.getTranslationUnitDecl()),
+                        SourceSelectionKind::None));
+  }
+
+  /// Moves the root out of the stack; call once, after the traversal.
+  /// Returns None when the root has no selected children.
+  Optional<SelectedASTNode> getSelectedASTNode() {
+    assert(SelectionStack.size() == 1 && "stack was not popped");
+    SelectedASTNode Result = std::move(SelectionStack.back());
+    SelectionStack.pop_back();
+    if (Result.Children.empty())
+      return None;
+    return Result;
+  }
+
+  bool TraverseLexicallyOrderedDecl(Decl *D) {
+    if (isa<TranslationUnitDecl>(D))
+      return LexicallyOrderedRecursiveASTVisitor::TraverseLexicallyOrderedDecl(
+          D);
+    if (D->isImplicit())
+      return true;
+
+    // Check if this declaration is written in the file of interest.
+    const SourceRange DeclRange = D->getSourceRange();
+    const SourceManager &SM = Context.getSourceManager();
+    SourceLocation FileLoc;
+    if (DeclRange.getBegin().isMacroID() && !DeclRange.getEnd().isMacroID())
+      FileLoc = DeclRange.getEnd();
+    else
+      FileLoc = SM.getSpellingLoc(DeclRange.getBegin());
+    if (SM.getFileID(FileLoc) != TargetFile)
+      return true;
+
+    // FIXME (Alex Lorenz): Add location adjustment for ObjCImplDecls.
+    // Reuse the range computed above rather than querying D again.
+    SourceSelectionKind SelectionKind =
+        selectionKindFor(CharSourceRange::getTokenRange(DeclRange));
+    SelectionStack.push_back(
+        SelectedASTNode(DynTypedNode::create(*D), SelectionKind));
+    // The base result is not propagated. Whatever it is, the node pushed
+    // above is popped below so the stack stays balanced.
+    LexicallyOrderedRecursiveASTVisitor::TraverseLexicallyOrderedDecl(D);
+    popAndAddToSelectionIfSelected(SelectionKind);
+
+    if (DeclRange.getEnd().isValid() &&
+        SM.isBeforeInTranslationUnit(SelectionEnd.isValid() ? SelectionEnd
+                                                            : SelectionBegin,
+                                     DeclRange.getEnd())) {
+      // Stop early when we've reached a declaration after the selection.
+      return false;
+    }
+    return true;
+  }
+
+  bool TraverseStmt(Stmt *S) {
+    if (!S)
+      return true;
+    // FIXME (Alex Lorenz): Improve handling for macro locations.
+    SourceSelectionKind SelectionKind =
+        selectionKindFor(CharSourceRange::getTokenRange(S->getSourceRange()));
+    SelectionStack.push_back(
+        SelectedASTNode(DynTypedNode::create(*S), SelectionKind));
+    LexicallyOrderedRecursiveASTVisitor::TraverseStmt(S);
+    popAndAddToSelectionIfSelected(SelectionKind);
+    return true;
+  }
+
+private:
+  /// Pops the current node and attaches it to its parent if it is selected
+  /// (\p SelectionKind is not None) or has at least one selected child.
+  void popAndAddToSelectionIfSelected(SourceSelectionKind SelectionKind) {
+    SelectedASTNode Node = std::move(SelectionStack.back());
+    SelectionStack.pop_back();
+    if (SelectionKind != SourceSelectionKind::None || !Node.Children.empty())
+      SelectionStack.back().Children.push_back(std::move(Node));
+  }
+
+  /// Classifies \p Range against the current selection.
+  ///
+  /// A token range is first extended past its last token. Ranges whose
+  /// endpoints are not a pair of file locations get None (presumably this
+  /// covers macro locations). A cursor selection yields only
+  /// ContainsSelection or None.
+  SourceSelectionKind selectionKindFor(CharSourceRange Range) {
+    SourceLocation End = Range.getEnd();
+    const SourceManager &SM = Context.getSourceManager();
+    if (Range.isTokenRange())
+      End = Lexer::getLocForEndOfToken(End, 0, SM, Context.getLangOpts());
+    if (!SourceLocation::isPairOfFileLocations(Range.getBegin(), End))
+      return SourceSelectionKind::None;
+    if (!SelectionEnd.isValid()) {
+      // Do a quick check when the selection is of length 0.
+      if (SM.isPointWithin(SelectionBegin, Range.getBegin(), End))
+        return SourceSelectionKind::ContainsSelection;
+      return SourceSelectionKind::None;
+    }
+    bool HasStart = SM.isPointWithin(SelectionBegin, Range.getBegin(), End);
+    bool HasEnd = SM.isPointWithin(SelectionEnd, Range.getBegin(), End);
+    if (HasStart && HasEnd)
+      return SourceSelectionKind::ContainsSelection;
+    if (SM.isPointWithin(Range.getBegin(), SelectionBegin, SelectionEnd) &&
+        SM.isPointWithin(End, SelectionBegin, SelectionEnd))
+      return SourceSelectionKind::InsideSelection;
+    // Ensure there's at least some overlap with the 'start'/'end' selection
+    // types.
+    if (HasStart && SelectionBegin != End)
+      return SourceSelectionKind::ContainsSelectionStart;
+    if (HasEnd && SelectionEnd != Range.getBegin())
+      return SourceSelectionKind::ContainsSelectionEnd;
+
+    return SourceSelectionKind::None;
+  }
+
+  const SourceLocation SelectionBegin, SelectionEnd;
+  FileID TargetFile;
+  const ASTContext &Context;
+  /// Path from the TU root to the node currently being traversed.
+  std::vector<SelectedASTNode> SelectionStack;
+};
+
+} // end anonymous namespace
+
+/// Returns the tree of AST nodes selected by \p SelectionRange, or None if
+/// no node is selected.
+///
+/// The range must be valid, its endpoints a pair of file locations, and both
+/// endpoints in the same file. Only declarations written in that file are
+/// considered.
+Optional<SelectedASTNode>
+clang::tooling::findSelectedASTNodes(const ASTContext &Context,
+                                     SourceRange SelectionRange) {
+  assert(SelectionRange.isValid() &&
+         SourceLocation::isPairOfFileLocations(SelectionRange.getBegin(),
+                                               SelectionRange.getEnd()) &&
+         "Expected a file range");
+  const FileID TargetFile =
+      Context.getSourceManager().getFileID(SelectionRange.getBegin());
+  assert(Context.getSourceManager().getFileID(SelectionRange.getEnd()) ==
+             TargetFile &&
+         "selection range must span one file");
+
+  // Walk the whole translation unit from its root declaration.
+  ASTSelectionFinder Finder(SelectionRange, TargetFile, Context);
+  Finder.TraverseDecl(Context.getTranslationUnitDecl());
+  return Finder.getSelectedASTNode();
+}
+
+/// Returns the spelling of \p Kind used by SelectedASTNode::dump.
+static const char *selectionKindToString(SourceSelectionKind Kind) {
+  const char *Spelling = nullptr;
+  switch (Kind) {
+  case SourceSelectionKind::None:
+    Spelling = "none";
+    break;
+  case SourceSelectionKind::ContainsSelection:
+    Spelling = "contains-selection";
+    break;
+  case SourceSelectionKind::ContainsSelectionStart:
+    Spelling = "contains-selection-start";
+    break;
+  case SourceSelectionKind::ContainsSelectionEnd:
+    Spelling = "contains-selection-end";
+    break;
+  case SourceSelectionKind::InsideSelection:
+    Spelling = "inside";
+    break;
+  }
+  if (!Spelling)
+    llvm_unreachable("invalid selection kind");
+  return Spelling;
+}
+
+/// Prints the label of \p Node: the declaration kind followed by "Decl"
+/// (plus the quoted name for named declarations), or the statement class
+/// name. Prints nothing for other node kinds.
+static void printNodeLabel(const DynTypedNode &Node, llvm::raw_ostream &OS) {
+  const Decl *D = Node.get<Decl>();
+  if (!D) {
+    if (const Stmt *S = Node.get<Stmt>())
+      OS << S->getStmtClassName();
+    return;
+  }
+  OS << D->getDeclKindName() << "Decl";
+  if (const auto *ND = dyn_cast<NamedDecl>(D))
+    OS << " \"" << ND->getNameAsString() << '"';
+}
+
+/// Prints \p Node and, recursively, its children, one node per line. Each
+/// nesting level is indented by two more spaces.
+static void dump(const SelectedASTNode &Node, llvm::raw_ostream &OS,
+                 unsigned Indent = 0) {
+  OS.indent(Indent * 2);
+  printNodeLabel(Node.Node, OS);
+  OS << ' ' << selectionKindToString(Node.SelectionKind) << "\n";
+  for (const SelectedASTNode &Child : Node.Children)
+    dump(Child, OS, Indent + 1);
+}
+
+void SelectedASTNode::dump(llvm::raw_ostream &OS) const { ::dump(*this, OS); }
Index: include/clang/Tooling/Refactoring/ASTSelection.h
===================================================================
--- /dev/null
+++ include/clang/Tooling/Refactoring/ASTSelection.h
@@ -0,0 +1,87 @@
+//===--- ASTSelection.h - Clang refactoring library -----------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_REFACTOR_AST_SELECTION_H
+#define LLVM_CLANG_TOOLING_REFACTOR_AST_SELECTION_H
+
+#include "clang/AST/ASTTypeTraits.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/SourceLocation.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/Support/raw_ostream.h"
+#include <vector>
+
+namespace clang {
+
+class ASTContext;
+
+namespace tooling {
+
+/// Describes whether, and why, an AST node is considered to be selected.
+enum class SourceSelectionKind {
+  /// A node that's not selected.
+  None,
+
+  /// A node that's considered to be selected because the whole selection range
+  /// is inside of its source range.
+  ContainsSelection,
+  /// A node that's considered to be selected because the start of the selection
+  /// range is inside its source range.
+  ContainsSelectionStart,
+  /// A node that's considered to be selected because the end of the selection
+  /// range is inside its source range.
+  ContainsSelectionEnd,
+
+  /// A node that's considered to be selected because the node is entirely in
+  /// the selection range.
+  InsideSelection,
+};
+
+/// Represents a selected AST node.
+///
+/// AST selection is represented using a tree of \c SelectedASTNode. The tree
+/// follows the top-down shape of the actual AST. Each selected node has
+/// a selection kind. The kind might be none as the node itself might not
+/// actually be selected, e.g. a statement in a macro whose child is in a macro
+/// argument.
+struct SelectedASTNode {
+  /// The AST node that this tree node describes.
+  ast_type_traits::DynTypedNode Node;
+  /// Whether, and why, \c Node is considered to be selected.
+  SourceSelectionKind SelectionKind;
+  /// The selected nodes nested under \c Node.
+  std::vector<SelectedASTNode> Children;
+
+  SelectedASTNode(const ast_type_traits::DynTypedNode &Node,
+                  SourceSelectionKind SelectionKind)
+      : Node(Node), SelectionKind(SelectionKind) {}
+  // Declaring the move operations suppresses the implicit copy operations,
+  // so SelectedASTNode trees are move-only.
+  SelectedASTNode(SelectedASTNode &&) = default;
+  SelectedASTNode &operator=(SelectedASTNode &&) = default;
+
+  /// Prints this node and, recursively, its children to \p OS as an indented
+  /// tree.
+  void dump(llvm::raw_ostream &OS = llvm::errs()) const;
+};
+
+/// Traverses the given ASTContext and creates a tree of selected AST nodes.
+///
+/// \param SelectionRange the selected source range. It must be a valid range
+/// whose begin and end are file locations in the same file.
+///
+/// \returns None if no nodes are selected in the AST, or a selected AST node
+/// that corresponds to the TranslationUnitDecl otherwise.
+Optional<SelectedASTNode> findSelectedASTNodes(const ASTContext &Context,
+                                               SourceRange SelectionRange);
+
+} // end namespace tooling
+} // end namespace clang
+
+#endif // LLVM_CLANG_TOOLING_REFACTOR_AST_SELECTION_H
Index: include/clang/Basic/SourceManager.h
===================================================================
--- include/clang/Basic/SourceManager.h
+++ include/clang/Basic/SourceManager.h
@@ -1520,6 +1520,16 @@
     return LHSLoaded;
   }
 
+  /// Returns true if \p Location is within the inclusive range
+  /// [\p Start, \p End]. Locations are ordered with isBeforeInTranslationUnit;
+  /// if \p Start comes after \p End, only the two endpoints match.
+  bool isPointWithin(SourceLocation Location, SourceLocation Start,
+                     SourceLocation End) const {
+    return Location == Start || Location == End ||
+           (isBeforeInTranslationUnit(Start, Location) &&
+            isBeforeInTranslationUnit(Location, End));
+  }
+
   // Iterators over FileInfos.
   typedef llvm::DenseMap<const FileEntry*, SrcMgr::ContentCache*>
       ::const_iterator fileinfo_iterator;
Index: include/clang/Basic/SourceLocation.h
===================================================================
--- include/clang/Basic/SourceLocation.h
+++ include/clang/Basic/SourceLocation.h
@@ -172,6 +172,15 @@
     return getFromRawEncoding((unsigned)(uintptr_t)Encoding);
   }
 
+  /// Returns true when both \p Start and \p End are valid locations for which
+  /// isFileID() holds.
+  static bool isPairOfFileLocations(SourceLocation Start, SourceLocation End) {
+    auto IsValidFileLoc = [](SourceLocation Loc) {
+      return Loc.isValid() && Loc.isFileID();
+    };
+    return IsValidFileLoc(Start) && IsValidFileLoc(End);
+  }
+
   void print(raw_ostream &OS, const SourceManager &SM) const;
   std::string printToString(const SourceManager &SM) const;
   void dump(const SourceManager &SM) const;
Index: include/clang/AST/LexicallyOrderedRecursiveASTVisitor.h
===================================================================
--- /dev/null
+++ include/clang/AST/LexicallyOrderedRecursiveASTVisitor.h
@@ -0,0 +1,174 @@
+//===--- LexicallyOrderedRecursiveASTVisitor.h - ----------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines the LexicallyOrderedRecursiveASTVisitor interface, which
+//  recursively traverses the entire AST in a lexical order.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_LEXICALLY_ORDERED_RECURSIVEASTVISITOR_H
+#define LLVM_CLANG_AST_LEXICALLY_ORDERED_RECURSIVEASTVISITOR_H
+
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/SourceManager.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/SaveAndRestore.h"
+
+namespace clang {
+
+/// A RecursiveASTVisitor subclass that guarantees that AST traversal is
+/// performed in a lexical order (i.e. the order in which declarations are
+/// written in the source).
+///
+/// RecursiveASTVisitor doesn't guarantee lexical ordering because there are
+/// some declarations, like Objective-C @implementation declarations
+/// that might be represented in the AST differently to how they were written
+/// in the source.
+/// In particular, Objective-C @implementation declarations may contain
+/// non-Objective-C declarations, like functions:
+///
+///   @implementation MyClass
+///
+///   - (void) method { }
+///   void normalFunction() { }
+///
+///   @end
+///
+/// Clang's AST stores these declarations outside of the @implementation
+/// declaration, so the example above would be represented using the following
+/// AST:
+///   |-ObjCImplementationDecl ... MyClass
+///   | `-ObjCMethodDecl ... method
+///   |    ...
+///   `-FunctionDecl ... normalFunction
+///       ...
+///
+/// This class ensures that these declarations are traversed before the
+/// corresponding TraverseDecl for the @implementation returns. This ensures
+/// that the lexical parent relationship between these declarations and the
+/// @implementation is preserved while traversing the AST. Note that the
+/// current implementation doesn't mix these declarations with the declarations
+/// contained in the @implementation, so the traversal of all of the
+/// declarations in the @implementation still doesn't follow the lexical order.
+///
+/// Subclasses should declare TraverseLexicallyOrderedDecl instead of
+/// TraverseDecl if they need to override the behaviour of TraverseDecl.
+template <typename Derived>
+class LexicallyOrderedRecursiveASTVisitor
+    : public RecursiveASTVisitor<Derived> {
+  using BaseType = RecursiveASTVisitor<Derived>;
+
+public:
+  LexicallyOrderedRecursiveASTVisitor(const SourceManager &SM) : SM(SM) {}
+
+  /// Traverses \p D with the regular RecursiveASTVisitor logic. Subclasses
+  /// override this instead of TraverseDecl to customize declaration traversal.
+  bool TraverseLexicallyOrderedDecl(Decl *D) {
+    return BaseType::TraverseDecl(D);
+  }
+
+  /// Skips declarations that were already traversed as lexical children of an
+  /// Objective-C implementation, records \p D as the current declaration and
+  /// forwards to the derived class' TraverseLexicallyOrderedDecl.
+  bool TraverseDecl(Decl *D) { // final
+    if (!D)
+      return true;
+    if (Traversed.count(D))
+      return true;
+    // The enclosing declaration becomes the parent of D for the duration of
+    // D's traversal.
+    llvm::SaveAndRestore<std::pair<const Decl *, Decl *>> ParentTracker(
+        ParentAndCurrentDecl, std::make_pair(ParentAndCurrentDecl.second, D));
+    return (static_cast<Derived *>(this))->TraverseLexicallyOrderedDecl(D);
+  }
+
+  bool TraverseObjCImplementationDecl(ObjCImplementationDecl *D) {
+    llvm::SmallVector<Decl *, 8> NestedDecls;
+    GatherAdditionalLexicallyNestedDeclarations(D, NestedDecls);
+    if (!BaseType::TraverseObjCImplementationDecl(D))
+      return false;
+    return TraverseAdditionalLexicallyNestedDeclarations(NestedDecls);
+  }
+
+  bool TraverseObjCCategoryImplDecl(ObjCCategoryImplDecl *D) {
+    llvm::SmallVector<Decl *, 8> NestedDecls;
+    GatherAdditionalLexicallyNestedDeclarations(D, NestedDecls);
+    if (!BaseType::TraverseObjCCategoryImplDecl(D))
+      return false;
+    return TraverseAdditionalLexicallyNestedDeclarations(NestedDecls);
+  }
+
+private:
+  /// Appends to \p Nested the declarations that follow \p D in its parent's
+  /// declaration context and start before \p D ends, i.e. the declarations
+  /// that are lexically written inside the @implementation.
+  void GatherAdditionalLexicallyNestedDeclarations(
+      ObjCImplDecl *D, llvm::SmallVectorImpl<Decl *> &Nested) {
+    const Decl *Parent = ParentAndCurrentDecl.first;
+    assert(Parent && "Lexically ordered traversal should start with the "
+                     "TranslationUnitDecl");
+    // Bail out rather than dereference a null parent or advance a
+    // past-the-end iterator when D isn't listed in the parent's decls.
+    const auto *ParentDC = dyn_cast_or_null<DeclContext>(Parent);
+    if (!ParentDC)
+      return;
+    DeclContext::decl_range Siblings = ParentDC->decls();
+    auto SelfIterator =
+        llvm::find_if(Siblings, [=](Decl *Sibling) { return Sibling == D; });
+    if (SelfIterator == Siblings.end())
+      return;
+    ++SelfIterator;
+    for (Decl *Sibling : llvm::make_range(SelfIterator, Siblings.end())) {
+      if (!SM.isBeforeInTranslationUnit(Sibling->getLocStart(), D->getLocEnd()))
+        break;
+      Nested.push_back(Sibling);
+    }
+  }
+
+  /// Traverses the declarations gathered for an Objective-C implementation and
+  /// marks them so that the traversal of the enclosing declaration context
+  /// skips them when it reaches them later.
+  bool TraverseAdditionalLexicallyNestedDeclarations(
+      llvm::ArrayRef<Decl *> Nested) {
+    // FIXME: Ideally the gathered declarations and the declarations in the
+    // @implementation should be mixed and sorted to get a true lexical order,
+    // but right now we only care about getting the correct lexical parent, so
+    // we can traverse the gathered nested declarations after the declarations
+    // in the decl context.
+    assert(!BaseType::getDerived().shouldTraversePostOrder() &&
+           "post-order traversal is not supported for lexically ordered "
+           "recursive ast visitor");
+    for (Decl *D : Nested) {
+      // Drop any mark left by an earlier traversal with this visitor, as
+      // TraverseDecl would otherwise skip D here too.
+      Traversed.erase(D);
+      if (!BaseType::getDerived().TraverseDecl(D))
+        return false;
+      Traversed.insert(D);
+    }
+    return true;
+  }
+
+  const SourceManager &SM;
+  /// Declarations that were already traversed as lexical children of an
+  /// Objective-C implementation.
+  llvm::DenseSet<const Decl *> Traversed;
+  /// The parent of the declaration that's currently being traversed, and that
+  /// declaration itself.
+  std::pair<const Decl *, Decl *> ParentAndCurrentDecl = {nullptr, nullptr};
+};
+
+} // end namespace clang
+
+#endif // LLVM_CLANG_AST_LEXICALLY_ORDERED_RECURSIVEASTVISITOR_H
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to