This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new d153a2c9 feat(python): auto-register __ffi_* dunder methods in @py_class as TypeMethod (#508)
d153a2c9 is described below

commit d153a2c9407ee04ec3dc1ccbdf6a0664854337b7
Author: Junru Shao <[email protected]>
AuthorDate: Sun Mar 22 14:24:35 2026 -0700

    feat(python): auto-register __ffi_* dunder methods in @py_class as TypeMethod (#508)
    
    ## Summary
    
    - Extend `@py_class` to detect and register recognized FFI dunder
    methods (`__ffi_repr__`, `__ffi_eq__`, `__s_equal__`, etc.) defined on
    Python dataclasses, making them callable from C++ and other FFI
    languages
    - The `_FFI_RECOGNIZED_METHODS` allowlist in `py_class.py` gates which
    dunders are collected; `_collect_py_methods()` scans the class dict, and
    `TypeInfo._register_py_methods()` registers the results via the C API
    (`TVMFFITypeRegisterMethod` + `TVMFFITypeRegisterAttr`)
    - No C++ changes — `dataclass.cc` is unmodified from upstream
    
    ## Changes since v1
    
    - Rebased onto upstream/main (`d8bd1890`) which adds structural
    equality/hashing support (#507)
    - Resolved merge conflicts: preserved `structure_kind` /
    `STRUCTURE_KIND_MAP` from upstream
    - Applied Gemini Code Assist review feedback:
      - **HIGH**: Fixed `func_any` memory leak in `_register_py_methods` by
        adding `TVMFFIObjectDecRef` in a `try/finally` block
      - **MEDIUM**: Changed `_collect_py_methods` to iterate
        `cls.__dict__.items()` instead of `list(cls.__dict__)` plus a second
        lookup
    
    ## Test plan
    
    - [x] 12 new dedicated tests covering allowlist filtering, registration,
    and round-trip dispatch (repr, eq, compare, hash)
    - [x] All 693 tests pass (including structural equality tests that now
    work with the `structure_kind` base from #507)
    - [ ] CI: lint, C++ tests, Python tests, Rust tests
    
    🤖 Generated with [Claude Code](https://claude.com/claude-code)
---
 python/tvm_ffi/cython/base.pxi          |  1 +
 python/tvm_ffi/cython/type_info.pxi     | 86 +++++++++++++++++++++++++++++++--
 python/tvm_ffi/dataclasses/py_class.py  | 49 +++++++++++++++++++
 tests/python/test_dataclass_compare.py  | 78 ++++++++++++++++++++++++++++++
 tests/python/test_dataclass_hash.py     | 38 +++++++++++++++
 tests/python/test_dataclass_py_class.py | 64 ++++++++++++++++++++++++
 tests/python/test_dataclass_repr.py     | 41 ++++++++++++++++
 7 files changed, 354 insertions(+), 3 deletions(-)

diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 007aa0a9..b851e1e2 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -323,6 +323,7 @@ cdef extern from "tvm/ffi/c_api.h":
         int32_t parent_type_index
     ) nogil
     int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo* 
info) nogil
+    int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* 
info) nogil
     int TVMFFITypeRegisterMetadata(int32_t type_index, const 
TVMFFITypeMetadata* metadata) nogil
     int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray* 
attr_name,
                                const TVMFFIAny* attr_value) nogil
diff --git a/python/tvm_ffi/cython/type_info.pxi 
b/python/tvm_ffi/cython/type_info.pxi
index 98c5f8ca..128efa70 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -709,7 +709,7 @@ class TypeInfo:
 
         Delegates to the module-level _register_fields function,
         stores the resulting list[TypeField] on self.fields,
-        then reads back methods registered by C++ via _register_methods.
+        then reads back methods registered by C++ via _read_back_methods.
 
         Can only be called once (fields must be None beforehand).
 
@@ -725,9 +725,27 @@ class TypeInfo:
             f"_register_fields already called for {self.type_key!r}"
         )
         self.fields = _register_fields(self, fields, structure_kind)
-        self._register_methods()
+        self._read_back_methods()
 
