steveire updated this revision to Diff 199181.
steveire added a comment.

Format


Repository:
  rC Clang

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D61837/new/

https://reviews.llvm.org/D61837

Files:
  include/clang/AST/ASTContext.h
  include/clang/AST/ASTNodeTraverser.h
  include/clang/ASTMatchers/ASTMatchers.h
  include/clang/ASTMatchers/ASTMatchersInternal.h
  lib/AST/ASTContext.cpp
  lib/ASTMatchers/ASTMatchFinder.cpp
  lib/ASTMatchers/ASTMatchersInternal.cpp
  unittests/ASTMatchers/ASTMatchersTraversalTest.cpp

Index: unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
===================================================================
--- unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
+++ unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
@@ -1510,6 +1510,72 @@
       notMatches("class C {}; C a = C();", varDecl(has(cxxConstructExpr()))));
 }
 
+TEST(Traversal, traverseMatcher) {
+
+  StringRef VarDeclCode = R"cpp(
+void foo()
+{
+  int i = 3.0;
+}
+)cpp";
+
+  auto Matcher = varDecl(hasInitializer(floatLiteral()));
+
+  EXPECT_TRUE(
+      notMatches(VarDeclCode, traverse(ast_type_traits::TK_AsIs, Matcher)));
+  EXPECT_TRUE(
+      matches(VarDeclCode,
+              traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses,
+                       Matcher)));
+}
+
+TEST(Traversal, traverseMatcherNesting) {
+
+  StringRef Code = R"cpp(
+float bar(int i)
+{
+  return i;
+}
+
+void foo()
+{
+  bar(bar(3.0));
+}
+)cpp";
+
+  EXPECT_TRUE(matches(
+      Code,
+      traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses,
+               callExpr(has(callExpr(traverse(
+                   ast_type_traits::TK_AsIs,
+                   callExpr(has(implicitCastExpr(has(floatLiteral())))))))))));
+}
+
+TEST(Traversal, traverseMatcherThroughMemoization) {
+
+  StringRef Code = R"cpp(
+void foo()
+{
+  int i = 3.0;
+}
+  )cpp";
+
+  auto Matcher = varDecl(hasInitializer(floatLiteral()));
+
+  // Matchers such as hasDescendant memoize their results per AST node.
+  // In the matcher below, the first use of hasDescendant(Matcher) fails,
+  // while the use of it inside the traverse() matcher should pass, making
+  // the overall matcher a true match.
+  // This test verifies that the first (false) result is not re-used, which
+  // would cause the overall matcher to be incorrectly false.
+
+  EXPECT_TRUE(matches(
+      Code, functionDecl(anyOf(
+                hasDescendant(Matcher),
+                traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses,
+                         functionDecl(hasDescendant(Matcher)))))));
+}
+
 TEST(IgnoringImpCasts, MatchesImpCasts) {
   // This test checks that ignoringImpCasts matches when implicit casts are
   // present and its inner matcher alone does not match.
Index: lib/ASTMatchers/ASTMatchersInternal.cpp
===================================================================
--- lib/ASTMatchers/ASTMatchersInternal.cpp
+++ lib/ASTMatchers/ASTMatchersInternal.cpp
@@ -211,10 +211,17 @@
 bool DynTypedMatcher::matches(const ast_type_traits::DynTypedNode &DynNode,
                               ASTMatchFinder *Finder,
                               BoundNodesTreeBuilder *Builder) const {
-  if (RestrictKind.isBaseOf(DynNode.getNodeKind()) &&
-      Implementation->dynMatches(DynNode, Finder, Builder)) {
+  auto PreviousTraversalKind = Finder->getASTContext().GetTraversalKind();
+  Finder->getASTContext().SetTraversalKind(Implementation->TraversalKind());
+  auto N = Finder->getASTContext().TraverseIgnored(DynNode);
+  auto NodeKind = N.getNodeKind();
+
+  if (RestrictKind.isBaseOf(NodeKind) &&
+      Implementation->dynMatches(N, Finder, Builder)) {
+    Finder->getASTContext().SetTraversalKind(PreviousTraversalKind);
     return true;
   }
+  Finder->getASTContext().SetTraversalKind(PreviousTraversalKind);
   // Delete all bindings when a matcher does not match.
   // This prevents unexpected exposure of bound nodes in unmatches
   // branches of the match tree.
@@ -225,8 +232,11 @@
 bool DynTypedMatcher::matchesNoKindCheck(
     const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder,
     BoundNodesTreeBuilder *Builder) const {
-  assert(RestrictKind.isBaseOf(DynNode.getNodeKind()));
-  if (Implementation->dynMatches(DynNode, Finder, Builder)) {
+  auto N = Finder->getASTContext().TraverseIgnored(DynNode);
+  auto NodeKind = N.getNodeKind();
+
+  assert(RestrictKind.isBaseOf(NodeKind));
+  if (Implementation->dynMatches(N, Finder, Builder)) {
     return true;
   }
   // Delete all bindings when a matcher does not match.
Index: lib/ASTMatchers/ASTMatchFinder.cpp
===================================================================
--- lib/ASTMatchers/ASTMatchFinder.cpp
+++ lib/ASTMatchers/ASTMatchFinder.cpp
@@ -59,10 +59,12 @@
   DynTypedMatcher::MatcherIDType MatcherID;
   ast_type_traits::DynTypedNode Node;
   BoundNodesTreeBuilder BoundNodes;
+  ast_type_traits::TraversalKind Traversal = ast_type_traits::TK_AsIs;
 
   bool operator<(const MatchKey &Other) const {
-    return std::tie(MatcherID, Node, BoundNodes) <
-           std::tie(Other.MatcherID, Other.Node, Other.BoundNodes);
+    return std::tie(MatcherID, Node, BoundNodes, Traversal) <
+           std::tie(Other.MatcherID, Other.Node, Other.BoundNodes,
+                    Other.Traversal);
   }
 };
 
@@ -143,6 +145,8 @@
 
     ScopedIncrement ScopedDepth(&CurrentDepth);
     Stmt *StmtToTraverse = StmtNode;
+    if (Expr *ExprNode = dyn_cast_or_null<Expr>(StmtNode))
+      StmtToTraverse = Finder->getASTContext().TraverseIgnored(ExprNode);
     if (Traversal ==
         ast_type_traits::TraversalKind::TK_IgnoreImplicitCastsAndParentheses) {
       if (Expr *ExprNode = dyn_cast_or_null<Expr>(StmtNode))
@@ -384,6 +388,7 @@
 
   // Matches children or descendants of 'Node' with 'BaseMatcher'.
   bool memoizedMatchesRecursively(const ast_type_traits::DynTypedNode &Node,
+                                  ASTContext &Ctx,
                                   const DynTypedMatcher &Matcher,
                                   BoundNodesTreeBuilder *Builder, int MaxDepth,
                                   ast_type_traits::TraversalKind Traversal,
@@ -398,6 +403,7 @@
     Key.Node = Node;
     // Note that we key on the bindings *before* the match.
     Key.BoundNodes = *Builder;
+    Key.Traversal = Ctx.GetTraversalKind();
 
     MemoizationMap::iterator I = ResultCache.find(Key);
     if (I != ResultCache.end()) {
@@ -434,36 +440,36 @@
 
   // Implements ASTMatchFinder::matchesChildOf.
   bool matchesChildOf(const ast_type_traits::DynTypedNode &Node,
-                      const DynTypedMatcher &Matcher,
+                      ASTContext &Ctx, const DynTypedMatcher &Matcher,
                       BoundNodesTreeBuilder *Builder,
                       ast_type_traits::TraversalKind Traversal,
                       BindKind Bind) override {
     if (ResultCache.size() > MaxMemoizationEntries)
       ResultCache.clear();
-    return memoizedMatchesRecursively(Node, Matcher, Builder, 1, Traversal,
+    return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, 1, Traversal,
                                       Bind);
   }
   // Implements ASTMatchFinder::matchesDescendantOf.
   bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node,
-                           const DynTypedMatcher &Matcher,
+                           ASTContext &Ctx, const DynTypedMatcher &Matcher,
                            BoundNodesTreeBuilder *Builder,
                            BindKind Bind) override {
     if (ResultCache.size() > MaxMemoizationEntries)
       ResultCache.clear();
-    return memoizedMatchesRecursively(Node, Matcher, Builder, INT_MAX,
+    return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, INT_MAX,
                                       ast_type_traits::TraversalKind::TK_AsIs,
                                       Bind);
   }
   // Implements ASTMatchFinder::matchesAncestorOf.
   bool matchesAncestorOf(const ast_type_traits::DynTypedNode &Node,
-                         const DynTypedMatcher &Matcher,
+                         ASTContext &Ctx, const DynTypedMatcher &Matcher,
                          BoundNodesTreeBuilder *Builder,
                          AncestorMatchMode MatchMode) override {
     // Reset the cache outside of the recursive call to make sure we
     // don't invalidate any iterators.
     if (ResultCache.size() > MaxMemoizationEntries)
       ResultCache.clear();
-    return memoizedMatchesAncestorOfRecursively(Node, Matcher, Builder,
+    return memoizedMatchesAncestorOfRecursively(Node, Ctx, Matcher, Builder,
                                                 MatchMode);
   }
 
@@ -628,16 +634,19 @@
   // allow simple memoization on the ancestors. Thus, we only memoize as long
   // as there is a single parent.
   bool memoizedMatchesAncestorOfRecursively(
-      const ast_type_traits::DynTypedNode &Node, const DynTypedMatcher &Matcher,
-      BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) {
+      const ast_type_traits::DynTypedNode &Node, ASTContext &Ctx,
+      const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder,
+      AncestorMatchMode MatchMode) {
     // For AST-nodes that don't have an identity, we can't memoize.
     if (!Builder->isComparable())
-      return matchesAncestorOfRecursively(Node, Matcher, Builder, MatchMode);
+      return matchesAncestorOfRecursively(Node, Ctx, Matcher, Builder,
+                                          MatchMode);
 
     MatchKey Key;
     Key.MatcherID = Matcher.getID();
     Key.Node = Node;
     Key.BoundNodes = *Builder;
+    Key.Traversal = Ctx.GetTraversalKind();
 
     // Note that we cannot use insert and reuse the iterator, as recursive
     // calls to match might invalidate the result cache iterators.
@@ -649,8 +658,8 @@
 
     MemoizedMatchResult Result;
     Result.Nodes = *Builder;
-    Result.ResultOfMatch =
-        matchesAncestorOfRecursively(Node, Matcher, &Result.Nodes, MatchMode);
+    Result.ResultOfMatch = matchesAncestorOfRecursively(
+        Node, Ctx, Matcher, &Result.Nodes, MatchMode);
 
     MemoizedMatchResult &CachedResult = ResultCache[Key];
     CachedResult = std::move(Result);
@@ -660,6 +669,7 @@
   }
 
   bool matchesAncestorOfRecursively(const ast_type_traits::DynTypedNode &Node,
+                                    ASTContext &Ctx,
                                     const DynTypedMatcher &Matcher,
                                     BoundNodesTreeBuilder *Builder,
                                     AncestorMatchMode MatchMode) {
@@ -693,8 +703,8 @@
         return true;
       }
       if (MatchMode != ASTMatchFinder::AMM_ParentOnly) {
-        return memoizedMatchesAncestorOfRecursively(Parent, Matcher, Builder,
-                                                    MatchMode);
+        return memoizedMatchesAncestorOfRecursively(Parent, Ctx, Matcher,
+                                                    Builder, MatchMode);
         // Once we get back from the recursive call, the result will be the
         // same as the parent's result.
       }
Index: lib/AST/ASTContext.cpp
===================================================================
--- lib/AST/ASTContext.cpp
+++ lib/AST/ASTContext.cpp
@@ -98,6 +98,31 @@
   Float16Rank, HalfRank, FloatRank, DoubleRank, LongDoubleRank, Float128Rank
 };
 
+const Expr *ASTContext::TraverseIgnored(const Expr *E) {
+  return TraverseIgnored(const_cast<Expr *>(E));
+}
+
+Expr *ASTContext::TraverseIgnored(Expr *E) {
+  if (!E)
+    return nullptr;
+
+  switch (Traversal) {
+  case ast_type_traits::TK_AsIs:
+    return E;
+  case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses:
+    return E->IgnoreParenImpCasts();
+  }
+  llvm_unreachable("Invalid Traversal type!");
+}
+
+ast_type_traits::DynTypedNode
+ASTContext::TraverseIgnored(const ast_type_traits::DynTypedNode &N) {
+  if (auto E = N.get<Expr>()) {
+    return ast_type_traits::DynTypedNode::create(*TraverseIgnored(E));
+  }
+  return N;
+}
+
 RawComment *ASTContext::getRawCommentForDeclNoCache(const Decl *D) const {
   assert(D);
 
@@ -900,7 +925,7 @@
 
 void ASTContext::setTraversalScope(const std::vector<Decl *> &TopLevelDecls) {
   TraversalScope = TopLevelDecls;
-  Parents.reset();
+  Parents.clear();
 }
 
 void ASTContext::AddDeallocation(void (*Callback)(void*), void *Data) {
@@ -10259,7 +10284,8 @@
 class ASTContext::ParentMap::ASTVisitor
     : public RecursiveASTVisitor<ASTVisitor> {
 public:
-  ASTVisitor(ParentMap &Map) : Map(Map) {}
+  ASTVisitor(ParentMap &Map, ASTContext &Context)
+      : Map(Map), Context(Context) {}
 
 private:
   friend class RecursiveASTVisitor<ASTVisitor>;
@@ -10329,8 +10355,12 @@
   }
 
   bool TraverseStmt(Stmt *StmtNode) {
+    auto FilteredNode = StmtNode;
+    if (auto *ExprNode = dyn_cast_or_null<Expr>(FilteredNode))
+      FilteredNode = Context.TraverseIgnored(ExprNode);
     return TraverseNode(
-        StmtNode, StmtNode, [&] { return VisitorBase::TraverseStmt(StmtNode); },
+        FilteredNode, FilteredNode,
+        [&] { return VisitorBase::TraverseStmt(FilteredNode); },
         &Map.PointerParents);
   }
 
@@ -10349,20 +10379,22 @@
   }
 
   ParentMap &Map;
+  ASTContext &Context;
   llvm::SmallVector<ast_type_traits::DynTypedNode, 16> ParentStack;
 };
 
 ASTContext::ParentMap::ParentMap(ASTContext &Ctx) {
-  ASTVisitor(*this).TraverseAST(Ctx);
+  ASTVisitor(*this, Ctx).TraverseAST(Ctx);
 }
 
 ASTContext::DynTypedNodeList
 ASTContext::getParents(const ast_type_traits::DynTypedNode &Node) {
-  if (!Parents)
+  std::unique_ptr<ParentMap> &P = Parents[Traversal];
+  if (!P)
     // We build the parent map for the traversal scope (usually whole TU), as
     // hasAncestor can escape any subtree.
-    Parents = llvm::make_unique<ParentMap>(*this);
-  return Parents->getParents(Node);
+    P = llvm::make_unique<ParentMap>(*this);
+  return P->getParents(Node);
 }
 
 bool
Index: include/clang/ASTMatchers/ASTMatchersInternal.h
===================================================================
--- include/clang/ASTMatchers/ASTMatchersInternal.h
+++ include/clang/ASTMatchers/ASTMatchersInternal.h
@@ -282,6 +282,10 @@
   virtual bool dynMatches(const ast_type_traits::DynTypedNode &DynNode,
                           ASTMatchFinder *Finder,
                           BoundNodesTreeBuilder *Builder) const = 0;
+
+  virtual ast_type_traits::TraversalKind TraversalKind() const {
+    return ast_type_traits::TK_AsIs;
+  }
 };
 
 /// Generic interface for matchers on an AST node of type T.
@@ -991,7 +995,7 @@
                   std::is_base_of<QualType, T>::value,
                   "unsupported type for recursive matching");
     return matchesChildOf(ast_type_traits::DynTypedNode::create(Node),
-                          Matcher, Builder, Traverse, Bind);
+                          getASTContext(), Matcher, Builder, Traverse, Bind);
   }
 
   template <typename T>
@@ -1007,7 +1011,7 @@
                   std::is_base_of<QualType, T>::value,
                   "unsupported type for recursive matching");
     return matchesDescendantOf(ast_type_traits::DynTypedNode::create(Node),
-                               Matcher, Builder, Bind);
+                               getASTContext(), Matcher, Builder, Bind);
   }
 
   // FIXME: Implement support for BindKind.
