This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new c1df05f  [PYTHON] Further streamline number handling (#242)
c1df05f is described below

commit c1df05f3555d4e2a9e1a32822c0f41ccb8467251
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Nov 8 09:23:22 2025 -0500

    [PYTHON] Further streamline number handling (#242)
    
    This PR further streamlines number handling by
    introducing two custom protocols and moving Integral and Real handling
    to a more conservative path.
---
 python/tvm_ffi/_convert.py         |  4 +++
 python/tvm_ffi/cython/function.pxi | 71 +++++++++++++++++++++++++++++++++++---
 tests/python/test_function.py      | 33 ++++++++++++++++++
 3 files changed, 103 insertions(+), 5 deletions(-)

diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index d0a44e3..f090d28 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -129,6 +129,10 @@ def convert(value: Any) -> Any:  # noqa: PLR0911,PLR0912
         return value
     elif hasattr(value, "__dlpack_device__"):
         return value
+    elif hasattr(value, "__tvm_ffi_int__"):
+        return value
+    elif hasattr(value, "__tvm_ffi_float__"):
+        return value
     else:
         # in this case, it is an opaque python object
         return core._convert_to_opaque_object(value)
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 3ebe862..2db7c1a 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -17,7 +17,7 @@
 import ctypes
 import threading
 import os
-from numbers import Real, Integral
+from numbers import Integral, Real
 from typing import Any, Callable
 
 
@@ -276,7 +276,33 @@ cdef int TVMFFIPyArgSetterDLPack_(
     return 0
 
 
-cdef int TVMFFIPyArgSetterFFIObjectCompatible_(
+cdef int TVMFFIPyArgSetterIntegral_(
+    TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+    PyObject* py_arg, TVMFFIAny* out
+) except -1:
+    """Setter for Integral"""
+    cdef object arg = <object>py_arg
+    out.type_index = kTVMFFIInt
+    # keep it in cython so it will also check for fallback cases
+    # where the arg is not exactly the int class
+    out.v_int64 = <long long>arg
+    return 0
+
+
+cdef int TVMFFIPyArgSetterReal_(
+    TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+    PyObject* py_arg, TVMFFIAny* out
+) except -1:
+    """Setter for Real"""
+    cdef object arg = <object>py_arg
+    out.type_index = kTVMFFIFloat
+    # keep it in cython so it will also check for fallback cases
+    # where the arg is not exactly the float class
+    out.v_float64 = <double>arg
+    return 0
+
+
+cdef int TVMFFIPyArgSetterFFIObjectProtocol_(
     TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
     PyObject* py_arg, TVMFFIAny* out
 ) except -1:
@@ -608,6 +634,7 @@ cdef int TVMFFIPyArgSetterDTypeFromNumpy_(
     out.v_dtype = NUMPY_DTYPE_TO_DL_DATA_TYPE[py_obj]
     return 0
 
+
 cdef int TVMFFIPyArgSetterDLPackDataTypeProtocol_(
     TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
     PyObject* py_arg, TVMFFIAny* out
@@ -621,6 +648,29 @@ cdef int TVMFFIPyArgSetterDLPackDataTypeProtocol_(
     out.v_dtype.lanes = <long long>dltype_data_type[2]
     return 0
 
+
+cdef int TVMFFIPyArgSetterIntProtocol_(
+    TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+    PyObject* py_arg, TVMFFIAny* out
+) except -1:
+    """Setter for class with __tvm_ffi_int__() method"""
+    cdef object arg = <object>py_arg
+    out.type_index = kTVMFFIInt
+    out.v_int64 = <long long>(arg.__tvm_ffi_int__())
+    return 0
+
+
+cdef int TVMFFIPyArgSetterFloatProtocol_(
+    TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+    PyObject* py_arg, TVMFFIAny* out
+) except -1:
+    """Setter for class with __tvm_ffi_float__() method"""
+    cdef object arg = <object>py_arg
+    out.type_index = kTVMFFIFloat
+    out.v_float64 = <double>(arg.__tvm_ffi_float__())
+    return 0
+
+
 cdef _DISPATCH_TYPE_KEEP_ALIVE = set()
 cdef _DISPATCH_TYPE_KEEP_ALIVE_LOCK = threading.Lock()
 
@@ -668,7 +718,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
         # can directly map to tvm ffi object
         # usually used for solutions that takes subclass of ffi.Object
         # as a member variable
-        out.func = TVMFFIPyArgSetterFFIObjectCompatible_
+        out.func = TVMFFIPyArgSetterFFIObjectProtocol_
         return 0
     if os.environ.get("TVM_FFI_SKIP_C_DLPACK_EXCHANGE_API", "0") != "1":
         # Check for DLPackExchangeAPI struct (new approach)
@@ -698,10 +748,15 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
         out.func = TVMFFIPyArgSetterBool_
         return 0
     if isinstance(arg, Integral):
-        out.func = TVMFFIPyArgSetterInt_
+        # must occur before Real check
+        # cannot simply use TVMFFIPyArgSetterInt
+        # because Integral may not be exactly the int class
+        out.func = TVMFFIPyArgSetterIntegral_
         return 0
     if isinstance(arg, Real):
-        out.func = TVMFFIPyArgSetterFloat_
+        # cannot simply use TVMFFIPyArgSetterFloat
+        # because Real may not be exactly the float class
+        out.func = TVMFFIPyArgSetterReal_
         return 0
     # dtype is a subclass of str, so this check must occur before str
     if isinstance(arg, _CLASS_DTYPE):
@@ -760,6 +815,12 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
         # then it is a DLPack device protocol
         out.func = TVMFFIPyArgSetterDLPackDeviceProtocol_
         return 0
+    if hasattr(arg_class, "__tvm_ffi_int__"):
+        out.func = TVMFFIPyArgSetterIntProtocol_
+        return 0
+    if hasattr(arg_class, "__tvm_ffi_float__"):
+        out.func = TVMFFIPyArgSetterFloatProtocol_
+        return 0
     if isinstance(arg, Exception):
         out.func = TVMFFIPyArgSetterException_
         return 0
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index a1e68b0..a24395a 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -350,3 +350,36 @@ def test_function_with_dlpack_device_protocol() -> None:
     x = DLPackDeviceProtocol(device)
     y = fecho(x)
     assert y == device
+
+
+def test_integral_float_variants_passing() -> None:
+    fecho = tvm_ffi.get_global_func("testing.echo")
+    y = fecho(np.int32(1))
+    assert isinstance(y, int)
+    assert y == 1
+
+    y = fecho(np.float64(2.0))
+    assert isinstance(y, float)
+    assert y == 2.0
+
+    class IntProtocol:
+        def __init__(self, value: int) -> None:
+            self.value = value
+
+        def __tvm_ffi_int__(self) -> int:
+            return self.value
+
+    y = fecho(IntProtocol(10))
+    assert isinstance(y, int)
+    assert y == 10
+
+    class FloatProtocol:
+        def __init__(self, value: float) -> None:
+            self.value = value
+
+        def __tvm_ffi_float__(self) -> float:
+            return self.value
+
+    y = fecho(FloatProtocol(10))
+    assert isinstance(y, float)
+    assert y == 10

Reply via email to