This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 6c85e56  [STREAM] Enable compact with cuda-python driver stream (#236)
6c85e56 is described below

commit 6c85e562c00f098743ed257ff4516a250f5145e7
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Nov 7 14:07:05 2025 -0500

    [STREAM] Enable compact with cuda-python driver stream (#236)
    
    As of now, cuda-python driver streams do not yet support the CUDA stream
    protocol (`__cuda_stream__`). This PR enables a compatibility mode so we can
    take cuda_driver.CUstream arguments and treat them as void_p.
---
 python/tvm_ffi/_optional_torch_c_dlpack.py |  4 ++--
 python/tvm_ffi/cython/function.pxi         | 27 +++++++++++++++++++++++++--
 tests/python/test_stream.py                | 19 +++++++++++++++++++
 3 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 302a29c..7673c03 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -109,7 +109,7 @@ def load_torch_c_dlpack_extension() -> Any:
     return None
 
 
-def patch_torch_cuda_stream_protocol() -> Any:
+def patch_torch_cuda_stream_protocol() -> None:
     """Load the torch cuda stream protocol for older versions of torch."""
     try:
         import torch  # noqa: PLC0415
@@ -118,7 +118,7 @@ def patch_torch_cuda_stream_protocol() -> Any:
             return
         if not hasattr(torch.cuda.Stream, "__cuda_stream__"):
 
-            def __torch_cuda_stream__(self: torch.cuda.Stream) -> tuple[int, torch.cuda.Stream]:
+            def __torch_cuda_stream__(self: torch.cuda.Stream) -> tuple[int, int]:
                 """Return the version number and the cuda stream."""
                 return (0, self.cuda_stream)
 
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 4f9fca6..acfe3e2 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -33,9 +33,15 @@ if os.environ.get("TVM_FFI_BUILD_DOCS", "0") == "0":
         import numpy
     except ImportError:
         numpy = None
+
+    try:
+        from cuda.bindings import driver as cuda_driver
+    except ImportError:
+        cuda_driver = None
 else:
     torch = None
     numpy = None
+    cuda_driver = None
 
 
 cdef int _RELEASE_GIL_BY_DEFAULT = int(
@@ -287,7 +293,7 @@ cdef int TVMFFIPyArgSetterFFIObjectCompatible_(
     return 0
 
 
-cdef int TVMFFIPyArgSetterCUDAStream_(
+cdef int TVMFFIPyArgSetterCUDAStreamProtocol_(
     TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
     PyObject* py_arg, TVMFFIAny* out
 ) except -1:
@@ -301,6 +307,19 @@ cdef int TVMFFIPyArgSetterCUDAStream_(
     return 0
 
 
+cdef int TVMFFIPyArgSetterCUDADriverStreamFallback_(
+    TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+    PyObject* py_arg, TVMFFIAny* out
+) except -1:
+    """Setter for cuda.bindings.driver.CUstream as a fallback without __cuda_stream__ protocol"""
+    cdef object arg = <object>py_arg
+    # call driver stream
+    cdef long long long_ptr = int(arg)
+    out.type_index = kTVMFFIOpaquePtr
+    out.v_ptr = <void*>long_ptr
+    return 0
+
+
 cdef int TVMFFIPyArgSetterDType_(
     TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
     PyObject* py_arg, TVMFFIAny* out
@@ -658,7 +677,11 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
             return 0
     if hasattr(arg_class, "__cuda_stream__"):
         # cuda stream protocol
-        out.func = TVMFFIPyArgSetterCUDAStream_
+        out.func = TVMFFIPyArgSetterCUDAStreamProtocol_
+        return 0
+    if cuda_driver is not None and isinstance(arg, cuda_driver.CUstream):
+        # TODO(tqchen): remove this once cuda-python supports __cuda_stream__ protocol
+        out.func = TVMFFIPyArgSetterCUDADriverStreamFallback_
         return 0
     if torch is not None and isinstance(arg, torch.Tensor):
         out.func = TVMFFIPyArgSetterTorchFallback_
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
index fbe1a0a..3b58ccb 100644
--- a/tests/python/test_stream.py
+++ b/tests/python/test_stream.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import ctypes
 from types import ModuleType
 
 import pytest
@@ -30,6 +31,24 @@ except ImportError:
     torch = None
 
 
+try:
+    from cuda.bindings import driver as cuda_driver  # type: ignore[import-not-found]
+except ImportError:
+    cuda_driver = None
+
+
+@pytest.mark.skipif(cuda_driver is None, reason="Requires cuda-python")
+def test_cuda_driver_stream() -> None:
+    assert cuda_driver is not None
+    echo = tvm_ffi.get_global_func("testing.echo")
+    stream = cuda_driver.CUstream(0)
+    y = echo(stream)
+    assert y is not None
+    z = echo(cuda_driver.CUstream(1))
+    assert isinstance(z, ctypes.c_void_p)
+    assert z.value == 1
+
+
 def gen_check_stream_mod() -> tvm_ffi.Module:
     return tvm_ffi.cpp.load_inline(
         name="check_stream",

Reply via email to