@@ -1022,24 +1026,26 @@
                       std::is_base_of<TypeLoc, T>::value,
                   "type not allowed for recursive matching");
     return matchesAncestorOf(ast_type_traits::DynTypedNode::create(Node),
-                             Matcher, Builder, MatchMode);
+                             getASTContext(), Matcher, Builder, MatchMode);
   }
 
   virtual ASTContext &getASTContext() const = 0;
 
 protected:
   virtual bool matchesChildOf(const ast_type_traits::DynTypedNode &Node,
-                              const DynTypedMatcher &Matcher,
+                              ASTContext &Ctx, const DynTypedMatcher &Matcher,
                               BoundNodesTreeBuilder *Builder,
                               ast_type_traits::TraversalKind Traverse,
                               BindKind Bind) = 0;
 
   virtual bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node,
+                                   ASTContext &Ctx,
                                    const DynTypedMatcher &Matcher,
                                    BoundNodesTreeBuilder *Builder,
                                    BindKind Bind) = 0;
 
   virtual bool matchesAncestorOf(const ast_type_traits::DynTypedNode &Node,
+                                 ASTContext &Ctx,
                                  const DynTypedMatcher &Matcher,
                                  BoundNodesTreeBuilder *Builder,
                                  AncestorMatchMode MatchMode) = 0;
