This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 7f3bb77  [DLPACK] Upgrade DLPack Exchange API to pass by capsule (#288)
7f3bb77 is described below

commit 7f3bb77155645f90f7d221889b3795704ffd7d6f
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Nov 26 17:44:26 2025 -0500

    [DLPACK] Upgrade DLPack Exchange API to pass by capsule (#288)
    
    This PR upgrades the DLPack exchange API so that it can be passed by
    PyCapsule, while keeping backward compatibility for cases where it is
    passed by int.
---
 .../torch_c_dlpack_ext/torch_c_dlpack_ext/core.py  |  7 ++++-
 python/tvm_ffi/_optional_torch_c_dlpack.py         | 36 ++++++++++++++++++++--
 python/tvm_ffi/cython/function.pxi                 |  4 +--
 python/tvm_ffi/cython/tensor.pxi                   | 18 +++++++++++
 tests/python/test_cubin_launcher.py                | 20 ++++++++++++
 tests/python/test_dlpack_exchange_api.py           | 20 ++++++++++--
 6 files changed, 97 insertions(+), 8 deletions(-)
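
For reference, a minimal sketch of how a caller can consume the
__c_dlpack_exchange_api__ attribute under this change, accepting both the
legacy int form and the new PyCapsule named "dlpack_exchange_api". The
helper name read_exchange_api_ptr is illustrative only and not part of this
patch; the ctypes PyCapsule_GetPointer usage mirrors the updated test below.

    import ctypes

    def read_exchange_api_ptr(api_attr) -> int:
        """Return the DLPackExchangeAPI pointer as an integer.

        Accepts either the legacy raw-int form or the new PyCapsule form
        named "dlpack_exchange_api".
        """
        if isinstance(api_attr, int):
            # Legacy behavior: the attribute already holds the raw pointer.
            return api_attr
        # New behavior: extract the pointer from the PyCapsule via ctypes.
        get_ptr = ctypes.pythonapi.PyCapsule_GetPointer
        get_ptr.restype = ctypes.c_size_t  # read the pointer back as an integer
        get_ptr.argtypes = [ctypes.py_object, ctypes.c_char_p]
        return get_ptr(api_attr, b"dlpack_exchange_api")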

diff --git a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
index a2030ea..d0313a7 100644
--- a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
+++ b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
@@ -46,7 +46,12 @@ def load_torch_c_dlpack_extension() -> None:
     func = lib.TorchDLPackExchangeAPIPtr
     func.restype = ctypes.c_uint64
     func.argtypes = []
-    setattr(torch.Tensor, "__c_dlpack_exchange_api__", func())
+    # note: we need to keep this behavior for a while
+    # to ensure backward compatibility with older dependencies
+    # that rely on the value being an int.
+    # We eagerly upgrade the value to a PyCapsule on the tvm-ffi side instead.
+    dlpack_exchange_api_ptr_as_int = func()
+    setattr(torch.Tensor, "__c_dlpack_exchange_api__", dlpack_exchange_api_ptr_as_int)
     return lib
 
 
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 949d79f..7756926 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -44,6 +44,32 @@ from typing import Any
 logger = logging.getLogger(__name__)  # type: ignore
 
 
+def _create_dlpack_exchange_api_capsule(ptr_as_int: int) -> Any:
+    """Create a PyCapsule wrapping the DLPack exchange API pointer.
+
+    Parameters
+    ----------
+    ptr_as_int : int
+        The pointer to the DLPack exchange API as an integer.
+
+    Returns
+    -------
+    capsule : PyCapsule
+        A PyCapsule object wrapping the pointer with name "dlpack_exchange_api".
+
+    """
+    capsule_name = b"dlpack_exchange_api"
+    pythonapi = ctypes.pythonapi
+    pythonapi.PyCapsule_New.restype = ctypes.py_object
+    pythonapi.PyCapsule_New.argtypes = [
+        ctypes.c_void_p,
+        ctypes.c_char_p,
+        ctypes.c_void_p,
+    ]
+    capsule = pythonapi.PyCapsule_New(ctypes.c_void_p(ptr_as_int), capsule_name, None)
+    return capsule
+
+
 def load_torch_c_dlpack_extension() -> Any:  # noqa: PLR0912, PLR0915
     try:
         import torch  # noqa: PLC0415
@@ -61,6 +87,11 @@ def load_torch_c_dlpack_extension() -> Any:  # noqa: PLR0912, PLR0915
         import torch_c_dlpack_ext  # type: ignore  # noqa: PLC0415, F401
 
         if hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
+            if isinstance(torch.Tensor.__c_dlpack_exchange_api__, int):
+                # Bring the legacy int value up to speed with the new PyCapsule behavior
+                torch.Tensor.__c_dlpack_exchange_api__ = _create_dlpack_exchange_api_capsule(
+                    torch.Tensor.__c_dlpack_exchange_api__
+                )
             return None
     except ImportError:
         pass
@@ -118,9 +149,10 @@ def load_torch_c_dlpack_extension() -> Any:  # noqa: PLR0912, PLR0915
         func = lib.TorchDLPackExchangeAPIPtr
         func.restype = ctypes.c_uint64
         func.argtypes = []
-
+        # Create a PyCapsule from the pointer
+        capsule = _create_dlpack_exchange_api_capsule(func())
         # Set the DLPackExchangeAPI pointer on the class
-        setattr(torch.Tensor, "__c_dlpack_exchange_api__", func())
+        setattr(torch.Tensor, "__c_dlpack_exchange_api__", capsule)
 
         return lib
     except ImportError:
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 800427b..189a6fc 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -677,6 +677,7 @@ cdef int TVMFFIPyArgSetterFloatProtocol_(
 cdef _DISPATCH_TYPE_KEEP_ALIVE = set()
 cdef _DISPATCH_TYPE_KEEP_ALIVE_LOCK = threading.Lock()
 
+
 cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) except -1:
     """
     Factory function that creates an argument setter for a given Python argument type.
@@ -728,8 +729,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
         # This is checked on the CLASS, not the instance
         if hasattr(arg_class, "__c_dlpack_exchange_api__"):
             out.func = TVMFFIPyArgSetterDLPackExchangeAPI_
-            temp_ptr = arg_class.__c_dlpack_exchange_api__
-            out.c_dlpack_exchange_api = <const DLPackExchangeAPI*>(<long long>temp_ptr)
+            _get_dlpack_exchange_api(arg_class.__c_dlpack_exchange_api__, &(out.c_dlpack_exchange_api))
             return 0
     if hasattr(arg_class, "__cuda_stream__"):
         # cuda stream protocol
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 614e487..0521dcb 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -30,6 +30,24 @@ cdef const char* _c_str_dltensor = "dltensor"
 cdef const char* _c_str_used_dltensor = "used_dltensor"
 cdef const char* _c_str_dltensor_versioned = "dltensor_versioned"
 cdef const char* _c_str_used_dltensor_versioned = "used_dltensor_versioned"
+cdef const char* _c_str_dlpack_exchange_api = "dlpack_exchange_api"
+
+
+cdef int _get_dlpack_exchange_api(
+    object dlpack_exchange_api_obj,
+    const DLPackExchangeAPI** out_ptr
+) except -1:
+    if isinstance(dlpack_exchange_api_obj, int):
+        out_ptr[0] = <const DLPackExchangeAPI*>(<long long>dlpack_exchange_api_obj)
+        return 0
+
+    if pycapsule.PyCapsule_IsValid(dlpack_exchange_api_obj, _c_str_dlpack_exchange_api):
+        out_ptr[0] = <const DLPackExchangeAPI*>pycapsule.PyCapsule_GetPointer(
+            dlpack_exchange_api_obj, _c_str_dlpack_exchange_api
+        )
+        return 0
+    raise ValueError("Expected a dlpack_exchange_api capsule or an int pointer")
+
 
 cdef void _c_dlpack_deleter(object pycaps):
     cdef DLManagedTensor* dltensor
diff --git a/tests/python/test_cubin_launcher.py b/tests/python/test_cubin_launcher.py
index 3c0a683..d2e4ff0 100644
--- a/tests/python/test_cubin_launcher.py
+++ b/tests/python/test_cubin_launcher.py
@@ -43,6 +43,20 @@ def _is_cuda_available() -> bool:
     return torch.cuda.is_available()
 
 
+def _is_cuda_version_greater_than_13() -> bool:
+    """Check if CUDA version is greater than 13.0."""
+    if torch is None or not torch.cuda.is_available():
+        return False
+    if torch.version.cuda is None:
+        return False
+    try:
+        # Parse version string into tuple of integers (e.g., "12.1" -> (12, 1))
+        version_parts = tuple(int(x) for x in torch.version.cuda.split("."))
+        return version_parts > (13, 0)
+    except (ValueError, TypeError, AttributeError):
+        return False
+
+
 def _compile_kernel_to_cubin() -> bytes:
     """Compile simple CUDA kernels to CUBIN.
 
@@ -88,6 +102,9 @@ def _compile_kernel_to_cubin() -> bytes:
 @pytest.mark.skipif(sys.platform != "linux", reason="CUBIN launcher only supported on Linux")
 @pytest.mark.skipif(torch is None, reason="PyTorch not installed")
 @pytest.mark.skipif(not _is_cuda_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not _is_cuda_version_greater_than_13(), reason="CUDA version must be greater than 13.0"
+)
 def test_cubin_launcher_add_one() -> None:
     """Test loading and launching add_one kernel from CUBIN."""
     assert torch is not None, "PyTorch is required for this test"
@@ -212,6 +229,9 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_mul_two, cubin_test::LaunchMulTwo);
 @pytest.mark.skipif(sys.platform != "linux", reason="CUBIN launcher only supported on Linux")
 @pytest.mark.skipif(torch is None, reason="PyTorch not installed")
 @pytest.mark.skipif(not _is_cuda_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not _is_cuda_version_greater_than_13(), reason="CUDA version must be greater than 13.0"
+)
 def test_cubin_launcher_chained() -> None:
     """Test chaining multiple kernel launches."""
     assert torch is not None, "PyTorch is required for this test"
diff --git a/tests/python/test_dlpack_exchange_api.py b/tests/python/test_dlpack_exchange_api.py
index 11f93ae..048ade5 100644
--- a/tests/python/test_dlpack_exchange_api.py
+++ b/tests/python/test_dlpack_exchange_api.py
@@ -18,6 +18,7 @@
 
 from __future__ import annotations
 
+import ctypes
 import sys
 
 import pytest
@@ -46,9 +47,22 @@ def test_dlpack_exchange_api() -> None:
     assert torch is not None
 
     assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__")
-    api_ptr = torch.Tensor.__c_dlpack_exchange_api__
-    assert isinstance(api_ptr, int), "API pointer should be an integer"
-    assert api_ptr != 0, "API pointer should not be NULL"
+    api_attr = torch.Tensor.__c_dlpack_exchange_api__
+
+    # Handle both PyCapsule and integer types
+    if isinstance(api_attr, int):
+        # Direct integer pointer
+        api_ptr = api_attr
+        assert api_ptr != 0, "API pointer should not be NULL"
+    else:
+        # PyCapsule - extract the pointer as integer
+        pythonapi = ctypes.pythonapi
+        # Set restype to c_size_t to get the integer directly (avoids c_void_p quirks)
+        pythonapi.PyCapsule_GetPointer.restype = ctypes.c_size_t
+        pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
+        capsule_name = b"dlpack_exchange_api"
+        api_ptr = pythonapi.PyCapsule_GetPointer(api_attr, capsule_name)
+        assert api_ptr != 0, "API pointer from PyCapsule should not be NULL"
 
     tensor = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
 
