This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new e5f3af7b feat(python): reimplement c_class as register_object + structural dunders (#488)
e5f3af7b is described below
commit e5f3af7bb83e6461c45d08117e4eaabe51add3b1
Author: Junru Shao <[email protected]>
AuthorDate: Sat Feb 28 12:03:37 2026 -0800
feat(python): reimplement c_class as register_object + structural dunders (#488)
## Summary
Rewrite the `@c_class` decorator from a thin `register_object` pass-through into a `dataclass`-style decorator that combines FFI type registration with structural dunder methods derived from C++ reflection metadata.
- **`@c_class` now installs structural dunders** — `__init__`, `__repr__`, `__eq__`/`__ne__`, `__hash__`, and the ordering operators (`__lt__`, `__le__`, `__gt__`, `__ge__`). `__init__` is synthesized from C++ reflection metadata and `__repr__` uses `object_repr`. The comparison and hashing dunders delegate to the corresponding C++ recursive operations (`RecursiveEq`, `RecursiveHash`, `RecursiveLt`, etc.).
- **`@dataclass_transform`** is applied to `c_class` for IDE/type-checker support (pyright, mypy).
- **Migrated all test objects** in `tvm_ffi.testing` from `@register_object` to `@c_class`.
## Architecture
- `c_class.py`: the decorator accepts keyword-only `init`, `repr`, `eq`, `order`, and `unsafe_hash` parameters. It calls `register_object` and then `_install_dataclass_dunders`.
- `registry.py`:
  - `_install_dataclass_dunders` installs the dunders on the class.
  - `_install_init` synthesizes a reflection-based `__init__` or installs a guard. It is no longer called from `_add_class_attrs`; `__init__` installation now happens through `_install_dataclass_dunders`.
  - `_make_init` / `_make_init_signature` build an `inspect.Signature` from C++ field metadata, respecting the `kw_only`, `has_default`, and `c_init` traits.
  - `_is_comparable` centralizes the bidirectional `isinstance` guard.
- Each dunder is installed only if it is absent from `cls.__dict__`, so user-defined overrides are preserved.
- `__eq__`/`__ne__` and the ordering operators return `NotImplemented` for unrelated types, following Python data-model conventions.
## Public Interfaces
- `@c_class(type_key, *, init, repr, eq, order, unsafe_hash)` — new keyword-only arguments. The old usage `@c_class("key")` continues to work with sensible defaults (`init=True`, `repr=True`, all others `False`).
- No breaking changes — `eq`, `order`, `unsafe_hash` default to `False`.
## Test Plan
- [x] New `test_dataclass_c_class.py` (26 tests), covering:
  - custom `__init__` preservation and auto-generated `__init__` with defaults;
  - structural equality (reflexive, symmetric) and hashing (dict keys, set deduplication);
  - ordering (reflexive, antisymmetric);
  - `NotImplemented` for different types, and subclass equality;
  - `kw_only` from C++ reflection, `init_subset`, and derived-derived defaults.
- [x] Renamed `test_copy.py` → `test_dataclass_copy.py`, adding cycle/`Shape` coverage and `deep_copy.cc` branch-coverage tests.
- [x] `uv run pytest -vvs tests/python` — 960 passed, 23 skipped, 1 xfailed.
- [x] `sphinx-build -W --keep-going -b html docs docs/_build/html` — build succeeded.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
---
python/tvm_ffi/cython/object.pxi | 7 +-
python/tvm_ffi/dataclasses/c_class.py | 89 +++++-
python/tvm_ffi/registry.py | 130 ++++++++-
python/tvm_ffi/testing/testing.py | 60 ++--
tests/python/test_dataclass_c_class.py | 319 +++++++++++++++++++++
.../{test_copy.py => test_dataclass_copy.py} | 202 +++++++++++++
tests/python/test_dataclass_init.py | 2 +-
7 files changed, 772 insertions(+), 37 deletions(-)
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 97536fec..ba0a4906 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -203,9 +203,10 @@ class Object(CObject, metaclass=_ObjectSlotsMeta):
identity unless an overridden implementation is provided on the
concrete type. Use :py:meth:`same_as` to check whether two
references point to the same underlying object.
- - Subclasses that omit ``__slots__`` are treated as ``__slots__ = ()``.
- Subclasses that need per-instance dynamic attributes can opt in with
- ``__slots__ = ("__dict__",)``.
+ - Subclasses that omit ``__slots__`` get ``__slots__ = ()`` injected
+ automatically by the metaclass. Pass ``slots=False`` in the class
+ header (e.g. ``class Foo(Object, slots=False)``) to suppress this
+ and allow a per-instance ``__dict__``.
- Most users interact with subclasses (e.g. :class:`Tensor`,
:class:`Function`) rather than :py:class:`Object` directly.
diff --git a/python/tvm_ffi/dataclasses/c_class.py
b/python/tvm_ffi/dataclasses/c_class.py
index d836f1a9..ee6d4329 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -14,36 +14,103 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""The ``c_class`` decorator: pass-through to ``register_object``."""
+"""The ``c_class`` decorator: register_object + structural dunders."""
from __future__ import annotations
from collections.abc import Callable
-from typing import Any, TypeVar
+from typing import TypeVar
+
+from typing_extensions import dataclass_transform
_T = TypeVar("_T", bound=type)
-def c_class(type_key: str, **kwargs: Any) -> Callable[[_T], _T]:
- """Register a C++ FFI class by type key.
+@dataclass_transform(eq_default=False, order_default=False)
+def c_class(
+ type_key: str,
+ *,
+ init: bool = True,
+ repr: bool = True,
+ eq: bool = False,
+ order: bool = False,
+ unsafe_hash: bool = False,
+) -> Callable[[_T], _T]:
+ """Register a C++ FFI class and install structural dunder methods.
- This is a thin wrapper around :func:`~tvm_ffi.register_object` that
- accepts (and currently ignores) additional keyword arguments for
- forward compatibility.
+ Combines :func:`~tvm_ffi.register_object` with structural comparison,
+ hashing, and ordering derived from the C++ reflection metadata.
+ User-defined dunders in the class body are never overwritten.
Parameters
----------
type_key
The reflection key that identifies the C++ type in the FFI registry.
- kwargs
- Reserved for future use.
+ Must match a key already registered on the C++ side via
+ ``TVM_FFI_DECLARE_OBJECT_INFO``.
+ init
+ If True (default), install ``__init__`` from C++ reflection metadata.
+ The generated ``__init__`` respects ``Init()``, ``KwOnly()``, and
+ ``Default()`` traits declared on each C++ field. If the class body
+ already defines ``__init__``, it is kept.
+ repr
+ If True (default), install ``__repr__`` using
+ :func:`~tvm_ffi.core.object_repr`, which formats the object via
+ the C++ ``ReprPrint`` visitor. Skipped if the class body already
+ defines ``__repr__``.
+ eq
+ If True, install ``__eq__`` and ``__ne__`` using the C++ recursive
+ structural comparison (``RecursiveEq``). Returns ``NotImplemented``
+ for unrelated types. Defaults to False.
+ order
+ If True, install ``__lt__``, ``__le__``, ``__gt__``, ``__ge__``
+ using the C++ recursive comparators. Returns ``NotImplemented``
+ for unrelated types. Defaults to False.
+ unsafe_hash
+ If True, install ``__hash__`` using ``RecursiveHash``. Called
+ *unsafe* because mutable fields contribute to the hash, so mutating
+ an object while it is in a set or dict key will break invariants.
+ Defaults to False.
Returns
-------
Callable[[type], type]
A class decorator.
+ Examples
+ --------
+ Basic usage with default settings (``init`` and ``repr`` enabled):
+
+ .. code-block:: python
+
+ @c_class("my.Point")
+ class Point(Object):
+ x: float
+ y: float
+
+ Enable structural equality, hashing, and ordering:
+
+ .. code-block:: python
+
+ @c_class("my.Point", eq=True, unsafe_hash=True, order=True)
+ class Point(Object):
+ x: float
+ y: float
+
+ See Also
+ --------
+ :func:`tvm_ffi.register_object`
+ Lower-level decorator that only registers the type without
+ installing structural dunders.
+
"""
- from ..registry import register_object # noqa: PLC0415
+ from ..registry import _install_dataclass_dunders, register_object #
noqa: PLC0415
+
+ def decorator(cls: _T) -> _T:
+ cls = register_object(type_key)(cls)
+ _install_dataclass_dunders(
+ cls, init=init, repr=repr, eq=eq, order=order,
unsafe_hash=unsafe_hash
+ )
+ return cls
- return register_object(type_key)
+ return decorator
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index f2801b1f..b45b03e0 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -63,8 +63,8 @@ def register_object(type_key: str | None = None) ->
Callable[[_T], _T]:
return cls
raise ValueError(f"Cannot find object type index for
{object_name}")
info = core._register_object_by_index(type_index, cls)
- setattr(cls, "__tvm_ffi_type_info__", info)
_add_class_attrs(type_cls=cls, type_info=info)
+ setattr(cls, "__tvm_ffi_type_info__", info)
return cls
if isinstance(type_key, str):
@@ -418,7 +418,6 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo)
-> type:
setattr(type_cls, name, method.as_callable(type_cls))
elif not hasattr(type_cls, name):
setattr(type_cls, name, method.as_callable(type_cls))
- _install_init(type_cls, enabled=True)
is_container = type_info.type_key in (
"ffi.Array",
"ffi.Map",
@@ -456,7 +455,19 @@ def _setup_copy_methods(
def _install_init(cls: type, *, enabled: bool) -> None:
- """Install ``__init__`` from C++ reflection metadata, or a guard."""
+ """Install ``__init__`` from C++ reflection metadata, or a guard.
+
+ When *enabled* is True, looks for a ``__ffi_init__`` method in the
+ type's C++ reflection metadata. If the method has ``auto_init=True``
+ metadata (set by ``refl::init()`` in C++), a Python ``__init__`` is
+ synthesized with an ``inspect.Signature`` derived from the field
+ metadata (respecting ``Init()``, ``KwOnly()``, ``Default()`` traits).
+ Otherwise the raw ``__ffi_init__`` is exposed as ``__init__`` directly.
+
+ When *enabled* is False, installs a guard that raises ``TypeError``
+ on construction. Skipped entirely if the class body already defines
+ ``__init__``.
+ """
if "__init__" in cls.__dict__:
return
type_info: TypeInfo | None = getattr(cls, "__tvm_ffi_type_info__", None)
@@ -531,6 +542,119 @@ def _replace_unsupported(self: Any, **kwargs: Any) -> Any:
)
+def _install_dataclass_dunders(
+ cls: type,
+ *,
+ init: bool,
+ repr: bool,
+ eq: bool,
+ order: bool,
+ unsafe_hash: bool,
+) -> None:
+ """Install structural dunder methods on *cls*.
+
+ Each dunder delegates to the corresponding C++ recursive structural
+ operation (``RecursiveEq``, ``RecursiveHash``, ``RecursiveLt``, etc.).
+ If the user already defined a dunder in the class body
+ (i.e. it exists in ``cls.__dict__``), it is left untouched.
+
+ Parameters
+ ----------
+ cls
+ The class to install dunders on. Must have been processed by
+ :func:`register_object` first (so ``__tvm_ffi_type_info__`` exists).
+ init
+ If True, install ``__init__`` from C++ reflection metadata via
+ :func:`_install_init`.
+ repr
+ If True, install :func:`~tvm_ffi.core.object_repr` as ``__repr__``.
+ eq
+ If True, install ``__eq__`` and ``__ne__`` using ``RecursiveEq``.
+ Returns ``NotImplemented`` for unrelated types so Python can
+ fall back to identity comparison.
+ order
+ If True, install ``__lt__``, ``__le__``, ``__gt__``, ``__ge__``
+ using ``RecursiveLt``/``Le``/``Gt``/``Ge``. Returns
+ ``NotImplemented`` for unrelated types.
+ unsafe_hash
+ If True, install ``__hash__`` using ``RecursiveHash``.
+
+ """
+ _install_init(cls, enabled=init)
+
+ if repr and "__repr__" not in cls.__dict__:
+ from .core import object_repr # noqa: PLC0415
+
+ cls.__repr__ = object_repr # type: ignore[attr-defined]
+
+ from . import _ffi_api # noqa: PLC0415
+
+ def _is_comparable(self: Any, other: Any) -> bool:
+ """Return True if *self* and *other* share a type hierarchy."""
+ return isinstance(other, type(self)) or isinstance(self, type(other))
+
+ dunders: dict[str, Any] = {}
+
+ if eq:
+ recursive_eq = _ffi_api.RecursiveEq
+
+ def __eq__(self: Any, other: Any) -> bool:
+ if not _is_comparable(self, other):
+ return NotImplemented
+ return recursive_eq(self, other)
+
+ def __ne__(self: Any, other: Any) -> bool:
+ if not _is_comparable(self, other):
+ return NotImplemented
+ return not recursive_eq(self, other)
+
+ dunders["__eq__"] = __eq__
+ dunders["__ne__"] = __ne__
+
+ if unsafe_hash:
+ recursive_hash = _ffi_api.RecursiveHash
+
+ def __hash__(self: Any) -> int:
+ return recursive_hash(self)
+
+ dunders["__hash__"] = __hash__
+
+ if order:
+ recursive_lt = _ffi_api.RecursiveLt
+ recursive_le = _ffi_api.RecursiveLe
+ recursive_gt = _ffi_api.RecursiveGt
+ recursive_ge = _ffi_api.RecursiveGe
+
+ def __lt__(self: Any, other: Any) -> bool:
+ if not _is_comparable(self, other):
+ return NotImplemented
+ return recursive_lt(self, other)
+
+ def __le__(self: Any, other: Any) -> bool:
+ if not _is_comparable(self, other):
+ return NotImplemented
+ return recursive_le(self, other)
+
+ def __gt__(self: Any, other: Any) -> bool:
+ if not _is_comparable(self, other):
+ return NotImplemented
+ return recursive_gt(self, other)
+
+ def __ge__(self: Any, other: Any) -> bool:
+ if not _is_comparable(self, other):
+ return NotImplemented
+ return recursive_ge(self, other)
+
+ dunders["__lt__"] = __lt__
+ dunders["__le__"] = __le__
+ dunders["__gt__"] = __gt__
+ dunders["__ge__"] = __ge__
+
+ for name, impl in dunders.items():
+ if name not in cls.__dict__:
+ setattr(cls, name, impl)
+
+
def get_registered_type_keys() -> Sequence[str]:
"""Get the list of valid type keys registered to TVM-FFI.
diff --git a/python/tvm_ffi/testing/testing.py
b/python/tvm_ffi/testing/testing.py
index d98374d9..59f1e50d 100644
--- a/python/tvm_ffi/testing/testing.py
+++ b/python/tvm_ffi/testing/testing.py
@@ -35,10 +35,10 @@ from typing import ClassVar
from .. import _ffi_api
from ..core import Object
from ..dataclasses import c_class
-from ..registry import get_global_func, register_object
+from ..registry import get_global_func
-@register_object("testing.TestObjectBase")
+@c_class("testing.TestObjectBase")
class TestObjectBase(Object):
"""Test object base class."""
@@ -54,10 +54,12 @@ class TestObjectBase(Object):
# tvm-ffi-stubgen(end)
-@register_object("testing.TestIntPair")
+@c_class("testing.TestIntPair")
class TestIntPair(Object):
"""Test Int Pair."""
+ __test__ = False
+
# tvm-ffi-stubgen(begin): object/testing.TestIntPair
# fmt: off
a: int
@@ -71,7 +73,7 @@ class TestIntPair(Object):
# tvm-ffi-stubgen(end)
-@register_object("testing.TestObjectDerived")
+@c_class("testing.TestObjectDerived")
class TestObjectDerived(TestObjectBase):
"""Test object derived class."""
@@ -85,14 +87,14 @@ class TestObjectDerived(TestObjectBase):
# tvm-ffi-stubgen(end)
-@register_object("testing.TestNonCopyable")
+@c_class("testing.TestNonCopyable")
class TestNonCopyable(Object):
"""Test object with deleted copy constructor."""
value: int
-@register_object("testing.TestCompare")
+@c_class("testing.TestCompare")
class TestCompare(Object):
"""Test object with Compare(false) on ignored_field."""
@@ -111,7 +113,7 @@ class TestCompare(Object):
# tvm-ffi-stubgen(end)
-@register_object("testing.TestCustomCompare")
+@c_class("testing.TestCustomCompare")
class TestCustomCompare(Object):
"""Test object with custom __ffi_eq__/__ffi_compare__ hooks (compares only
key)."""
@@ -129,7 +131,7 @@ class TestCustomCompare(Object):
# tvm-ffi-stubgen(end)
-@register_object("testing.TestEqWithoutHash")
+@c_class("testing.TestEqWithoutHash")
class TestEqWithoutHash(Object):
"""Test object with __ffi_eq__ but no __ffi_hash__ (exercises hash
guard)."""
@@ -147,7 +149,7 @@ class TestEqWithoutHash(Object):
# tvm-ffi-stubgen(end)
-@register_object("testing.TestHash")
+@c_class("testing.TestHash")
class TestHash(Object):
"""Test object with Hash(false) on hash_ignored."""
@@ -166,7 +168,7 @@ class TestHash(Object):
# tvm-ffi-stubgen(end)
-@register_object("testing.TestCustomHash")
+@c_class("testing.TestCustomHash")
class TestCustomHash(Object):
"""Test object with custom __ffi_hash__ hook (hashes only key)."""
@@ -184,7 +186,7 @@ class TestCustomHash(Object):
# tvm-ffi-stubgen(end)
-@register_object("testing.SchemaAllTypes")
+@c_class("testing.SchemaAllTypes")
class _SchemaAllTypes:
# tvm-ffi-stubgen(ty-map): testing.SchemaAllTypes ->
testing._SchemaAllTypes
# tvm-ffi-stubgen(begin): object/testing.SchemaAllTypes
@@ -265,16 +267,30 @@ class _TestCxxClassBase(Object):
self.__ffi_init__(v_i64 + 1, v_i32 + 2)
-@c_class("testing.TestCxxClassDerived")
+@c_class("testing.TestCxxClassDerived", eq=True, order=True, unsafe_hash=True)
class _TestCxxClassDerived(_TestCxxClassBase):
v_f64: float
v_f32: float
+ if TYPE_CHECKING:
+
+ def __init__(self, v_i64: int, v_i32: int, v_f64: float, v_f32: float
= ...) -> None: ...
@c_class("testing.TestCxxClassDerivedDerived")
class _TestCxxClassDerivedDerived(_TestCxxClassDerived):
v_str: str
v_bool: bool
+ if TYPE_CHECKING:
+
+ def __init__(
+ self,
+ v_i64: int,
+ v_i32: int,
+ v_f64: float,
+ v_bool: bool,
+ v_f32: float = ...,
+ v_str: str = ...,
+ ) -> None: ...
@c_class("testing.TestCxxInitSubset")
@@ -282,6 +298,9 @@ class _TestCxxInitSubset(Object):
required_field: int
optional_field: int
note: str
+ if TYPE_CHECKING:
+
+ def __init__(self, required_field: int) -> None: ...
@c_class("testing.TestCxxKwOnly")
@@ -290,9 +309,12 @@ class _TestCxxKwOnly(Object):
y: int
z: int
w: int
+ if TYPE_CHECKING:
+
+ def __init__(self, *, x: int, y: int, z: int, w: int = ...) -> None:
...
-@register_object("testing.TestCxxAutoInit")
+@c_class("testing.TestCxxAutoInit")
class _TestCxxAutoInit(Object):
"""Test object with init(false) on b and KwOnly(true) on c."""
@@ -307,7 +329,7 @@ class _TestCxxAutoInit(Object):
def __init__(self, a: int, d: int = ..., *, c: int) -> None: ...
-@register_object("testing.TestCxxAutoInitSimple")
+@c_class("testing.TestCxxAutoInitSimple")
class _TestCxxAutoInitSimple(Object):
"""Test object with all fields positional (no init/KwOnly traits)."""
@@ -320,7 +342,7 @@ class _TestCxxAutoInitSimple(Object):
def __init__(self, x: int, y: int) -> None: ...
-@register_object("testing.TestCxxAutoInitAllInitOff")
+@c_class("testing.TestCxxAutoInitAllInitOff")
class _TestCxxAutoInitAllInitOff(Object):
"""Test object with all fields excluded from auto-init (init(false))."""
@@ -334,7 +356,7 @@ class _TestCxxAutoInitAllInitOff(Object):
def __init__(self) -> None: ...
-@register_object("testing.TestCxxAutoInitKwOnlyDefaults")
+@c_class("testing.TestCxxAutoInitKwOnlyDefaults")
class _TestCxxAutoInitKwOnlyDefaults(Object):
"""Test object with mixed positional/kw-only/default/init=False fields."""
@@ -352,7 +374,7 @@ class _TestCxxAutoInitKwOnlyDefaults(Object):
) -> None: ...
-@register_object("testing.TestCxxNoAutoInit")
+@c_class("testing.TestCxxNoAutoInit", init=False)
class _TestCxxNoAutoInit(Object):
"""Test object with init(false) at class level — no __ffi_init__
generated."""
@@ -362,7 +384,7 @@ class _TestCxxNoAutoInit(Object):
y: int
-@register_object("testing.TestCxxAutoInitParent")
+@c_class("testing.TestCxxAutoInitParent")
class _TestCxxAutoInitParent(Object):
"""Parent object for inheritance auto-init tests."""
@@ -375,7 +397,7 @@ class _TestCxxAutoInitParent(Object):
def __init__(self, parent_required: int, parent_default: int = ...) ->
None: ...
-@register_object("testing.TestCxxAutoInitChild")
+@c_class("testing.TestCxxAutoInitChild")
class _TestCxxAutoInitChild(_TestCxxAutoInitParent):
"""Child object for inheritance auto-init tests."""
diff --git a/tests/python/test_dataclass_c_class.py
b/tests/python/test_dataclass_c_class.py
new file mode 100644
index 00000000..52bf38bb
--- /dev/null
+++ b/tests/python/test_dataclass_c_class.py
@@ -0,0 +1,319 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for the c_class decorator (register_object + structural dunders)."""
+
+from __future__ import annotations
+
+import inspect
+
+import pytest
+from tvm_ffi.testing import (
+ _TestCxxClassBase,
+ _TestCxxClassDerived,
+ _TestCxxClassDerivedDerived,
+ _TestCxxInitSubset,
+ _TestCxxKwOnly,
+)
+
+# ---------------------------------------------------------------------------
+# 1. Custom __init__ preservation
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_custom_init() -> None:
+ """c_class preserves user-defined __init__."""
+ obj = _TestCxxClassBase(v_i64=10, v_i32=20)
+ assert obj.v_i64 == 11 # +1 from custom __init__
+ assert obj.v_i32 == 22 # +2 from custom __init__
+
+
+# ---------------------------------------------------------------------------
+# 2. Auto-generated __init__ with defaults
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_auto_init_defaults() -> None:
+ """Derived classes use auto-generated __init__ with C++ defaults."""
+ obj = _TestCxxClassDerived(v_i64=1, v_i32=2, v_f64=3.0)
+ assert obj.v_i64 == 1
+ assert obj.v_i32 == 2
+ assert obj.v_f64 == 3.0
+ assert obj.v_f32 == 8.0 # default from C++
+
+
+def test_c_class_auto_init_all_explicit() -> None:
+ """Auto-generated __init__ accepts all fields explicitly."""
+ obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0, v_f32=9.0)
+ assert obj.v_i64 == 123
+ assert obj.v_i32 == 456
+ assert obj.v_f64 == 4.0
+ assert obj.v_f32 == 9.0
+
+
+# ---------------------------------------------------------------------------
+# 3. Structural equality (__eq__)
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_eq() -> None:
+ """c_class installs __eq__ using RecursiveEq."""
+ a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ assert a == b
+ assert a is not b # different objects
+ c = _TestCxxClassDerived(1, 2, 3.0, 5.0)
+ assert a != c
+
+
+def test_c_class_eq_reflexive() -> None:
+ """Equality is reflexive: an object equals itself."""
+ a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ b = a # alias, same object
+ assert a == b
+
+
+def test_c_class_eq_symmetric() -> None:
+ """Equality is symmetric: a == b implies b == a."""
+ a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ assert a == b
+ assert b == a
+
+
+# ---------------------------------------------------------------------------
+# 4. Structural hash (__hash__)
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_hash() -> None:
+ """c_class installs __hash__ using RecursiveHash."""
+ a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ assert hash(a) == hash(b)
+
+
+def test_c_class_hash_as_dict_key() -> None:
+ """Equal objects can be used interchangeably as dict keys."""
+ a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ d = {a: "value"}
+ assert d[b] == "value"
+
+
+# ---------------------------------------------------------------------------
+# 5. Ordering (__lt__, __le__, __gt__, __ge__)
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_ordering() -> None:
+ """c_class installs ordering operators."""
+ small = _TestCxxClassDerived(0, 0, 0.0, 0.0)
+ big = _TestCxxClassDerived(100, 100, 100.0, 100.0)
+ assert small < big # ty: ignore[unsupported-operator]
+ assert small <= big # ty: ignore[unsupported-operator]
+ assert big > small # ty: ignore[unsupported-operator]
+ assert big >= small # ty: ignore[unsupported-operator]
+ assert not (big < small) # ty: ignore[unsupported-operator]
+ assert not (small > big) # ty: ignore[unsupported-operator]
+
+
+def test_c_class_ordering_reflexive() -> None:
+ """<= and >= are reflexive."""
+ a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ b = a # alias, same object
+ assert a <= b # ty: ignore[unsupported-operator]
+ assert a >= b # ty: ignore[unsupported-operator]
+
+
+def test_c_class_ordering_antisymmetric() -> None:
+ """If a < b then not b < a."""
+ a = _TestCxxClassDerived(0, 0, 0.0, 0.0)
+ b = _TestCxxClassDerived(100, 100, 100.0, 100.0)
+ if a < b: # ty: ignore[unsupported-operator]
+ assert not (b < a) # ty: ignore[unsupported-operator]
+ else:
+ assert not (a < b) # ty: ignore[unsupported-operator]
+
+
+# ---------------------------------------------------------------------------
+# 6. Equality with different types returns NotImplemented
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_eq_different_type() -> None:
+ """__eq__ returns NotImplemented for unrelated types."""
+ a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ assert a != "hello"
+ assert a != 42
+ assert a != 3.14
+ assert a is not None
+
+
+def test_c_class_ordering_different_type() -> None:
+ """Ordering against unrelated types raises TypeError."""
+ a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ with pytest.raises(TypeError):
+ a < "hello" # ty: ignore[unsupported-operator]
+ with pytest.raises(TypeError):
+ a <= 42 # ty: ignore[unsupported-operator]
+ with pytest.raises(TypeError):
+ a > 3.14 # ty: ignore[unsupported-operator]
+ with pytest.raises(TypeError):
+ a >= None # ty: ignore[unsupported-operator]
+
+
+# ---------------------------------------------------------------------------
+# 7. Subclass equality
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_subclass_eq() -> None:
+ """Subclass instances can be compared to parent instances without
crashing."""
+ derived = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ derived_derived = _TestCxxClassDerivedDerived(
+ v_i64=1, v_i32=2, v_f64=3.0, v_f32=4.0, v_str="hello", v_bool=True
+ )
+ # These are different types in the same hierarchy; comparison should
+ # return a bool (the result depends on C++ behavior).
+ result = derived == derived_derived
+ assert isinstance(result, bool)
+
+
+# ---------------------------------------------------------------------------
+# 8. KwOnly from C++ reflection
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_kw_only_signature() -> None:
+ """kw_only trait comes from C++ reflection, not Python decorator."""
+ sig = inspect.signature(_TestCxxKwOnly.__init__)
+ params = sig.parameters
+ for name in ("x", "y", "z", "w"):
+ assert params[name].kind == inspect.Parameter.KEYWORD_ONLY, (
+ f"Expected {name} to be KEYWORD_ONLY"
+ )
+
+
+def test_c_class_kw_only_call() -> None:
+ """KwOnly fields can be supplied as keyword arguments."""
+ obj = _TestCxxKwOnly(x=1, y=2, z=3, w=4)
+ assert obj.x == 1
+ assert obj.y == 2
+ assert obj.z == 3
+ assert obj.w == 4
+
+
+def test_c_class_kw_only_default() -> None:
+ """KwOnly field with a C++ default can be omitted."""
+ obj = _TestCxxKwOnly(x=1, y=2, z=3)
+ assert obj.w == 100
+
+
+def test_c_class_kw_only_rejects_positional() -> None:
+ """KwOnly fields reject positional arguments."""
+ with pytest.raises(TypeError, match="positional"):
+ _TestCxxKwOnly(1, 2, 3, 4) # ty: ignore[missing-argument,
too-many-positional-arguments]
+
+
+# ---------------------------------------------------------------------------
+# 9. Init subset from C++ reflection
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_init_subset_signature() -> None:
+ """init=False fields from C++ reflection are excluded from __init__."""
+ sig = inspect.signature(_TestCxxInitSubset.__init__)
+ params = tuple(sig.parameters)
+ assert "required_field" in params
+ assert "optional_field" not in params
+ assert "note" not in params
+
+
+def test_c_class_init_subset_defaults() -> None:
+ """init=False fields get their default values from C++."""
+ obj = _TestCxxInitSubset(required_field=42)
+ assert obj.required_field == 42
+ assert obj.optional_field == -1 # C++ default
+ assert obj.note == "default" # C++ default
+
+
+def test_c_class_init_subset_positional() -> None:
+ """Init-subset fields can be passed positionally."""
+ obj = _TestCxxInitSubset(7)
+ assert obj.required_field == 7
+ assert obj.optional_field == -1
+
+
+def test_c_class_init_subset_field_writable() -> None:
+ """Fields excluded from __init__ can still be assigned after
construction."""
+ obj = _TestCxxInitSubset(required_field=0)
+ obj.optional_field = 11
+ assert obj.optional_field == 11
+
+
+# ---------------------------------------------------------------------------
+# 10. DerivedDerived with defaults
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_derived_derived_defaults() -> None:
+ """DerivedDerived uses positional args; C++ defaults fill in omitted
fields."""
+ obj = _TestCxxClassDerivedDerived(1, 2, 3.0, True)
+ assert obj.v_i64 == 1
+ assert obj.v_i32 == 2
+ assert obj.v_f64 == 3.0
+ assert obj.v_f32 == 8.0 # C++ default
+ assert obj.v_str == "default" # C++ default
+ assert obj.v_bool is True
+
+
+def test_c_class_derived_derived_all_explicit() -> None:
+ """DerivedDerived with all fields explicitly provided."""
+ obj = _TestCxxClassDerivedDerived(
+ v_i64=123,
+ v_i32=456,
+ v_f64=4.0,
+ v_f32=9.0,
+ v_str="hello",
+ v_bool=True,
+ )
+ assert obj.v_i64 == 123
+ assert obj.v_i32 == 456
+ assert obj.v_f64 == 4.0
+ assert obj.v_f32 == 9.0
+ assert obj.v_str == "hello"
+ assert obj.v_bool is True
+
+
+# ---------------------------------------------------------------------------
+# 11. Hash / set usage
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_usable_in_set() -> None:
+ """Equal objects deduplicate in a set."""
+ a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+ c = _TestCxxClassDerived(5, 6, 7.0, 8.0)
+ s = {a, b, c}
+ assert len(s) == 2 # a and b are equal
+
+
+def test_c_class_unequal_objects_in_set() -> None:
+ """Distinct objects are separate entries in a set."""
+ objs = {_TestCxxClassDerived(i, i, float(i), float(i)) for i in range(5)}
+ assert len(objs) == 5
diff --git a/tests/python/test_copy.py b/tests/python/test_dataclass_copy.py
similarity index 80%
rename from tests/python/test_copy.py
rename to tests/python/test_dataclass_copy.py
index 3ec9a55c..df9b6367 100644
--- a/tests/python/test_copy.py
+++ b/tests/python/test_dataclass_copy.py
@@ -685,6 +685,208 @@ class TestDeepCopyBranches:
assert not pair.same_as(deep_pair)
assert deep_pair.a == 5
+ # --- Cycle preservation with immutable root containers ---
+
+ def test_cycle_list_root_map_backref_preserved(self) -> None:
+ """Control case: List root with Map back-reference should preserve cycle."""
+ # Cycle: root_list -> m -> root_list. Here the appendable List is the
+ # root, rather than an immutable Map/Array as in the neighbouring tests.
+ root_list = tvm_ffi.List()
+ m = tvm_ffi.Map({"list": root_list})
+ # Close the loop after the Map has captured the List.
+ root_list.append(m)
+
+ deep_list = copy.deepcopy(root_list)
+ assert not root_list.same_as(deep_list)
+ # The copied Map's back-reference must resolve to the copied List.
+ assert deep_list[0]["list"].same_as(deep_list)
+
+ def test_cycle_map_root_list_backref_preserved(self) -> None:
+ """Map root with List child pointing back should preserve cycle to root copy."""
+ # Cycle: m -> l -> m, copied starting from the Map.
+ l = tvm_ffi.List()
+ m = tvm_ffi.Map({"list": l})
+ l.append(m)
+
+ deep_map = copy.deepcopy(m)
+ assert not m.same_as(deep_map)
+ # The List child is copied, not shared with the original ...
+ assert not l.same_as(deep_map["list"])
+ # ... and its back-reference points at the copied Map, not the original.
+ assert deep_map["list"][0].same_as(deep_map)
+
+ def test_cycle_array_root_list_backref_preserved(self) -> None:
+ """Array root with List child pointing back should preserve cycle to root copy."""
+ # Cycle: a -> l -> a, copied starting from the Array.
+ l = tvm_ffi.List()
+ a = tvm_ffi.Array([l])
+ l.append(a)
+
+ deep_arr = copy.deepcopy(a)
+ assert not a.same_as(deep_arr)
+ # The List child is copied, not shared with the original ...
+ assert not l.same_as(deep_arr[0])
+ # ... and its back-reference points at the copied Array.
+ assert deep_arr[0][0].same_as(deep_arr)
+
+ def test_cycle_array_root_dict_backref_preserved(self) -> None:
+ """Array root with Dict child pointing back should preserve cycle to root copy."""
+ # Cycle: a -> d -> a, with the back-reference stored as a Dict value.
+ d = tvm_ffi.Dict()
+ a = tvm_ffi.Array([d])
+ d["self"] = a
+
+ deep_arr = copy.deepcopy(a)
+ assert not a.same_as(deep_arr)
+ # The Dict child is copied, not shared with the original ...
+ assert not d.same_as(deep_arr[0])
+ # ... and its "self" entry points at the copied Array.
+ assert deep_arr[0]["self"].same_as(deep_arr)
+
+ def test_cycle_map_root_dict_backref_preserved(self) -> None:
+ """Map root with Dict child pointing back should preserve cycle to root copy."""
+ # Cycle: m -> d -> m, with the back-reference stored as a Dict value.
+ d = tvm_ffi.Dict()
+ m = tvm_ffi.Map({"dict": d})
+ d["self"] = m
+
+ deep_map = copy.deepcopy(m)
+ assert not m.same_as(deep_map)
+ # The Dict child is copied, not shared with the original ...
+ assert not d.same_as(deep_map["dict"])
+ # ... and its "self" entry points at the copied Map.
+ assert deep_map["dict"]["self"].same_as(deep_map)
+
+ def test_cycle_map_root_backref_identity_not_duplicated(self) -> None:
+ """Back-references in a map-root cycle should point to the root copied map."""
+ # One List is stored under two keys and points back at m.
+ shared_list = tvm_ffi.List()
+ m = tvm_ffi.Map({"l1": shared_list, "l2": shared_list})
+ shared_list.append(m)
+
+ deep_map = copy.deepcopy(m)
+ # Sharing is preserved: both keys map to a single copied List ...
+ assert deep_map["l1"].same_as(deep_map["l2"])
+ # ... whose back-reference resolves to the copied root Map.
+ assert deep_map["l1"][0].same_as(deep_map)
+
+ def test_cycle_map_root_list_key_backref_preserved(self) -> None:
+ """Map-root cycles through keys should preserve back-reference to copied root."""
+ # The cycle runs through a key: the List is used as a Map key and is
+ # then mutated to point back at the Map.
+ key_list = tvm_ffi.List()
+ m = tvm_ffi.Map({key_list: 1})
+ key_list.append(m)
+
+ deep_map = copy.deepcopy(m)
+ # The Map has a single entry, so its first key is the copied key_list.
+ deep_key = next(iter(deep_map.keys()))
+ assert isinstance(deep_key, tvm_ffi.List)
+ assert deep_key[0].same_as(deep_map)
+
+ def test_cycle_map_root_dict_key_backref_preserved(self) -> None:
+ """Map-root cycles through Dict keys should preserve back-reference to copied root."""
+ # The cycle runs through a key: the Dict is used as a Map key and is then
+ # mutated to hold a reference back to the Map.
+ key_dict = tvm_ffi.Dict()
+ m = tvm_ffi.Map({key_dict: 1})
+ key_dict["self"] = m
+
+ deep_map = copy.deepcopy(m)
+ # The Map has a single entry, so its first key is the copied key_dict.
+ deep_key = next(iter(deep_map.keys()))
+ assert isinstance(deep_key, tvm_ffi.Dict)
+ assert deep_key["self"].same_as(deep_map)
+
+ def test_cycle_array_root_dict_contains_root_as_key(self) -> None:
+ """Array root with Dict child using the root as key should fix key to copied root."""
+ # Cycle through a key: root -> d, and d uses root itself as a key.
+ d = tvm_ffi.Dict()
+ root = tvm_ffi.Array([d])
+ d[root] = 1
+
+ deep_root = copy.deepcopy(root)
+ deep_dict = deep_root[0]
+ # d has exactly one key (root), so this picks its copied counterpart.
+ deep_key = next(iter(deep_dict.keys()))
+
+ assert not root.same_as(deep_root)
+ # The key must be rewritten to the copied root, not the original.
+ assert deep_key.same_as(deep_root)
+ assert not deep_key.same_as(root)
+
+ def test_cycle_map_root_dict_contains_root_as_key(self) -> None:
+ """Map root with Dict child using the root as key should fix key to copied root."""
+ # Cycle through a key: root -> d, and d uses root itself as a key.
+ d = tvm_ffi.Dict()
+ root = tvm_ffi.Map({"d": d})
+ d[root] = 1
+
+ deep_root = copy.deepcopy(root)
+ deep_dict = deep_root["d"]
+ # d has exactly one key (root), so this picks its copied counterpart.
+ deep_key = next(iter(deep_dict.keys()))
+
+ assert not root.same_as(deep_root)
+ # The key must be rewritten to the copied root, not the original.
+ assert deep_key.same_as(deep_root)
+ assert not deep_key.same_as(root)
+
+ # --- Python deepcopy protocol consistency for immutable Shape ---
+
+ def test_shape_root_python_deepcopy_matches_ffi_deepcopy(self) -> None:
+ """copy.deepcopy(Shape) should be consistent with ffi.DeepCopy."""
+ # Copy the same Shape through the FFI global "ffi.DeepCopy" and through
+ # Python's copy.deepcopy. The results must compare equal, and the Python
+ # path must keep the Shape type.
+ deep_copy_fn = tvm_ffi.get_global_func("ffi.DeepCopy")
+ s = tvm_ffi.Shape((2, 3, 4))
+ ffi_copied = deep_copy_fn(s)
+ py_copied = copy.deepcopy(s)
+ assert py_copied == ffi_copied
+ assert isinstance(py_copied, type(s))
+
+ def test_shape_inside_python_container_deepcopy(self) -> None:
+ """Python container deepcopy should handle Shape payloads."""
+ # The Shape is nested in plain Python containers: a list, and a dict
+ # inside that list. copy.deepcopy reaches it while recursing through
+ # them, and every copy must still compare equal to the original.
+ s = tvm_ffi.Shape((1, 2))
+ payload = [s, {"shape": s}]
+ copied = copy.deepcopy(payload)
+ assert copied[0] == s
+ assert copied[1]["shape"] == s # ty: ignore[invalid-argument-type]
+
+ # --- Cycle fixup: immutable container → reflected object back-reference ---
+
+ def test_cycle_array_root_object_backreference(self) -> None:
+ """Array A → Object X, X.v_array = A. Deep copy from A."""
+ # Create the object with empty containers, then close the cycle by
+ # assigning the root Array into its v_array field.
+ obj = tvm_ffi.testing.create_object(
+ "testing.TestObjectDerived",
+ v_i64=42,
+ v_map=tvm_ffi.Map({}),
+ v_array=tvm_ffi.Array([]),
+ )
+ arr = tvm_ffi.Array([obj])
+ obj.v_array = arr # ty: ignore[unresolved-attribute]
+
+ # NOTE(review): _deep_copy is defined outside this hunk; presumably it
+ # wraps the "ffi.DeepCopy" global function — confirm.
+ arr_deep = _deep_copy(arr)
+
+ assert not arr.same_as(arr_deep)
+ obj_deep = arr_deep[0]
+ # The reflected object is copied with its scalar field intact ...
+ assert not obj.same_as(obj_deep)
+ assert obj_deep.v_i64 == 42
+ # ... and its v_array back-reference is redirected to the copied root.
+ assert not obj_deep.v_array.same_as(arr)
+ assert obj_deep.v_array.same_as(arr_deep)
+
+ def test_cycle_map_root_object_backreference(self) -> None:
+ """Map M → Object X, X.v_map = M. Deep copy from M."""
+ # Create the object with empty containers, then close the cycle by
+ # assigning the root Map into its v_map field.
+ obj = tvm_ffi.testing.create_object(
+ "testing.TestObjectDerived",
+ v_i64=7,
+ v_map=tvm_ffi.Map({}),
+ v_array=tvm_ffi.Array([]),
+ )
+ m = tvm_ffi.Map({"key": obj})
+ obj.v_map = m # ty: ignore[unresolved-attribute]
+
+ # NOTE(review): _deep_copy is defined outside this hunk; presumably it
+ # wraps the "ffi.DeepCopy" global function — confirm.
+ m_deep = _deep_copy(m)
+
+ assert not m.same_as(m_deep)
+ obj_deep = m_deep["key"]
+ # The reflected object is copied with its scalar field intact ...
+ assert not obj.same_as(obj_deep)
+ assert obj_deep.v_i64 == 7
+ # ... and its v_map back-reference is redirected to the copied root.
+ assert not obj_deep.v_map.same_as(m)
+ assert obj_deep.v_map.same_as(m_deep)
+
+ def test_cycle_nested_array_object_array(self) -> None:
+ """Array → Object → Array → Object → back to root Array."""
+ # Structure: root_arr -> outer -> outer.v_array -> inner -> root_arr.
+ inner = tvm_ffi.testing.create_object(
+ "testing.TestObjectDerived",
+ v_i64=1,
+ v_map=tvm_ffi.Map({}),
+ v_array=tvm_ffi.Array([]),
+ )
+ outer = tvm_ffi.testing.create_object(
+ "testing.TestObjectDerived",
+ v_i64=2,
+ v_map=tvm_ffi.Map({}),
+ v_array=tvm_ffi.Array([inner]),
+ )
+ root_arr = tvm_ffi.Array([outer])
+ # Close the cycle two levels below the root.
+ inner.v_array = root_arr # ty: ignore[unresolved-attribute]
+
+ root_deep = _deep_copy(root_arr)
+
+ assert not root_arr.same_as(root_deep)
+ # Walk the copied graph down to the copied inner object.
+ outer_deep = root_deep[0]
+ inner_deep = outer_deep.v_array[0]
+ # The cycle must close on the copied root, not the original.
+ assert not inner_deep.v_array.same_as(root_arr)
+ assert inner_deep.v_array.same_as(root_deep)
+
# --------------------------------------------------------------------------- #
# __replace__
diff --git a/tests/python/test_dataclass_init.py b/tests/python/test_dataclass_init.py
index b957f54e..918838cf 100644
--- a/tests/python/test_dataclass_init.py
+++ b/tests/python/test_dataclass_init.py
@@ -886,7 +886,7 @@ class TestClassLevelInitFalse:
assert field_names == ["x", "y"]
def test_direct_construction_raises(self) -> None:
- with pytest.raises(TypeError):
+ # Match on the guard's message so that an unrelated TypeError (e.g. one
+ # raised while binding arguments) cannot make this test pass by accident.
+ with pytest.raises(TypeError, match="cannot be constructed directly"):
_TestCxxNoAutoInit(1, 2) # ty: ignore[too-many-positional-arguments]
def test_has_shallow_copy(self) -> None: