steveire updated this revision to Diff 300629.
steveire added a comment.

Rebase


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D80961/new/

https://reviews.llvm.org/D80961

Files:
  clang/include/clang/AST/ASTNodeTraverser.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/ASTMatchers/ASTMatchersInternal.h
  clang/lib/AST/ASTDumper.cpp
  clang/lib/ASTMatchers/ASTMatchFinder.cpp
  clang/lib/ASTMatchers/ASTMatchersInternal.cpp
  clang/unittests/AST/ASTTraverserTest.cpp
  clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
  clang/unittests/Tooling/TransformerTest.cpp

Index: clang/unittests/Tooling/TransformerTest.cpp
===================================================================
--- clang/unittests/Tooling/TransformerTest.cpp
+++ clang/unittests/Tooling/TransformerTest.cpp
@@ -1064,6 +1064,70 @@
   EXPECT_EQ(ErrorCount, 0);
 }
 
+// Under TK_IgnoreUnlessSpelledInSource the rule must not touch the template;
+// under TK_AsIs it also rewrites 'T m_t' via the TemplStruct<int> instance.
+TEST_F(TransformerTest, TemplateInstantiation) {
+  std::string NonTemplatesInput = R"cpp(
+struct S {
+  int m_i;
+};
+)cpp";
+  std::string NonTemplatesExpected = R"cpp(
+struct S {
+  safe_int m_i;
+};
+)cpp";
+
+  std::string TemplatesInput = R"cpp(
+template<typename T>
+struct TemplStruct {
+  TemplStruct() {}
+  ~TemplStruct() {}
+
+private:
+  T m_t;
+};
+
+void instantiate()
+{
+  TemplStruct<int> ti;
+}
+)cpp";
+
+  auto MatchedField = fieldDecl(hasType(asString("int"))).bind("theField");
+
+  // Changes the 'int' in 'S', but not the 'T' in 'TemplStruct':
+  testRule(makeRule(traverse(TK_IgnoreUnlessSpelledInSource, MatchedField),
+                    changeTo(cat("safe_int ", name("theField")))),
+           NonTemplatesInput + TemplatesInput,
+           NonTemplatesExpected + TemplatesInput);
+
+  // In AsIs mode, template instantiations are modified, which is
+  // often not desired:
+
+  std::string IncorrectTemplatesExpected = R"cpp(
+template<typename T>
+struct TemplStruct {
+  TemplStruct() {}
+  ~TemplStruct() {}
+
+private:
+  safe_int m_t;
+};
+
+void instantiate()
+{
+  TemplStruct<int> ti;
+}
+)cpp";
+
+  // Changes the 'int' in 'S', and (incorrectly) the 'T' in 'TemplStruct':
+  testRule(makeRule(traverse(TK_AsIs, MatchedField),
+                    changeTo(cat("safe_int ", name("theField")))),
+           NonTemplatesInput + TemplatesInput,
+           NonTemplatesExpected + IncorrectTemplatesExpected);
+}
+
 // Transformation of macro source text when the change encompasses the entirety
 // of the expanded text.
 TEST_F(TransformerTest, SimpleMacro) {
Index: clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
===================================================================
--- clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
+++ clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
@@ -2085,9 +2085,17 @@
       traverse(TK_AsIs,
                staticAssertDecl(has(implicitCastExpr(has(
                    substNonTypeTemplateParmExpr(has(integerLiteral())))))))));
+  EXPECT_TRUE(matches(
+      Code, traverse(TK_IgnoreUnlessSpelledInSource,
+                     staticAssertDecl(has(declRefExpr(
+                         to(nonTypeTemplateParmDecl(hasName("alignment"))),
+                         hasType(asString("unsigned int"))))))));
 