@@ -1151,6 +1157,27 @@
   }
 };
 
+template <typename T>
+class TraversalMatcher : public WrapperMatcherInterface<T> {
+  ast_type_traits::TraversalKind Traversal;
+
+public:
+  explicit TraversalMatcher(ast_type_traits::TraversalKind TK,
+                            const Matcher<T> &ChildMatcher)
+      : TraversalMatcher::WrapperMatcherInterface(ChildMatcher), Traversal(TK) {
+  }
+
+  bool matches(const T &Node, ASTMatchFinder *Finder,
+               BoundNodesTreeBuilder *Builder) const override {
+    return this->InnerMatcher.matches(
+        ast_type_traits::DynTypedNode::create(Node), Finder, Builder);
+  }
+
+  ast_type_traits::TraversalKind TraversalKind() const override {
+    return Traversal;
+  }
+};
+
 /// A PolymorphicMatcherWithParamN<MatcherT, P1, ..., PN> object can be
 /// created from N parameters p1, ..., pN (of type P1, ..., PN) and
 /// used as a Matcher<T> where a MatcherT<T, P1, ..., PN>(p1, ..., pN)
Index: include/clang/ASTMatchers/ASTMatchers.h
===================================================================
--- include/clang/ASTMatchers/ASTMatchers.h
+++ include/clang/ASTMatchers/ASTMatchers.h
@@ -698,6 +698,29 @@
                              Builder);
 }
 
