This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 5393647  [DLPack] Further compatibility (#315)
5393647 is described below

commit 539364726ea51d5ea4c7695bf9a1a6cfc37ca8cc
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Dec 5 12:23:26 2025 -0500

    [DLPack] Further compatibility (#315)
    
    This PR updates the DLPack support to be compatible with both the
    c_dlpack_exchange_api and dlpack_c_exchange_api names, guarding against
    future DLPack naming changes. Backward compatibility is kept.
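    
    For illustration, a minimal sketch in plain Python of the dual-name
    lookup this change standardizes on (the helper name below is
    hypothetical, not part of this PR):
    
        def _get_exchange_api_attr(tensor_cls):
            # Prefer the new name; fall back to the legacy one.
            for name in ("__dlpack_c_exchange_api__", "__c_dlpack_exchange_api__"):
                if hasattr(tensor_cls, name):
                    return getattr(tensor_cls, name)
            return None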
---
 addons/torch_c_dlpack_ext/build_backend.py         |  6 ++--
 .../torch_c_dlpack_ext/torch_c_dlpack_ext/core.py  | 25 +++++++++++++++-
 python/tvm_ffi/_optional_torch_c_dlpack.py         | 33 +++++++++++++++-------
 python/tvm_ffi/core.pyi                            |  2 +-
 python/tvm_ffi/cython/base.pxi                     |  4 +--
 python/tvm_ffi/cython/function.pxi                 | 12 ++++----
 python/tvm_ffi/cython/tensor.pxi                   | 17 ++++++-----
 python/tvm_ffi/cython/tvm_ffi_python_helpers.h     | 22 +++++++--------
 tests/python/test_dlpack_exchange_api.py           | 32 ++++++++-------------
 tests/python/test_load_inline.py                   |  4 +--
 10 files changed, 93 insertions(+), 64 deletions(-)

diff --git a/addons/torch_c_dlpack_ext/build_backend.py b/addons/torch_c_dlpack_ext/build_backend.py
index e4504fa..639977e 100644
--- a/addons/torch_c_dlpack_ext/build_backend.py
+++ b/addons/torch_c_dlpack_ext/build_backend.py
@@ -66,9 +66,11 @@ def build_wheel(
         # build wheel from sdist package, compile the torch c dlpack ext library locally.
         import torch  # noqa: PLC0415
 
-        if hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
+        if hasattr(torch.Tensor, "__dlpack_c_exchange_api__") or hasattr(
+            torch.Tensor, "__c_dlpack_exchange_api__"
+        ):
             print(
-                "torch.Tensor already has attribute __c_dlpack_exchange_api__. "
+                "torch.Tensor already has attribute __dlpack_c_exchange_api__. "
                 "No need to build any torch c dlpackc libs."
             )
         else:
diff --git a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
index d0313a7..6b63e0a 100644
--- a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
+++ b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
@@ -19,14 +19,31 @@
 import ctypes
 import sys
 from pathlib import Path
+from typing import Any
 
 import torch
 from packaging.version import Version
 
 
+def _create_dlpack_exchange_api_capsule(ptr_as_int: int) -> Any:
+    """Create a PyCapsule wrapping the DLPack exchange API pointer."""
+    capsule_name = b"dlpack_exchange_api"
+    pythonapi = ctypes.pythonapi
+    pythonapi.PyCapsule_New.restype = ctypes.py_object
+    pythonapi.PyCapsule_New.argtypes = [
+        ctypes.c_void_p,
+        ctypes.c_char_p,
+        ctypes.c_void_p,
+    ]
+    capsule = pythonapi.PyCapsule_New(ctypes.c_void_p(ptr_as_int), capsule_name, None)
+    return capsule
+
+
 def load_torch_c_dlpack_extension() -> None:
     """Load the torch c dlpack extension based on torch version."""
-    if hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
+    if hasattr(torch.Tensor, "__dlpack_c_exchange_api__") or hasattr(
+        torch.Tensor, "__c_dlpack_exchange_api__"
+    ):
         return None
     version = Version(torch.__version__)
     if sys.platform.startswith("win32"):
@@ -52,6 +69,12 @@ def load_torch_c_dlpack_extension() -> None:
     # We will do eager upgrade to PyCapsule in the tvm-ffi side instead.
     dlpack_exchange_api_ptr_as_int = func()
     setattr(torch.Tensor, "__c_dlpack_exchange_api__", dlpack_exchange_api_ptr_as_int)
+    setattr(
+        torch.Tensor,
+        "__dlpack_c_exchange_api__",
+        _create_dlpack_exchange_api_capsule(dlpack_exchange_api_ptr_as_int),
+    )
+
     return lib
 
 
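As context for the ctypes capsule creation above, a minimal standalone
round-trip sketch (the pointer value is a dummy chosen for illustration):

    import ctypes

    pythonapi = ctypes.pythonapi
    pythonapi.PyCapsule_New.restype = ctypes.py_object
    pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
    pythonapi.PyCapsule_GetPointer.restype = ctypes.c_size_t
    pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]

    # Wrap a (dummy) pointer in a named capsule, then read it back by name.
    name = b"dlpack_exchange_api"
    cap = pythonapi.PyCapsule_New(ctypes.c_void_p(0x1234), name, None)
    assert pythonapi.PyCapsule_GetPointer(cap, name) == 0x1234
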
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index f3d0119..dde0550 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -70,12 +70,31 @@ def _create_dlpack_exchange_api_capsule(ptr_as_int: int) -> Any:
     return capsule
 
 
+def _check_and_update_dlpack_c_exchange_api(tensor_cls: object) -> bool:
+    """Check if the DLPack exchange API is available and update the __dlpack_c_exchange_api__ attribute."""
+    if hasattr(tensor_cls, "__dlpack_c_exchange_api__"):
+        return True
+    # legacy path compatibility handling
+    if hasattr(tensor_cls, "__c_dlpack_exchange_api__"):
+        c_dlpack_attribute = tensor_cls.__c_dlpack_exchange_api__
+        if isinstance(c_dlpack_attribute, int):
+            setattr(
+                tensor_cls,
+                "__dlpack_c_exchange_api__",
+                _create_dlpack_exchange_api_capsule(c_dlpack_attribute),
+            )
+        else:
+            setattr(tensor_cls, "__dlpack_c_exchange_api__", c_dlpack_attribute)
+        return True
+    return False
+
+
 def load_torch_c_dlpack_extension() -> Any:  # noqa: PLR0912, PLR0915
     try:
         import torch  # noqa: PLC0415
 
-        if hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
-            # skip loading the extension if the __c_dlpack_exchange_api__
+        if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
+            # skip loading the extension if the __dlpack_c_exchange_api__
             # attribute is already set so we don't have to do it in
             # newer version of PyTorch
             return None
@@ -86,12 +105,7 @@ def load_torch_c_dlpack_extension() -> Any:  # noqa: PLR0912, PLR0915
     try:
         import torch_c_dlpack_ext  # type: ignore  # noqa: PLC0415, F401
 
-        if hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
-            if isinstance(torch.Tensor.__c_dlpack_exchange_api__, int):
-                # Brings up to speed with the new PyCapsule behavior
-                torch.Tensor.__c_dlpack_exchange_api__ = _create_dlpack_exchange_api_capsule(
-                    torch.Tensor.__c_dlpack_exchange_api__
-                )
+        if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
             return None
     except ImportError:
         pass
@@ -152,8 +166,7 @@ def load_torch_c_dlpack_extension() -> Any:  # noqa: PLR0912, PLR0915
         # Create a PyCapsule from the pointer
         capsule = _create_dlpack_exchange_api_capsule(func())
         # Set the DLPackExchangeAPI pointer on the class
-        setattr(torch.Tensor, "__c_dlpack_exchange_api__", capsule)
-
+        setattr(torch.Tensor, "__dlpack_c_exchange_api__", capsule)
         return lib
     except ImportError:
         pass
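
A hedged usage sketch of the helper introduced above (the class and the
pointer value are stand-ins for illustration): a class carrying only the
legacy integer attribute gains the new capsule-valued attribute.

    class FakeTensor:  # stand-in for torch.Tensor
        __c_dlpack_exchange_api__ = 0x1234  # legacy integer pointer form

    assert _check_and_update_dlpack_c_exchange_api(FakeTensor)
    # FakeTensor now also carries __dlpack_c_exchange_api__ as a PyCapsule.
    assert hasattr(FakeTensor, "__dlpack_c_exchange_api__")
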
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 06a78e6..2cee79c 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -175,7 +175,7 @@ def from_dlpack(
 ) -> Tensor: ...
 
 class DLTensorTestWrapper:
-    __c_dlpack_exchange_api__: int
+    __dlpack_c_exchange_api__: int
     def __init__(self, tensor: Tensor) -> None: ...
     def __tvm_ffi_env_stream__(self) -> int: ...
     def __dlpack_device__(self) -> tuple[int, int]: ...
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 933bc86..9e7ff87 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -322,11 +322,11 @@ cdef extern from "tvm_ffi_python_helpers.h":
         int device_type
         int device_id
         TVMFFIStreamHandle stream
-        const DLPackExchangeAPI* c_dlpack_exchange_api
+        const DLPackExchangeAPI* dlpack_c_exchange_api
 
     ctypedef struct TVMFFIPyArgSetter:
        int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,  PyObject* py_arg, TVMFFIAny* out) except -1
-        const DLPackExchangeAPI* c_dlpack_exchange_api
+        const DLPackExchangeAPI* dlpack_c_exchange_api
 
    ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, TVMFFIPyArgSetter* out) except -1
     # The main call function
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 29af699..01a3366 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -156,10 +156,10 @@ cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
     cdef DLManagedTensorVersioned* temp_managed_tensor
     cdef TVMFFIObjectHandle temp_chandle
     cdef void* current_stream = NULL
-    cdef const DLPackExchangeAPI* exchange_api = this.c_dlpack_exchange_api
+    cdef const DLPackExchangeAPI* exchange_api = this.dlpack_c_exchange_api
 
     # Set the exchange API in context
-    ctx.c_dlpack_exchange_api = exchange_api
+    ctx.dlpack_c_exchange_api = exchange_api
 
     # Convert PyObject to DLPack using the struct's function pointer
     if exchange_api.managed_tensor_from_py_object_no_sync(arg, &temp_managed_tensor) != 0:
@@ -239,7 +239,7 @@ cdef int TVMFFIPyArgSetterTorchFallback_(
     out.type_index = kTVMFFITensor
     out.v_ptr = (<Tensor>arg).chandle
     temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
-    ctx.c_dlpack_exchange_api = GetTorchFallbackExchangeAPI()
+    ctx.dlpack_c_exchange_api = GetTorchFallbackExchangeAPI()
     # record the stream and device for torch context
     if is_cuda and ctx.device_type == -1:
         ctx.device_type = temp_dltensor.device.device_type
@@ -740,12 +740,12 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
         # as a member variable
         out.func = TVMFFIPyArgSetterFFIObjectProtocol_
         return 0
-    if os.environ.get("TVM_FFI_SKIP_C_DLPACK_EXCHANGE_API", "0") != "1":
+    if os.environ.get("TVM_FFI_SKIP_DLPACK_C_EXCHANGE_API", "0") != "1":
         # Check for DLPackExchangeAPI struct (new approach)
         # This is checked on the CLASS, not the instance
-        if hasattr(arg_class, "__c_dlpack_exchange_api__"):
+        if hasattr(arg_class, "__dlpack_c_exchange_api__"):
             out.func = TVMFFIPyArgSetterDLPackExchangeAPI_
-            _get_dlpack_exchange_api(arg_class.__c_dlpack_exchange_api__, &(out.c_dlpack_exchange_api))
+            _get_dlpack_exchange_api(arg_class.__dlpack_c_exchange_api__, &(out.dlpack_c_exchange_api))
             return 0
     if hasattr(arg_class, "__cuda_stream__"):
         # cuda stream protocol
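
The opt-out environment variable is renamed alongside the attribute. A
hedged sketch of disabling the struct-based fast path (behavior inferred
from the check above; the variable must be set before arguments are
converted):

    import os

    # Skip the __dlpack_c_exchange_api__ fast path in the arg setter factory.
    os.environ["TVM_FFI_SKIP_DLPACK_C_EXCHANGE_API"] = "1"
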
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 841687a..1f4973d 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -133,9 +133,9 @@ cdef inline int _from_dlpack_universal(
     cdef int favor_legacy_dlpack = True
     cdef const DLPackExchangeAPI* exchange_api = NULL
 
-    if hasattr(ext_tensor, "__c_dlpack_exchange_api__"):
+    if hasattr(ext_tensor, "__dlpack_c_exchange_api__"):
         try:
-            _get_dlpack_exchange_api(ext_tensor.__c_dlpack_exchange_api__, &exchange_api)
+            _get_dlpack_exchange_api(ext_tensor.__dlpack_c_exchange_api__, &exchange_api)
             return _from_dlpack_exchange_api(
                 ext_tensor,
                 exchange_api,
@@ -405,7 +405,7 @@ cdef int _dltensor_test_wrapper_current_work_stream(
 # Module-level static DLPackExchangeAPI for DLTensorTestWrapper
 cdef DLPackExchangeAPI _dltensor_test_wrapper_static_api
 
-cdef const DLPackExchangeAPI* _dltensor_test_wrapper_get_exchange_api() noexcept:
+cdef DLPackExchangeAPI* _dltensor_test_wrapper_get_exchange_api() noexcept:
     """Get the static DLPackExchangeAPI instance for DLTensorTestWrapper."""
     global _dltensor_test_wrapper_static_api
 
@@ -430,15 +430,14 @@ cdef const DLPackExchangeAPI* _dltensor_test_wrapper_get_exchange_api() noexcept
     return &_dltensor_test_wrapper_static_api
 
 
-def _dltensor_test_wrapper_exchange_api_ptr():
-    """Return the pointer to the DLPackExchangeAPI struct as an integer."""
-    return <long long>_dltensor_test_wrapper_get_exchange_api()
-
-
 cdef class DLTensorTestWrapper:
     """Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose.
     """
-    __c_dlpack_exchange_api__ = _dltensor_test_wrapper_exchange_api_ptr()
+    __dlpack_c_exchange_api__ = pycapsule.PyCapsule_New(
+        _dltensor_test_wrapper_get_exchange_api(),
+        b"dlpack_exchange_api",
+        NULL
+    )
 
     cdef Tensor tensor
     cdef dict __dict__
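
Since DLTensorTestWrapper now exposes the API as a named PyCapsule, the
capsule name can be checked from Python via the standard CPython
PyCapsule_GetName API. A minimal sketch (the tvm_ffi.core module path is
assumed from core.pyi above):

    import ctypes
    import tvm_ffi

    ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
    ctypes.pythonapi.PyCapsule_GetName.argtypes = [ctypes.py_object]

    cap = tvm_ffi.core.DLTensorTestWrapper.__dlpack_c_exchange_api__
    assert ctypes.pythonapi.PyCapsule_GetName(cap) == b"dlpack_exchange_api"
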
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index 88c27d7..93d6540 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -86,7 +86,7 @@ class TVMFFIPyCallContext {
   /*! \brief Detected stream, if any */
   void* stream = nullptr;
   /*! \brief the DLPack exchange API, if any */
-  const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
+  const DLPackExchangeAPI* dlpack_c_exchange_api{nullptr};
   /*! \brief pointer to the call stack space */
   TVMFFIPyCallStack* call_stack = nullptr;
   /*! \brief the temporary arguments to be recycled */
@@ -174,7 +174,7 @@ struct TVMFFIPyArgSetter {
    * \brief Optional DLPackExchangeAPI struct pointer.
   * This is the new struct-based approach that bundles all DLPack exchange functions.
    */
-  const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
+  const DLPackExchangeAPI* dlpack_c_exchange_api{nullptr};
   /*!
    * \brief Invoke the setter.
    * \param call_ctx The call context.
@@ -297,10 +297,10 @@ class TVMFFIPyCallManager {
         // setting failed, directly return
         if (c_api_ret_code[0] != 0) return 0;
       }
-      if (ctx.c_dlpack_exchange_api != nullptr &&
-          ctx.c_dlpack_exchange_api->managed_tensor_allocator != nullptr) {
+      if (ctx.dlpack_c_exchange_api != nullptr &&
+          ctx.dlpack_c_exchange_api->managed_tensor_allocator != nullptr) {
         c_api_ret_code[0] = TVMFFIEnvSetDLPackManagedTensorAllocator(
-            ctx.c_dlpack_exchange_api->managed_tensor_allocator, 0, &prev_tensor_allocator);
+            ctx.dlpack_c_exchange_api->managed_tensor_allocator, 0, &prev_tensor_allocator);
         if (c_api_ret_code[0] != 0) return 0;
       }
       // call the function
@@ -321,14 +321,14 @@ class TVMFFIPyCallManager {
           return -1;
         }
       }
-      if (ctx.c_dlpack_exchange_api != nullptr &&
-          prev_tensor_allocator != ctx.c_dlpack_exchange_api->managed_tensor_allocator) {
+      if (ctx.dlpack_c_exchange_api != nullptr &&
+          prev_tensor_allocator != ctx.dlpack_c_exchange_api->managed_tensor_allocator) {
         c_api_ret_code[0] =
             TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0, nullptr);
         if (c_api_ret_code[0] != 0) return 0;
       }
-      if (optional_out_ctx_dlpack_api != nullptr && ctx.c_dlpack_exchange_api != nullptr) {
-        *optional_out_ctx_dlpack_api = ctx.c_dlpack_exchange_api;
+      if (optional_out_ctx_dlpack_api != nullptr && ctx.dlpack_c_exchange_api != nullptr) {
+        *optional_out_ctx_dlpack_api = ctx.dlpack_c_exchange_api;
       }
       return 0;
     } catch (const std::exception& ex) {
@@ -380,8 +380,8 @@ class TVMFFIPyCallManager {
           parent_ctx->stream = ctx.stream;
         }
         // DLPack exchange API
-        if (parent_ctx->c_dlpack_exchange_api == nullptr) {
-          parent_ctx->c_dlpack_exchange_api = ctx.c_dlpack_exchange_api;
+        if (parent_ctx->dlpack_c_exchange_api == nullptr) {
+          parent_ctx->dlpack_c_exchange_api = ctx.dlpack_c_exchange_api;
         }
       }
       return 0;
diff --git a/tests/python/test_dlpack_exchange_api.py b/tests/python/test_dlpack_exchange_api.py
index 7df8dcf..70e7586 100644
--- a/tests/python/test_dlpack_exchange_api.py
+++ b/tests/python/test_dlpack_exchange_api.py
@@ -27,7 +27,7 @@ try:
     import torch  # type: ignore[no-redef]
 
     # Import tvm_ffi to load the DLPack exchange API extension
-    # This sets torch.Tensor.__c_dlpack_exchange_api__
+    # This sets torch.Tensor.__dlpack_c_exchange_api__
     import tvm_ffi
     from torch.utils import cpp_extension  # type: ignore
     from tvm_ffi import libinfo
@@ -35,7 +35,7 @@ except ImportError:
     torch = None
 
 # Check if DLPack Exchange API is available
-_has_dlpack_api = torch is not None and hasattr(torch.Tensor, "__c_dlpack_exchange_api__")
+_has_dlpack_api = torch is not None and hasattr(torch.Tensor, "__dlpack_c_exchange_api__")
 
 
 @pytest.mark.skipif(not _has_dlpack_api, reason="PyTorch DLPack Exchange API not available")
@@ -45,24 +45,16 @@ def test_dlpack_exchange_api() -> None:
         pytest.xfail("DLPack Exchange API test is known to fail on Windows platform")
 
     assert torch is not None
-
-    assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__")
-    api_attr = torch.Tensor.__c_dlpack_exchange_api__
-
-    # Handle both PyCapsule and integer types
-    if isinstance(api_attr, int):
-        # Direct integer pointer
-        api_ptr = api_attr
-        assert api_ptr != 0, "API pointer should not be NULL"
-    else:
-        # PyCapsule - extract the pointer as integer
-        pythonapi = ctypes.pythonapi
-        # Set restype to c_size_t to get integer directly (avoids c_void_p quirks)
-        pythonapi.PyCapsule_GetPointer.restype = ctypes.c_size_t
-        pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
-        capsule_name = b"dlpack_exchange_api"
-        api_ptr = pythonapi.PyCapsule_GetPointer(api_attr, capsule_name)
-        assert api_ptr != 0, "API pointer from PyCapsule should not be NULL"
+    assert hasattr(torch.Tensor, "__dlpack_c_exchange_api__")
+    api_attr = torch.Tensor.__dlpack_c_exchange_api__
+    # PyCapsule - extract the pointer as integer
+    pythonapi = ctypes.pythonapi
+    # Set restype to c_size_t to get integer directly (avoids c_void_p quirks)
+    pythonapi.PyCapsule_GetPointer.restype = ctypes.c_size_t
+    pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
+    capsule_name = b"dlpack_exchange_api"
+    api_ptr = pythonapi.PyCapsule_GetPointer(api_attr, capsule_name)
+    assert api_ptr != 0, "API pointer from PyCapsule should not be NULL"
 
     tensor = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
 
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index f39c485..77b9f8b 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -213,8 +213,8 @@ def test_load_inline_cuda() -> None:
 @pytest.mark.skipif(torch is None, reason="Requires torch")
 def test_load_inline_with_env_tensor_allocator() -> None:
     assert torch is not None
-    if not hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
-        pytest.skip("Torch does not support __c_dlpack_exchange_api__")
+    if not hasattr(torch.Tensor, "__dlpack_c_exchange_api__"):
+        pytest.skip("Torch does not support __dlpack_c_exchange_api__")
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
         cpp_sources=r"""