-  EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource,
-                                     staticAssertDecl(has(integerLiteral())))));
+  EXPECT_TRUE(matches(Code, traverse(TK_AsIs, staticAssertDecl(hasDescendant(
+                                                  integerLiteral())))));
+  EXPECT_FALSE(matches(
+      Code, traverse(TK_IgnoreUnlessSpelledInSource,
+                     staticAssertDecl(hasDescendant(integerLiteral())))));
 
   Code = R"cpp(
 
Index: clang/unittests/AST/ASTTraverserTest.cpp
===================================================================
--- clang/unittests/AST/ASTTraverserTest.cpp
+++ clang/unittests/AST/ASTTraverserTest.cpp
@@ -68,6 +68,14 @@
   void Visit(const TemplateArgument &A, SourceRange R = {},
              const Decl *From = nullptr, const char *Label = nullptr) {
     OS << "TemplateArgument";
+    // Type arguments additionally print their spelled type, e.g.
+    // "TemplateArgument type int"; every other argument kind prints only
+    // the "TemplateArgument" label.
+    const bool IsTypeArgument = A.getKind() == TemplateArgument::Type;
+    if (!IsTypeArgument)
+      return;
+    const auto SpelledType = A.getAsType().getAsString();
+    OS << " type " << SpelledType;
   }
 
   template <typename... T> void Visit(T...) {}
