This revision was automatically updated to reflect the committed changes.
Closed by commit rGe4d5f00093be: [ASTMatchers] Fix hasParent while ignoring 
unwritten nodes (authored by stephenkelly).

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D96113/new/

https://reviews.llvm.org/D96113

Files:
  clang/include/clang/AST/ParentMapContext.h
  clang/lib/AST/ParentMapContext.cpp
  clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp

Index: clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
===================================================================
--- clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
+++ clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
@@ -2933,6 +2933,37 @@
     EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
     EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
   }
+  {
+    auto M = ifStmt(hasParent(compoundStmt(hasParent(cxxForRangeStmt()))));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    auto M = cxxForRangeStmt(
+        has(varDecl(hasName("i"), hasParent(cxxForRangeStmt()))));
+    EXPECT_FALSE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    auto M = cxxForRangeStmt(hasDescendant(varDecl(
+        hasName("i"), hasParent(declStmt(hasParent(cxxForRangeStmt()))))));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    auto M = cxxForRangeStmt(hasRangeInit(declRefExpr(
+        to(varDecl(hasName("arr"))), hasParent(cxxForRangeStmt()))));
+    EXPECT_FALSE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+
+  {
+    auto M = cxxForRangeStmt(hasRangeInit(declRefExpr(
+        to(varDecl(hasName("arr"))), hasParent(varDecl(hasParent(declStmt(
+                                         hasParent(cxxForRangeStmt()))))))));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
 
   Code = R"cpp(
   struct Range {
@@ -3035,6 +3066,15 @@
         matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
                              true, {"-std=c++20"}));
   }
+  {
+    auto M = cxxForRangeStmt(hasInitStatement(declStmt(
+        hasSingleDecl(varDecl(hasName("a"))), hasParent(cxxForRangeStmt()))));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
+                             true, {"-std=c++20"}));
+  }
 
   Code = R"cpp(
   struct Range {
@@ -3511,6 +3551,20 @@
                            forFunction(functionDecl(hasName("func13"))))))),
       langCxx20OrLater()));
 
+  EXPECT_TRUE(matches(Code,
+                      traverse(TK_IgnoreUnlessSpelledInSource,
+                               compoundStmt(hasParent(lambdaExpr(forFunction(
+                                   functionDecl(hasName("func13"))))))),
+                      langCxx20OrLater()));
+
+  EXPECT_TRUE(matches(
+      Code,
+      traverse(TK_IgnoreUnlessSpelledInSource,
+               templateTypeParmDecl(hasName("TemplateType"),
+                                    hasParent(lambdaExpr(forFunction(
+                                        functionDecl(hasName("func14"))))))),
+      langCxx20OrLater()));
+
   EXPECT_TRUE(matches(
       Code,
       traverse(TK_IgnoreUnlessSpelledInSource,
@@ -3635,6 +3689,16 @@
         matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
                              true, {"-std=c++20"}));
   }
