arphaman updated this revision to Diff 107067.
arphaman marked 7 inline comments as done.
arphaman added a comment.

- Address review comments.
- Remove the `Location` parameter and `ContainsSelectionPoint` enum value.
- Stop traversing early once a declaration that ends after the selection range
  has been reached.


Repository:
  rL LLVM

https://reviews.llvm.org/D35012

Files:
  include/clang/Basic/SourceLocation.h
  include/clang/Basic/SourceManager.h
  include/clang/Tooling/Refactoring/ASTSelection.h
  lib/Tooling/Refactoring/ASTSelection.cpp
  lib/Tooling/Refactoring/CMakeLists.txt
  unittests/Tooling/ASTSelectionTest.cpp
  unittests/Tooling/CMakeLists.txt

Index: unittests/Tooling/CMakeLists.txt
===================================================================
--- unittests/Tooling/CMakeLists.txt
+++ unittests/Tooling/CMakeLists.txt
@@ -11,6 +11,7 @@
 endif()
 
 add_clang_unittest(ToolingTests
+  ASTSelectionTest.cpp
   CastExprTest.cpp
   CommentHandlerTest.cpp
   CompilationDatabaseTest.cpp
Index: unittests/Tooling/ASTSelectionTest.cpp
===================================================================
--- /dev/null
+++ unittests/Tooling/ASTSelectionTest.cpp
@@ -0,0 +1,456 @@
+//===- unittest/Tooling/ASTSelectionTest.cpp ------------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestVisitor.h"
+#include "clang/Basic/SourceManager.h"
+#include "clang/Tooling/Refactoring/ASTSelection.h"
+
+using namespace clang;
+using namespace tooling;
+
+namespace {
+
+struct FileLocation {
+  unsigned Line, Column; // 1-based, as expected by translateLineCol.
+  /// Returns this position as a SourceLocation in the main file.
+  SourceLocation translate(const SourceManager &SM) const {
+    return SM.translateLineCol(SM.getMainFileID(), Line, Column);
+  }
+};
+
+using FileRange = std::pair<FileLocation, FileLocation>; // {Begin, End}.
+
+class SelectionFinderVisitor : public TestVisitor<SelectionFinderVisitor> {
+  FileLocation Location;
+  Optional<FileRange> SelectionRange;
+
+public:
+  Optional<SelectedASTNode> Selection; // None if nothing was selected.
+
+  SelectionFinderVisitor(FileLocation Location,
+                         Optional<FileRange> SelectionRange)
+      : Location(Location), SelectionRange(SelectionRange) {}
+
+  bool VisitTranslationUnitDecl(const TranslationUnitDecl *TU) {
+    const ASTContext &Ctx = TU->getASTContext();
+    const SourceManager &SrcMgr = Ctx.getSourceManager();
+
+    // Without an explicit range, select an empty range at the cursor.
+    SourceLocation Begin = SelectionRange
+                               ? SelectionRange->first.translate(SrcMgr)
+                               : Location.translate(SrcMgr);
+    SourceLocation End =
+        SelectionRange ? SelectionRange->second.translate(SrcMgr) : Begin;
+    Selection = findSelectedASTNodes(Ctx, SourceRange(Begin, End));
+
+    // Only the TU is needed, so returning false stops the traversal.
+    return false;
+  }
+};
+
+Optional<SelectedASTNode>
+findSelectedASTNodes(StringRef Source, FileLocation Location,
+                     Optional<FileRange> SelectionRange,
+                     SelectionFinderVisitor::Language Language =
+                         SelectionFinderVisitor::Lang_CXX11) {
+  SelectionFinderVisitor Finder(Location, SelectionRange);
+  EXPECT_TRUE(Finder.runOver(Source, Language)); // The tool run must succeed.
+  return std::move(Finder.Selection); // A member, so it must be moved out.
+}
+
+void checkNodeImpl(bool IsTypeMatched, const SelectedASTNode &Node,
+                   SourceSelectionKind SelectionKind, unsigned NumChildren) {
+  ASSERT_TRUE(IsTypeMatched); // Computed by checkNode via isa<T>.
+  EXPECT_EQ(Node.Children.size(), NumChildren);
+  ASSERT_EQ(Node.SelectionKind, SelectionKind);
+}
+
+void checkDeclName(const SelectedASTNode &Node, StringRef Name) {
+  const auto *ND = Node.Node.get<NamedDecl>();
+  ASSERT_TRUE(ND != nullptr); // Fatal: ND is dereferenced below.
+  ASSERT_EQ(ND->getName(), Name);
+}
+
+template <typename T>
+const SelectedASTNode &
+checkNode(const SelectedASTNode &StmtNode, SourceSelectionKind SelectionKind,
+          unsigned NumChildren = 0,
+          typename std::enable_if<std::is_base_of<Stmt, T>::value, T>::type
+              *StmtOverloadChecker = nullptr) { // Enabled only for Stmt T.
+  checkNodeImpl(isa<T>(StmtNode.Node.get<Stmt>()), StmtNode, SelectionKind,
+                NumChildren);
+  return StmtNode; // Returned so that its children can be checked.
+}
+
+template <typename T>
+const SelectedASTNode &
+checkNode(const SelectedASTNode &DeclNode, SourceSelectionKind SelectionKind,
+          unsigned NumChildren = 0, StringRef Name = "",
+          typename std::enable_if<std::is_base_of<Decl, T>::value, T>::type
+              *DeclOverloadChecker = nullptr) { // Enabled only for Decl T.
+  checkNodeImpl(isa<T>(DeclNode.Node.get<Decl>()), DeclNode, SelectionKind,
+                NumChildren);
+  if (!Name.empty()) // An empty Name skips the name check.
+    checkDeclName(DeclNode, Name);
+  return DeclNode; // Returned so that its children can be checked.
+}
+
+struct ForAllChildrenOf {
+  const SelectedASTNode &Node;
+  // FIXME: Checks Node's kind, not each Child's; leaf nodes are never checked.
+  static void childKindVerifier(const SelectedASTNode &Node,
+                                SourceSelectionKind SelectionKind) {
+    for (const SelectedASTNode &Child : Node.Children) {
+      ASSERT_EQ(Node.SelectionKind, SelectionKind); // See the FIXME above.
+      childKindVerifier(Child, SelectionKind);
+    }
+  }
+
+public:
+  ForAllChildrenOf(const SelectedASTNode &Node) : Node(Node) {}
+  /// Meant to assert Kind for all nodes below Node; see the FIXME above.
+  void shouldHaveSelectionKind(SourceSelectionKind Kind) {
+    childKindVerifier(Node, Kind);
+  }
+};
+
+ForAllChildrenOf allChildrenOf(const SelectedASTNode &Node) {
+  return ForAllChildrenOf{Node}; // Fluent entry point for kind checks.
+}
+
+TEST(ASTSelectionFinder, CursorNoSelection) {
+  // A cursor on the leading space is outside of every declaration.
+  auto Node = findSelectedASTNodes(" void f() { }", {1, 1}, None);
+  EXPECT_FALSE(Node);
+}
+
+TEST(ASTSelectionFinder, CursorAtStartOfFunction) {
+  Optional<SelectedASTNode> Node =
+      findSelectedASTNodes("void f() { }", {1, 1}, None);
+  ASSERT_TRUE(Node); // Fatal: Node is dereferenced below.
+  checkNode<TranslationUnitDecl>(*Node, SourceSelectionKind::None,
+                                 /*NumChildren=*/1);
+  checkNode<FunctionDecl>(Node->Children[0],
+                          SourceSelectionKind::ContainsSelection,
+                          /*NumChildren=*/0, /*Name=*/"f");
+
+  // Check that the dumping works.
+  std::string DumpValue;
+  llvm::raw_string_ostream OS(DumpValue);
+  Node->Children[0].dump(OS);
+  ASSERT_EQ(OS.str(), "FunctionDecl \"f\" contains-selection\n");
+}
+
+TEST(ASTSelectionFinder, RangeNoSelection) {
+  // Ranges that only cover whitespace before 'void' select nothing.
+  // An empty range:
+  Optional<SelectedASTNode> EmptyRange = findSelectedASTNodes(
+      " void f() { }", /*Location=*/{1, 1}, FileRange{{1, 1}, {1, 1}});
+  EXPECT_FALSE(EmptyRange);
+
+  // A one-character range:
+  Optional<SelectedASTNode> SpaceRange = findSelectedASTNodes(
+      "  void f() { }", /*Location=*/{1, 1}, FileRange{{1, 1}, {1, 2}});
+  EXPECT_FALSE(SpaceRange);
+}
+
+TEST(ASTSelectionFinder, EmptyRangeFallbackToCursor) {
+  Optional<SelectedASTNode> Node =
+      findSelectedASTNodes("void f() { }", {1, 1}, FileRange{{1, 1}, {1, 1}});
+  ASSERT_TRUE(Node); // An empty range is treated like a cursor at its start.
+  checkNode<FunctionDecl>(Node->Children[0],
+                          SourceSelectionKind::ContainsSelection,
+                          /*NumChildren=*/0, /*Name=*/"f");
+}
+
+TEST(ASTSelectionFinder, WholeFunctionSelection) {
+  StringRef Source = "int f(int x) { return x;\n}\nvoid f2() { }";
+  // From 'int' until just after '}':
+  {
+    auto Node = findSelectedASTNodes(Source, {1, 1}, FileRange{{1, 1}, {2, 2}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/2, /*Name=*/"f");
+    checkNode<ParmVarDecl>(Fn.Children[0],
+                           SourceSelectionKind::InsideSelection);
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[1], SourceSelectionKind::InsideSelection,
+        /*NumChildren=*/1);
+    const auto &Return = checkNode<ReturnStmt>(
+        Body.Children[0], SourceSelectionKind::InsideSelection,
+        /*NumChildren=*/1);
+    checkNode<ImplicitCastExpr>(Return.Children[0],
+                                SourceSelectionKind::InsideSelection,
+                                /*NumChildren=*/1);
+    checkNode<DeclRefExpr>(Return.Children[0].Children[0],
+                           SourceSelectionKind::InsideSelection);
+  }
+  // From 'int' until just before '}':
+  {
+    auto Node = findSelectedASTNodes(Source, {2, 1}, FileRange{{1, 1}, {2, 1}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/2, /*Name=*/"f");
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[1], SourceSelectionKind::ContainsSelectionEnd,
+        /*NumChildren=*/1);
+    checkNode<ReturnStmt>(Body.Children[0],
+                          SourceSelectionKind::InsideSelection,
+                          /*NumChildren=*/1);
+  }
+  // From '{' until just after '}':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {1, 14}, FileRange{{1, 14}, {2, 2}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"f");
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1);
+    checkNode<ReturnStmt>(Body.Children[0],
+                          SourceSelectionKind::InsideSelection,
+                          /*NumChildren=*/1);
+  }
+  // From 'x' until just after '}':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {2, 2}, FileRange{{1, 11}, {2, 2}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/2, /*Name=*/"f");
+    checkNode<ParmVarDecl>(Fn.Children[0],
+                           SourceSelectionKind::ContainsSelectionStart);
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[1], SourceSelectionKind::InsideSelection,
+        /*NumChildren=*/1);
+    checkNode<ReturnStmt>(Body.Children[0],
+                          SourceSelectionKind::InsideSelection,
+                          /*NumChildren=*/1);
+  }
+}
+
+TEST(ASTSelectionFinder, MultipleFunctionSelection) {
+  StringRef Source = R"(void f0() {
+}
+void f1() { }
+void f2() { }
+void f3() { }
+)";
+  auto SelectedF1F2 = [](Optional<SelectedASTNode> Node) {
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 2u); // Children[1] is accessed below.
+    checkNode<FunctionDecl>(Node->Children[0],
+                            SourceSelectionKind::InsideSelection,
+                            /*NumChildren=*/1, /*Name=*/"f1");
+    checkNode<FunctionDecl>(Node->Children[1],
+                            SourceSelectionKind::InsideSelection,
+                            /*NumChildren=*/1, /*Name=*/"f2");
+  };
+  // Just after '}' of f0 and just before 'void' of f3:
+  SelectedF1F2(findSelectedASTNodes(Source, {2, 2}, FileRange{{2, 2}, {5, 1}}));
+  // Just before 'void' of f1 and just after '}' of f2:
+  SelectedF1F2(
+      findSelectedASTNodes(Source, {3, 1}, FileRange{{3, 1}, {4, 14}}));
+}
+
+TEST(ASTSelectionFinder, MultipleStatementSelection) {
+  StringRef Source = R"(void f(int x, int y) {
+  int z = x;
+  f(2, 3);
+  if (x == 0) {
+    return;
+  }
+  x = 1;
+  return;
+})";
+  // From 'f(2, 3)' until just before 'x = 1;':
+  {
+    auto Node = findSelectedASTNodes(Source, {3, 2}, FileRange{{3, 2}, {7, 1}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"f");
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/2);
+    allChildrenOf(checkNode<CallExpr>(Body.Children[0],
+                                      SourceSelectionKind::InsideSelection,
+                                      /*NumChildren=*/3))
+        .shouldHaveSelectionKind(SourceSelectionKind::InsideSelection);
+    allChildrenOf(checkNode<IfStmt>(Body.Children[1],
+                                    SourceSelectionKind::InsideSelection,
+                                    /*NumChildren=*/2))
+        .shouldHaveSelectionKind(SourceSelectionKind::InsideSelection);
+  }
+  // From 'f(2, 3)' until just before ';' in 'x = 1;':
+  {
+    auto Node = findSelectedASTNodes(Source, {3, 2}, FileRange{{3, 2}, {7, 8}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"f");
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/3);
+    checkNode<CallExpr>(Body.Children[0], SourceSelectionKind::InsideSelection,
+                        /*NumChildren=*/3);
+    checkNode<IfStmt>(Body.Children[1], SourceSelectionKind::InsideSelection,
+                      /*NumChildren=*/2);
+    checkNode<BinaryOperator>(Body.Children[2],
+                              SourceSelectionKind::InsideSelection,
+                              /*NumChildren=*/2);
+  }
+  // From the middle of 'int z = x;' until the middle of 'x = 1;':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {2, 10}, FileRange{{2, 10}, {7, 5}});
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Fn = checkNode<FunctionDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"f");
+    const auto &Body = checkNode<CompoundStmt>(
+        Fn.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/4);
+    checkNode<DeclStmt>(Body.Children[0],
+                        SourceSelectionKind::ContainsSelectionStart,
+                        /*NumChildren=*/1);
+    checkNode<CallExpr>(Body.Children[1], SourceSelectionKind::InsideSelection,
+                        /*NumChildren=*/3);
+    checkNode<IfStmt>(Body.Children[2], SourceSelectionKind::InsideSelection,
+                      /*NumChildren=*/2);
+    checkNode<BinaryOperator>(Body.Children[3],
+                              SourceSelectionKind::ContainsSelectionEnd,
+                              /*NumChildren=*/1);
+  }
+}
+
+TEST(ASTSelectionFinder, SelectionInFunctionInObjCImplementation) {
+  StringRef Source = R"(
+@interface I
+@end
+@implementation I
+
+int notSelected() { }
+
+int selected(int x) {
+  return x;
+}
+
+@end
+@implementation I(Cat)
+
+void catF() { }
+
+@end
+
+void outerFunction() { }
+)";
+  // Just the 'x' expression in 'selected':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {9, 10}, FileRange{{9, 10}, {9, 11}},
+                             SelectionFinderVisitor::Lang_OBJC);
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Impl = checkNode<ObjCImplementationDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"I");
+    const auto &Fn = checkNode<FunctionDecl>(
+        Impl.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"selected");
+    allChildrenOf(Fn).shouldHaveSelectionKind(
+        SourceSelectionKind::ContainsSelection);
+  }
+  // The entire 'catF' (its body does not contain the selection start):
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {15, 1}, FileRange{{15, 1}, {15, 16}},
+                             SelectionFinderVisitor::Lang_OBJC);
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    const auto &Impl = checkNode<ObjCCategoryImplDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"Cat");
+    const auto &Fn = checkNode<FunctionDecl>(
+        Impl.Children[0], SourceSelectionKind::ContainsSelection,
+        /*NumChildren=*/1, /*Name=*/"catF");
+    checkNode<CompoundStmt>(Fn.Children[0],
+                            SourceSelectionKind::InsideSelection);
+  }
+  // From the line before 'selected' to the line after 'catF':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {16, 1}, FileRange{{7, 1}, {16, 1}},
+                             SelectionFinderVisitor::Lang_OBJC);
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 2u);
+    const auto &Impl = checkNode<ObjCImplementationDecl>(
+        Node->Children[0], SourceSelectionKind::ContainsSelectionStart,
+        /*NumChildren=*/1, /*Name=*/"I");
+    const auto &Selected = checkNode<FunctionDecl>(
+        Impl.Children[0], SourceSelectionKind::InsideSelection,
+        /*NumChildren=*/2, /*Name=*/"selected");
+    allChildrenOf(Selected).shouldHaveSelectionKind(
+        SourceSelectionKind::InsideSelection);
+    const auto &Cat = checkNode<ObjCCategoryImplDecl>(
+        Node->Children[1], SourceSelectionKind::ContainsSelectionEnd,
+        /*NumChildren=*/1, /*Name=*/"Cat");
+    const auto &CatF = checkNode<FunctionDecl>(
+        Cat.Children[0], SourceSelectionKind::InsideSelection,
+        /*NumChildren=*/1, /*Name=*/"catF");
+    allChildrenOf(CatF).shouldHaveSelectionKind(
+        SourceSelectionKind::InsideSelection);
+  }
+  // Just 'outerFunction':
+  {
+    auto Node =
+        findSelectedASTNodes(Source, {19, 1}, FileRange{{19, 1}, {19, 25}},
+                             SelectionFinderVisitor::Lang_OBJC);
+    ASSERT_TRUE(Node);
+    ASSERT_EQ(Node->Children.size(), 1u);
+    checkNode<FunctionDecl>(Node->Children[0],
+                            SourceSelectionKind::ContainsSelection,
+                            /*NumChildren=*/1, /*Name=*/"outerFunction");
+  }
+}
+
+TEST(ASTSelectionFinder, AvoidImplicitDeclarations) {
+  StringRef Source = R"(
+struct Copy {
+  int x;
+};
+void foo() {
+  Copy x;
+  Copy y = x;
+}
+)";
+  // The entire struct 'Copy'; its implicit members must not be selected:
+  auto Node = findSelectedASTNodes(Source, {2, 1}, FileRange{{2, 1}, {4, 3}});
+  ASSERT_TRUE(Node);
+  ASSERT_EQ(Node->Children.size(), 1u);
+  const auto &Record = checkNode<CXXRecordDecl>(
+      Node->Children[0], SourceSelectionKind::InsideSelection,
+      /*NumChildren=*/1, /*Name=*/"Copy");
+  checkNode<FieldDecl>(Record.Children[0],
+                       SourceSelectionKind::InsideSelection);
+}
+
+} // end anonymous namespace
Index: lib/Tooling/Refactoring/CMakeLists.txt
===================================================================
--- lib/Tooling/Refactoring/CMakeLists.txt
+++ lib/Tooling/Refactoring/CMakeLists.txt
@@ -4,6 +4,7 @@
   )
 
 add_clang_library(clangToolingRefactor
+  ASTSelection.cpp
   AtomicChange.cpp
   Rename/RenamingAction.cpp
   Rename/USRFinder.cpp
Index: lib/Tooling/Refactoring/ASTSelection.cpp
===================================================================
--- /dev/null
+++ lib/Tooling/Refactoring/ASTSelection.cpp
@@ -0,0 +1,236 @@
+//===--- ASTSelection.cpp - Clang refactoring library ---------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/Refactoring/ASTSelection.h"
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/Lex/Lexer.h"
+
+using namespace clang;
+using namespace tooling;
+using ast_type_traits::DynTypedNode;
+
+namespace {
+
+/// Constructs the tree of selected AST nodes that either contain the location
+/// of the cursor or overlap with the selection range.
+class ASTSelectionFinder : public RecursiveASTVisitor<ASTSelectionFinder> {
+public:
+  ASTSelectionFinder(SourceRange Selection, const ASTContext &Context)
+      : SelectionBegin(Selection.getBegin()),
+        SelectionEnd(Selection.getBegin() == Selection.getEnd()
+                         ? SourceLocation() // Empty range: cursor mode.
+                         : Selection.getEnd()),
+        Context(Context) {
+    // The TU decl is the root of the selected node tree.
+    SelectionStack.push_back(
+        SelectedASTNode(DynTypedNode::create(*Context.getTranslationUnitDecl()),
+                        SourceSelectionKind::None));
+  }
+
+  unsigned getNumTopLevelMatches() const {
+    return SelectionStack[0].Children.size(); // Children of the TU root.
+  }
+  /// Pops the TU root node; returns None if nothing below it was selected.
+  Optional<SelectedASTNode> getSelectedASTNode() {
+    assert(SelectionStack.size() == 1 && "stack was not popped");
+    SelectedASTNode Result = std::move(SelectionStack.back());
+    SelectionStack.pop_back();
+    if (Result.Children.empty())
+      return None;
+    return Result;
+  }
+  /// Skips implicit decls; keeps \p D if it or a descendant is selected.
+  bool TraverseDecl(Decl *D) {
+    if (!D)
+      return true;
+    if (D->isImplicit())
+      return true;
+    // FIXME (Alex Lorenz): Add location adjustment for ObjCImplDecls.
+    SourceSelectionKind SelectionKind =
+        selectionKindFor(CharSourceRange::getTokenRange(D->getSourceRange()));
+    SelectionStack.push_back(
+        SelectedASTNode(DynTypedNode::create(*D), SelectionKind));
+    RecursiveASTVisitor::TraverseDecl(D);
+    popAndAddToSelectionIfSelected(SelectionKind);
+    return true;
+  }
+
+  /// Traverses the given declaration and, if it's selected, adds it to the
+  /// children of the last child of the node on top of the selection stack
+  /// instead of children of the node on top of the selection stack.
+  ///
+  /// This is useful when traversing the AST nodes that are not ordered in the
+  /// lexical order, e.g. declarations that are located in Objective-C
+  /// @implementation declarations are stored after the @implementation in the
+  /// AST. The tree that models the selected AST puts those declarations into
+  /// the @implementation to have a more accurate lexical representation of
+  /// the source.
+  void TraverseDeclInPrevious(Decl *D) {
+    assert(!SelectionStack.back().Children.empty() &&
+           "No previous declaration");
+    SelectedASTNode &Previous = SelectionStack.back().Children.back();
+    SelectionStack.push_back(
+        SelectedASTNode(Previous.Node, Previous.SelectionKind));
+    TraverseDecl(D);
+    std::vector<SelectedASTNode> Children =
+        std::move(SelectionStack.back().Children);
+    SelectionStack.pop_back();
+    for (auto &&Child : Children)
+      Previous.Children.push_back(std::move(Child));
+  }
+  /// Keeps \p S if it or a descendant is selected.
+  bool TraverseStmt(Stmt *S) {
+    if (!S)
+      return true;
+    // FIXME (Alex Lorenz): Improve handling for macro locations.
+    SourceSelectionKind SelectionKind =
+        selectionKindFor(CharSourceRange::getTokenRange(S->getSourceRange()));
+    SelectionStack.push_back(
+        SelectedASTNode(DynTypedNode::create(*S), SelectionKind));
+    RecursiveASTVisitor::TraverseStmt(S);
+    popAndAddToSelectionIfSelected(SelectionKind);
+    return true;
+  }
+
+private:
+  void popAndAddToSelectionIfSelected(SourceSelectionKind SelectionKind) {
+    SelectedASTNode Node = std::move(SelectionStack.back());
+    SelectionStack.pop_back();
+    if (SelectionKind != SourceSelectionKind::None || !Node.Children.empty())
+      SelectionStack.back().Children.push_back(std::move(Node));
+  }
+  /// Classifies how \p Range overlaps the selection (None for non-file ranges).
+  SourceSelectionKind selectionKindFor(CharSourceRange Range) {
+    SourceLocation End = Range.getEnd();
+    const SourceManager &SM = Context.getSourceManager();
+    if (Range.isTokenRange())
+      End = Lexer::getLocForEndOfToken(End, 0, SM, Context.getLangOpts());
+    if (!SourceLocation::isPairOfFileLocations(Range.getBegin(), End))
+      return SourceSelectionKind::None;
+    if (!SelectionEnd.isValid()) {
+      // Do a quick check when the selection is of length 0.
+      if (SM.isPointWithin(SelectionBegin, Range.getBegin(), End))
+        return SourceSelectionKind::ContainsSelection;
+      return SourceSelectionKind::None;
+    }
+    bool HasStart = SM.isPointWithin(SelectionBegin, Range.getBegin(), End);
+    bool HasEnd = SM.isPointWithin(SelectionEnd, Range.getBegin(), End);
+    if (HasStart && HasEnd)
+      return SourceSelectionKind::ContainsSelection;
+    if (SM.isPointWithin(Range.getBegin(), SelectionBegin, SelectionEnd) &&
+        SM.isPointWithin(End, SelectionBegin, SelectionEnd))
+      return SourceSelectionKind::InsideSelection;
+    // Require a non-empty overlap: a node that only touches the selection at
+    // one of its boundary points is not selected.
+    if (HasStart && SelectionBegin != End)
+      return SourceSelectionKind::ContainsSelectionStart;
+    if (HasEnd && SelectionEnd != Range.getBegin())
+      return SourceSelectionKind::ContainsSelectionEnd;
+
+    return SourceSelectionKind::None;
+  }
+  // SelectionEnd is invalid when the selection is empty (cursor mode).
+  const SourceLocation SelectionBegin, SelectionEnd;
+  const ASTContext &Context;
+  std::vector<SelectedASTNode> SelectionStack;
+};
+
+} // end anonymous namespace
+/// Builds the selection tree from the top-level decls in the selection's file.
+Optional<SelectedASTNode>
+clang::tooling::findSelectedASTNodes(const ASTContext &Context,
+                                     SourceRange SelectionRange) {
+  assert(SelectionRange.isValid() &&
+         SourceLocation::isPairOfFileLocations(SelectionRange.getBegin(),
+                                               SelectionRange.getEnd()) &&
+         "Expected a file range");
+  FileID TargetFile =
+      Context.getSourceManager().getFileID(SelectionRange.getBegin());
+  assert(Context.getSourceManager().getFileID(SelectionRange.getEnd()) ==
+             TargetFile &&
+         "selection range must span one file");
+  // Visit top-level decls in order until we are past the selection.
+  ASTSelectionFinder Visitor(SelectionRange, Context);
+  const SourceManager &SM = Context.getSourceManager();
+  SourceLocation ObjCImplEndLoc; // Valid while in an ObjC @implementation.
+  unsigned NumMatches = 0; // Top-level matches seen so far.
+  for (Decl *D : Context.getTranslationUnitDecl()->decls()) {
+    if (ObjCImplEndLoc.isValid() &&
+        !SM.isBeforeInTranslationUnit(D->getLocStart(), ObjCImplEndLoc))
+      ObjCImplEndLoc = SourceLocation();
+
+    // Check if this declaration is written in the file of interest.
+    const SourceRange DeclRange = D->getSourceRange();
+    SourceLocation FileLoc;
+    if (DeclRange.getBegin().isMacroID() && !DeclRange.getEnd().isMacroID())
+      FileLoc = DeclRange.getEnd();
+    else
+      FileLoc = SM.getSpellingLoc(DeclRange.getBegin());
+    if (SM.getFileID(FileLoc) == TargetFile) {
+      if (ObjCImplEndLoc.isValid())
+        Visitor.TraverseDeclInPrevious(D);
+      else
+        Visitor.TraverseDecl(D);
+    }
+
+    unsigned PrevNumMatches = NumMatches;
+    NumMatches = Visitor.getNumTopLevelMatches();
+    // Objective-C @implementation declarations might have trailing declarations
+    // that are written in the @implementation, but stored outside of it in the
+    // AST. Pretend that we are still traversing the ObjCImplDecl until we
+    // reach a declaration that's outside of the @implementation.
+    if (NumMatches != PrevNumMatches && isa<ObjCImplDecl>(D)) {
+      ObjCImplEndLoc = DeclRange.getEnd();
+      continue;
+    }
+    if (ObjCImplEndLoc.isInvalid() &&
+        ((DeclRange.getEnd().isFileID() && DeclRange.getEnd().isValid() &&
+          SM.isBeforeInTranslationUnit(SelectionRange.getEnd(),
+                                       DeclRange.getEnd())) ||
+         (PrevNumMatches && PrevNumMatches == NumMatches))) {
+      // Stop early when we've reached a declaration after the selection
+      // range or when we've stopped finding matches.
+      break;
+    }
+  }
+  return Visitor.getSelectedASTNode();
+}
+/// Returns the spelling of \p Kind used when dumping a SelectedASTNode.
+static const char *selectionKindToString(SourceSelectionKind Kind) {
+  switch (Kind) {
+  case SourceSelectionKind::None:
+    return "none";
+  case SourceSelectionKind::ContainsSelection:
+    return "contains-selection";
+  case SourceSelectionKind::ContainsSelectionStart:
+    return "contains-selection-start";
+  case SourceSelectionKind::ContainsSelectionEnd:
+    return "contains-selection-end";
+  case SourceSelectionKind::InsideSelection:
+    return "inside";
+  }
+  llvm_unreachable("invalid selection kind");
+}
+
+static void dump(const SelectedASTNode &Node, llvm::raw_ostream &OS,
+                 unsigned Indent = 0) {
+  OS.indent(Indent * 2); // Two spaces per nesting level.
+  if (const auto *S = Node.Node.get<Stmt>()) {
+    OS << S->getStmtClassName();
+  } else if (const auto *D = Node.Node.get<Decl>()) {
+    OS << D->getDeclKindName() << "Decl";
+    if (const auto *ND = dyn_cast<NamedDecl>(D))
+      OS << " \"" << ND->getNameAsString() << '"';
+  }
+  OS << ' ' << selectionKindToString(Node.SelectionKind) << "\n";
+  for (const SelectedASTNode &Child : Node.Children)
+    dump(Child, OS, Indent + 1); // Children are printed one level deeper.
+}
+/// Prints this node and its descendants to \p OS as an indented tree.
+void SelectedASTNode::dump(llvm::raw_ostream &OS) const { ::dump(*this, OS); }
Index: include/clang/Tooling/Refactoring/ASTSelection.h
===================================================================
--- /dev/null
+++ include/clang/Tooling/Refactoring/ASTSelection.h
@@ -0,0 +1,74 @@
+//===--- ASTSelection.h - Clang refactoring library -----------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_REFACTOR_AST_SELECTION_H
+#define LLVM_CLANG_TOOLING_REFACTOR_AST_SELECTION_H
+
+#include "clang/AST/ASTTypeTraits.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/SourceLocation.h"
+#include <vector>
+
+namespace clang {
+
+class ASTContext;
+
+namespace tooling {
+/// Describes how the source range of an AST node relates to a selection.
+enum class SourceSelectionKind {
+  /// A node that's not selected.
+  None,
+
+  /// A node that's considered to be selected because the whole selection range
+  /// (or the cursor location, for an empty one) is inside of its source range.
+  ContainsSelection,
+  /// A node that's considered to be selected because the start of the selection
+  /// range is inside its source range.
+  ContainsSelectionStart,
+  /// A node that's considered to be selected because the end of the selection
+  /// range is inside its source range.
+  ContainsSelectionEnd,
+
+  /// A node that's considered to be selected because the node is entirely in
+  /// the selection range.
+  InsideSelection,
+};
+
+/// Represents a selected AST node.
+///
+/// AST selection is represented using a tree of \c SelectedASTNode. The tree
+/// follows the top-down shape of the actual AST. Each selected node has
+/// a selection kind. The kind might be none as the node itself might not
+/// actually be selected, e.g. a statement in macro whose child is in a macro
+/// argument.
+struct SelectedASTNode {
+  ast_type_traits::DynTypedNode Node;
+  SourceSelectionKind SelectionKind;
+  std::vector<SelectedASTNode> Children;
+  /// Creates a node without children. The type is move-only.
+  SelectedASTNode(const ast_type_traits::DynTypedNode &Node,
+                  SourceSelectionKind SelectionKind)
+      : Node(Node), SelectionKind(SelectionKind) {}
+  SelectedASTNode(SelectedASTNode &&) = default;
+  SelectedASTNode &operator=(SelectedASTNode &&) = default;
+  /// Prints this node and its descendants to \p OS, one node per line.
+  void dump(llvm::raw_ostream &OS = llvm::errs()) const;
+};
+
+/// Traverses \p Context and creates a tree of the AST nodes selected by
+/// \p SelectionRange, which must lie in one file; empty ranges act as cursors.
+/// \returns None if no nodes are selected in the AST, or a selected AST node
+/// that corresponds to the TranslationUnitDecl otherwise.
+Optional<SelectedASTNode> findSelectedASTNodes(const ASTContext &Context,
+                                               SourceRange SelectionRange);
+
+} // end namespace tooling
+} // end namespace clang
+
+#endif // LLVM_CLANG_TOOLING_REFACTOR_AST_SELECTION_H
Index: include/clang/Basic/SourceManager.h
===================================================================
--- include/clang/Basic/SourceManager.h
+++ include/clang/Basic/SourceManager.h
@@ -1520,6 +1520,14 @@
     return LHSLoaded;
   }
 
+  /// Return true if \p Location is within [\p Start, \p End], inclusive.
+  bool isPointWithin(SourceLocation Location, SourceLocation Start,
+                     SourceLocation End) const {
+    return Location == Start || Location == End ||
+           (isBeforeInTranslationUnit(Start, Location) &&
+            isBeforeInTranslationUnit(Location, End));
+  }
+
   // Iterators over FileInfos.
   typedef llvm::DenseMap<const FileEntry*, SrcMgr::ContentCache*>
       ::const_iterator fileinfo_iterator;
Index: include/clang/Basic/SourceLocation.h
===================================================================
--- include/clang/Basic/SourceLocation.h
+++ include/clang/Basic/SourceLocation.h
@@ -172,6 +172,11 @@
     return getFromRawEncoding((unsigned)(uintptr_t)Encoding);
   }
 
+  /// Returns true if \p Start and \p End are both valid file locations.
+  static bool isPairOfFileLocations(SourceLocation Start, SourceLocation End) {
+    return Start.isValid() && End.isValid() && Start.isFileID() &&
+           End.isFileID();
+  }
   void print(raw_ostream &OS, const SourceManager &SM) const;
   std::string printToString(const SourceManager &SM) const;
   void dump(const SourceManager &SM) const;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to