This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 7f3f872  [DLPack] Leverage exchange api when possible (#260)
7f3f872 is described below

commit 7f3f8726156ab6e33f781562afafd9c6f219551f
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Nov 12 14:53:20 2025 -0500

    [DLPack] Leverage exchange api when possible (#260)
    
    This PR updates `from_dlpack` to leverage the DLPack exchange API when
    possible.
---
 python/tvm_ffi/cython/function.pxi | 13 +++++----
 python/tvm_ffi/cython/tensor.pxi   | 56 +++++++++++++++++++++++++-------------
 2 files changed, 45 insertions(+), 24 deletions(-)

diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index c1fb6a2..800427b 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -156,20 +156,20 @@ cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
     cdef DLManagedTensorVersioned* temp_managed_tensor
     cdef TVMFFIObjectHandle temp_chandle
     cdef void* current_stream = NULL
-    cdef const DLPackExchangeAPI* api = this.c_dlpack_exchange_api
+    cdef const DLPackExchangeAPI* exchange_api = this.c_dlpack_exchange_api
 
     # Set the exchange API in context
-    ctx.c_dlpack_exchange_api = api
+    ctx.c_dlpack_exchange_api = exchange_api
 
     # Convert PyObject to DLPack using the struct's function pointer
-    if api.managed_tensor_from_py_object_no_sync(arg, &temp_managed_tensor) != 0:
+    if exchange_api.managed_tensor_from_py_object_no_sync(arg, &temp_managed_tensor) != 0:
         return -1
 
     # Query current stream from producer if device is not CPU
     if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
-        if ctx.device_type == -1 and api.current_work_stream != NULL:
+        if ctx.device_type == -1 and exchange_api.current_work_stream != NULL:
             # First time seeing a device, query the stream
-            if api.current_work_stream(
+            if exchange_api.current_work_stream(
                 temp_managed_tensor.dl_tensor.device.device_type,
                 temp_managed_tensor.dl_tensor.device.device_id,
                 &current_stream
@@ -180,6 +180,9 @@ cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
 
     # Convert to TVM Tensor
     if TVMFFITensorFromDLPackVersioned(temp_managed_tensor, 0, 0, &temp_chandle) != 0:
+        # recycle the managed tensor to avoid leak
+        if temp_managed_tensor.deleter != NULL:
+            temp_managed_tensor.deleter(temp_managed_tensor)
         raise BufferError("Failed to convert DLManagedTensorVersioned to ffi.Tensor")
 
     out.type_index = kTVMFFITensor
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 6a9cf9e..614e487 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -45,20 +45,6 @@ cdef void _c_dlpack_versioned_deleter(object pycaps):
         dltensor.deleter(dltensor)
 
 
-cdef inline object _from_dlpack_intptr(
-    void* dlpack
-):
-    cdef TVMFFIObjectHandle chandle
-    cdef DLManagedTensor* ptr = <DLManagedTensor*>dlpack
-    cdef int c_api_ret_code
-    cdef int c_req_alignment = 0
-    cdef int c_req_contiguous = 0
-    c_api_ret_code = TVMFFITensorFromDLPack(
-        ptr, c_req_alignment, c_req_contiguous, &chandle)
-    CHECK_CALL(c_api_ret_code)
-    return make_tensor_from_chandle(chandle)
-
-
 cdef inline int _from_dlpack(
     object dltensor, int require_alignment,
     int require_contiguous, TVMFFIObjectHandle* out
@@ -100,6 +86,26 @@ cdef inline int _from_dlpack_versioned(
     raise ValueError("Expect a dltensor_versioned field, PyCapsule can only be consumed once")
 
 
+cdef inline int _from_dlpack_exchange_api(
+    object ext_tensor, DLPackExchangeAPI* exchange_api, int require_alignment,
+    int require_contiguous, TVMFFIObjectHandle* out
+) except -1:
+    cdef DLManagedTensorVersioned* temp_managed_tensor
+    cdef PyObject* ext_tensor_pyobj = <PyObject*>ext_tensor
+    if exchange_api.managed_tensor_from_py_object_no_sync(ext_tensor_pyobj, &temp_managed_tensor) != 0:
+        return -1
+
+    # Convert to TVM Tensor
+    if TVMFFITensorFromDLPackVersioned(
+        temp_managed_tensor, require_alignment, require_contiguous, out
+    ) != 0:
+        # recycle the managed tensor to avoid leak
+        if temp_managed_tensor.deleter != NULL:
+            temp_managed_tensor.deleter(temp_managed_tensor)
+        raise BufferError("Failed to convert DLManagedTensorVersioned to ffi.Tensor")
+
+    return 0
+
 cdef inline int _from_dlpack_universal(
     object ext_tensor, int require_alignment,
     int require_contiguous, TVMFFIObjectHandle* out
@@ -108,9 +114,21 @@ cdef inline int _from_dlpack_universal(
     # move to false as most frameworks get upgraded.
     cdef int favor_legacy_dlpack = True
 
+    if hasattr(ext_tensor, "__c_dlpack_exchange_api__"):
+        try:
+            return _from_dlpack_exchange_api(
+                ext_tensor,
+            <DLPackExchangeAPI*><long long>(ext_tensor.__c_dlpack_exchange_api__),
+                require_alignment,
+                require_contiguous,
+                out
+            )
+        except BufferError:
+            pass
+
     if hasattr(ext_tensor, "__dlpack__"):
         if favor_legacy_dlpack:
-            _from_dlpack(
+            return _from_dlpack(
                 ext_tensor.__dlpack__(),
                 require_alignment,
                 require_contiguous,
@@ -118,14 +136,14 @@ cdef inline int _from_dlpack_universal(
             )
         else:
             try:
-                _from_dlpack_versioned(
+                return _from_dlpack_versioned(
                     ext_tensor.__dlpack__(max_version=__dlpack_version__),
                     require_alignment,
                     require_contiguous,
                     out
                 )
             except TypeError:
-                _from_dlpack(
+                return _from_dlpack(
                     ext_tensor.__dlpack__(),
                     require_alignment,
                     require_contiguous,
@@ -133,14 +151,14 @@ cdef inline int _from_dlpack_universal(
                 )
     else:
         if pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor_versioned):
-            _from_dlpack_versioned(
+            return _from_dlpack_versioned(
                 ext_tensor,
                 require_alignment,
                 require_contiguous,
                 out
             )
         elif pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor):
-            _from_dlpack(
+            return _from_dlpack(
                 ext_tensor,
                 require_alignment,
                 require_contiguous,

Reply via email to