alexander-shaposhnikov updated this revision to Diff 511854.
alexander-shaposhnikov added a comment.

This version is partially based on https://reviews.llvm.org/D147722.
The case with self-friendship is problematic (I had to add a workaround; any 
suggestions would be greatly appreciated).
Added more tests.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D146178/new/

https://reviews.llvm.org/D146178

Files:
  clang/include/clang/Sema/Template.h
  clang/lib/Sema/SemaConcept.cpp
  clang/lib/Sema/SemaOverload.cpp
  clang/lib/Sema/SemaTemplateDeduction.cpp
  clang/lib/Sema/SemaTemplateInstantiate.cpp
  clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
  clang/test/SemaTemplate/concepts-out-of-line-def.cpp
  clang/test/SemaTemplate/concepts.cpp

Index: clang/test/SemaTemplate/concepts.cpp
===================================================================
--- clang/test/SemaTemplate/concepts.cpp
+++ clang/test/SemaTemplate/concepts.cpp
@@ -816,3 +816,62 @@
 static_assert(Parent<int>::TakesBinary<int, 0ULL>::i == 0);
 }
 
+namespace TemplateInsideNonTemplateClass {
+template<typename T, typename U> concept C = true;
+
+template<typename T> auto L = []<C<T> U>() {};
+
+struct Q {
+  template<C<int> U> friend constexpr auto decltype(L<int>)::operator()() const;
+};
+} // namespace TemplateInsideNonTemplateClass
+
+namespace GH61959 {
+template <typename T0>
+concept C = (sizeof(T0) >= 4);
+
+template<typename...>
+struct Orig { };
+
+template<typename T>
+struct Orig<T> {
+  template<typename> requires C<T>
+  void f() { }
+
+  template<typename> requires true
+  void f() { }
+};
+
+template <typename...> struct Mod {};
+
+template <typename T1, typename T2>
+struct Mod<T1, T2> {
+  template <typename> requires C<T1>
+  constexpr static int f() { return 1; }
+
+  template <typename> requires C<T2>
+  constexpr static int f() { return 2; }
+};
+
+static_assert(Mod<int, char>::f<double>() == 1);
+static_assert(Mod<char, int>::f<double>() == 2);
+
+template<typename T>
+struct Outer {
+  template<typename ...>
+  struct Inner {};
+
+  template<typename U>
+  struct Inner<U> {
+    template<typename V>
+    void foo()  requires C<U> && C<T> && C<V>{}
+    template<typename V>
+    void foo()  requires true{}
+  };
+};
+
+void bar() {
+  Outer<int>::Inner<float> I;
+  I.foo<char>();
+}
+}
Index: clang/test/SemaTemplate/concepts-out-of-line-def.cpp
===================================================================
--- clang/test/SemaTemplate/concepts-out-of-line-def.cpp
+++ clang/test/SemaTemplate/concepts-out-of-line-def.cpp
@@ -127,3 +127,220 @@
 static_assert(S<int>::specialization("str") == SPECIALIZATION_REQUIRES);
 
 } // namespace multiple_template_parameter_lists
