This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a12367c4b1 [NVSHMEM] Fix compatibility with CUDA code without nvshmem use (#18222)
a12367c4b1 is described below
commit a12367c4b19bd4f13ccfd6c8f5acb63958f9c35f
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Aug 22 08:02:28 2025 -0400
[NVSHMEM] Fix compatibility with CUDA code without nvshmem use (#18222)
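This PR fixes two bugs that prevented normal TIR functions (ones that
don't use any NVSHMEM API) from compiling and running when
`set(USE_NVSHMEM xxx)` is enabled:
1. `compile_cuda` now decides whether to compile to cubin and link NVSHMEM
   by checking whether the generated code includes `<nvshmem.h>` or
   `<nvshmemx.h>`, instead of checking whether the global function
   `runtime.nvshmem.cumodule_init` is registered.
2. `NVSHMEMXCumoduleInit` no longer fails on a non-zero return from
   `nvshmemx_cumodule_init`, which is expected for modules that use no
   NVSHMEM functions.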
Co-authored-by: Bohan Hou <[email protected]>
---
python/tvm/contrib/nvcc.py | 9 ++++++---
src/runtime/contrib/nvshmem/init.cc | 6 ++++--
2 files changed, 10 insertions(+), 5 deletions(-)
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index c79305a739..e9d8fac761 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -57,10 +57,13 @@ def compile_cuda(code, target_format=None, arch=None,
options=None, path_target=
"""
# Check for NVSHMEM dependency
nvshmem_include_path, nvshmem_lib_path = None, None
- use_nvshmem = (
- tvm.get_global_func("runtime.nvshmem.cumodule_init",
allow_missing=True) is not None
- )
+ use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in
code
if use_nvshmem:
+ # NOTE: we cannot check whether nvshmem is used based on whether
+ # the global function "runtime.nvshmem.cumodule_init" is defined.
+ # The reason is because that if the input code does not use any
NVSHMEM functions
+ # while the global function is defined, using cubin to compile the
+ # code may cause a compilation error.
target_format = "cubin"
nvshmem_include_path, nvshmem_lib_path = find_nvshmem_paths()
diff --git a/src/runtime/contrib/nvshmem/init.cc
b/src/runtime/contrib/nvshmem/init.cc
index 1b0a65f4f1..4cb0558d61 100644
--- a/src/runtime/contrib/nvshmem/init.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -114,8 +114,10 @@ void NVSHMEMXCumoduleInit(void* cuModule) {
// nvshmemx_cumodule_init. If not, we skip the cumodule initialization.
if (status == NVSHMEM_STATUS_IS_INITIALIZED || status ==
NVSHMEM_STATUS_LIMITED_MPG ||
status == NVSHMEM_STATUS_FULL_MPG) {
- int result = nvshmemx_cumodule_init(mod);
- ICHECK_EQ(result, 0) << "nvshmemx_cumodule_init failed with error code: "
<< result;
+ // NOTE: we do not check the return value of nvshmemx_cumodule_init.
+ // The reason is because that the input cuModule might not use any NVSHMEM
functions,
+ // in which case the nvshmemx_cumodule_init will fail.
+ nvshmemx_cumodule_init(mod);
}
}