+/// Causes all nested matchers to be matched with the specified traversal kind.
+///
+/// Given
+/// \code
+///   void foo()
+///   {
+///       int i = 3.0;
+///   }
+/// \endcode
+/// The matcher
+/// \code
+///   traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses,
+///     varDecl(hasInitializer(floatLiteral()))
+///   )
+/// \endcode
+/// matches the variable declaration of "i", seeing through the implicit cast.
+template <typename T>
+internal::Matcher<T> traverse(ast_type_traits::TraversalKind TK,
+                              const internal::Matcher<T> &InnerMatcher) {
+  return internal::Matcher<T>(
+      new internal::TraversalMatcher<T>(TK, InnerMatcher));
+}
+
 /// Matches expressions that match InnerMatcher after any implicit AST
 /// nodes are stripped off.
 ///
Index: include/clang/AST/ASTNodeTraverser.h
===================================================================
--- include/clang/AST/ASTNodeTraverser.h
+++ include/clang/AST/ASTNodeTraverser.h
@@ -65,6 +65,9 @@
   /// not already been loaded.
   bool Deserialize = false;
 
+  ast_type_traits::TraversalKind Traversal =
+      ast_type_traits::TraversalKind::TK_AsIs;
+
   NodeDelegateType &getNodeDelegate() {
     return getDerived().doGetNodeDelegate();
   }
