This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new c1df05f [PYTHON] Further streamline number handling (#242)
c1df05f is described below
commit c1df05f3555d4e2a9e1a32822c0f41ccb8467251
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Nov 8 09:23:22 2025 -0500
[PYTHON] Further streamline number handling (#242)
This PR further streamlines number handling by
introducing two custom protocols (`__tvm_ffi_int__` and `__tvm_ffi_float__`)
and moving Integral and Real handling to a more conservative path.
---
python/tvm_ffi/_convert.py | 4 +++
python/tvm_ffi/cython/function.pxi | 71 +++++++++++++++++++++++++++++++++++---
tests/python/test_function.py | 33 ++++++++++++++++++
3 files changed, 103 insertions(+), 5 deletions(-)
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index d0a44e3..f090d28 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -129,6 +129,10 @@ def convert(value: Any) -> Any: # noqa: PLR0911,PLR0912
return value
elif hasattr(value, "__dlpack_device__"):
return value
+ elif hasattr(value, "__tvm_ffi_int__"):
+ return value
+ elif hasattr(value, "__tvm_ffi_float__"):
+ return value
else:
# in this case, it is an opaque python object
return core._convert_to_opaque_object(value)
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 3ebe862..2db7c1a 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -17,7 +17,7 @@
import ctypes
import threading
import os
-from numbers import Real, Integral
+from numbers import Integral, Real
from typing import Any, Callable
@@ -276,7 +276,33 @@ cdef int TVMFFIPyArgSetterDLPack_(
return 0
-cdef int TVMFFIPyArgSetterFFIObjectCompatible_(
+cdef int TVMFFIPyArgSetterIntegral_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for numbers.Integral args that may not be exactly int; stores them as kTVMFFIInt."""
+ cdef object arg = <object>py_arg
+ out.type_index = kTVMFFIInt
+ # use Cython's generic <long long> object cast instead of the exact-int setter,
+ # so Integral types other than the int class are still converted
+ out.v_int64 = <long long>arg  # conversion errors propagate to the caller via except -1
+ return 0
+
+
+cdef int TVMFFIPyArgSetterReal_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for numbers.Real args that may not be exactly float; stores them as kTVMFFIFloat."""
+ cdef object arg = <object>py_arg
+ out.type_index = kTVMFFIFloat
+ # use Cython's generic <double> object cast instead of the exact-float setter,
+ # so Real types other than the float class are still converted
+ out.v_float64 = <double>arg  # conversion errors propagate to the caller via except -1
+ return 0
+
+
+cdef int TVMFFIPyArgSetterFFIObjectProtocol_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* py_arg, TVMFFIAny* out
) except -1:
@@ -608,6 +634,7 @@ cdef int TVMFFIPyArgSetterDTypeFromNumpy_(
out.v_dtype = NUMPY_DTYPE_TO_DL_DATA_TYPE[py_obj]
return 0
+
cdef int TVMFFIPyArgSetterDLPackDataTypeProtocol_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* py_arg, TVMFFIAny* out
@@ -621,6 +648,29 @@ cdef int TVMFFIPyArgSetterDLPackDataTypeProtocol_(
out.v_dtype.lanes = <long long>dltype_data_type[2]
return 0
+
+cdef int TVMFFIPyArgSetterIntProtocol_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for objects whose class defines __tvm_ffi_int__(); passes its result as kTVMFFIInt."""
+ cdef object arg = <object>py_arg
+ out.type_index = kTVMFFIInt
+ out.v_int64 = <long long>(arg.__tvm_ffi_int__())  # result cast to long long; errors propagate via except -1
+ return 0
+
+
+cdef int TVMFFIPyArgSetterFloatProtocol_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for objects whose class defines __tvm_ffi_float__(); passes its result as kTVMFFIFloat."""
+ cdef object arg = <object>py_arg
+ out.type_index = kTVMFFIFloat
+ out.v_float64 = <double>(arg.__tvm_ffi_float__())  # result cast to double, so an int return becomes a float
+ return 0
+
+
cdef _DISPATCH_TYPE_KEEP_ALIVE = set()
cdef _DISPATCH_TYPE_KEEP_ALIVE_LOCK = threading.Lock()
@@ -668,7 +718,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
# can directly map to tvm ffi object
# usually used for solutions that takes subclass of ffi.Object
# as a member variable
- out.func = TVMFFIPyArgSetterFFIObjectCompatible_
+ out.func = TVMFFIPyArgSetterFFIObjectProtocol_
return 0
if os.environ.get("TVM_FFI_SKIP_C_DLPACK_EXCHANGE_API", "0") != "1":
# Check for DLPackExchangeAPI struct (new approach)
@@ -698,10 +748,15 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
out.func = TVMFFIPyArgSetterBool_
return 0
if isinstance(arg, Integral):
- out.func = TVMFFIPyArgSetterInt_
+ # must occur before Real check
+ # cannot simply use TVMFFIPyArgSetterInt
+ # because Integral may not be exactly the int class
+ out.func = TVMFFIPyArgSetterIntegral_
return 0
if isinstance(arg, Real):
- out.func = TVMFFIPyArgSetterFloat_
+ # cannot simply use TVMFFIPyArgSetterFloat
+ # because Real may not be exactly the float class
+ out.func = TVMFFIPyArgSetterReal_
return 0
# dtype is a subclass of str, so this check must occur before str
if isinstance(arg, _CLASS_DTYPE):
@@ -760,6 +815,12 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
# then it is a DLPack device protocol
out.func = TVMFFIPyArgSetterDLPackDeviceProtocol_
return 0
+ if hasattr(arg_class, "__tvm_ffi_int__"):
+ out.func = TVMFFIPyArgSetterIntProtocol_
+ return 0
+ if hasattr(arg_class, "__tvm_ffi_float__"):
+ out.func = TVMFFIPyArgSetterFloatProtocol_
+ return 0
if isinstance(arg, Exception):
out.func = TVMFFIPyArgSetterException_
return 0
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index a1e68b0..a24395a 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -350,3 +350,36 @@ def test_function_with_dlpack_device_protocol() -> None:
x = DLPackDeviceProtocol(device)
y = fecho(x)
assert y == device
+
+
+def test_integral_float_variants_passing() -> None:
+ fecho = tvm_ffi.get_global_func("testing.echo")  # echo returns its argument after the FFI round trip
+ y = fecho(np.int32(1))  # numpy integer scalar, not exactly the int class
+ assert isinstance(y, int)
+ assert y == 1
+
+ y = fecho(np.float64(2.0))  # numpy float scalar; expected back as a plain Python float
+ assert isinstance(y, float)
+ assert y == 2.0
+
+ class IntProtocol:  # opts in through the __tvm_ffi_int__ protocol
+ def __init__(self, value: int) -> None:
+ self.value = value
+
+ def __tvm_ffi_int__(self) -> int:
+ return self.value
+
+ y = fecho(IntProtocol(10))
+ assert isinstance(y, int)
+ assert y == 10
+
+ class FloatProtocol:  # opts in through the __tvm_ffi_float__ protocol
+ def __init__(self, value: float) -> None:
+ self.value = value
+
+ def __tvm_ffi_float__(self) -> float:
+ return self.value
+
+ y = fecho(FloatProtocol(10))  # int payload on purpose: the float setter must still yield a float
+ assert isinstance(y, float)
+ assert y == 10