-    def _register_methods(self):
+    def _register_py_methods(self, py_methods=None):
+        """Register user-defined dunder methods and re-read the method table.
+
+        When *py_methods* is non-empty, each entry is registered as both
+        TypeMethod and TypeAttr via the C API.  Regardless, the full
+        method list is always re-read from the C type table so that
+        system-generated methods (``__ffi_init__``, ``__ffi_shallow_copy__``)
+        are picked up.
+
+        Parameters
+        ----------
+        py_methods : list[tuple[str, callable, bool]] | None
+            Each entry is ``(name, func, is_static)``.
+        """
+        if py_methods:
+            _register_py_methods(self.type_index, py_methods)
+        self._read_back_methods()
+
+    def _read_back_methods(self):
         """Read methods from the C type table into self.methods.
 
         Called after C++ registers __ffi_init__, __ffi_shallow_copy__, etc.
@@ -1029,6 +1047,68 @@ cdef _register_type_metadata(int32_t type_index, int32_t 
total_size, int structu
     CHECK_CALL(TVMFFITypeRegisterMetadata(type_index, &metadata))
 
 
+cdef _register_py_methods(int32_t type_index, list py_methods):
+    """Register user-defined dunder methods as both TypeMethod and TypeAttr.
+
+    For each method in *py_methods*:
+    1. Convert the Python callable to a ``TVMFFIAny`` (``ffi::Function``).
+    2. Call ``TVMFFITypeRegisterMethod`` so the method appears in the
+       type's reflection metadata (``TypeInfo.methods``).
+    3. Ensure the type-attribute column exists (sentinel call with
+       ``type_index = kTVMFFINone``), then call ``TVMFFITypeRegisterAttr``
+       so the C++ runtime dispatch can find the hook.
+
+    Parameters
+    ----------
+    type_index : int
+        The runtime type index of the type.
+    py_methods : list[tuple[str, callable, bool]]
+        Each entry is ``(name, func, is_static)``.
+    """
+    cdef TVMFFIMethodInfo method_info
+    cdef TVMFFIAny func_any
+    cdef TVMFFIAny sentinel_any
+    cdef int c_api_ret_code
+    cdef ByteArrayArg name_arg
+
+    sentinel_any.type_index = kTVMFFINone
+    sentinel_any.v_int64 = 0
+
+    for name, func, is_static in py_methods:
+        func_any.type_index = kTVMFFINone
+        func_any.v_int64 = 0
+        try:
+            name_bytes = c_str(name)
+            name_arg = ByteArrayArg(name_bytes)
+
+            # Convert Python callable -> TVMFFIAny (creates a FunctionObj)
+            TVMFFIPyPyObjectToFFIAny(
+                TVMFFIPyArgSetterFactory_,
+                <PyObject*>func,
+                &func_any,
+                &c_api_ret_code,
+            )
+            CHECK_CALL(c_api_ret_code)
+
+            # 1. Register as TypeMethod
+            method_info.name = name_arg.cdata
+            method_info.doc.data = NULL
+            method_info.doc.size = 0
+            method_info.flags = kTVMFFIFieldFlagBitMaskIsStaticMethod if 
is_static else 0
+            method_info.method = func_any
+            method_info.metadata.data = NULL
+            method_info.metadata.size = 0
+            CHECK_CALL(TVMFFITypeRegisterMethod(type_index, &method_info))
+
+            # 2. Ensure type-attr column exists (sentinel: kTVMFFINone)
+            CHECK_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_arg.cdata, 
&sentinel_any))
+            # 3. Register as TypeAttr
+            CHECK_CALL(TVMFFITypeRegisterAttr(type_index, &name_arg.cdata, 
&func_any))
+        finally:
+            if func_any.type_index >= kTVMFFIStaticObjectBegin and 
func_any.v_obj != NULL:
+                TVMFFIObjectDecRef(<TVMFFIObjectHandle>func_any.v_obj)
+
+
 def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[..., 
Any]:
     def wrapper(self: Any, *args: Any) -> Any:
         return method_func(self, *args)
diff --git a/python/tvm_ffi/dataclasses/py_class.py 
b/python/tvm_ffi/dataclasses/py_class.py
index 103f2c26..07b6aa08 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -202,6 +202,30 @@ def _collect_own_fields(
     return fields
 
 
+def _collect_py_methods(cls: type) -> list[tuple[str, Any, bool]] | None:
+    """Extract recognized FFI dunder methods from the class body.
+
+    Only names listed in :data:`_FFI_RECOGNIZED_METHODS` are collected.
+
+    Returns a list of ``(name, func, is_static)`` tuples, or ``None``
+    if no eligible methods were found.
+    """
+    methods: list[tuple[str, Any, bool]] = []
+    for name, value in cls.__dict__.items():
+        if name not in _FFI_RECOGNIZED_METHODS:
+            continue
+        if isinstance(value, staticmethod):
+            func = value.__func__
+            is_static = True
+        elif callable(value):
+            func = value
+            is_static = False
+        else:
+            continue
+        methods.append((name, func, is_static))
+    return methods if methods else None
+
+
 def _phase2_register_fields(
     cls: type,
     type_info: Any,
@@ -224,10 +248,13 @@ def _phase2_register_fields(
         return False
 
     own_fields = _collect_own_fields(cls, hints, params["kw_only"])
+    py_methods = _collect_py_methods(cls)
 
     # Register fields and type-level structural eq/hash kind with the C layer.
     structure_kind = _STRUCTURE_KIND_MAP.get(params.get("structure"))
     type_info._register_fields(own_fields, structure_kind)
+    # Register user-defined dunder methods and read back system-generated ones.
+    type_info._register_py_methods(py_methods)
     _add_class_attrs(cls, type_info)
 
     # Remove deferred __init__ and restore user-defined __init__ if saved
@@ -350,6 +377,28 @@ _STRUCTURE_KIND_MAP: dict[str | None, int] = {
     "singleton": 5,  # kTVMFFISEqHashKindUniqueInstance
 }
 
+#: Allowlist of dunder method names that ``@py_class`` will auto-register
+#: as both TypeMethod (for reflection) and TypeAttr (for C++ dispatch).
+#:
+#: Only names in this set are collected from the class body.
+#: System-managed names (``__ffi_init__``, ``__ffi_shallow_copy__``, etc.)
+#: are intentionally absent because the C++ runtime generates them.
+_FFI_RECOGNIZED_METHODS: frozenset[str] = frozenset(
+    {
+        # Recursive operations (RecursiveHash, RecursiveEq, RecursiveCompare, 
ReprPrint)
+        "__ffi_repr__",
+        "__ffi_hash__",
+        "__ffi_eq__",
+        "__ffi_compare__",
+        # Structural equality/hashing (StructuralEqual, StructuralHash)
+        "__s_equal__",
+        "__s_hash__",
+        # Serialization (ToJSONGraph, FromJSONGraph)
+        "__data_to_json__",
+        "__data_from_json__",
+    }
+)
+
 
 @dataclass_transform(
     eq_default=False,
diff --git a/tests/python/test_dataclass_compare.py 
b/tests/python/test_dataclass_compare.py
index 83234abe..36f3fea6 100644
--- a/tests/python/test_dataclass_compare.py
+++ b/tests/python/test_dataclass_compare.py
@@ -1270,3 +1270,81 @@ def test_custom_compare_ordering_consistency() -> None:
     assert not RecursiveGt(a, b)
     assert RecursiveLe(a, b)
     assert RecursiveGe(a, b)
+
+
+# ---------------------------------------------------------------------------
+# Custom __ffi_eq__ / __ffi_compare__ hooks via @py_class
+# ---------------------------------------------------------------------------
+import itertools as _itertools_cmp
+from typing import Any as _Any_cmp
+from typing import Callable as _Callable_cmp
+
+from tvm_ffi._ffi_api import RecursiveHash as _RecursiveHash_cmp
+from tvm_ffi.core import Object as _Object_cmp
+from tvm_ffi.dataclasses import py_class as _py_class_cmp
+
+_counter_cmp = _itertools_cmp.count()
+
+
+def _unique_key_cmp(base: str) -> str:
+    return f"testing.cmp_pc.{base}_{next(_counter_cmp)}"
+
+
+@_py_class_cmp(_unique_key_cmp("PyEqHash"))
+class _PyEqHash(_Object_cmp):
+    key: int
+    label: str
+
+    def __ffi_hash__(self, fn_hash: _Callable_cmp[..., _Any_cmp]) -> int:
+        return fn_hash(self.key)
+
+    def __ffi_eq__(self, other: _PyEqHash, fn_eq: _Callable_cmp[..., 
_Any_cmp]) -> bool:
+        return fn_eq(self.key, other.key)
+
+
+@_py_class_cmp(_unique_key_cmp("PyCmp"))
+class _PyCmp(_Object_cmp):
+    key: int
+    label: str
+
+    def __ffi_hash__(self, fn_hash: _Callable_cmp[..., _Any_cmp]) -> int:
+        return fn_hash(self.key)
+
+    def __ffi_eq__(self, other: _PyCmp, fn_eq: _Callable_cmp[..., _Any_cmp]) 
-> bool:
+        return fn_eq(self.key, other.key)
+
+    def __ffi_compare__(self, other: _PyCmp, fn_cmp: _Callable_cmp[..., 
_Any_cmp]) -> int:
+        return fn_cmp(self.key, other.key)
+
+
+def test_py_class_custom_eq_ignores_label() -> None:
+    assert RecursiveEq(_PyEqHash(42, "alpha"), _PyEqHash(42, "beta"))
+
+
+def test_py_class_custom_eq_different_key() -> None:
+    assert not RecursiveEq(_PyEqHash(1, "same"), _PyEqHash(2, "same"))
+
+
+def test_py_class_custom_eq_hash_consistency() -> None:
+    a, b = _PyEqHash(42, "alpha"), _PyEqHash(42, "beta")
+    assert RecursiveEq(a, b)
+    assert _RecursiveHash_cmp(a) == _RecursiveHash_cmp(b)
+
+
+def test_py_class_custom_compare_ordering() -> None:
+    a = _PyCmp(1, "zzz")
+    b = _PyCmp(2, "aaa")
+    assert RecursiveLt(a, b)
+    assert RecursiveLe(a, b)
+    assert not RecursiveGt(a, b)
+    assert not RecursiveGe(a, b)
+
+
+def test_py_class_custom_compare_equal_keys() -> None:
+    a = _PyCmp(42, "alpha")
+    b = _PyCmp(42, "beta")
+    assert RecursiveEq(a, b)
+    assert RecursiveLe(a, b)
+    assert RecursiveGe(a, b)
+    assert not RecursiveLt(a, b)
+    assert not RecursiveGt(a, b)
diff --git a/tests/python/test_dataclass_hash.py 
b/tests/python/test_dataclass_hash.py
index ec58c2b7..a79c519e 100644
--- a/tests/python/test_dataclass_hash.py
+++ b/tests/python/test_dataclass_hash.py
@@ -968,3 +968,41 @@ def test_eq_without_hash_inside_container_raises() -> None:
     arr = tvm_ffi.Array([obj])
     with pytest.raises(ValueError, match="__ffi_eq__ or __ffi_compare__ but 
not __ffi_hash__"):
         RecursiveHash(arr)
+
+
+# ---------------------------------------------------------------------------
+# Custom __ffi_hash__ hook via @py_class
+# ---------------------------------------------------------------------------
+import itertools as _itertools_hash
+from typing import Any as _Any_hash
+from typing import Callable as _Callable_hash
+
+from tvm_ffi.core import Object as _Object_hash
+from tvm_ffi.dataclasses import py_class as _py_class_hash
+
+_counter_hash = _itertools_hash.count()
+
+
+def _unique_key_hash(base: str) -> str:
+    return f"testing.hash_pc.{base}_{next(_counter_hash)}"
+
+
+@_py_class_hash(_unique_key_hash("PyCustomHash"))
+class _PyCustomHash(_Object_hash):
+    key: int
+    label: str
+
+    def __ffi_hash__(self, fn_hash: _Callable_hash[..., _Any_hash]) -> int:
+        return fn_hash(self.key)
+
+
+def test_py_class_custom_hash_ignores_label() -> None:
+    a = _PyCustomHash(42, "alpha")
+    b = _PyCustomHash(42, "beta")
+    assert RecursiveHash(a) == RecursiveHash(b)
+
+
+def test_py_class_custom_hash_different_key() -> None:
+    a = _PyCustomHash(1, "same")
+    b = _PyCustomHash(2, "same")
+    assert RecursiveHash(a) != RecursiveHash(b)
diff --git a/tests/python/test_dataclass_py_class.py 
b/tests/python/test_dataclass_py_class.py
index 520589ce..f9dc5344 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -4497,3 +4497,67 @@ class TestMultiTypeCopy:
         assert not obj.dict_str_int.same_as(obj2.dict_str_int)  # 
ty:ignore[unresolved-attribute]
         assert tuple(obj2.list_int) == (1, 2, 3)
         assert obj2.dict_str_int["a"] == 1
+
+
+# ---------------------------------------------------------------------------
+# _collect_py_methods allowlist and method introspection
+# ---------------------------------------------------------------------------
+
+
+class TestPyMethodAllowlist:
+    """Only names in ``_FFI_RECOGNIZED_METHODS`` are collected by 
``_collect_py_methods``."""
+
+    def test_system_methods_not_in_allowlist(self) -> None:
+        from tvm_ffi.dataclasses.py_class import _collect_py_methods  # noqa: 
PLC0415
+
+        @py_class(_unique_key("Allow"))
+        class Allow(core.Object):
+            x: int
+
+            def __ffi_init__(self, x: int) -> None:  # ty: 
ignore[invalid-method-override]
+                pass
+
+            def __ffi_shallow_copy__(self) -> None:
+                pass
+
+            def __ffi_repr__(self, fn_repr: Any) -> str:
+                return "repr"
+
+        collected = _collect_py_methods(Allow)
+        assert collected is not None
+        names = {name for name, _, _ in collected}
+        assert "__ffi_repr__" in names
+        assert "__ffi_init__" not in names
+        assert "__ffi_shallow_copy__" not in names
+
+    def test_arbitrary_ffi_dunder_not_collected(self) -> None:
+        from tvm_ffi.dataclasses.py_class import _collect_py_methods  # noqa: 
PLC0415
+
+        @py_class(_unique_key("Arb"))
+        class Arb(core.Object):
+            x: int
+
+            def __ffi_custom_op__(self, y: int) -> int:
+                return self.x + y
+
+        collected = _collect_py_methods(Arb)
+        assert collected is None
+
+
+class TestPyMethodIntrospection:
+    """Registered __ffi_* methods appear in ``TypeInfo.methods``."""
+
+    def test_ffi_repr_in_methods(self) -> None:
+        @py_class(_unique_key("IntrRepr"))
+        class IntrRepr(core.Object):
+            x: int
+
+            def __ffi_repr__(self, fn_repr: Any) -> str:
+                return "repr"
+
+        info = getattr(IntrRepr, "__tvm_ffi_type_info__")
+        names = {m.name for m in info.methods}
+        assert "__ffi_repr__" in names
+        # system methods still present
+        assert "__ffi_init__" in names
+        assert "__ffi_shallow_copy__" in names
diff --git a/tests/python/test_dataclass_repr.py 
b/tests/python/test_dataclass_repr.py
index 62e9aff7..3f989e24 100644
--- a/tests/python/test_dataclass_repr.py
+++ b/tests/python/test_dataclass_repr.py
@@ -728,5 +728,46 @@ def test_repr_py_class_in_array() -> None:
     assert "2" in r
 
 
+# ---------------------------------------------------------------------------
+# Custom __ffi_repr__ hook via @py_class
+# ---------------------------------------------------------------------------
+from typing import Any as _Any_repr
+from typing import Callable as _Callable_repr
+
+
+def test_py_class_custom_ffi_repr() -> None:
+    """ReprPrint dispatches the user-defined __ffi_repr__ hook."""
+
+    @_py_class_repr(_unique_key_repr("CRepr"))
+    class CRepr(_Object_repr):
+        value: int
+
+        def __ffi_repr__(self, fn_repr: _Callable_repr[..., _Any_repr]) -> str:
+            return f"<CRepr:{self.value}>"
+
+    assert ReprPrint(CRepr(42)) == "<CRepr:42>"
+    assert ReprPrint(CRepr(999)) == "<CRepr:999>"
+
+
+def test_py_class_ffi_repr_with_fields_and_copy() -> None:
+    """Fields work normally and copy preserves __ffi_repr__ behaviour."""
+    import copy as _copy_repr  # noqa: PLC0415
+
+    @_py_class_repr(_unique_key_repr("FnR"))
+    class FnR(_Object_repr):
+        a: int
+        b: str
+
+        def __ffi_repr__(self, fn_repr: _Callable_repr[..., _Any_repr]) -> str:
+            return f"FnR({self.a}, {self.b!r})"
+
+    obj = FnR(10, "hi")
+    assert obj.a == 10
+    assert obj.b == "hi"
+    assert ReprPrint(obj) == "FnR(10, 'hi')"
+    obj2 = _copy_repr.copy(obj)
+    assert ReprPrint(obj2) == "FnR(10, 'hi')"
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])

Reply via email to