+
+static constexpr int CONSTRAINED_METHOD_1 = 1;
+static constexpr int CONSTRAINED_METHOD_2 = 2;
+
+namespace constrained_members {
+
+template <int>
+struct S {
+  template <Concept C>
+  static constexpr int constrained_method();
+};
+
+template <>
+template <Concept C>
+constexpr int S<1>::constrained_method() { return CONSTRAINED_METHOD_1; }
+
+template <>
+template <Concept C>
+constexpr int S<2>::constrained_method() { return CONSTRAINED_METHOD_2; }
+
+static_assert(S<1>::constrained_method<XY>() == CONSTRAINED_METHOD_1);
+static_assert(S<2>::constrained_method<XY>() == CONSTRAINED_METHOD_2);
+
+
+template <class T1, class T2>
+concept ConceptT1T2 = true;
+
+template<typename T3>
+struct S12 {
+  template<ConceptT1T2<T3> T4>
+  static constexpr int constrained_method();
+};
+
+template<>
+template<ConceptT1T2<int> T5>
+constexpr int S12<int>::constrained_method() { return CONSTRAINED_METHOD_1; }
+
+template<>
+template<ConceptT1T2<double> T5>
+constexpr int S12<double>::constrained_method() { return CONSTRAINED_METHOD_2; }
+
+static_assert(S12<int>::constrained_method<XY>() == CONSTRAINED_METHOD_1);
+static_assert(S12<double>::constrained_method<XY>() == CONSTRAINED_METHOD_2);
+
+} // namespace constrained_members
+
+namespace constrained_members_of_nested_types {
+
+template <int>
+struct S {
+  struct Inner0 {
+    struct Inner1 {
+      template <Concept C>
+      static constexpr int constrained_method();
+    };
+  };
+};
+
+template <>
+template <Concept C>
+constexpr int S<1>::Inner0::Inner1::constrained_method() { return CONSTRAINED_METHOD_1; }
+
+template <>
+template <Concept C>
+constexpr int S<2>::Inner0::Inner1::constrained_method() { return CONSTRAINED_METHOD_2; }
+
+static_assert(S<1>::Inner0::Inner1::constrained_method<XY>() == CONSTRAINED_METHOD_1);
+static_assert(S<2>::Inner0::Inner1::constrained_method<XY>() == CONSTRAINED_METHOD_2);
+
+
+template <class T1, class T2>
+concept ConceptT1T2 = true;
+
+template<typename T3>
+struct S12 {
+  struct Inner0 {
+    struct Inner1 {
+      template<ConceptT1T2<T3> T4>
+      static constexpr int constrained_method();
+    };
+  };
+};
+
+template<>
+template<ConceptT1T2<int> T5>
+constexpr int S12<int>::Inner0::Inner1::constrained_method() { return CONSTRAINED_METHOD_1; }
+
+template<>
+template<ConceptT1T2<double> T5>
+constexpr int S12<double>::Inner0::Inner1::constrained_method() { return CONSTRAINED_METHOD_2; }
+
+static_assert(S12<int>::Inner0::Inner1::constrained_method<XY>() == CONSTRAINED_METHOD_1);
+static_assert(S12<double>::Inner0::Inner1::constrained_method<XY>() == CONSTRAINED_METHOD_2);
+
+} // namespace constrained_members_of_nested_types
+
+namespace constrained_member_sfinae {
+
+template<int N> struct S {
+  template<class T>
+  static constexpr int constrained_method() requires (sizeof(int[N * 1073741824 + 4]) == 16) {
+    return CONSTRAINED_METHOD_1;
+  }
+
+  template<class T>
+  static constexpr int constrained_method() requires (sizeof(int[N]) == 16);
+};
+
+template<>
+template<typename T>
+constexpr int S<4>::constrained_method() requires (sizeof(int[4]) == 16) {
+  return CONSTRAINED_METHOD_2;
+}
+
+// Verify that there is no ambiguity in this case.
+static_assert(S<4>::constrained_method<double>() == CONSTRAINED_METHOD_2);
+
+} // namespace constrained_member_sfinae
+
+namespace requires_expression_references_members {
+
+void accept1(int x);
+void accept2(XY xy);
+
+template <class T> struct S {
+  T Field = T();
+
+  constexpr int constrained_method()
+      requires requires { accept1(Field); };
+
+  constexpr int constrained_method()
+      requires requires { accept2(Field); };
+};
+
+template <class T>
+constexpr int S<T>::constrained_method()
+  requires requires { accept1(Field); } {
+  return CONSTRAINED_METHOD_1;
+}
+
+template <class T>
+constexpr int S<T>::constrained_method()
+  requires requires { accept2(Field); } {
+  return CONSTRAINED_METHOD_2;
+}
+
+static_assert(S<int>().constrained_method() == CONSTRAINED_METHOD_1);
+static_assert(S<XY>().constrained_method() == CONSTRAINED_METHOD_2);
+
+} // namespace requires_expression_references_members
+
+namespace GH60231 {
+
+template<typename T0> concept C = true;
+
+template <typename T1>
+struct S {
+  template <typename F1> requires C<S<T1>>
+  void foo1(F1 f);
+
+  template <typename F2>
+  void foo2(F2 f) requires C<S<T1>>;
+
+  template <typename F3> requires C<F3>
+  void foo3(F3 f);
+};
+
+template <typename T2>
+template <typename F4> requires C<S<T2>>
+void S<T2>::foo1(F4 f) {}
+
+template <typename T3>
+template <typename F5>
+void S<T3>::foo2(F5 f) requires C<S<T3>> {}
+
+template <typename T4>
+template <typename F6> requires C<F6>
+void S<T4>::foo3(F6 f) {}
+
+} // namespace GH60231
+
+namespace GH62003 {
+
+template <typename T0> concept Concept = true;
+
+template <class T1>
+struct S1 {
+  template <Concept C1>
+  static constexpr int foo();
+};
+template <class T2>
+template <Concept C2>
+constexpr int S1<T2>::foo() { return 1; }
+
+template <Concept C3>
+struct S2 {
+  template <class T3>
+  static constexpr int foo();
+};
+template <Concept C4>
+template <class T4>
+constexpr int S2<C4>::foo() { return 2; }
+
+template <Concept C5>
+struct S3 {
+  template <Concept C6>
+  static constexpr int foo();
+};
+template <Concept C7>
+template <Concept C8>
+constexpr int S3<C7>::foo() { return 3; }
+
+static_assert(S1<int>::foo<int>() == 1);
+static_assert(S2<int>::foo<int>() == 2);
+static_assert(S3<int>::foo<int>() == 3);
+
+} // namespace GH62003
Index: clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -1653,33 +1653,12 @@
         << QualifierLoc.getSourceRange();
       return nullptr;
     }