@@ -243,7 +251,7 @@
 
   verifyWithDynNode(TA,
                     R"cpp(
-TemplateArgument
+TemplateArgument type int
 `-BuiltinType
 )cpp");
 
@@ -1042,4 +1050,145 @@
   }
 }
 
+// Template instantiations are dumped only in TK_AsIs traversal mode.
+TEST(Traverse, IgnoreUnlessSpelledInSourceTemplateInstantiations) {
+  auto AST = buildASTFromCode(R"cpp(
+template<typename T>
+struct TemplStruct {
+  TemplStruct() {}
+  ~TemplStruct() {}
+
+private:
+  T m_t;
+};
+
+template<typename T>
+T timesTwo(T input)
+{
+  return input * 2;
+}
+
+void instantiate()
+{
+  TemplStruct<int> ti;
+  TemplStruct<double> td;
+  (void)timesTwo<int>(2);
+  (void)timesTwo<double>(2);
+}
+)cpp");
+  {
+    auto BN = ast_matchers::match(
+        classTemplateDecl(hasName("TemplStruct")).bind("rec"),
+        AST->getASTContext());
+    ASSERT_EQ(BN.size(), 1u); // BN[0] is accessed below.
+
+    EXPECT_EQ(dumpASTString(TK_IgnoreUnlessSpelledInSource,
+                            BN[0].getNodeAs<Decl>("rec")),
+              R"cpp(
+ClassTemplateDecl 'TemplStruct'
+|-TemplateTypeParmDecl 'T'
+`-CXXRecordDecl 'TemplStruct'
+  |-CXXRecordDecl 'TemplStruct'
+  |-CXXConstructorDecl 'TemplStruct<T>'
+  | `-CompoundStmt
+  |-CXXDestructorDecl '~TemplStruct<T>'
+  | `-CompoundStmt
+  |-AccessSpecDecl
+  `-FieldDecl 'm_t'
+)cpp");
+
+    EXPECT_EQ(dumpASTString(TK_AsIs, BN[0].getNodeAs<Decl>("rec")),
+              R"cpp(
+ClassTemplateDecl 'TemplStruct'
+|-TemplateTypeParmDecl 'T'
+|-CXXRecordDecl 'TemplStruct'
+| |-CXXRecordDecl 'TemplStruct'
+| |-CXXConstructorDecl 'TemplStruct<T>'
+| | `-CompoundStmt
+| |-CXXDestructorDecl '~TemplStruct<T>'
+| | `-CompoundStmt
+| |-AccessSpecDecl
+| `-FieldDecl 'm_t'
+|-ClassTemplateSpecializationDecl 'TemplStruct'
+| |-TemplateArgument type int
+| | `-BuiltinType
+| |-CXXRecordDecl 'TemplStruct'
+| |-CXXConstructorDecl 'TemplStruct'
+| | `-CompoundStmt
+| |-CXXDestructorDecl '~TemplStruct'
+| | `-CompoundStmt
+| |-AccessSpecDecl
+| |-FieldDecl 'm_t'
+| `-CXXConstructorDecl 'TemplStruct'
+|   `-ParmVarDecl ''
+`-ClassTemplateSpecializationDecl 'TemplStruct'
+  |-TemplateArgument type double
+  | `-BuiltinType
+  |-CXXRecordDecl 'TemplStruct'
+  |-CXXConstructorDecl 'TemplStruct'
+  | `-CompoundStmt
+  |-CXXDestructorDecl '~TemplStruct'
+  | `-CompoundStmt
+  |-AccessSpecDecl
+  |-FieldDecl 'm_t'
+  `-CXXConstructorDecl 'TemplStruct'
+    `-ParmVarDecl ''
+)cpp");
+  }
+  {
+    auto BN = ast_matchers::match(
+        functionTemplateDecl(hasName("timesTwo")).bind("fn"),
+        AST->getASTContext());
+    ASSERT_EQ(BN.size(), 1u); // BN[0] is accessed below.
+
+    EXPECT_EQ(dumpASTString(TK_IgnoreUnlessSpelledInSource,
+                            BN[0].getNodeAs<Decl>("fn")),
+              R"cpp(
+FunctionTemplateDecl 'timesTwo'
+|-TemplateTypeParmDecl 'T'
+`-FunctionDecl 'timesTwo'
+  |-ParmVarDecl 'input'
+  `-CompoundStmt
+    `-ReturnStmt
+      `-BinaryOperator
+        |-DeclRefExpr 'input'
+        `-IntegerLiteral
+)cpp");
+
+    EXPECT_EQ(dumpASTString(TK_AsIs, BN[0].getNodeAs<Decl>("fn")),
+              R"cpp(
+FunctionTemplateDecl 'timesTwo'
+|-TemplateTypeParmDecl 'T'
+|-FunctionDecl 'timesTwo'
+| |-ParmVarDecl 'input'
+| `-CompoundStmt
+|   `-ReturnStmt
+|     `-BinaryOperator
+|       |-DeclRefExpr 'input'
+|       `-IntegerLiteral
+|-FunctionDecl 'timesTwo'
+| |-TemplateArgument type int
+| | `-BuiltinType
+| |-ParmVarDecl 'input'
+| `-CompoundStmt
+|   `-ReturnStmt
+|     `-BinaryOperator
+|       |-ImplicitCastExpr
+|       | `-DeclRefExpr 'input'
+|       `-IntegerLiteral
+`-FunctionDecl 'timesTwo'
+  |-TemplateArgument type double
+  | `-BuiltinType
+  |-ParmVarDecl 'input'
+  `-CompoundStmt
+    `-ReturnStmt
+      `-BinaryOperator
+        |-ImplicitCastExpr
+        | `-DeclRefExpr 'input'
+        `-ImplicitCastExpr
+          `-IntegerLiteral
+)cpp");
+  }
+}
+
 } // namespace clang
Index: clang/lib/ASTMatchers/ASTMatchersInternal.cpp
===================================================================
--- clang/lib/ASTMatchers/ASTMatchersInternal.cpp
+++ clang/lib/ASTMatchers/ASTMatchersInternal.cpp
@@ -284,6 +284,13 @@
   TraversalKindScope RAII(Finder->getASTContext(),
                           Implementation->TraversalKind());
 
+  if (Finder->getASTContext().getParentMapContext().getTraversalKind() ==
+      TK_IgnoreUnlessSpelledInSource) {
+    if (Finder->isMatchingInImplicitTemplateInstantiation()) {
+      return false;
+    }
+  }
+
   auto N =
       Finder->getASTContext().getParentMapContext().traverseIgnored(DynNode);
 
