This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new d153a2c9 feat(python): auto-register __ffi_* dunder methods in
@py_class as TypeMethod (#508)
d153a2c9 is described below
commit d153a2c9407ee04ec3dc1ccbdf6a0664854337b7
Author: Junru Shao <[email protected]>
AuthorDate: Sun Mar 22 14:24:35 2026 -0700
feat(python): auto-register __ffi_* dunder methods in @py_class as
TypeMethod (#508)
## Summary
- Extend `@py_class` to detect and register recognized FFI dunder
methods (`__ffi_repr__`, `__ffi_eq__`, `__s_equal__`, etc.) defined on
Python dataclasses, making them callable from C++ and other FFI
languages
- `_FFI_RECOGNIZED_METHODS` allowlist in `py_class.py` gates which
dunders are collected; `_collect_py_methods()` scans class dict and
`TypeInfo._register_py_methods()` registers via the C API
(`TVMFFITypeRegisterMethod` + `TVMFFITypeRegisterAttr`)
- No C++ changes — `dataclass.cc` is unmodified from upstream
## Changes since v1
- Rebased onto upstream/main (`d8bd1890`) which adds structural
equality/hashing support (#507)
- Resolved merge conflicts: preserved `structure_kind` /
`STRUCTURE_KIND_MAP` from upstream
- Applied Gemini Code Assist review feedback:
- **HIGH**: Fixed `func_any` memory leak in `_register_py_methods` by
adding `TVMFFIObjectDecRef` in a `try/finally` block
- **MEDIUM**: Changed `_collect_py_methods` to iterate
`cls.__dict__.items()` instead of `list(cls.__dict__)` + second lookup
## Test plan
- [x] 12 new dedicated tests covering allowlist filtering, registration,
round-trip dispatch (repr, eq, compare, hash)
- [x] All 693 tests pass (including structural equality tests that now
work with the `structure_kind` base from #507)
- [ ] CI: lint, C++ tests, Python tests, Rust tests
🤖 Generated with [Claude Code](https://claude.com/claude-code)
---
python/tvm_ffi/cython/base.pxi | 1 +
python/tvm_ffi/cython/type_info.pxi | 86 +++++++++++++++++++++++++++++++--
python/tvm_ffi/dataclasses/py_class.py | 49 +++++++++++++++++++
tests/python/test_dataclass_compare.py | 78 ++++++++++++++++++++++++++++++
tests/python/test_dataclass_hash.py | 38 +++++++++++++++
tests/python/test_dataclass_py_class.py | 64 ++++++++++++++++++++++++
tests/python/test_dataclass_repr.py | 41 ++++++++++++++++
7 files changed, 354 insertions(+), 3 deletions(-)
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 007aa0a9..b851e1e2 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -323,6 +323,7 @@ cdef extern from "tvm/ffi/c_api.h":
int32_t parent_type_index
) nogil
int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo*
info) nogil
+ int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo*
info) nogil
int TVMFFITypeRegisterMetadata(int32_t type_index, const
TVMFFITypeMetadata* metadata) nogil
int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray*
attr_name,
const TVMFFIAny* attr_value) nogil
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 98c5f8ca..128efa70 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -709,7 +709,7 @@ class TypeInfo:
Delegates to the module-level _register_fields function,
stores the resulting list[TypeField] on self.fields,
- then reads back methods registered by C++ via _register_methods.
+ then reads back methods registered by C++ via _read_back_methods.
Can only be called once (fields must be None beforehand).
@@ -725,9 +725,27 @@ class TypeInfo:
f"_register_fields already called for {self.type_key!r}"
)
self.fields = _register_fields(self, fields, structure_kind)
- self._register_methods()
+ self._read_back_methods()
- def _register_methods(self):
+ def _register_py_methods(self, py_methods=None):
+ """Register user-defined dunder methods and re-read the method table.
+
+ When *py_methods* is non-empty, each entry is registered as both
+ TypeMethod and TypeAttr via the C API. Regardless, the full
+ method list is always re-read from the C type table so that
+ system-generated methods (``__ffi_init__``, ``__ffi_shallow_copy__``)
+ are picked up.
+
+ Parameters
+ ----------
+ py_methods : list[tuple[str, callable, bool]] | None
+ Each entry is ``(name, func, is_static)``.
+ """
+ if py_methods:
+ _register_py_methods(self.type_index, py_methods)
+ self._read_back_methods()
+
+ def _read_back_methods(self):
"""Read methods from the C type table into self.methods.
Called after C++ registers __ffi_init__, __ffi_shallow_copy__, etc.
@@ -1029,6 +1047,68 @@ cdef _register_type_metadata(int32_t type_index, int32_t
total_size, int structu
CHECK_CALL(TVMFFITypeRegisterMetadata(type_index, &metadata))
+cdef _register_py_methods(int32_t type_index, list py_methods):
+ """Register user-defined dunder methods as both TypeMethod and TypeAttr.
+
+ For each method in *py_methods*:
+ 1. Convert the Python callable to a ``TVMFFIAny`` (``ffi::Function``).
+ 2. Call ``TVMFFITypeRegisterMethod`` so the method appears in the
+ type's reflection metadata (``TypeInfo.methods``).
+ 3. Ensure the type-attribute column exists (sentinel call with
+ ``type_index = kTVMFFINone``), then call ``TVMFFITypeRegisterAttr``
+ so the C++ runtime dispatch can find the hook.
+
+ Parameters
+ ----------
+ type_index : int
+ The runtime type index of the type.
+ py_methods : list[tuple[str, callable, bool]]
+ Each entry is ``(name, func, is_static)``.
+ """
+ cdef TVMFFIMethodInfo method_info
+ cdef TVMFFIAny func_any
+ cdef TVMFFIAny sentinel_any
+ cdef int c_api_ret_code
+ cdef ByteArrayArg name_arg
+
+ sentinel_any.type_index = kTVMFFINone
+ sentinel_any.v_int64 = 0
+
+ for name, func, is_static in py_methods:
+ func_any.type_index = kTVMFFINone
+ func_any.v_int64 = 0
+ try:
+ name_bytes = c_str(name)
+ name_arg = ByteArrayArg(name_bytes)
+
+ # Convert Python callable -> TVMFFIAny (creates a FunctionObj)
+ TVMFFIPyPyObjectToFFIAny(
+ TVMFFIPyArgSetterFactory_,
+ <PyObject*>func,
+ &func_any,
+ &c_api_ret_code,
+ )
+ CHECK_CALL(c_api_ret_code)
+
+ # 1. Register as TypeMethod
+ method_info.name = name_arg.cdata
+ method_info.doc.data = NULL
+ method_info.doc.size = 0
+ method_info.flags = kTVMFFIFieldFlagBitMaskIsStaticMethod if
is_static else 0
+ method_info.method = func_any
+ method_info.metadata.data = NULL
+ method_info.metadata.size = 0
+ CHECK_CALL(TVMFFITypeRegisterMethod(type_index, &method_info))
+
+ # 2. Ensure type-attr column exists (sentinel: kTVMFFINone)
+ CHECK_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_arg.cdata,
&sentinel_any))
+ # 3. Register as TypeAttr
+ CHECK_CALL(TVMFFITypeRegisterAttr(type_index, &name_arg.cdata,
&func_any))
+ finally:
+ if func_any.type_index >= kTVMFFIStaticObjectBegin and
func_any.v_obj != NULL:
+ TVMFFIObjectDecRef(<TVMFFIObjectHandle>func_any.v_obj)
+
+
def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[...,
Any]:
def wrapper(self: Any, *args: Any) -> Any:
return method_func(self, *args)
diff --git a/python/tvm_ffi/dataclasses/py_class.py
b/python/tvm_ffi/dataclasses/py_class.py
index 103f2c26..07b6aa08 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -202,6 +202,30 @@ def _collect_own_fields(
return fields
+def _collect_py_methods(cls: type) -> list[tuple[str, Any, bool]] | None:
+ """Extract recognized FFI dunder methods from the class body.
+
+ Only names listed in :data:`_FFI_RECOGNIZED_METHODS` are collected.
+
+ Returns a list of ``(name, func, is_static)`` tuples, or ``None``
+ if no eligible methods were found.
+ """
+ methods: list[tuple[str, Any, bool]] = []
+ for name, value in cls.__dict__.items():
+ if name not in _FFI_RECOGNIZED_METHODS:
+ continue
+ if isinstance(value, staticmethod):
+ func = value.__func__
+ is_static = True
+ elif callable(value):
+ func = value
+ is_static = False
+ else:
+ continue
+ methods.append((name, func, is_static))
+ return methods if methods else None
+
+
def _phase2_register_fields(
cls: type,
type_info: Any,
@@ -224,10 +248,13 @@ def _phase2_register_fields(
return False
own_fields = _collect_own_fields(cls, hints, params["kw_only"])
+ py_methods = _collect_py_methods(cls)
# Register fields and type-level structural eq/hash kind with the C layer.
structure_kind = _STRUCTURE_KIND_MAP.get(params.get("structure"))
type_info._register_fields(own_fields, structure_kind)
+ # Register user-defined dunder methods and read back system-generated ones.
+ type_info._register_py_methods(py_methods)
_add_class_attrs(cls, type_info)
# Remove deferred __init__ and restore user-defined __init__ if saved
@@ -350,6 +377,28 @@ _STRUCTURE_KIND_MAP: dict[str | None, int] = {
"singleton": 5, # kTVMFFISEqHashKindUniqueInstance
}
+#: Allowlist of dunder method names that ``@py_class`` will auto-register
+#: as both TypeMethod (for reflection) and TypeAttr (for C++ dispatch).
+#:
+#: Only names in this set are collected from the class body.
+#: System-managed names (``__ffi_init__``, ``__ffi_shallow_copy__``, etc.)
+#: are intentionally absent because the C++ runtime generates them.
+_FFI_RECOGNIZED_METHODS: frozenset[str] = frozenset(
+ {
+ # Recursive operations (RecursiveHash, RecursiveEq, RecursiveCompare,
ReprPrint)
+ "__ffi_repr__",
+ "__ffi_hash__",
+ "__ffi_eq__",
+ "__ffi_compare__",
+ # Structural equality/hashing (StructuralEqual, StructuralHash)
+ "__s_equal__",
+ "__s_hash__",
+ # Serialization (ToJSONGraph, FromJSONGraph)
+ "__data_to_json__",
+ "__data_from_json__",
+ }
+)
+
@dataclass_transform(
eq_default=False,
diff --git a/tests/python/test_dataclass_compare.py
b/tests/python/test_dataclass_compare.py
index 83234abe..36f3fea6 100644
--- a/tests/python/test_dataclass_compare.py
+++ b/tests/python/test_dataclass_compare.py
@@ -1270,3 +1270,81 @@ def test_custom_compare_ordering_consistency() -> None:
assert not RecursiveGt(a, b)
assert RecursiveLe(a, b)
assert RecursiveGe(a, b)
+
+
+# ---------------------------------------------------------------------------
+# Custom __ffi_eq__ / __ffi_compare__ hooks via @py_class
+# ---------------------------------------------------------------------------
+import itertools as _itertools_cmp
+from typing import Any as _Any_cmp
+from typing import Callable as _Callable_cmp
+
+from tvm_ffi._ffi_api import RecursiveHash as _RecursiveHash_cmp
+from tvm_ffi.core import Object as _Object_cmp
+from tvm_ffi.dataclasses import py_class as _py_class_cmp
+
+_counter_cmp = _itertools_cmp.count()
+
+
+def _unique_key_cmp(base: str) -> str:
+ return f"testing.cmp_pc.{base}_{next(_counter_cmp)}"
+
+
+@_py_class_cmp(_unique_key_cmp("PyEqHash"))
+class _PyEqHash(_Object_cmp):
+ key: int
+ label: str
+
+ def __ffi_hash__(self, fn_hash: _Callable_cmp[..., _Any_cmp]) -> int:
+ return fn_hash(self.key)
+
+ def __ffi_eq__(self, other: _PyEqHash, fn_eq: _Callable_cmp[...,
_Any_cmp]) -> bool:
+ return fn_eq(self.key, other.key)
+
+
+@_py_class_cmp(_unique_key_cmp("PyCmp"))
+class _PyCmp(_Object_cmp):
+ key: int
+ label: str
+
+ def __ffi_hash__(self, fn_hash: _Callable_cmp[..., _Any_cmp]) -> int:
+ return fn_hash(self.key)
+
+ def __ffi_eq__(self, other: _PyCmp, fn_eq: _Callable_cmp[..., _Any_cmp])
-> bool:
+ return fn_eq(self.key, other.key)
+
+ def __ffi_compare__(self, other: _PyCmp, fn_cmp: _Callable_cmp[...,
_Any_cmp]) -> int:
+ return fn_cmp(self.key, other.key)
+
+
+def test_py_class_custom_eq_ignores_label() -> None:
+ assert RecursiveEq(_PyEqHash(42, "alpha"), _PyEqHash(42, "beta"))
+
+
+def test_py_class_custom_eq_different_key() -> None:
+ assert not RecursiveEq(_PyEqHash(1, "same"), _PyEqHash(2, "same"))
+
+
+def test_py_class_custom_eq_hash_consistency() -> None:
+ a, b = _PyEqHash(42, "alpha"), _PyEqHash(42, "beta")
+ assert RecursiveEq(a, b)
+ assert _RecursiveHash_cmp(a) == _RecursiveHash_cmp(b)
+
+
+def test_py_class_custom_compare_ordering() -> None:
+ a = _PyCmp(1, "zzz")
+ b = _PyCmp(2, "aaa")
+ assert RecursiveLt(a, b)
+ assert RecursiveLe(a, b)
+ assert not RecursiveGt(a, b)
+ assert not RecursiveGe(a, b)
+
+
+def test_py_class_custom_compare_equal_keys() -> None:
+ a = _PyCmp(42, "alpha")
+ b = _PyCmp(42, "beta")
+ assert RecursiveEq(a, b)
+ assert RecursiveLe(a, b)
+ assert RecursiveGe(a, b)
+ assert not RecursiveLt(a, b)
+ assert not RecursiveGt(a, b)
diff --git a/tests/python/test_dataclass_hash.py
b/tests/python/test_dataclass_hash.py
index ec58c2b7..a79c519e 100644
--- a/tests/python/test_dataclass_hash.py
+++ b/tests/python/test_dataclass_hash.py
@@ -968,3 +968,41 @@ def test_eq_without_hash_inside_container_raises() -> None:
arr = tvm_ffi.Array([obj])
with pytest.raises(ValueError, match="__ffi_eq__ or __ffi_compare__ but
not __ffi_hash__"):
RecursiveHash(arr)
+
+
+# ---------------------------------------------------------------------------
+# Custom __ffi_hash__ hook via @py_class
+# ---------------------------------------------------------------------------
+import itertools as _itertools_hash
+from typing import Any as _Any_hash
+from typing import Callable as _Callable_hash
+
+from tvm_ffi.core import Object as _Object_hash
+from tvm_ffi.dataclasses import py_class as _py_class_hash
+
+_counter_hash = _itertools_hash.count()
+
+
+def _unique_key_hash(base: str) -> str:
+ return f"testing.hash_pc.{base}_{next(_counter_hash)}"
+
+
+@_py_class_hash(_unique_key_hash("PyCustomHash"))
+class _PyCustomHash(_Object_hash):
+ key: int
+ label: str
+
+ def __ffi_hash__(self, fn_hash: _Callable_hash[..., _Any_hash]) -> int:
+ return fn_hash(self.key)
+
+
+def test_py_class_custom_hash_ignores_label() -> None:
+ a = _PyCustomHash(42, "alpha")
+ b = _PyCustomHash(42, "beta")
+ assert RecursiveHash(a) == RecursiveHash(b)
+
+
+def test_py_class_custom_hash_different_key() -> None:
+ a = _PyCustomHash(1, "same")
+ b = _PyCustomHash(2, "same")
+ assert RecursiveHash(a) != RecursiveHash(b)
diff --git a/tests/python/test_dataclass_py_class.py
b/tests/python/test_dataclass_py_class.py
index 520589ce..f9dc5344 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -4497,3 +4497,67 @@ class TestMultiTypeCopy:
assert not obj.dict_str_int.same_as(obj2.dict_str_int) #
ty:ignore[unresolved-attribute]
assert tuple(obj2.list_int) == (1, 2, 3)
assert obj2.dict_str_int["a"] == 1
+
+
+# ---------------------------------------------------------------------------
+# _collect_py_methods allowlist and method introspection
+# ---------------------------------------------------------------------------
+
+
+class TestPyMethodAllowlist:
+ """Only names in ``_FFI_RECOGNIZED_METHODS`` are collected by
``_collect_py_methods``."""
+
+ def test_system_methods_not_in_allowlist(self) -> None:
+ from tvm_ffi.dataclasses.py_class import _collect_py_methods # noqa:
PLC0415
+
+ @py_class(_unique_key("Allow"))
+ class Allow(core.Object):
+ x: int
+
+ def __ffi_init__(self, x: int) -> None: # ty:
ignore[invalid-method-override]
+ pass
+
+ def __ffi_shallow_copy__(self) -> None:
+ pass
+
+ def __ffi_repr__(self, fn_repr: Any) -> str:
+ return "repr"
+
+ collected = _collect_py_methods(Allow)
+ assert collected is not None
+ names = {name for name, _, _ in collected}
+ assert "__ffi_repr__" in names
+ assert "__ffi_init__" not in names
+ assert "__ffi_shallow_copy__" not in names
+
+ def test_arbitrary_ffi_dunder_not_collected(self) -> None:
+ from tvm_ffi.dataclasses.py_class import _collect_py_methods # noqa:
PLC0415
+
+ @py_class(_unique_key("Arb"))
+ class Arb(core.Object):
+ x: int
+
+ def __ffi_custom_op__(self, y: int) -> int:
+ return self.x + y
+
+ collected = _collect_py_methods(Arb)
+ assert collected is None
+
+
+class TestPyMethodIntrospection:
+ """Registered __ffi_* methods appear in ``TypeInfo.methods``."""
+
+ def test_ffi_repr_in_methods(self) -> None:
+ @py_class(_unique_key("IntrRepr"))
+ class IntrRepr(core.Object):
+ x: int
+
+ def __ffi_repr__(self, fn_repr: Any) -> str:
+ return "repr"
+
+ info = getattr(IntrRepr, "__tvm_ffi_type_info__")
+ names = {m.name for m in info.methods}
+ assert "__ffi_repr__" in names
+ # system methods still present
+ assert "__ffi_init__" in names
+ assert "__ffi_shallow_copy__" in names
diff --git a/tests/python/test_dataclass_repr.py
b/tests/python/test_dataclass_repr.py
index 62e9aff7..3f989e24 100644
--- a/tests/python/test_dataclass_repr.py
+++ b/tests/python/test_dataclass_repr.py
@@ -728,5 +728,46 @@ def test_repr_py_class_in_array() -> None:
assert "2" in r
+# ---------------------------------------------------------------------------
+# Custom __ffi_repr__ hook via @py_class
+# ---------------------------------------------------------------------------
+from typing import Any as _Any_repr
+from typing import Callable as _Callable_repr
+
+
+def test_py_class_custom_ffi_repr() -> None:
+ """ReprPrint dispatches the user-defined __ffi_repr__ hook."""
+
+ @_py_class_repr(_unique_key_repr("CRepr"))
+ class CRepr(_Object_repr):
+ value: int
+
+ def __ffi_repr__(self, fn_repr: _Callable_repr[..., _Any_repr]) -> str:
+ return f"<CRepr:{self.value}>"
+
+ assert ReprPrint(CRepr(42)) == "<CRepr:42>"
+ assert ReprPrint(CRepr(999)) == "<CRepr:999>"
+
+
+def test_py_class_ffi_repr_with_fields_and_copy() -> None:
+ """Fields work normally and copy preserves __ffi_repr__ behaviour."""
+ import copy as _copy_repr # noqa: PLC0415
+
+ @_py_class_repr(_unique_key_repr("FnR"))
+ class FnR(_Object_repr):
+ a: int
+ b: str
+
+ def __ffi_repr__(self, fn_repr: _Callable_repr[..., _Any_repr]) -> str:
+ return f"FnR({self.a}, {self.b!r})"
+
+ obj = FnR(10, "hi")
+ assert obj.a == 10
+ assert obj.b == "hi"
+ assert ReprPrint(obj) == "FnR(10, 'hi')"
+ obj2 = _copy_repr.copy(obj)
+ assert ReprPrint(obj2) == "FnR(10, 'hi')"
+
+
if __name__ == "__main__":
pytest.main([__file__, "-v"])