This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a12367c4b1 [NVSHMEM] Fix compatibility with CUDA code without nvshmem use (#18222)
a12367c4b1 is described below

commit a12367c4b19bd4f13ccfd6c8f5acb63958f9c35f
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Aug 22 08:02:28 2025 -0400

    [NVSHMEM] Fix compatibility with CUDA code without nvshmem use (#18222)
    
    This PR fixes two bugs that prevent normal TIR functions (ones that
    don't use any NVSHMEM API) from compiling and running when
    `set(USE_NVSHMEM xxx)` is enabled.
    
    Co-authored-by: Bohan Hou <[email protected]>
---
 python/tvm/contrib/nvcc.py          | 9 ++++++---
 src/runtime/contrib/nvshmem/init.cc | 6 ++++--
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index c79305a739..e9d8fac761 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -57,10 +57,13 @@ def compile_cuda(code, target_format=None, arch=None, 
options=None, path_target=
     """
     # Check for NVSHMEM dependency
     nvshmem_include_path, nvshmem_lib_path = None, None
-    use_nvshmem = (
-        tvm.get_global_func("runtime.nvshmem.cumodule_init", 
allow_missing=True) is not None
-    )
+    use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in 
code
     if use_nvshmem:
+        # NOTE: we cannot check whether nvshmem is used based on whether
+        # the global function "runtime.nvshmem.cumodule_init" is defined.
+        # The reason is because that if the input code does not use any 
NVSHMEM functions
+        # while the global function is defined, using cubin to compile the
+        # code may cause a compilation error.
         target_format = "cubin"
         nvshmem_include_path, nvshmem_lib_path = find_nvshmem_paths()
 
diff --git a/src/runtime/contrib/nvshmem/init.cc 
b/src/runtime/contrib/nvshmem/init.cc
index 1b0a65f4f1..4cb0558d61 100644
--- a/src/runtime/contrib/nvshmem/init.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -114,8 +114,10 @@ void NVSHMEMXCumoduleInit(void* cuModule) {
   // nvshmemx_cumodule_init. If not, we skip the cumodule initialization.
   if (status == NVSHMEM_STATUS_IS_INITIALIZED || status == 
NVSHMEM_STATUS_LIMITED_MPG ||
       status == NVSHMEM_STATUS_FULL_MPG) {
-    int result = nvshmemx_cumodule_init(mod);
-    ICHECK_EQ(result, 0) << "nvshmemx_cumodule_init failed with error code: " 
<< result;
+    // NOTE: we do not check the return value of nvshmemx_cumodule_init.
+    // The reason is because that the input cuModule might not use any NVSHMEM 
functions,
+    // in which case the nvshmemx_cumodule_init will fail.
+    nvshmemx_cumodule_init(mod);
   }
 }
 

Reply via email to