This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 360648f feat: Add __repr__ generation support for @c_class dataclasses (#411)
360648f is described below
commit 360648f30ccb14523ab6fbb81f37eb085b801f98
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Sun Jan 18 11:30:53 2026 +0800
feat: Add __repr__ generation support for @c_class dataclasses (#411)
- Add `repr` parameter to `c_class()` decorator (default: True)
- Add `repr` parameter to `field()` function (default: True)
- Implement `method_repr()` to generate `__repr__` methods, and factor the field-hierarchy walk shared with `method_init()` into `_get_all_fields()`
- Generated repr format: `ClassName(field1=value1, field2=value2, ...)`
- Fields declared with `repr=False` are excluded from the representation
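This implements part of the dataclass feature-parity work tracked in #356. Note that `_add_method` no longer skips names already present in the class attributes, so generated methods now override inherited ones such as `Object.__repr__`.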
---
python/tvm_ffi/dataclasses/_utils.py | 63 +++++++++++++++++++++++++++-----
python/tvm_ffi/dataclasses/c_class.py | 12 ++++--
python/tvm_ffi/dataclasses/field.py | 12 +++++-
tests/python/test_dataclasses_c_class.py | 35 ++++++++++++++++++
4 files changed, 108 insertions(+), 14 deletions(-)
diff --git a/python/tvm_ffi/dataclasses/_utils.py
b/python/tvm_ffi/dataclasses/_utils.py
index bd647a6..5ed4e96 100644
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -58,14 +58,13 @@ def type_info_to_cls(
def _add_method(name: str, func: Callable[..., Any]) -> None:
if name == "__ffi_init__":
name = "__c_ffi_init__"
- if name in attrs: # already defined
- return
+ # Allow overriding methods (including from base classes like
Object.__repr__)
+ # by always adding to attrs, which will be used when creating the new
class
func.__module__ = cls.__module__
func.__name__ = name
func.__qualname__ = f"{cls.__qualname__}.{name}"
func.__doc__ = f"Method `{name}` of class `{cls.__qualname__}`"
attrs[name] = func
- setattr(cls, name, func)
for name, method_impl in methods.items():
if method_impl is not None:
@@ -98,6 +97,57 @@ def fill_dataclass_field(type_cls: type, type_field:
TypeField) -> None:
type_field.dataclass_field = rhs
+def _get_all_fields(type_info: TypeInfo) -> list[TypeField]:
+ """Collect all fields from the type hierarchy, from parents to children."""
+ fields: list[TypeField] = []
+ cur_type_info: TypeInfo | None = type_info
+ while cur_type_info is not None:
+ fields.extend(reversed(cur_type_info.fields))
+ cur_type_info = cur_type_info.parent_type_info
+ fields.reverse()
+ return fields
+
+
+def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]:
+ """Generate a ``__repr__`` method for the dataclass.
+
+ The generated representation includes all fields with ``repr=True`` in
+ the format ``ClassName(field1=value1, field2=value2, ...)``.
+ """
+ # Step 0. Collect all fields from the type hierarchy
+ fields = _get_all_fields(type_info)
+
+ # Step 1. Filter fields that should appear in repr
+ repr_fields: list[str] = []
+ for field in fields:
+ assert field.name is not None
+ assert field.dataclass_field is not None
+ if field.dataclass_field.repr:
+ repr_fields.append(field.name)
+
+ # Step 2. Generate the repr method
+ if not repr_fields:
+ # No fields to show, return a simple class name representation
+ body_lines = [f"return f'{type_cls.__name__}()'"]
+ else:
+ # Build field representations
+ fields_str = ", ".join(
+ f"{field_name}={{self.{field_name}!r}}" for field_name in
repr_fields
+ )
+ body_lines = [f"return f'{type_cls.__name__}({fields_str})'"]
+
+ source_lines = ["def __repr__(self) -> str:"]
+ source_lines.extend(f" {line}" for line in body_lines)
+ source = "\n".join(source_lines)
+
+ # Note: Code generation in this case is guaranteed to be safe,
+ # because the generated code does not contain any untrusted input.
+ namespace: dict[str, Any] = {}
+ exec(source, {}, namespace)
+ __repr__ = namespace["__repr__"]
+ return __repr__
+
+
def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
"""Generate an ``__init__`` that forwards to the FFI constructor.
@@ -105,12 +155,7 @@ def method_init(type_cls: type, type_info: TypeInfo) ->
Callable[..., None]:
reflected field list, supporting default values and ``__post_init__``.
"""
# Step 0. Collect all fields from the type hierarchy
- fields: list[TypeField] = []
- cur_type_info: TypeInfo | None = type_info
- while cur_type_info is not None:
- fields.extend(reversed(cur_type_info.fields))
- cur_type_info = cur_type_info.parent_type_info
- fields.reverse()
+ fields = _get_all_fields(type_info)
# sanity check
for type_method in type_info.methods:
if type_method.name == "__ffi_init__":
diff --git a/python/tvm_ffi/dataclasses/c_class.py
b/python/tvm_ffi/dataclasses/c_class.py
index 65d7c73..8171b1b 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -41,7 +41,7 @@ _InputClsType = TypeVar("_InputClsType")
@dataclass_transform(field_specifiers=(field,))
def c_class(
- type_key: str, init: bool = True
+ type_key: str, init: bool = True, repr: bool = True
) -> Callable[[Type[_InputClsType]], Type[_InputClsType]]: # noqa: UP006
"""(Experimental) Create a dataclass-like proxy for a C++ class registered
with TVM FFI.
@@ -71,6 +71,10 @@ def c_class(
signature. The generated initializer calls the C++ ``__init__``
function registered with ``ObjectDef`` and invokes ``__post_init__`` if
it exists on the Python class.
+ repr
+ If ``True`` and the Python class does not define ``__repr__``, a
+ representation method is auto-generated that includes all fields with
+ ``repr=True``.
Returns
-------
@@ -118,8 +122,9 @@ def c_class(
"""
def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]:
# noqa: UP006
- nonlocal init
+ nonlocal init, repr
init = init and "__init__" not in super_type_cls.__dict__
+ repr = repr and "__repr__" not in super_type_cls.__dict__
# Step 1. Retrieve `type_info` from registry
type_info: TypeInfo =
_lookup_or_register_type_info_from_type_key(type_key)
assert type_info.parent_type_info is not None
@@ -129,10 +134,11 @@ def c_class(
_utils.fill_dataclass_field(super_type_cls, type_field)
# Step 3. Create the proxy class with the fields as properties
fn_init = _utils.method_init(super_type_cls, type_info) if init else
None
+ fn_repr = _utils.method_repr(super_type_cls, type_info) if repr else
None
type_cls: Type[_InputClsType] = _utils.type_info_to_cls( # noqa: UP006
type_info=type_info,
cls=super_type_cls,
- methods={"__init__": fn_init},
+ methods={"__init__": fn_init, "__repr__": fn_repr},
)
_set_type_cls(type_info, type_cls)
return type_cls
diff --git a/python/tvm_ffi/dataclasses/field.py
b/python/tvm_ffi/dataclasses/field.py
index d10612c..d0e27b1 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -37,7 +37,7 @@ class Field:
way the decorator understands.
"""
- __slots__ = ("default_factory", "init", "name")
+ __slots__ = ("default_factory", "init", "name", "repr")
def __init__(
self,
@@ -45,11 +45,13 @@ class Field:
name: str | None = None,
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
init: bool = True,
+ repr: bool = True,
) -> None:
"""Do not call directly; use :func:`field` instead."""
self.name = name
self.default_factory = default_factory
self.init = init
+ self.repr = repr
def field(
@@ -57,6 +59,7 @@ def field(
default: _FieldValue | _MISSING_TYPE = MISSING, # type: ignore[assignment]
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING, #
type: ignore[assignment]
init: bool = True,
+ repr: bool = True,
) -> _FieldValue:
"""(Experimental) Declare a dataclass-style field on a :func:`c_class`
proxy.
@@ -78,6 +81,9 @@ def field(
init
If ``True`` the field is included in the generated ``__init__``.
If ``False`` the field is omitted from input arguments of ``__init__``.
+ repr
+ If ``True`` the field is included in the generated ``__repr__``.
+ If ``False`` the field is omitted from the ``__repr__`` output.
Note
----
@@ -123,9 +129,11 @@ def field(
raise ValueError("Cannot specify both `default` and `default_factory`")
if not isinstance(init, bool):
raise TypeError("`init` must be a bool")
+ if not isinstance(repr, bool):
+ raise TypeError("`repr` must be a bool")
if default is not MISSING:
default_factory = _make_default_factory(default)
- ret = Field(default_factory=default_factory, init=init)
+ ret = Field(default_factory=default_factory, init=init, repr=repr)
return cast(_FieldValue, ret)
diff --git a/tests/python/test_dataclasses_c_class.py
b/tests/python/test_dataclasses_c_class.py
index 676bbf5..5361cb6 100644
--- a/tests/python/test_dataclasses_c_class.py
+++ b/tests/python/test_dataclasses_c_class.py
@@ -94,3 +94,38 @@ def test_cxx_class_init_subset_positional() -> None:
assert obj.optional_field == -1
obj.optional_field = 11
assert obj.optional_field == 11
+
+
+def test_cxx_class_repr() -> None:
+ obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0, v_f32=8.0)
+ repr_str = repr(obj)
+ assert "_TestCxxClassDerived" in repr_str
+ if "__repr__" in _TestCxxClassDerived.__dict__:
+ assert "v_i64=123" in repr_str
+ assert "v_i32=456" in repr_str
+ assert "v_f64=4.0" in repr_str
+ assert "v_f32=8.0" in repr_str
+
+
+def test_cxx_class_repr_default() -> None:
+ obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0)
+ repr_str = repr(obj)
+ assert "_TestCxxClassDerived" in repr_str
+ if "__repr__" in _TestCxxClassDerived.__dict__:
+ assert "v_i64=123" in repr_str
+ assert "v_i32=456" in repr_str
+ assert "v_f64=4.0" in repr_str
+ assert "v_f32=8.0" in repr_str
+
+
+def test_cxx_class_repr_derived_derived() -> None:
+ obj = _TestCxxClassDerivedDerived(
+ v_i64=123, v_i32=456, v_f64=4.0, v_f32=8.0, v_str="hello", v_bool=True
+ )
+ repr_str = repr(obj)
+ assert "_TestCxxClassDerivedDerived" in repr_str
+ if "__repr__" in _TestCxxClassDerivedDerived.__dict__:
+ assert "v_i64=123" in repr_str
+ assert "v_i32=456" in repr_str
+ assert "v_str='hello'" in repr_str or 'v_str="hello"' in repr_str
+ assert "v_bool=True" in repr_str