This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 7f3bb77 [DLPACK] Upgrade DLPack Exchange API to pass by capsule (#288)
7f3bb77 is described below
commit 7f3bb77155645f90f7d221889b3795704ffd7d6f
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Nov 26 17:44:26 2025 -0500
[DLPACK] Upgrade DLPack Exchange API to pass by capsule (#288)
This PR upgrades the DLPack exchange API handling to support the case where
the API is passed by PyCapsule, while keeping backward compatibility for
cases where it is passed by int.
---
.../torch_c_dlpack_ext/torch_c_dlpack_ext/core.py | 7 ++++-
python/tvm_ffi/_optional_torch_c_dlpack.py | 36 ++++++++++++++++++++--
python/tvm_ffi/cython/function.pxi | 4 +--
python/tvm_ffi/cython/tensor.pxi | 18 +++++++++++
tests/python/test_cubin_launcher.py | 20 ++++++++++++
tests/python/test_dlpack_exchange_api.py | 20 ++++++++++--
6 files changed, 97 insertions(+), 8 deletions(-)
diff --git a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
index a2030ea..d0313a7 100644
--- a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
+++ b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
@@ -46,7 +46,12 @@ def load_torch_c_dlpack_extension() -> None:
func = lib.TorchDLPackExchangeAPIPtr
func.restype = ctypes.c_uint64
func.argtypes = []
- setattr(torch.Tensor, "__c_dlpack_exchange_api__", func())
+ # note: we need to keep this behavior for a while
+ # to ensure backward compatibility with older versions dependencies
+ # that relies on the value being int.
+ # We will do eager upgrade to PyCapsule in the tvm-ffi side instead.
+ dlpack_exchange_api_ptr_as_int = func()
+ setattr(torch.Tensor, "__c_dlpack_exchange_api__", dlpack_exchange_api_ptr_as_int)
return lib
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 949d79f..7756926 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -44,6 +44,32 @@ from typing import Any
logger = logging.getLogger(__name__) # type: ignore
+def _create_dlpack_exchange_api_capsule(ptr_as_int: int) -> Any:
+ """Create a PyCapsule wrapping the DLPack exchange API pointer.
+
+ Parameters
+ ----------
+ ptr_as_int : int
+ The pointer to the DLPack exchange API as an integer.
+
+ Returns
+ -------
+ capsule : PyCapsule
+ A PyCapsule object wrapping the pointer with name "dlpack_exchange_api".
+
+ """
+ capsule_name = b"dlpack_exchange_api"
+ pythonapi = ctypes.pythonapi
+ pythonapi.PyCapsule_New.restype = ctypes.py_object
+ pythonapi.PyCapsule_New.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_char_p,
+ ctypes.c_void_p,
+ ]
+ capsule = pythonapi.PyCapsule_New(ctypes.c_void_p(ptr_as_int), capsule_name, None)
+ return capsule
+
+
def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912, PLR0915
try:
import torch # noqa: PLC0415
@@ -61,6 +87,11 @@ def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912, PLR0915
import torch_c_dlpack_ext # type: ignore # noqa: PLC0415, F401
if hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
+ if isinstance(torch.Tensor.__c_dlpack_exchange_api__, int):
+ # Brings up to speed with the new PyCapsule behavior
+ torch.Tensor.__c_dlpack_exchange_api__ = _create_dlpack_exchange_api_capsule(
+ torch.Tensor.__c_dlpack_exchange_api__
+ )
return None
except ImportError:
pass
@@ -118,9 +149,10 @@ def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912, PLR0915
func = lib.TorchDLPackExchangeAPIPtr
func.restype = ctypes.c_uint64
func.argtypes = []
-
+ # Create a PyCapsule from the pointer
+ capsule = _create_dlpack_exchange_api_capsule(func())
# Set the DLPackExchangeAPI pointer on the class
- setattr(torch.Tensor, "__c_dlpack_exchange_api__", func())
+ setattr(torch.Tensor, "__c_dlpack_exchange_api__", capsule)
return lib
except ImportError:
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 800427b..189a6fc 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -677,6 +677,7 @@ cdef int TVMFFIPyArgSetterFloatProtocol_(
cdef _DISPATCH_TYPE_KEEP_ALIVE = set()
cdef _DISPATCH_TYPE_KEEP_ALIVE_LOCK = threading.Lock()
+
cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) except -1:
"""
Factory function that creates an argument setter for a given Python
argument type.
@@ -728,8 +729,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
# This is checked on the CLASS, not the instance
if hasattr(arg_class, "__c_dlpack_exchange_api__"):
out.func = TVMFFIPyArgSetterDLPackExchangeAPI_
- temp_ptr = arg_class.__c_dlpack_exchange_api__
- out.c_dlpack_exchange_api = <const DLPackExchangeAPI*>(<long long>temp_ptr)
+ _get_dlpack_exchange_api(arg_class.__c_dlpack_exchange_api__, &(out.c_dlpack_exchange_api))
return 0
if hasattr(arg_class, "__cuda_stream__"):
# cuda stream protocol
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 614e487..0521dcb 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -30,6 +30,24 @@ cdef const char* _c_str_dltensor = "dltensor"
cdef const char* _c_str_used_dltensor = "used_dltensor"
cdef const char* _c_str_dltensor_versioned = "dltensor_versioned"
cdef const char* _c_str_used_dltensor_versioned = "used_dltensor_versioned"
+cdef const char* _c_str_dlpack_exchange_api = "dlpack_exchange_api"
+
+
+cdef int _get_dlpack_exchange_api(
+ object dlpack_exchange_api_obj,
+ const DLPackExchangeAPI** out_ptr
+) except -1:
+ if isinstance(dlpack_exchange_api_obj, int):
+ out_ptr[0] = <const DLPackExchangeAPI*>(<long long>dlpack_exchange_api_obj)
+ return 0
+
+ if pycapsule.PyCapsule_IsValid(dlpack_exchange_api_obj, _c_str_dlpack_exchange_api):
+ out_ptr[0] = <const DLPackExchangeAPI*>pycapsule.PyCapsule_GetPointer(
+ dlpack_exchange_api_obj, _c_str_dlpack_exchange_api
+ )
+ return 0
+ raise ValueError("Expect a dlpack_exchange_api field")
+
cdef void _c_dlpack_deleter(object pycaps):
cdef DLManagedTensor* dltensor
diff --git a/tests/python/test_cubin_launcher.py b/tests/python/test_cubin_launcher.py
index 3c0a683..d2e4ff0 100644
--- a/tests/python/test_cubin_launcher.py
+++ b/tests/python/test_cubin_launcher.py
@@ -43,6 +43,20 @@ def _is_cuda_available() -> bool:
return torch.cuda.is_available()
+def _is_cuda_version_greater_than_13() -> bool:
+ """Check if CUDA version is greater than 13.0."""
+ if torch is None or not torch.cuda.is_available():
+ return False
+ if torch.version.cuda is None:
+ return False
+ try:
+ # Parse version string into tuple of integers (e.g., "12.1" -> (12, 1))
+ version_parts = tuple(int(x) for x in torch.version.cuda.split("."))
+ return version_parts > (13, 0)
+ except (ValueError, TypeError, AttributeError):
+ return False
+
+
def _compile_kernel_to_cubin() -> bytes:
"""Compile simple CUDA kernels to CUBIN.
@@ -88,6 +102,9 @@ def _compile_kernel_to_cubin() -> bytes:
@pytest.mark.skipif(sys.platform != "linux", reason="CUBIN launcher only supported on Linux")
@pytest.mark.skipif(torch is None, reason="PyTorch not installed")
@pytest.mark.skipif(not _is_cuda_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+ not _is_cuda_version_greater_than_13(), reason="CUDA version must be greater than 13.0"
+)
def test_cubin_launcher_add_one() -> None:
"""Test loading and launching add_one kernel from CUBIN."""
assert torch is not None, "PyTorch is required for this test"
@@ -212,6 +229,9 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_mul_two, cubin_test::LaunchMulTwo);
@pytest.mark.skipif(sys.platform != "linux", reason="CUBIN launcher only supported on Linux")
@pytest.mark.skipif(torch is None, reason="PyTorch not installed")
@pytest.mark.skipif(not _is_cuda_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+ not _is_cuda_version_greater_than_13(), reason="CUDA version must be greater than 13.0"
+)
def test_cubin_launcher_chained() -> None:
"""Test chaining multiple kernel launches."""
assert torch is not None, "PyTorch is required for this test"
diff --git a/tests/python/test_dlpack_exchange_api.py b/tests/python/test_dlpack_exchange_api.py
index 11f93ae..048ade5 100644
--- a/tests/python/test_dlpack_exchange_api.py
+++ b/tests/python/test_dlpack_exchange_api.py
@@ -18,6 +18,7 @@
from __future__ import annotations
+import ctypes
import sys
import pytest
@@ -46,9 +47,22 @@ def test_dlpack_exchange_api() -> None:
assert torch is not None
assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__")
- api_ptr = torch.Tensor.__c_dlpack_exchange_api__
- assert isinstance(api_ptr, int), "API pointer should be an integer"
- assert api_ptr != 0, "API pointer should not be NULL"
+ api_attr = torch.Tensor.__c_dlpack_exchange_api__
+
+ # Handle both PyCapsule and integer types
+ if isinstance(api_attr, int):
+ # Direct integer pointer
+ api_ptr = api_attr
+ assert api_ptr != 0, "API pointer should not be NULL"
+ else:
+ # PyCapsule - extract the pointer as integer
+ pythonapi = ctypes.pythonapi
+ # Set restype to c_size_t to get integer directly (avoids c_void_p quirks)
+ pythonapi.PyCapsule_GetPointer.restype = ctypes.c_size_t
+ pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
+ capsule_name = b"dlpack_exchange_api"
+ api_ptr = pythonapi.PyCapsule_GetPointer(api_attr, capsule_name)
+ assert api_ptr != 0, "API pointer from PyCapsule should not be NULL"
tensor = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)