@@ -304,6 +311,13 @@
   TraversalKindScope raii(Finder->getASTContext(),
                           Implementation->TraversalKind());
 
+  if (Finder->getASTContext().getParentMapContext().getTraversalKind() ==
+      TK_IgnoreUnlessSpelledInSource) {
+    if (Finder->isMatchingInImplicitTemplateInstantiation()) {
+      return false;
+    }
+  }
+
   auto N =
       Finder->getASTContext().getParentMapContext().traverseIgnored(DynNode);
 
Index: clang/lib/ASTMatchers/ASTMatchFinder.cpp
===================================================================
--- clang/lib/ASTMatchers/ASTMatchFinder.cpp
+++ clang/lib/ASTMatchers/ASTMatchFinder.cpp
@@ -584,6 +584,57 @@
   bool shouldVisitTemplateInstantiations() const { return true; }
   bool shouldVisitImplicitCode() const { return true; }
 
+  /// True while the visitor is inside one of the
+  /// TraverseTemplateInstantiations overloads below.
+  bool TraversingImplicitTemplateInstantiation = false;
+
+  bool isMatchingInImplicitTemplateInstantiation() const override {
+    return TraversingImplicitTemplateInstantiation;
+  }
+
+  /// RAII helper: sets TraversingImplicitTemplateInstantiation to \p B for
+  /// the lifetime of the scope and restores the previous value afterwards,
+  /// so nested scopes unwind correctly.
+  struct ImplicitTemplateInstantiationScope {
+    ImplicitTemplateInstantiationScope(MatchASTVisitor *V, bool B)
+        : MV(V), MB(V->TraversingImplicitTemplateInstantiation) {
+      V->TraversingImplicitTemplateInstantiation = B;
+    }
+    ~ImplicitTemplateInstantiationScope() {
+      MV->TraversingImplicitTemplateInstantiation = MB;
+    }
+
+    // Non-copyable: each scope must restore the saved value exactly once.
+    ImplicitTemplateInstantiationScope(
+        const ImplicitTemplateInstantiationScope &) = delete;
+    ImplicitTemplateInstantiationScope &
+    operator=(const ImplicitTemplateInstantiationScope &) = delete;
+
+  private:
+    MatchASTVisitor *MV;
+    bool MB;
+  };
+
+  // Set the flag while traversing template instantiations, so that matchers
+  // run in TK_IgnoreUnlessSpelledInSource mode do not match those nodes.
+  bool TraverseTemplateInstantiations(ClassTemplateDecl *D) {
+    ImplicitTemplateInstantiationScope RAII(this, true);
+    return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
+        D);
+  }
+
+  bool TraverseTemplateInstantiations(VarTemplateDecl *D) {
+    ImplicitTemplateInstantiationScope RAII(this, true);
+    return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
+        D);
+  }
+
+  bool TraverseTemplateInstantiations(FunctionTemplateDecl *D) {
+    ImplicitTemplateInstantiationScope RAII(this, true);
+    return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
+        D);
+  }
+
 private:
   class TimeBucketRegion {
   public:
Index: clang/lib/AST/ASTDumper.cpp
===================================================================
--- clang/lib/AST/ASTDumper.cpp
+++ clang/lib/AST/ASTDumper.cpp
@@ -129,9 +129,10 @@
 
   Visit(D->getTemplatedDecl());
 
-  for (const auto *Child : D->specializations())
-    dumpTemplateDeclSpecialization(Child, DumpExplicitInst,
-                                   !D->isCanonicalDecl());
+  if (GetTraversalKind() == TK_AsIs)
+    for (const auto *Child : D->specializations())
+      dumpTemplateDeclSpecialization(Child, DumpExplicitInst,
+                                     !D->isCanonicalDecl());
 }
 
 void ASTDumper::VisitFunctionTemplateDecl(const FunctionTemplateDecl *D) {
Index: clang/include/clang/ASTMatchers/ASTMatchersInternal.h
===================================================================
--- clang/include/clang/ASTMatchers/ASTMatchersInternal.h
+++ clang/include/clang/ASTMatchers/ASTMatchersInternal.h
@@ -586,6 +586,12 @@
       return this->InnerMatcher.matches(DynTypedNode::create(*Node), Finder,
                                         Builder);
     }