-
-    if (PrevClassTemplate) {
-      const ClassTemplateDecl *MostRecentPrevCT =
-          PrevClassTemplate->getMostRecentDecl();
-      TemplateParameterList *PrevParams =
-          MostRecentPrevCT->getTemplateParameters();
-
-      // Make sure the parameter lists match.
-      if (!SemaRef.TemplateParameterListsAreEqual(
-              D->getTemplatedDecl(), InstParams,
-              MostRecentPrevCT->getTemplatedDecl(), PrevParams, true,
-              Sema::TPL_TemplateMatch))
-        return nullptr;
-
-      // Do some additional validation, then merge default arguments
-      // from the existing declarations.
-      if (SemaRef.CheckTemplateParameterList(InstParams, PrevParams,
-                                             Sema::TPC_ClassTemplate))
-        return nullptr;
-    }
   }
 
   CXXRecordDecl *RecordInst = CXXRecordDecl::Create(
       SemaRef.Context, Pattern->getTagKind(), DC, Pattern->getBeginLoc(),
       Pattern->getLocation(), Pattern->getIdentifier(), PrevDecl,
       /*DelayTypeCreation=*/true);
-
   if (QualifierLoc)
     RecordInst->setQualifierInfo(QualifierLoc);
 
@@ -1689,16 +1668,37 @@
   ClassTemplateDecl *Inst
     = ClassTemplateDecl::Create(SemaRef.Context, DC, D->getLocation(),
                                 D->getIdentifier(), InstParams, RecordInst);