+  {
+    auto M = cxxRewrittenBinaryOperator(
+        hasLHS(expr(hasParent(cxxRewrittenBinaryOperator()))),
+        hasRHS(expr(hasParent(cxxRewrittenBinaryOperator()))));
+    EXPECT_FALSE(
+        matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
+                             true, {"-std=c++20"}));
+  }
   {
     EXPECT_TRUE(matchesConditionally(
         Code,
Index: clang/lib/AST/ParentMapContext.cpp
===================================================================
--- clang/lib/AST/ParentMapContext.cpp
+++ clang/lib/AST/ParentMapContext.cpp
@@ -49,7 +49,17 @@
   return N;
 }
 
+template <typename T, typename... U>
+std::tuple<bool, DynTypedNodeList, const T *, const U *...>
+matchParents(const DynTypedNodeList &NodeList,
+             ParentMapContext::ParentMap *ParentMap); // Called from ParentMap.
+// Forward-declared so ParentMap can befriend it; defined after ParentMap.
+template <typename, typename...> struct MatchParents;
+
 class ParentMapContext::ParentMap {
+
+  template <typename, typename...> friend struct ::MatchParents;
+
   /// Contains parents of a node.
   using ParentVector = llvm::SmallVector<DynTypedNode, 2>;
 
@@ -117,11 +127,72 @@
     if (Node.getNodeKind().hasPointerIdentity()) {
       auto ParentList =
           getDynNodeFromMap(Node.getMemoizationData(), PointerParents);
-      if (ParentList.size() == 1 && TK == TK_IgnoreUnlessSpelledInSource) {
-        const auto *E = ParentList[0].get<Expr>();
-        const auto *Child = Node.get<Expr>();
-        if (E && Child)
-          return AscendIgnoreUnlessSpelledInSource(E, Child);
+      if (ParentList.size() > 0 && TK == TK_IgnoreUnlessSpelledInSource) {
+
+        const auto *ChildExpr = Node.get<Expr>();
+
+        {
+          // Don't match explicit node types because different stdlib
+          // implementations implement this in different ways and have
+          // different intermediate nodes.
+          // Look up 4 levels for a cxxRewrittenBinaryOperator as that is
+          // enough for the major stdlib implementations.
+          auto RewrittenBinOpParentsList = ParentList;
+          int I = 0;
+          while (ChildExpr && RewrittenBinOpParentsList.size() == 1 &&
+                 I++ < 4) {
+            const auto *S = RewrittenBinOpParentsList[0].get<Stmt>();
+            if (!S)
+              break;
+
+            const auto *RWBO = dyn_cast<CXXRewrittenBinaryOperator>(S);
+            if (!RWBO) {
+              RewrittenBinOpParentsList = getDynNodeFromMap(S, PointerParents);
+              continue;
+            }
+            if (RWBO->getLHS()->IgnoreUnlessSpelledInSource() != ChildExpr &&
+                RWBO->getRHS()->IgnoreUnlessSpelledInSource() != ChildExpr)
+              break;
+            return DynTypedNode::create(*RWBO);
+          }
+        }
+
+        const auto *ParentExpr = ParentList[0].get<Expr>();
+        if (ParentExpr && ChildExpr)
+          return AscendIgnoreUnlessSpelledInSource(ParentExpr, ChildExpr);
+
+        {
+          auto AncestorNodes =
+              matchParents<DeclStmt, CXXForRangeStmt>(ParentList, this);
+          if (std::get<bool>(AncestorNodes) &&
+              std::get<const CXXForRangeStmt *>(AncestorNodes)
+                      ->getLoopVarStmt() ==
+                  std::get<const DeclStmt *>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
+        {
+          auto AncestorNodes = matchParents<VarDecl, DeclStmt, CXXForRangeStmt>(
+              ParentList, this);
+          if (std::get<bool>(AncestorNodes) &&
+              std::get<const CXXForRangeStmt *>(AncestorNodes)
+                      ->getRangeStmt() ==
+                  std::get<const DeclStmt *>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
+        {
+          auto AncestorNodes =
+              matchParents<CXXMethodDecl, CXXRecordDecl, LambdaExpr>(ParentList,
+                                                                     this);
+          if (std::get<bool>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
+        {
+          auto AncestorNodes =
+              matchParents<FunctionTemplateDecl, CXXRecordDecl, LambdaExpr>(
+                  ParentList, this);
+          if (std::get<bool>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
       }
       return ParentList;
     }
@@ -194,6 +265,59 @@
   }
 };
 
+template <typename Tuple, std::size_t... Is> // tuple_pop_front helper.
+auto tuple_pop_front_impl(const Tuple &tuple, std::index_sequence<Is...>) {
+  return std::make_tuple(std::get<1 + Is>(tuple)...);
+}
+/// Returns a copy of \p tuple with its first element removed.
+template <typename Tuple> auto tuple_pop_front(const Tuple &tuple) {
+  return tuple_pop_front_impl(
+      tuple, std::make_index_sequence<std::tuple_size<Tuple>::value - 1>());
+}
+/// Matches NodeList[0] as T, then each sole parent in turn as U...
+template <typename T, typename... U> struct MatchParents {
+  static std::tuple<bool, DynTypedNodeList, const T *, const U *...>
+  match(const DynTypedNodeList &NodeList,
+        ParentMapContext::ParentMap *ParentMap) {
+    if (const auto *TypedNode = NodeList[0].get<T>()) {
+      auto NextParentList =
+          ParentMap->getDynNodeFromMap(TypedNode, ParentMap->PointerParents);
+      if (NextParentList.size() == 1) {
+        auto TailTuple = MatchParents<U...>::match(NextParentList, ParentMap);
+        if (std::get<bool>(TailTuple)) {
+          return std::tuple_cat(
+              std::make_tuple(true, std::get<DynTypedNodeList>(TailTuple),
+                              TypedNode),
+              tuple_pop_front(tuple_pop_front(TailTuple))); // Drop bool, list.
+        }
+      }
+    }
+    return std::tuple_cat(std::make_tuple(false, NodeList),
+                          std::tuple<const T *, const U *...>()); // Null ptrs.
+  }
+};
+/// Last link: NodeList[0] must be a T that itself has exactly one parent.
+template <typename T> struct MatchParents<T> {
+  static std::tuple<bool, DynTypedNodeList, const T *>
+  match(const DynTypedNodeList &NodeList,
+        ParentMapContext::ParentMap *ParentMap) {
+    if (const auto *TypedNode = NodeList[0].get<T>()) {
+      auto NextParentList =
+          ParentMap->getDynNodeFromMap(TypedNode, ParentMap->PointerParents);
+      if (NextParentList.size() == 1)
+        return std::make_tuple(true, NodeList, TypedNode);
+    }
+    return std::make_tuple(false, NodeList, nullptr);
+  }
+};
+/// Returns {ok, list of the final match (else NodeList), T *, U *...}.
+template <typename T, typename... U>
+std::tuple<bool, DynTypedNodeList, const T *, const U *...>
+matchParents(const DynTypedNodeList &NodeList,
+             ParentMapContext::ParentMap *ParentMap) {
+  return MatchParents<T, U...>::match(NodeList, ParentMap);
+}
+
 /// Template specializations to abstract away from pointers and TypeLocs.
 /// @{
 template <typename T> static DynTypedNode createDynTypedNode(const T &Node) {
Index: clang/include/clang/AST/ParentMapContext.h
===================================================================
--- clang/include/clang/AST/ParentMapContext.h
+++ clang/include/clang/AST/ParentMapContext.h
@@ -64,9 +64,10 @@
   Expr *traverseIgnored(Expr *E) const;
   DynTypedNode traverseIgnored(const DynTypedNode &N) const;
 
+  class ParentMap;
+
 private:
   ASTContext &ASTCtx;
-  class ParentMap;
   TraversalKind Traversal = TK_AsIs;
   std::unique_ptr<ParentMap> Parents;
 };
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to