This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 0f8bf9f  Introduce Device Protocol (#179)
0f8bf9f is described below

commit 0f8bf9fc582fff89e838cadf2aacbfb2a5724ddf
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Oct 20 21:07:30 2025 -0700

    Introduce Device Protocol (#179)
    
    This PR adds support for the `__dlpack_device__` protocol: objects that
    expose `__dlpack_device__` (but not `__dlpack__`) can now be passed to
    FFI calls and are converted to a Device argument.
---
 python/tvm_ffi/cython/function.pxi | 18 ++++++++++++++++++
 tests/python/test_function.py      | 16 ++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 1138476..bf84091 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -324,6 +324,19 @@ cdef int TVMFFIPyArgSetterDevice_(
     out.v_device = (<Device>arg).cdevice
     return 0
 
+cdef int TVMFFIPyArgSetterDLPackDeviceProtocol_(
+    TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+    PyObject* py_arg, TVMFFIAny* out
+) except -1:
+    """Setter that packs an object implementing __dlpack_device__ into `out` as a Device."""
+    cdef object arg = <object>py_arg
+    cdef tuple dlpack_device = arg.__dlpack_device__()  # NOTE(review): a cdef tuple may also hold None; confirm a None return is rejected before indexing
+    out.type_index = kTVMFFIDevice  # NOTE(review): set before the int32 conversions below, which may raise; confirm callers ignore `out` on error
+    out.v_device = TVMFFIDLDeviceFromIntPair(  # element 0 -> device type, element 1 -> device id
+        <int32_t>dlpack_device[0],
+        <int32_t>dlpack_device[1]
+    )
+    return 0
 
 cdef int TVMFFIPyArgSetterStr_(
     TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
@@ -716,6 +729,11 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
         # prefer dlpack as it covers all DLDataType struct
         out.func = TVMFFIPyArgSetterDLPackDataTypeProtocol_
         return 0
+    if hasattr(arg_class, "__dlpack_device__") and not hasattr(arg_class, "__dlpack__"):
+        # if a class has __dlpack_device__ but not __dlpack__,
+        # then it is treated as a DLPack device protocol object
+        out.func = TVMFFIPyArgSetterDLPackDeviceProtocol_
+        return 0
     if isinstance(arg, Exception):
         out.func = TVMFFIPyArgSetterException_
         return 0
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 6123a8e..8d6dd34 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -328,3 +328,19 @@ def test_function_with_dlpack_data_type_protocol() -> None:
     x = DLPackDataTypeProtocol((dtype.type_code, dtype.bits, dtype.lanes))
     y = fecho(x)
     assert y == dtype
+
+
+def test_function_with_dlpack_device_protocol() -> None:
+    device = tvm_ffi.device("cuda:1")  # non-CPU type and non-zero index so both tuple fields are checked
+
+    class DLPackDeviceProtocol:  # exposes __dlpack_device__ only (no __dlpack__)
+        def __init__(self, device: tvm_ffi.Device) -> None:
+            self.device = device
+
+        def __dlpack_device__(self) -> tuple[int, int]:  # (DLPack device type, device index)
+            return (self.device.dlpack_device_type(), self.device.index)
+
+    fecho = tvm_ffi.get_global_func("testing.echo")  # presumably returns its argument unchanged
+    x = DLPackDeviceProtocol(device)
+    y = fecho(x)  # x should cross the FFI as a Device via the __dlpack_device__ setter
+    assert y == device

Reply via email to