-  assert(!(isFriend && Owner->isDependentContext()));
-  Inst->setPreviousDecl(PrevClassTemplate);
-
   RecordInst->setDescribedClassTemplate(Inst);
 
   if (isFriend) {
-    if (PrevClassTemplate)
+    assert(!Owner->isDependentContext());
+    Inst->setLexicalDeclContext(Owner);
+    RecordInst->setLexicalDeclContext(Owner);
+
+    if (PrevClassTemplate) {
+      RecordInst->setTypeForDecl(
+          PrevClassTemplate->getTemplatedDecl()->getTypeForDecl());
+      const ClassTemplateDecl *MostRecentPrevCT =
+          PrevClassTemplate->getMostRecentDecl();
+      TemplateParameterList *PrevParams =
+          MostRecentPrevCT->getTemplateParameters();
+
+      // Make sure the parameter lists match.
+      if (!SemaRef.TemplateParameterListsAreEqual(
+              RecordInst, InstParams, MostRecentPrevCT->getTemplatedDecl(),
+              PrevParams, true, Sema::TPL_TemplateMatch))
+        return nullptr;
+
+      // Do some additional validation, then merge default arguments
+      // from the existing declarations.
+      if (SemaRef.CheckTemplateParameterList(InstParams, PrevParams,
+                                             Sema::TPC_ClassTemplate))
+        return nullptr;
+
       Inst->setAccess(PrevClassTemplate->getAccess());
-    else
+    } else {
       Inst->setAccess(D->getAccess());
+    }
 
     Inst->setObjectOfFriendDecl();
     // TODO: do we want to track the instantiation progeny of this
@@ -1709,15 +1709,15 @@
       Inst->setInstantiatedFromMemberTemplate(D);
   }
 
+  Inst->setPreviousDecl(PrevClassTemplate);
+
   // Trigger creation of the type for the instantiation.
-  SemaRef.Context.getInjectedClassNameType(RecordInst,
-                                    Inst->getInjectedClassNameSpecialization());
+  SemaRef.Context.getInjectedClassNameType(
+      RecordInst, Inst->getInjectedClassNameSpecialization());
 
   // Finish handling of friends.
   if (isFriend) {
     DC->makeDeclVisibleInContext(Inst);
-    Inst->setLexicalDeclContext(Owner);
-    RecordInst->setLexicalDeclContext(Owner);
     return Inst;
   }
 
Index: clang/lib/Sema/SemaTemplateInstantiate.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -131,6 +131,14 @@
   return Response::Done();
 }
 
+Response HandlePartialClassTemplateSpec(
+    const ClassTemplatePartialSpecializationDecl *PartialClassTemplSpec,
+    MultiLevelTemplateArgumentList &Result, bool SkipForSpecialization) {
+  if (!SkipForSpecialization)
+      Result.addOuterRetainedLevels(PartialClassTemplSpec->getTemplateDepth());
+  return Response::Done();
+}
+
 // Add template arguments from a class template instantiation.
 Response
 HandleClassTemplateSpec(const ClassTemplateSpecializationDecl *ClassTemplSpec,
@@ -208,6 +216,10 @@
   return Response::UseNextDecl(Function);
 }
 