@@ -74,6 +77,8 @@
   void setDeserialize(bool D) { Deserialize = D; }
   bool getDeserialize() const { return Deserialize; }
 
+  void SetTraversalKind(ast_type_traits::TraversalKind TK) { Traversal = TK; }
+
   void Visit(const Decl *D) {
     getNodeDelegate().AddChild([=] {
       getNodeDelegate().Visit(D);
@@ -97,8 +102,20 @@
     });
   }
 
-  void Visit(const Stmt *S, StringRef Label = {}) {
+  void Visit(const Stmt *S_, StringRef Label = {}) {
     getNodeDelegate().AddChild(Label, [=] {
+      const Stmt *S = S_;
+
+      if (auto *E = dyn_cast_or_null<Expr>(S)) {
+        switch (Traversal) {
+        case ast_type_traits::TK_AsIs:
+          break;
+        case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses:
+          S = E->IgnoreParenImpCasts();
+          break;
+        }
+      }
+
       getNodeDelegate().Visit(S);
 
       if (!S) {
Index: include/clang/AST/ASTContext.h
===================================================================
--- include/clang/AST/ASTContext.h
+++ include/clang/AST/ASTContext.h
@@ -556,7 +556,17 @@
   const TargetInfo *AuxTarget = nullptr;
   clang::PrintingPolicy PrintingPolicy;
 
+  ast_type_traits::TraversalKind Traversal = ast_type_traits::TK_AsIs;
+
 public:
+  ast_type_traits::TraversalKind GetTraversalKind() const { return Traversal; }
+  void SetTraversalKind(ast_type_traits::TraversalKind TK) { Traversal = TK; }
+
+  const Expr *TraverseIgnored(const Expr *E);
+  Expr *TraverseIgnored(Expr *E);
+  ast_type_traits::DynTypedNode
+  TraverseIgnored(const ast_type_traits::DynTypedNode &N);
+
   IdentifierTable &Idents;
   SelectorTable &Selectors;
   Builtin::Context &BuiltinInfo;
@@ -2942,7 +2952,7 @@
 
   std::vector<Decl *> TraversalScope;
   class ParentMap;
-  std::unique_ptr<ParentMap> Parents;
+  std::map<ast_type_traits::TraversalKind, std::unique_ptr<ParentMap>> Parents;
 
   std::unique_ptr<VTableContextBase> VTContext;
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to