+
+    // Report the inner matcher's traversal kind (e.g. one set with
+    // traverse()) so that it also applies when this wrapper is matched.
+    llvm::Optional<clang::TraversalKind> TraversalKind() const override {
+      return this->InnerMatcher.getTraversalKind();
+    }
   };
 
 private:
@@ -1056,6 +1062,10 @@
 
   virtual ASTContext &getASTContext() const = 0;
 
+  /// Returns true while the finder is traversing a template instantiation.
+  /// Matchers using TK_IgnoreUnlessSpelledInSource do not match there.
+  virtual bool isMatchingInImplicitTemplateInstantiation() const = 0;
+
 protected:
   virtual bool matchesChildOf(const DynTypedNode &Node, ASTContext &Ctx,
                               const DynTypedMatcher &Matcher,
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -461,6 +461,15 @@
 
   bool canIgnoreChildDeclWhileTraversingDeclContext(const Decl *Child);
 
+  // These are public so that a derived visitor can redefine them and still
+  // call the RecursiveASTVisitor implementation from its own version.
+#define DEF_TRAVERSE_TMPL_INST(TMPLDECLKIND)                                   \
+  bool TraverseTemplateInstantiations(TMPLDECLKIND##TemplateDecl *D);
+  DEF_TRAVERSE_TMPL_INST(Class)
+  DEF_TRAVERSE_TMPL_INST(Var)
+  DEF_TRAVERSE_TMPL_INST(Function)
+#undef DEF_TRAVERSE_TMPL_INST
+
 private:
   // These are helper methods used by more than one Traverse* method.
   bool TraverseTemplateParameterListHelper(TemplateParameterList *TPL);
@@ -469,12 +478,6 @@
   template <typename T>
   bool TraverseDeclTemplateParameterLists(T *D);
 
-#define DEF_TRAVERSE_TMPL_INST(TMPLDECLKIND)                                   \
-  bool TraverseTemplateInstantiations(TMPLDECLKIND##TemplateDecl *D);
-  DEF_TRAVERSE_TMPL_INST(Class)
-  DEF_TRAVERSE_TMPL_INST(Var)
-  DEF_TRAVERSE_TMPL_INST(Function)
-#undef DEF_TRAVERSE_TMPL_INST
   bool TraverseTemplateArgumentLocsHelper(const TemplateArgumentLoc *TAL,
                                           unsigned Count);
   bool TraverseArrayTypeLocHelper(ArrayTypeLoc TL);
Index: clang/include/clang/AST/ASTNodeTraverser.h
===================================================================
--- clang/include/clang/AST/ASTNodeTraverser.h
+++ clang/include/clang/AST/ASTNodeTraverser.h
@@ -82,6 +82,8 @@
   bool getDeserialize() const { return Deserialize; }
 
   void SetTraversalKind(TraversalKind TK) { Traversal = TK; }
+  /// Returns the current traversal kind; see SetTraversalKind().
+  TraversalKind GetTraversalKind() const { return Traversal; }
 
   void Visit(const Decl *D) {
     getNodeDelegate().AddChild([=] {
@@ -481,8 +483,9 @@
 
     Visit(D->getTemplatedDecl());
 
-    for (const auto *Child : D->specializations())
-      dumpTemplateDeclSpecialization(Child);
+    if (Traversal == TK_AsIs)
+      for (const auto *Child : D->specializations())
+        dumpTemplateDeclSpecialization(Child);
   }
 
   void VisitTypeAliasDecl(const TypeAliasDecl *D) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to