+Response HandleFunctionTemplateDecl(const FunctionTemplateDecl *FTD) {
+  return Response::ChangeDecl(FTD->getLexicalDeclContext());
+}
+
 Response HandleRecordDecl(const CXXRecordDecl *Rec,
                           MultiLevelTemplateArgumentList &Result,
                           ASTContext &Context,
@@ -306,6 +318,10 @@
     if (const auto *VarTemplSpec =
             dyn_cast<VarTemplateSpecializationDecl>(CurDecl)) {
       R = HandleVarTemplateSpec(VarTemplSpec, Result, SkipForSpecialization);
+    } else if (const auto *PartialClassTemplSpec =
+                   dyn_cast<ClassTemplatePartialSpecializationDecl>(CurDecl)) {
+      R = HandlePartialClassTemplateSpec(PartialClassTemplSpec, Result,
+                                         SkipForSpecialization);
     } else if (const auto *ClassTemplSpec =
                    dyn_cast<ClassTemplateSpecializationDecl>(CurDecl)) {
       R = HandleClassTemplateSpec(ClassTemplSpec, Result,
@@ -318,6 +334,8 @@
     } else if (const auto *CSD =
                    dyn_cast<ImplicitConceptSpecializationDecl>(CurDecl)) {
       R = HandleImplicitConceptSpecializationDecl(CSD, Result);
+    } else if (const auto *FTD = dyn_cast<FunctionTemplateDecl>(CurDecl)) {
+      R = HandleFunctionTemplateDecl(FTD);
     } else if (!isa<DeclContext>(CurDecl)) {
       R = Response::DontClearRelativeToPrimaryNextDecl(CurDecl);
       if (CurDecl->getDeclContext()->isTranslationUnit()) {
Index: clang/lib/Sema/SemaTemplateDeduction.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateDeduction.cpp
+++ clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -2881,7 +2881,7 @@
   // not class-scope explicit specialization, so replace with Deduced Args
   // instead of adding to inner-most.
   if (NeedsReplacement)
-    MLTAL.replaceInnermostTemplateArguments(CanonicalDeducedArgs);
+    MLTAL.replaceInnermostTemplateArguments(Template, CanonicalDeducedArgs);
 
   if (S.CheckConstraintSatisfaction(Template, AssociatedConstraints, MLTAL,
                                     Info.getLocation(),
Index: clang/lib/Sema/SemaOverload.cpp
===================================================================
--- clang/lib/Sema/SemaOverload.cpp
+++ clang/lib/Sema/SemaOverload.cpp
@@ -1294,7 +1294,7 @@
     // We check the return type and template parameter lists for function
     // templates first; the remaining checks follow.
     bool SameTemplateParameterList = TemplateParameterListsAreEqual(
-        NewTemplate->getTemplateParameters(),
+        NewTemplate, NewTemplate->getTemplateParameters(), OldTemplate,
         OldTemplate->getTemplateParameters(), false, TPL_TemplateMatch);
     bool SameReturnType = Context.hasSameType(Old->getDeclaredReturnType(),
                                               New->getDeclaredReturnType());
Index: clang/lib/Sema/SemaConcept.cpp
===================================================================
--- clang/lib/Sema/SemaConcept.cpp
+++ clang/lib/Sema/SemaConcept.cpp
@@ -260,6 +260,9 @@
     return SubstitutedAtomicExpr;
   }
 
+  if (SubstitutedAtomicExpr.get()->isValueDependent())
+    return SubstitutedAtomicExpr;
+
   EnterExpressionEvaluationContext ConstantEvaluated(
       S, Sema::ExpressionEvaluationContext::ConstantEvaluated);
   SmallVector<PartialDiagnosticAt, 2> EvaluationDiags;
@@ -721,7 +724,7 @@
       ND, /*Final=*/false, /*Innermost=*/nullptr, /*RelativeToPrimary=*/true,
       /*Pattern=*/nullptr,
       /*ForConstraintInstantiation=*/true, SkipForSpecialization);
-  return MLTAL.getNumSubstitutedLevels();
+  return MLTAL.getNumLevels();
 }
 
 namespace {
@@ -752,27 +755,56 @@
   };
 } // namespace
 
+static const Expr *SubstituteConstraintExpression(Sema &S, const NamedDecl *ND,
+                                                  const Expr *ConstrExpr,
+                                                  const NamedDecl *OtherND) {
+  bool ForConstraintInstantiation = false;
+  // FIXME: Remove this workaround, currently it is necessary
+  // to avoid breaking the self-friendship case:
+  // template<Constraint C0>
+  // struct S {
+  //    template <Constraint C1>
+  //    friend class S;
+  // };
+  // Setting ForConstraintInstantiation = true makes
+  // getTemplateInstantiationArgs collect arguments from the
+  // enclosing template class. Substitution of these args into
+  // ConstrExpr for the inner template will properly adjust the depths.
+  if (isa<CXXRecordDecl>(ND) && isa<CXXRecordDecl>(OtherND))
+    ForConstraintInstantiation = true;
+  MultiLevelTemplateArgumentList MLTAL = S.getTemplateInstantiationArgs(
+      ND, /*Final=*/false, /*Innermost=*/nullptr,
+      /*RelativeToPrimary=*/true,
+      /*Pattern=*/nullptr, ForConstraintInstantiation,
+      /*SkipForSpecialization*/ false);
+  Sema::SFINAETrap SFINAE(S, /*AccessCheckingSFINAE=*/false);
+  std::optional<Sema::CXXThisScopeRAII> ThisScope;
+  if (auto *RD = dyn_cast<CXXRecordDecl>(ND->getDeclContext()))
+    ThisScope.emplace(S, const_cast<CXXRecordDecl *>(RD), Qualifiers());
+  ExprResult SubstConstr =
+      S.SubstConstraintExpr(const_cast<clang::Expr *>(ConstrExpr), MLTAL);
+  if (SFINAE.hasErrorOccurred() || !SubstConstr.isUsable())
+    return nullptr;
+  return SubstConstr.get();
+}
+
 bool Sema::AreConstraintExpressionsEqual(const NamedDecl *Old,
                                          const Expr *OldConstr,
                                          const NamedDecl *New,
                                          const Expr *NewConstr) {
+  if (OldConstr == NewConstr)
+    return true;
   if (Old && New && Old != New) {
-    unsigned Depth1 = CalculateTemplateDepthForConstraints(
-        *this, Old);
-    unsigned Depth2 = CalculateTemplateDepthForConstraints(
-        *this, New);
-
-    // Adjust the 'shallowest' verison of this to increase the depth to match
-    // the 'other'.
-    if (Depth2 > Depth1) {
-      OldConstr = AdjustConstraintDepth(*this, Depth2 - Depth1)
-                      .TransformExpr(const_cast<Expr *>(OldConstr))
-                      .get();
-    } else if (Depth1 > Depth2) {
-      NewConstr = AdjustConstraintDepth(*this, Depth1 - Depth2)
-                      .TransformExpr(const_cast<Expr *>(NewConstr))
-                      .get();
-    }
+    if (const Expr *SubstConstr =
+            SubstituteConstraintExpression(*this, Old, OldConstr, New))
+      OldConstr = SubstConstr;
+    else
+      return false;
+    if (const Expr *SubstConstr =
+            SubstituteConstraintExpression(*this, New, NewConstr, Old))
+      NewConstr = SubstConstr;
+    else
+      return false;
   }
 
   llvm::FoldingSetNodeID ID1, ID2;
Index: clang/include/clang/Sema/Template.h
===================================================================
--- clang/include/clang/Sema/Template.h
+++ clang/include/clang/Sema/Template.h
@@ -232,9 +232,21 @@
     /// Replaces the current 'innermost' level with the provided argument list.
     /// This is useful for type deduction cases where we need to get the entire
     /// list from the AST, but then add the deduced innermost list.
-    void replaceInnermostTemplateArguments(ArgList Args) {
-      assert(TemplateArgumentLists.size() > 0 && "Replacing in an empty list?");
-      TemplateArgumentLists[0].Args = Args;
+    void replaceInnermostTemplateArguments(Decl *AssociatedDecl, ArgList Args) {
+      assert((!TemplateArgumentLists.empty() || NumRetainedOuterLevels) &&
+             "Replacing in an empty list?");
+
+      if (!TemplateArgumentLists.empty()) {
+        assert((TemplateArgumentLists[0].AssociatedDeclAndFinal.getPointer() ||
+                TemplateArgumentLists[0].AssociatedDeclAndFinal.getPointer() ==
+                    AssociatedDecl) &&
+               "Trying to change incorrect declaration?");
+        TemplateArgumentLists[0].Args = Args;
+      } else {
+        --NumRetainedOuterLevels;
+        TemplateArgumentLists.push_back(
+            {{AssociatedDecl, /*Final=*/false}, Args});
+      }
     }
 
     /// Add an outermost level that we are not substituting. We have no
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to