Author: Weibo He
Date: 2026-06-17T19:38:17+08:00
New Revision: 564e83191cc5686429cefe69aff81c2aeb268f5a

URL: 
https://github.com/llvm/llvm-project/commit/564e83191cc5686429cefe69aff81c2aeb268f5a
DIFF: 
https://github.com/llvm/llvm-project/commit/564e83191cc5686429cefe69aff81c2aeb268f5a.diff

LOG: [clang][Sema][CUDA] Restrict immediate template resolution to host-device 
functions (#200662)

Since overload resolution gives higher priority to `__host__` and
`__device__` attributes, HD functions may favor template candidates even
when a non‑template candidate would be a perfect match. This patch
resolves templates eagerly only for HD functions, not for all code
compiled with `-x cuda`, thus preventing valid host code from being
rejected.

Close #200545

Added: 
    clang/test/SemaCUDA/pr200545.cu

Modified: 
    clang/docs/ReleaseNotes.rst
    clang/include/clang/Sema/Overload.h
    clang/lib/Sema/SemaOverload.cpp

Removed: 
    


################################################################################
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 89909fe27cbb9..930f26ce8f5d8 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -701,6 +701,7 @@ Bug Fixes in This Version
   an array via an element-at-a-time copy loop (#GH192026)
 - Fixed an issue where certain designated initializers would be rejected for 
constexpr variables. (#GH193373)
 - Fixed a crash when ``#embed`` is used with C++ modules (#GH195350)
+- Fixed a bug where ``-x cuda`` caused clang to eagerly resolve function template candidates outside ``__host__ __device__`` functions, rejecting valid code. (#GH200545)
 - Fixed an issue where ``__typeof_unqual`` and ``__typeof_unqual__`` were 
rejected as a declaration specifier in block scope in C++.
 - Fixed crash when checking for overflow for unary operator that can't 
overflow (#GH170072)
 - Clang no longer handles a `" q-char-sequence "` header name as a string 
literal (#GH132643).

diff --git a/clang/include/clang/Sema/Overload.h b/clang/include/clang/Sema/Overload.h
index d42963e325b58..1e412ff6fc9e2 100644
--- a/clang/include/clang/Sema/Overload.h
+++ b/clang/include/clang/Sema/Overload.h
@@ -1353,7 +1353,7 @@ class Sema;
     bool shouldDeferDiags(Sema &S, ArrayRef<Expr *> Args, SourceLocation 
OpLoc);
 
     // Whether the resolution of template candidates should be deferred
-    bool shouldDeferTemplateArgumentDeduction(const LangOptions &Opts) const;
+    bool shouldDeferTemplateArgumentDeduction(const Sema &S) const;
 
     /// Determine when this overload candidate will be new to the
     /// overload set.
@@ -1545,22 +1545,6 @@ class Sema;
   // good candidate as we can get, despite the fact that it takes one less
   // parameter.
   bool shouldEnforceArgLimit(bool PartialOverloading, FunctionDecl *Function);
-
-  inline bool OverloadCandidateSet::shouldDeferTemplateArgumentDeduction(
-      const LangOptions &Opts) const {
-    return
-        // For user defined conversion we need to check against different
-        // combination of CV qualifiers and look at any explicit specifier, so
-        // always deduce template candidates.
-        Kind != CSK_InitByUserDefinedConversion
-        // When doing code completion, we want to see all the
-        // viable candidates.
-        && Kind != CSK_CodeCompletion
-        // CUDA may prefer template candidates even when a non-candidate
-        // is a perfect match
-        && !Opts.CUDA;
-  }
-
 } // namespace clang
 
 #endif // LLVM_CLANG_SEMA_OVERLOAD_H

diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index a5bd32c35e758..c663765573612 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -8196,7 +8196,7 @@ void Sema::AddMethodTemplateCandidate(
     return;
 
   if (ExplicitTemplateArgs ||
-      !CandidateSet.shouldDeferTemplateArgumentDeduction(getLangOpts())) {
+      !CandidateSet.shouldDeferTemplateArgumentDeduction(*this)) {
     AddMethodTemplateCandidateImmediately(
         *this, CandidateSet, MethodTmpl, FoundDecl, ActingContext,
         ExplicitTemplateArgs, ObjectType, ObjectClassification, Args,
@@ -8326,7 +8326,7 @@ void Sema::AddTemplateOverloadCandidate(
   bool DependentExplicitSpecifier = hasDependentExplicit(FunctionTemplate);
 
   if (ExplicitTemplateArgs ||
-      !CandidateSet.shouldDeferTemplateArgumentDeduction(getLangOpts()) ||
+      !CandidateSet.shouldDeferTemplateArgumentDeduction(*this) ||
       (isa<CXXConstructorDecl>(FunctionTemplate->getTemplatedDecl()) &&
        DependentExplicitSpecifier)) {
 
@@ -8764,7 +8764,7 @@ void Sema::AddTemplateConversionCandidate(
   if (!CandidateSet.isNewCandidate(FunctionTemplate))
     return;
 
-  if (!CandidateSet.shouldDeferTemplateArgumentDeduction(getLangOpts()) ||
+  if (!CandidateSet.shouldDeferTemplateArgumentDeduction(*this) ||
       CandidateSet.getKind() ==
           OverloadCandidateSet::CSK_InitByUserDefinedConversion ||
       CandidateSet.getKind() == OverloadCandidateSet::CSK_InitByConstructor) {
@@ -11585,7 +11585,7 @@ OverloadingResult 
OverloadCandidateSet::BestViableFunction(Sema &S,
                                                            SourceLocation Loc,
                                                            iterator &Best) {
 
-  assert((shouldDeferTemplateArgumentDeduction(S.getLangOpts()) ||
+  assert((shouldDeferTemplateArgumentDeduction(S) ||
           DeferredCandidatesCount == 0) &&
          "Unexpected deferred template candidates");
 
@@ -13537,6 +13537,28 @@ void OverloadCandidateSet::NoteCandidates(Sema &S, 
ArrayRef<Expr *> Args,
   }
 }
 
+bool OverloadCandidateSet::shouldDeferTemplateArgumentDeduction(
+    const Sema &S) const {
+  if (S.getLangOpts().CUDA) {
+    auto *Caller = S.getCurFunctionDecl(true);
+    // Overloading based on __host__ and __device__ attributes takes
+    // higher priority, so HD functions may favor template candidates even when
+    // a non-template candidate would be a perfect match.
+    if (Caller && Caller->hasAttr<CUDAHostAttr>() &&
+        Caller->hasAttr<CUDADeviceAttr>())
+      return false;
+  }
+
+  return
+      // For user defined conversion we need to check against different
+      // combination of CV qualifiers and look at any explicit specifier, so
+      // always deduce template candidates.
+      Kind != CSK_InitByUserDefinedConversion
+      // When doing code completion, we want to see all the
+      // viable candidates.
+      && Kind != CSK_CodeCompletion;
+}
+
 static SourceLocation
 GetLocationForCandidate(const TemplateSpecCandidate *Cand) {
   return Cand->Specialization ? Cand->Specialization->getLocation()

diff --git a/clang/test/SemaCUDA/pr200545.cu b/clang/test/SemaCUDA/pr200545.cu
new file mode 100644
index 0000000000000..b839cf134ed7a
--- /dev/null
+++ b/clang/test/SemaCUDA/pr200545.cu
@@ -0,0 +1,109 @@
+// Test that template argument deduction is deferred correctly.
+//
+// RUN: %clang_cc1 -std=c++20 -fsyntax-only -verify 
-verify-ignore-unexpected=note %s
+
+#include "Inputs/cuda.h"
+
+namespace h_free_call {
+  template<class T>
+  concept DoNotDeduct = []() {
+    static_assert(sizeof(T) == 0);
+    return true;
+  }();
+
+  void fn(int) {}
+  void fn(DoNotDeduct auto) {}
+
+  void call() {
+    fn(0);
+    fn(nullptr); // expected-error@-9 {{static assertion failed due to 
requirement 'sizeof(std::nullptr_t) == 0'}}
+  }
+}
+
+namespace h_member_call {
+  template<class T>
+  concept DoNotDeduct = []() {
+    static_assert(sizeof(T) == 0);
+    return true;
+  }();
+
+  struct A {
+    void operator=(int) {}
+    void operator=(DoNotDeduct auto) {}
+  };
+
+  void call(A a) {
+    a.operator=(0);
+    a.operator=(nullptr); // expected-error@-11 {{static assertion failed due 
to requirement 'sizeof(std::nullptr_t) == 0'}}
+  }
+}
+
+namespace h_conversion_call {
+  template<class T>
+  concept DoNotDeduct = []() {
+    static_assert(sizeof(T) == 0);
+    return true;
+  }();
+
+  struct A {
+    operator int();
+    template<DoNotDeduct T> operator T();
+  };
+
+  void call(A a) {
+    switch (a) {}
+    (void)float(a); // expected-error@-11 {{static assertion failed due to 
requirement 'sizeof(float) == 0'}}
+  }
+}
+
+namespace hd_free_call {
+  template<class T>
+  concept DoNotDeduct = []() {
+    static_assert(sizeof(T) == 0);
+    return true;
+  }();
+
+  __host__ __device__ void fn(int) {}
+  __host__ __device__ void fn(DoNotDeduct auto) {}
+
+  __host__ __device__ void call() {
+    fn(0); // expected-error@-8 {{static assertion failed due to requirement 
'sizeof(int) == 0'}}
+    fn(nullptr); // expected-error@-9 {{static assertion failed due to 
requirement 'sizeof(std::nullptr_t) == 0'}}
+  }
+}
+
+namespace hd_member_call {
+  template<class T>
+  concept DoNotDeduct = []() {
+    static_assert(sizeof(T) == 0);
+    return true;
+  }();
+
+  struct A {
+    __host__ __device__ void operator=(int) {}
+    __host__ __device__ void operator=(DoNotDeduct auto) {}
+  };
+
+  __host__ __device__ void call(A a) {
+    a.operator=(0); // expected-error@-10 {{static assertion failed due to 
requirement 'sizeof(int) == 0'}}
+    a.operator=(nullptr); // expected-error@-11 {{static assertion failed due 
to requirement 'sizeof(std::nullptr_t) == 0'}}
+  }
+}
+
+namespace hd_conversion_call {
+  template<class T>
+  concept DoNotDeduct = []() {
+    static_assert(sizeof(T) == 0);
+    return true;
+  }();
+
+  struct A {
+    __host__ __device__ operator int();
+    template<DoNotDeduct T> __host__ __device__ operator T();
+  };
+
+  __host__ __device__ void call(A a) {
+    switch (a) {} // expected-error@-10 {{static assertion failed due to 
requirement 'sizeof(int) == 0'}}
+    (void)float(a); // expected-error@-11 {{static assertion failed due to 
requirement 'sizeof(float) == 0'}}
+  }
+}


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to