This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 6973d225 feat(python): wire __init__ from C++ reflection in
register_object and stubgen (#491)
6973d225 is described below
commit 6973d225eb3c67a7c306e36b20a100c5e9ff46f7
Author: Junru Shao <[email protected]>
AuthorDate: Sun Mar 1 16:11:36 2026 -0800
feat(python): wire __init__ from C++ reflection in register_object and
stubgen (#491)
## Summary
Wire `__init__` from C++ reflection metadata in both `register_object`
(runtime) and `tvm-ffi-stubgen` (static stubs), so that registered
classes with `refl::init()` work out of the box without `@c_class`.
## Architecture
- **`stub/utils.py`**: New `InitFieldInfo` dataclass.
`ObjectInfo.from_type_info()` walks the TypeInfo parent chain
(parent-first) to collect the per-field `c_init`, `c_kw_only`, and
`c_has_default` flags. `ObjectInfo.gen_init()` then emits typed
`__init__` stubs via `_gen_auto_init` (KWARGS protocol with proper
signature) or `_gen_c_init` (positional pass-through from
`__c_ffi_init__`).
- **`stub/codegen.py`**: `generate_object()` now calls `gen_init()` and
injects the result into the `TYPE_CHECKING` block before method stubs.
- **`registry.py`**: `register_object._register()` calls
`_install_init(cls, enabled=True)` after `_add_class_attrs`, wiring
`__init__` → `__ffi_init__` for any class whose C++ `ObjectDef`
registered `refl::init()`.
- **`_install_init`**: When `enabled=True` and no `__ffi_init__` exists,
returns silently instead of installing a TypeError guard —
backward-compatible for `Object()` and unregistered subclasses. The
`enabled=False` guard (used by `@c_class(init=False)`) is unchanged.
- Removed duplicate `_install_init` definition that shadowed the primary
one after rebase.
## Public Interfaces
- `register_object` now auto-wires `__init__` when C++ `__ffi_init__`
exists; previously only `@c_class` did this.
- `core.pyi`: Added `TypeField.c_init`, `c_kw_only`, `c_has_default` and
`TypeInfo.type_ancestors` stubs.
- Slots docstrings updated: recommend `__slots__ = ("__dict__",)`
instead of the removed `slots=False` metaclass keyword.
## Behavioral Changes
- Registered classes with `__ffi_init__` (non-auto-init) get `__init__ =
__ffi_init__` automatically — fixes `IntPair(1, 2)` in
`examples/python_packaging`.
- Classes without `__ffi_init__` keep default `object.__init__` behavior
(no guard installed).
- `@c_class(init=False)` still installs a `TypeError` guard as before.
- 30 unnecessary `ty: ignore` comments removed after stubgen generates
proper `__init__` signatures.
## Test Plan
- [x] Pre-commit hooks pass (ruff check, ruff format, ty check,
cython-lint, clang-format, ASF headers)
- [x] Full `pytest tests/python` (requires build)
- [x] `examples/python_packaging` end-to-end (requires build + `uv pip
install`)
- [x] C++ tests unaffected (no C++ changes)
## Untested Edge Cases
- stubgen `_gen_c_init` on packages whose shared library is not loaded
at stubgen time (the metadata would be unavailable; no crash, just no
`__init__` emitted).
- Interaction of auto-wired `__init__` with `__init_subclass__` hooks on
deeply nested Object hierarchies.
---
.../python/my_ffi_extension/_ffi_api.py | 2 +
examples/python_packaging/run_example.py | 2 +-
python/tvm_ffi/core.pyi | 4 +
python/tvm_ffi/cython/object.pxi | 5 +-
python/tvm_ffi/dataclasses/__init__.py | 2 +-
python/tvm_ffi/registry.py | 46 ++-
python/tvm_ffi/structural.py | 3 +
python/tvm_ffi/stub/codegen.py | 4 +-
python/tvm_ffi/stub/utils.py | 119 +++++++-
python/tvm_ffi/testing/testing.py | 326 +++++++++++++--------
tests/python/test_dataclass_c_class.py | 22 +-
tests/python/test_dataclass_compare.py | 142 ++++-----
tests/python/test_dataclass_copy.py | 48 +--
tests/python/test_dataclass_hash.py | 119 ++++----
tests/python/test_dataclass_repr.py | 4 +-
tests/python/test_object.py | 14 +-
tests/python/test_serialization.py | 10 +-
17 files changed, 553 insertions(+), 319 deletions(-)
diff --git a/examples/python_packaging/python/my_ffi_extension/_ffi_api.py
b/examples/python_packaging/python/my_ffi_extension/_ffi_api.py
index 0ea71b89..0b67bd8b 100644
--- a/examples/python_packaging/python/my_ffi_extension/_ffi_api.py
+++ b/examples/python_packaging/python/my_ffi_extension/_ffi_api.py
@@ -48,6 +48,8 @@ class IntPair(_ffi_Object):
a: int
b: int
if TYPE_CHECKING:
+ def __init__(self, _0: int, _1: int, /) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
@staticmethod
def __c_ffi_init__(_0: int, _1: int, /) -> Object: ...
def sum(self, /) -> int: ...
diff --git a/examples/python_packaging/run_example.py
b/examples/python_packaging/run_example.py
index fd3c03cf..8c601e60 100644
--- a/examples/python_packaging/run_example.py
+++ b/examples/python_packaging/run_example.py
@@ -45,7 +45,7 @@ def run_raise_error() -> None:
def run_int_pair() -> None:
"""Invoke IntPair from the extension to demonstrate object handling."""
print("=========== Example 4: IntPair ===========")
- pair = my_ffi_extension.IntPair(1, 2)
+ pair = my_ffi_extension.IntPair(1, 2) # ty:
ignore[too-many-positional-arguments]
print(f"a={pair.a}")
print(f"b={pair.b}")
print(f"sum={pair.sum()}")
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index c7b35b76..d39d83c1 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -252,6 +252,9 @@ class TypeField:
metadata: dict[str, Any]
getter: Any
setter: Any
+ c_init: bool
+ c_kw_only: bool
+ c_has_default: bool
dataclass_field: Any | None
def as_property(self, cls: type) -> property: ...
@@ -269,6 +272,7 @@ class TypeInfo:
type_cls: type | None
type_index: int
type_key: str
+ type_ancestors: list[int]
fields: list[TypeField]
methods: list[TypeMethod]
parent_type_info: TypeInfo | None
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index ba0a4906..6eb22539 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -204,9 +204,8 @@ class Object(CObject, metaclass=_ObjectSlotsMeta):
concrete type. Use :py:meth:`same_as` to check whether two
references point to the same underlying object.
- Subclasses that omit ``__slots__`` get ``__slots__ = ()`` injected
- automatically by the metaclass. Pass ``slots=False`` in the class
- header (e.g. ``class Foo(Object, slots=False)``) to suppress this
- and allow a per-instance ``__dict__``.
+ automatically by the metaclass. To allow a per-instance ``__dict__``,
+ declare ``__slots__ = ("__dict__",)`` explicitly in the class body.
- Most users interact with subclasses (e.g. :class:`Tensor`,
:class:`Function`) rather than :py:class:`Object` directly.
diff --git a/python/tvm_ffi/dataclasses/__init__.py
b/python/tvm_ffi/dataclasses/__init__.py
index 912d6bd1..8ea00ab3 100644
--- a/python/tvm_ffi/dataclasses/__init__.py
+++ b/python/tvm_ffi/dataclasses/__init__.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""C++ FFI classes registered via ``c_class`` decorator."""
+"""C++ FFI classes with structural comparison and hashing."""
from .c_class import c_class
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index b45b03e0..82540494 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -42,6 +42,17 @@ def register_object(type_key: str | None = None) ->
Callable[[_T], _T]:
The type key of the node. It requires ``type_key`` to be registered
already
on the C++ side. If not specified, the class name will be used.
+ Notes
+ -----
+ All :class:`Object` subclasses get ``__slots__ = ()`` by default via the
+ metaclass, preventing per-instance ``__dict__``. To opt out and allow
+ arbitrary instance attributes, declare ``__slots__ = ("__dict__",)``
+ explicitly in the class body::
+
+ @tvm_ffi.register_object("test.MyObject")
+ class MyObject(Object):
+ __slots__ = ("__dict__",)
+
Examples
--------
The following code registers MyObject using type key "test.MyObject", if
the
@@ -65,6 +76,7 @@ def register_object(type_key: str | None = None) ->
Callable[[_T], _T]:
info = core._register_object_by_index(type_index, cls)
_add_class_attrs(type_cls=cls, type_info=info)
setattr(cls, "__tvm_ffi_type_info__", info)
+ _install_init(cls, enabled=True)
return cls
if isinstance(type_key, str):
@@ -334,7 +346,14 @@ __SENTINEL = object()
def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
- """Build a Python ``__init__`` that delegates to the C++ auto-generated
``__ffi_init__``."""
+ """Build a Python ``__init__`` that delegates to the C++ auto-generated
``__ffi_init__``.
+
+ Reads per-field ``c_init``, ``c_kw_only``, and ``c_has_default`` from the
+ TypeField bitmask fields and produces a function with matching Python
+ signature. The ``__init__`` body is a trivial adapter — all validation
+ (too many positional, duplicates, missing required, kw_only enforcement,
+ unknown kwargs) is handled by C++.
+ """
sig = _make_init_signature(type_info)
kwargs_obj = core.KWARGS
@@ -353,7 +372,12 @@ def _make_init(type_cls: type, type_info: TypeInfo) ->
Callable[..., None]:
def _make_init_signature(type_info: TypeInfo) -> inspect.Signature:
- """Build an ``inspect.Signature`` from reflection field metadata."""
+ """Build an ``inspect.Signature`` from reflection field metadata.
+
+ Walks the parent chain (parent-first) to collect all ``init=True`` fields,
+ reorders required-before-optional within each group, and returns a
+ Signature for introspection.
+ """
positional: list[tuple[str, bool]] = [] # (name, has_default)
kw_only: list[tuple[str, bool]] = [] # (name, has_default)
@@ -480,19 +504,11 @@ def _install_init(cls: type, *, enabled: bool) -> None:
setattr(cls, "__init__", _make_init(cls, type_info))
else:
setattr(cls, "__init__", getattr(cls, "__ffi_init__"))
- return
- if issubclass(cls, core.PyNativeObject):
- return
- msg = (
- f"`{cls.__name__}` (C++ type `{type_info.type_key}`) has no
__ffi_init__ "
- f"registered. Either add `refl::init()` to its C++ ObjectDef, "
- f"or pass `init=False` to @c_class."
- )
- else:
- msg = (
- f"`{cls.__name__}` cannot be constructed directly. "
- f"Define a custom __init__ or use a factory method."
- )
+ return
+ msg = (
+ f"`{cls.__name__}` cannot be constructed directly. "
+ f"Define a custom __init__ or use a factory method."
+ )
def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
raise TypeError(msg)
diff --git a/python/tvm_ffi/structural.py b/python/tvm_ffi/structural.py
index 798e264f..02785386 100644
--- a/python/tvm_ffi/structural.py
+++ b/python/tvm_ffi/structural.py
@@ -206,7 +206,10 @@ class StructuralKey(Object):
key: Any
hash_i64: int
if TYPE_CHECKING:
+ def __init__(self, key: Any, hash_i64: int) -> None: ...
def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
# fmt: on
# tvm-ffi-stubgen(end)
diff --git a/python/tvm_ffi/stub/codegen.py b/python/tvm_ffi/stub/codegen.py
index 9b9294ec..6ca36f43 100644
--- a/python/tvm_ffi/stub/codegen.py
+++ b/python/tvm_ffi/stub/codegen.py
@@ -109,7 +109,8 @@ def generate_object(
info = obj_info
method_names = {m.schema.name.rsplit(".", 1)[-1] for m in info.methods}
fn_ty_map = _type_suffix_and_record(ty_map, imports,
func_names=method_names)
- if info.methods:
+ init_lines = info.gen_init(fn_ty_map, indent=opt.indent)
+ if info.methods or init_lines:
imports.append(
ImportItem(
"typing.TYPE_CHECKING",
@@ -120,6 +121,7 @@ def generate_object(
"# fmt: off",
*info.gen_fields(fn_ty_map, indent=0),
"if TYPE_CHECKING:",
+ *init_lines,
*info.gen_methods(fn_ty_map, indent=opt.indent),
"# fmt: on",
]
diff --git a/python/tvm_ffi/stub/utils.py b/python/tvm_ffi/stub/utils.py
index e02bbc3b..5393e7e4 100644
--- a/python/tvm_ffi/stub/utils.py
+++ b/python/tvm_ffi/stub/utils.py
@@ -20,13 +20,20 @@ from __future__ import annotations
import dataclasses
from io import StringIO
-from typing import Callable
+from typing import Any, Callable
from tvm_ffi.core import TypeInfo, TypeSchema
from . import consts as C
+def _parse_type_schema(raw: str | dict[str, Any]) -> TypeSchema:
+ """Parse a type schema from either a JSON string or an already-parsed
dict."""
+ if isinstance(raw, dict):
+ return TypeSchema.from_json_obj(raw)
+ return TypeSchema.from_json_str(raw)
+
+
@dataclasses.dataclass
class InitConfig:
"""Configuration for generating new stubs.
@@ -178,6 +185,16 @@ class FuncInfo:
return buf.getvalue()
[email protected]
+class InitFieldInfo:
+ """A field that participates in the auto-generated ``__init__``."""
+
+ name: str
+ schema: NamedTypeSchema
+ kw_only: bool
+ has_default: bool
+
+
@dataclasses.dataclass
class ObjectInfo:
"""Information of an object type, including its fields and methods."""
@@ -186,6 +203,9 @@ class ObjectInfo:
methods: list[FuncInfo]
type_key: str | None = None
parent_type_key: str | None = None
+ init_fields: list[InitFieldInfo] = dataclasses.field(default_factory=list)
+ has_auto_init: bool = False
+ has_c_init: bool = False
@staticmethod
def from_type_info(type_info: TypeInfo) -> ObjectInfo:
@@ -193,11 +213,45 @@ class ObjectInfo:
parent_type_key: str | None = None
if type_info.parent_type_info is not None:
parent_type_key = type_info.parent_type_info.type_key
+
+ # Detect auto_init / c_init from __ffi_init__ method metadata.
+ has_auto_init = False
+ has_c_init = False
+ for method in type_info.methods:
+ if method.name == "__ffi_init__":
+ has_c_init = True
+ has_auto_init = bool(method.metadata.get("auto_init", False))
+ break
+
+ # Walk parent chain (parent-first) to collect all init-eligible fields.
+ init_fields: list[InitFieldInfo] = []
+ if has_auto_init:
+ ti: TypeInfo | None = type_info
+ chain: list[TypeInfo] = []
+ while ti is not None:
+ chain.append(ti)
+ ti = ti.parent_type_info
+ for ancestor_info in reversed(chain):
+ for field in ancestor_info.fields:
+ if not field.c_init:
+ continue
+ init_fields.append(
+ InitFieldInfo(
+ name=field.name,
+ schema=NamedTypeSchema(
+ name=field.name,
+
schema=_parse_type_schema(field.metadata["type_schema"]),
+ ),
+ kw_only=field.c_kw_only,
+ has_default=field.c_has_default,
+ )
+ )
+
return ObjectInfo(
fields=[
NamedTypeSchema(
name=field.name,
-
schema=TypeSchema.from_json_str(field.metadata["type_schema"]),
+ schema=_parse_type_schema(field.metadata["type_schema"]),
)
for field in type_info.fields
],
@@ -205,7 +259,7 @@ class ObjectInfo:
FuncInfo(
schema=NamedTypeSchema(
name=C.FN_NAME_MAP.get(method.name, method.name),
-
schema=TypeSchema.from_json_str(method.metadata["type_schema"]),
+
schema=_parse_type_schema(method.metadata["type_schema"]),
),
is_member=not method.is_static,
)
@@ -213,6 +267,9 @@ class ObjectInfo:
],
type_key=type_info.type_key,
parent_type_key=parent_type_key,
+ init_fields=init_fields,
+ has_auto_init=has_auto_init,
+ has_c_init=has_c_init,
)
def gen_fields(self, ty_map: Callable[[str], str], indent: int) ->
list[str]:
@@ -229,3 +286,59 @@ class ObjectInfo:
ret.append(f"{indent_str}@staticmethod")
ret.append(method.gen(ty_map, indent))
return ret
+
+ def gen_init(self, ty_map: Callable[[str], str], indent: int) -> list[str]:
+ """Generate an ``__init__`` stub from reflection metadata."""
+ if self.has_auto_init:
+ return self._gen_auto_init(ty_map, indent)
+ if self.has_c_init:
+ return self._gen_c_init(ty_map, indent)
+ return []
+
+ def _gen_auto_init(self, ty_map: Callable[[str], str], indent: int) ->
list[str]:
+ """Generate ``__init__`` for auto-init types (KWARGS protocol)."""
+ indent_str = " " * indent
+ positional = [f for f in self.init_fields if not f.kw_only]
+ kw_only = [f for f in self.init_fields if f.kw_only]
+
+ pos_required = [f for f in positional if not f.has_default]
+ pos_default = [f for f in positional if f.has_default]
+ kw_required = [f for f in kw_only if not f.has_default]
+ kw_default = [f for f in kw_only if f.has_default]
+
+ parts: list[str] = []
+ for f in pos_required:
+ parts.append(f"{f.name}: {f.schema.repr(ty_map)}")
+ for f in pos_default:
+ parts.append(f"{f.name}: {f.schema.repr(ty_map)} = ...")
+ if kw_required or kw_default:
+ parts.append("*")
+ for f in kw_required:
+ parts.append(f"{f.name}: {f.schema.repr(ty_map)}")
+ for f in kw_default:
+ parts.append(f"{f.name}: {f.schema.repr(ty_map)} = ...")
+
+ params = ", ".join(parts)
+ if params:
+ return [f"{indent_str}def __init__(self, {params}) -> None: ..."]
+ return [f"{indent_str}def __init__(self) -> None: ..."]
+
+ def _gen_c_init(self, ty_map: Callable[[str], str], indent: int) ->
list[str]:
+ """Generate ``__init__`` for non-auto-init types (from
``__c_ffi_init__``)."""
+ indent_str = " " * indent
+ for method in self.methods:
+ func_name = method.schema.name.rsplit(".", 1)[-1]
+ if func_name != "__c_ffi_init__":
+ continue
+ schema = method.schema
+ if schema.origin != "Callable" or not schema.args:
+ break
+ arg_types = schema.args[1:] # skip return type (args[0])
+ parts: list[str] = []
+ for i, arg in enumerate(arg_types):
+ parts.append(f"_{i}: {arg.repr(ty_map)}")
+ params = ", ".join(parts)
+ if params:
+ return [f"{indent_str}def __init__(self, {params}, /) -> None:
..."]
+ return [f"{indent_str}def __init__(self) -> None: ..."]
+ return []
diff --git a/python/tvm_ffi/testing/testing.py
b/python/tvm_ffi/testing/testing.py
index 59f1e50d..3cdc288b 100644
--- a/python/tvm_ffi/testing/testing.py
+++ b/python/tvm_ffi/testing/testing.py
@@ -48,8 +48,11 @@ class TestObjectBase(Object):
v_f64: float
v_str: str
if TYPE_CHECKING:
+ def __init__(self, v_i64: int = ..., v_f64: float = ..., v_str: str =
...) -> None: ...
def __ffi_shallow_copy__(self, /) -> Object: ...
def add_i64(self, _1: int, /) -> int: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
# fmt: on
# tvm-ffi-stubgen(end)
@@ -65,6 +68,7 @@ class TestIntPair(Object):
a: int
b: int
if TYPE_CHECKING:
+ def __init__(self, _0: int, _1: int, /) -> None: ...
def __ffi_shallow_copy__(self, /) -> Object: ...
@staticmethod
def __c_ffi_init__(_0: int, _1: int, /) -> Object: ...
@@ -82,7 +86,10 @@ class TestObjectDerived(TestObjectBase):
v_map: Mapping[Any, Any]
v_array: Sequence[Any]
if TYPE_CHECKING:
+ def __init__(self, v_map: Mapping[Any, Any], v_array: Sequence[Any],
v_i64: int = ..., v_f64: float = ..., v_str: str = ...) -> None: ...
def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
# fmt: on
# tvm-ffi-stubgen(end)
@@ -91,97 +98,13 @@ class TestObjectDerived(TestObjectBase):
class TestNonCopyable(Object):
"""Test object with deleted copy constructor."""
- value: int
-
-
-@c_class("testing.TestCompare")
-class TestCompare(Object):
- """Test object with Compare(false) on ignored_field."""
-
- __test__ = False
-
- # tvm-ffi-stubgen(begin): object/testing.TestCompare
- # fmt: off
- key: int
- name: str
- ignored_field: int
- if TYPE_CHECKING:
- def __ffi_shallow_copy__(self, /) -> Object: ...
- @staticmethod
- def __c_ffi_init__(_0: int, _1: str, _2: int, /) -> Object: ...
- # fmt: on
- # tvm-ffi-stubgen(end)
-
-
-@c_class("testing.TestCustomCompare")
-class TestCustomCompare(Object):
- """Test object with custom __ffi_eq__/__ffi_compare__ hooks (compares only
key)."""
-
- __test__ = False
-
- # tvm-ffi-stubgen(begin): object/testing.TestCustomCompare
- # fmt: off
- key: int
- label: str
- if TYPE_CHECKING:
- def __ffi_shallow_copy__(self, /) -> Object: ...
- @staticmethod
- def __c_ffi_init__(_0: int, _1: str, /) -> Object: ...
- # fmt: on
- # tvm-ffi-stubgen(end)
-
-
-@c_class("testing.TestEqWithoutHash")
-class TestEqWithoutHash(Object):
- """Test object with __ffi_eq__ but no __ffi_hash__ (exercises hash
guard)."""
-
- __test__ = False
-
- # tvm-ffi-stubgen(begin): object/testing.TestEqWithoutHash
+ # tvm-ffi-stubgen(begin): object/testing.TestNonCopyable
# fmt: off
- key: int
- label: str
- if TYPE_CHECKING:
- def __ffi_shallow_copy__(self, /) -> Object: ...
- @staticmethod
- def __c_ffi_init__(_0: int, _1: str, /) -> Object: ...
- # fmt: on
- # tvm-ffi-stubgen(end)
-
-
-@c_class("testing.TestHash")
-class TestHash(Object):
- """Test object with Hash(false) on hash_ignored."""
-
- __test__ = False
-
- # tvm-ffi-stubgen(begin): object/testing.TestHash
- # fmt: off
- key: int
- name: str
- hash_ignored: int
- if TYPE_CHECKING:
- def __ffi_shallow_copy__(self, /) -> Object: ...
- @staticmethod
- def __c_ffi_init__(_0: int, _1: str, _2: int, /) -> Object: ...
- # fmt: on
- # tvm-ffi-stubgen(end)
-
-
-@c_class("testing.TestCustomHash")
-class TestCustomHash(Object):
- """Test object with custom __ffi_hash__ hook (hashes only key)."""
-
- __test__ = False
-
- # tvm-ffi-stubgen(begin): object/testing.TestCustomHash
- # fmt: off
- key: int
- label: str
+ value: int
if TYPE_CHECKING:
- def __ffi_shallow_copy__(self, /) -> Object: ...
+ def __init__(self, _0: int, /) -> None: ...
@staticmethod
- def __c_ffi_init__(_0: int, _1: str, /) -> Object: ...
+ def __c_ffi_init__(_0: int, /) -> Object: ...
# fmt: on
# tvm-ffi-stubgen(end)
@@ -256,10 +179,117 @@ def add_one(x: int) -> int:
return get_global_func("testing.add_one")(x)
+@c_class("testing.TestCompare")
+class TestCompare(Object):
+ """Test object with Compare(false) on ignored_field."""
+
+ __test__ = False
+
+ # tvm-ffi-stubgen(begin): object/testing.TestCompare
+ # fmt: off
+ key: int
+ name: str
+ ignored_field: int
+ if TYPE_CHECKING:
+ def __init__(self, _0: int, _1: str, _2: int, /) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(_0: int, _1: str, _2: int, /) -> Object: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
+
+@c_class("testing.TestHash")
+class TestHash(Object):
+ """Test object with Hash(false) on hash_ignored."""
+
+ __test__ = False
+
+ # tvm-ffi-stubgen(begin): object/testing.TestHash
+ # fmt: off
+ key: int
+ name: str
+ hash_ignored: int
+ if TYPE_CHECKING:
+ def __init__(self, _0: int, _1: str, _2: int, /) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(_0: int, _1: str, _2: int, /) -> Object: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
+
+@c_class("testing.TestCustomHash")
+class TestCustomHash(Object):
+ """Test object with custom __ffi_hash__ hook (hashes only key)."""
+
+ __test__ = False
+
+ # tvm-ffi-stubgen(begin): object/testing.TestCustomHash
+ # fmt: off
+ key: int
+ label: str
+ if TYPE_CHECKING:
+ def __init__(self, _0: int, _1: str, /) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(_0: int, _1: str, /) -> Object: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
+
+@c_class("testing.TestCustomCompare")
+class TestCustomCompare(Object):
+ """Test object with custom __ffi_eq__/__ffi_compare__ hooks (compares only
key)."""
+
+ __test__ = False
+
+ # tvm-ffi-stubgen(begin): object/testing.TestCustomCompare
+ # fmt: off
+ key: int
+ label: str
+ if TYPE_CHECKING:
+ def __init__(self, _0: int, _1: str, /) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(_0: int, _1: str, /) -> Object: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
+
+@c_class("testing.TestEqWithoutHash")
+class TestEqWithoutHash(Object):
+ """Test object with __ffi_eq__ but no __ffi_hash__ (exercises hash
guard)."""
+
+ __test__ = False
+
+ # tvm-ffi-stubgen(begin): object/testing.TestEqWithoutHash
+ # fmt: off
+ key: int
+ label: str
+ if TYPE_CHECKING:
+ def __init__(self, _0: int, _1: str, /) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(_0: int, _1: str, /) -> Object: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
+
@c_class("testing.TestCxxClassBase")
class _TestCxxClassBase(Object):
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxClassBase ->
testing._TestCxxClassBase
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxClassBase
+ # fmt: off
v_i64: int
v_i32: int
+ if TYPE_CHECKING:
+ def __init__(self, v_i64: int, v_i32: int) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
not_field_1 = 1
not_field_2: ClassVar[int] = 2
@@ -269,49 +299,69 @@ class _TestCxxClassBase(Object):
@c_class("testing.TestCxxClassDerived", eq=True, order=True, unsafe_hash=True)
class _TestCxxClassDerived(_TestCxxClassBase):
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxClassDerived ->
testing._TestCxxClassDerived
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxClassDerived
+ # fmt: off
v_f64: float
v_f32: float
if TYPE_CHECKING:
-
def __init__(self, v_i64: int, v_i32: int, v_f64: float, v_f32: float
= ...) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
@c_class("testing.TestCxxClassDerivedDerived")
class _TestCxxClassDerivedDerived(_TestCxxClassDerived):
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxClassDerivedDerived ->
testing._TestCxxClassDerivedDerived
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxClassDerivedDerived
+ # fmt: off
v_str: str
v_bool: bool
if TYPE_CHECKING:
-
- def __init__(
- self,
- v_i64: int,
- v_i32: int,
- v_f64: float,
- v_bool: bool,
- v_f32: float = ...,
- v_str: str = ...,
- ) -> None: ...
+ def __init__(self, v_i64: int, v_i32: int, v_f64: float, v_bool: bool,
v_f32: float = ..., v_str: str = ...) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
@c_class("testing.TestCxxInitSubset")
class _TestCxxInitSubset(Object):
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxInitSubset ->
testing._TestCxxInitSubset
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxInitSubset
+ # fmt: off
required_field: int
optional_field: int
note: str
if TYPE_CHECKING:
-
def __init__(self, required_field: int) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
@c_class("testing.TestCxxKwOnly")
class _TestCxxKwOnly(Object):
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxKwOnly -> testing._TestCxxKwOnly
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxKwOnly
+ # fmt: off
x: int
y: int
z: int
w: int
if TYPE_CHECKING:
-
def __init__(self, *, x: int, y: int, z: int, w: int = ...) -> None:
...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
@c_class("testing.TestCxxAutoInit")
@@ -320,13 +370,20 @@ class _TestCxxAutoInit(Object):
__test__ = False
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxAutoInit ->
testing._TestCxxAutoInit
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxAutoInit
+ # fmt: off
a: int
b: int
c: int
d: int
if TYPE_CHECKING:
-
def __init__(self, a: int, d: int = ..., *, c: int) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
@c_class("testing.TestCxxAutoInitSimple")
@@ -335,11 +392,18 @@ class _TestCxxAutoInitSimple(Object):
__test__ = False
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxAutoInitSimple ->
testing._TestCxxAutoInitSimple
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxAutoInitSimple
+ # fmt: off
x: int
y: int
if TYPE_CHECKING:
-
def __init__(self, x: int, y: int) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
@c_class("testing.TestCxxAutoInitAllInitOff")
@@ -348,12 +412,19 @@ class _TestCxxAutoInitAllInitOff(Object):
__test__ = False
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxAutoInitAllInitOff ->
testing._TestCxxAutoInitAllInitOff
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxAutoInitAllInitOff
+ # fmt: off
x: int
y: int
z: int
if TYPE_CHECKING:
-
def __init__(self) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
@c_class("testing.TestCxxAutoInitKwOnlyDefaults")
@@ -362,16 +433,21 @@ class _TestCxxAutoInitKwOnlyDefaults(Object):
__test__ = False
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxAutoInitKwOnlyDefaults ->
testing._TestCxxAutoInitKwOnlyDefaults
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxAutoInitKwOnlyDefaults
+ # fmt: off
p_required: int
p_default: int
k_required: int
k_default: int
hidden: int
if TYPE_CHECKING:
-
- def __init__(
- self, p_required: int, p_default: int = ..., *, k_required: int,
k_default: int = ...
- ) -> None: ...
+ def __init__(self, p_required: int, p_default: int = ..., *,
k_required: int, k_default: int = ...) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
@c_class("testing.TestCxxNoAutoInit", init=False)
@@ -380,8 +456,15 @@ class _TestCxxNoAutoInit(Object):
__test__ = False
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxNoAutoInit ->
testing._TestCxxNoAutoInit
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxNoAutoInit
+ # fmt: off
x: int
y: int
+ if TYPE_CHECKING:
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
@c_class("testing.TestCxxAutoInitParent")
@@ -390,11 +473,18 @@ class _TestCxxAutoInitParent(Object):
__test__ = False
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxAutoInitParent ->
testing._TestCxxAutoInitParent
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxAutoInitParent
+ # fmt: off
parent_required: int
parent_default: int
if TYPE_CHECKING:
-
def __init__(self, parent_required: int, parent_default: int = ...) ->
None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
@c_class("testing.TestCxxAutoInitChild")
@@ -403,15 +493,15 @@ class _TestCxxAutoInitChild(_TestCxxAutoInitParent):
__test__ = False
+ # tvm-ffi-stubgen(ty-map): testing.TestCxxAutoInitChild ->
testing._TestCxxAutoInitChild
+ # tvm-ffi-stubgen(begin): object/testing.TestCxxAutoInitChild
+ # fmt: off
child_required: int
child_kw_only: int
if TYPE_CHECKING:
-
- def __init__(
- self,
- parent_required: int,
- child_required: int,
- parent_default: int = ...,
- *,
- child_kw_only: int,
- ) -> None: ...
+ def __init__(self, parent_required: int, child_required: int,
parent_default: int = ..., *, child_kw_only: int) -> None: ...
+ def __ffi_shallow_copy__(self, /) -> Object: ...
+ @staticmethod
+ def __c_ffi_init__(*args: Any) -> Any: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
diff --git a/tests/python/test_dataclass_c_class.py
b/tests/python/test_dataclass_c_class.py
index 52bf38bb..9128660c 100644
--- a/tests/python/test_dataclass_c_class.py
+++ b/tests/python/test_dataclass_c_class.py
@@ -123,30 +123,30 @@ def test_c_class_ordering() -> None:
"""c_class installs ordering operators."""
small = _TestCxxClassDerived(0, 0, 0.0, 0.0)
big = _TestCxxClassDerived(100, 100, 100.0, 100.0)
- assert small < big # ty: ignore[unsupported-operator]
- assert small <= big # ty: ignore[unsupported-operator]
- assert big > small # ty: ignore[unsupported-operator]
- assert big >= small # ty: ignore[unsupported-operator]
- assert not (big < small) # ty: ignore[unsupported-operator]
- assert not (small > big) # ty: ignore[unsupported-operator]
+ assert small < big
+ assert small <= big
+ assert big > small
+ assert big >= small
+ assert not (big < small)
+ assert not (small > big)
def test_c_class_ordering_reflexive() -> None:
"""<= and >= are reflexive."""
a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
b = a # alias, same object
- assert a <= b # ty: ignore[unsupported-operator]
- assert a >= b # ty: ignore[unsupported-operator]
+ assert a <= b
+ assert a >= b
def test_c_class_ordering_antisymmetric() -> None:
"""If a < b then not b < a."""
a = _TestCxxClassDerived(0, 0, 0.0, 0.0)
b = _TestCxxClassDerived(100, 100, 100.0, 100.0)
- if a < b: # ty: ignore[unsupported-operator]
- assert not (b < a) # ty: ignore[unsupported-operator]
+ if a < b:
+ assert not (b < a)
else:
- assert not (a < b) # ty: ignore[unsupported-operator]
+ assert not (a < b)
# ---------------------------------------------------------------------------
diff --git a/tests/python/test_dataclass_compare.py
b/tests/python/test_dataclass_compare.py
index a03f1a0a..83234abe 100644
--- a/tests/python/test_dataclass_compare.py
+++ b/tests/python/test_dataclass_compare.py
@@ -448,17 +448,17 @@ def test_equal_dicts_under_ordering() -> None:
def test_reflected_obj_eq() -> None:
- a = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- b = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- c = TestIntPair(1, 3) # ty: ignore[too-many-positional-arguments]
+ a = TestIntPair(1, 2)
+ b = TestIntPair(1, 2)
+ c = TestIntPair(1, 3)
assert RecursiveEq(a, b)
assert not RecursiveEq(a, c)
def test_reflected_obj_ordering() -> None:
- a = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- b = TestIntPair(1, 3) # ty: ignore[too-many-positional-arguments]
- c = TestIntPair(2, 0) # ty: ignore[too-many-positional-arguments]
+ a = TestIntPair(1, 2)
+ b = TestIntPair(1, 3)
+ c = TestIntPair(2, 0)
assert RecursiveLt(a, b) # first field equal, second: 2 < 3
assert RecursiveLt(a, c) # first field: 1 < 2
@@ -470,21 +470,21 @@ def test_reflected_obj_ordering() -> None:
def test_compare_off_ignored_field() -> None:
"""ignored_field is excluded from comparison via Compare(false)."""
- a = TestCompare(1, "x", 100) # ty: ignore[too-many-positional-arguments]
- b = TestCompare(1, "x", 999) # ty: ignore[too-many-positional-arguments]
+ a = TestCompare(1, "x", 100)
+ b = TestCompare(1, "x", 999)
assert RecursiveEq(a, b)
def test_compare_off_key_differs() -> None:
- a = TestCompare(1, "x", 100) # ty: ignore[too-many-positional-arguments]
- b = TestCompare(2, "x", 100) # ty: ignore[too-many-positional-arguments]
+ a = TestCompare(1, "x", 100)
+ b = TestCompare(2, "x", 100)
assert not RecursiveEq(a, b)
assert RecursiveLt(a, b)
def test_compare_off_name_differs() -> None:
- a = TestCompare(1, "a", 100) # ty: ignore[too-many-positional-arguments]
- b = TestCompare(1, "b", 100) # ty: ignore[too-many-positional-arguments]
+ a = TestCompare(1, "a", 100)
+ b = TestCompare(1, "b", 100)
assert not RecursiveEq(a, b)
assert RecursiveLt(a, b)
@@ -495,7 +495,7 @@ def test_compare_off_name_differs() -> None:
def test_same_pointer() -> None:
- x = TestIntPair(42, 99) # ty: ignore[too-many-positional-arguments]
+ x = TestIntPair(42, 99)
assert RecursiveEq(x, x)
@@ -506,14 +506,14 @@ def test_same_pointer() -> None:
def test_different_obj_types_eq() -> None:
"""RecursiveEq returns False for different object types."""
- a = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- b = TestCompare(1, "x", 0) # ty: ignore[too-many-positional-arguments]
+ a = TestIntPair(1, 2)
+ b = TestCompare(1, "x", 0)
assert not RecursiveEq(a, b)
def test_different_obj_types_ordering_raises() -> None:
- a = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- b = TestCompare(1, "x", 0) # ty: ignore[too-many-positional-arguments]
+ a = TestIntPair(1, 2)
+ b = TestCompare(1, "x", 0)
with pytest.raises(TypeError):
RecursiveLt(a, b)
@@ -524,20 +524,20 @@ def test_different_obj_types_ordering_raises() -> None:
def test_nested_objects_in_array() -> None:
- a1 = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- a2 = TestIntPair(3, 4) # ty: ignore[too-many-positional-arguments]
- b1 = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- b2 = TestIntPair(3, 4) # ty: ignore[too-many-positional-arguments]
+ a1 = TestIntPair(1, 2)
+ a2 = TestIntPair(3, 4)
+ b1 = TestIntPair(1, 2)
+ b2 = TestIntPair(3, 4)
arr_a = tvm_ffi.Array([a1, a2])
arr_b = tvm_ffi.Array([b1, b2])
assert RecursiveEq(arr_a, arr_b)
def test_nested_objects_in_array_differ() -> None:
- a1 = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- a2 = TestIntPair(3, 4) # ty: ignore[too-many-positional-arguments]
- b1 = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- b2 = TestIntPair(3, 5) # ty: ignore[too-many-positional-arguments]
+ a1 = TestIntPair(1, 2)
+ a2 = TestIntPair(3, 4)
+ b1 = TestIntPair(1, 2)
+ b2 = TestIntPair(3, 5)
arr_a = tvm_ffi.Array([a1, a2])
arr_b = tvm_ffi.Array([b1, b2])
assert not RecursiveEq(arr_a, arr_b)
@@ -736,8 +736,8 @@ def test_object_array_of_objects() -> None:
v_map=tvm_ffi.Map({}),
v_array=tvm_ffi.Array(
[
- TestIntPair(1, 2), # ty: ignore[too-many-positional-arguments]
- TestIntPair(3, 4), # ty: ignore[too-many-positional-arguments]
+ TestIntPair(1, 2),
+ TestIntPair(3, 4),
]
),
)
@@ -749,8 +749,8 @@ def test_object_array_of_objects() -> None:
v_map=tvm_ffi.Map({}),
v_array=tvm_ffi.Array(
[
- TestIntPair(1, 2), # ty: ignore[too-many-positional-arguments]
- TestIntPair(3, 4), # ty: ignore[too-many-positional-arguments]
+ TestIntPair(1, 2),
+ TestIntPair(3, 4),
]
),
)
@@ -766,8 +766,8 @@ def test_object_array_of_objects_differ() -> None:
v_map=tvm_ffi.Map({}),
v_array=tvm_ffi.Array(
[
- TestIntPair(1, 2), # ty: ignore[too-many-positional-arguments]
- TestIntPair(3, 4), # ty: ignore[too-many-positional-arguments]
+ TestIntPair(1, 2),
+ TestIntPair(3, 4),
]
),
)
@@ -779,8 +779,8 @@ def test_object_array_of_objects_differ() -> None:
v_map=tvm_ffi.Map({}),
v_array=tvm_ffi.Array(
[
- TestIntPair(1, 2), # ty: ignore[too-many-positional-arguments]
- TestIntPair(3, 5), # ty: ignore[too-many-positional-arguments]
+ TestIntPair(1, 2),
+ TestIntPair(3, 5),
]
),
)
@@ -796,7 +796,7 @@ def test_object_map_with_object_values() -> None:
v_str="",
v_map=tvm_ffi.Map(
{
- "x": TestIntPair(1, 2), # ty:
ignore[too-many-positional-arguments]
+ "x": TestIntPair(1, 2),
}
),
v_array=tvm_ffi.Array([]),
@@ -808,7 +808,7 @@ def test_object_map_with_object_values() -> None:
v_str="",
v_map=tvm_ffi.Map(
{
- "x": TestIntPair(1, 2), # ty:
ignore[too-many-positional-arguments]
+ "x": TestIntPair(1, 2),
}
),
v_array=tvm_ffi.Array([]),
@@ -859,24 +859,24 @@ def test_deep_object_in_object() -> None:
def test_inherited_fields_eq() -> None:
- a = _TestCxxClassDerived(10, 20, 1.5, 2.5) # ty:
ignore[too-many-positional-arguments]
- b = _TestCxxClassDerived(10, 20, 1.5, 2.5) # ty:
ignore[too-many-positional-arguments]
+ a = _TestCxxClassDerived(10, 20, 1.5, 2.5)
+ b = _TestCxxClassDerived(10, 20, 1.5, 2.5)
assert RecursiveEq(a, b)
def test_inherited_fields_differ_in_base() -> None:
- a = _TestCxxClassDerived(10, 20, 1.5, 2.5) # ty:
ignore[too-many-positional-arguments]
- b = _TestCxxClassDerived(99, 20, 1.5, 2.5) # ty:
ignore[too-many-positional-arguments]
+ a = _TestCxxClassDerived(10, 20, 1.5, 2.5)
+ b = _TestCxxClassDerived(99, 20, 1.5, 2.5)
assert not RecursiveEq(a, b)
assert RecursiveLt(a, b)
def test_three_level_inheritance_eq_and_differ() -> None:
# Positional order: required (v_i64, v_i32, v_f64, v_bool), then optional
(v_f32, v_str)
- a = _TestCxxClassDerivedDerived(1, 2, 3.0, True, 4.0, "hi") # ty:
ignore[too-many-positional-arguments]
- b = _TestCxxClassDerivedDerived(1, 2, 3.0, True, 4.0, "hi") # ty:
ignore[too-many-positional-arguments]
+ a = _TestCxxClassDerivedDerived(1, 2, 3.0, True, 4.0, "hi")
+ b = _TestCxxClassDerivedDerived(1, 2, 3.0, True, 4.0, "hi")
assert RecursiveEq(a, b)
- c = _TestCxxClassDerivedDerived(1, 2, 3.0, False, 4.0, "hi") # ty:
ignore[too-many-positional-arguments]
+ c = _TestCxxClassDerivedDerived(1, 2, 3.0, False, 4.0, "hi")
assert not RecursiveEq(a, c)
@@ -888,14 +888,14 @@ def test_three_level_inheritance_eq_and_differ() -> None:
def test_compare_off_inside_array() -> None:
a = tvm_ffi.Array(
[
- TestCompare(1, "x", 100), # ty:
ignore[too-many-positional-arguments]
- TestCompare(2, "y", 200), # ty:
ignore[too-many-positional-arguments]
+ TestCompare(1, "x", 100),
+ TestCompare(2, "y", 200),
]
)
b = tvm_ffi.Array(
[
- TestCompare(1, "x", 999), # ty:
ignore[too-many-positional-arguments]
- TestCompare(2, "y", 888), # ty:
ignore[too-many-positional-arguments]
+ TestCompare(1, "x", 999),
+ TestCompare(2, "y", 888),
]
)
assert RecursiveEq(a, b)
@@ -911,7 +911,7 @@ def test_compare_off_inside_nested_object() -> None:
v_map=tvm_ffi.Map({}),
v_array=tvm_ffi.Array(
[
- TestCompare(1, "n", 100), # ty:
ignore[too-many-positional-arguments]
+ TestCompare(1, "n", 100),
]
),
)
@@ -923,7 +923,7 @@ def test_compare_off_inside_nested_object() -> None:
v_map=tvm_ffi.Map({}),
v_array=tvm_ffi.Array(
[
- TestCompare(1, "n", 999), # ty:
ignore[too-many-positional-arguments]
+ TestCompare(1, "n", 999),
]
),
)
@@ -982,17 +982,17 @@ def test_map_with_array_values_eq() -> None:
def test_dict_with_object_values_eq() -> None:
a = tvm_ffi.Dict(
{
- "k": TestIntPair(1, 2), # ty:
ignore[too-many-positional-arguments]
+ "k": TestIntPair(1, 2),
}
)
b = tvm_ffi.Dict(
{
- "k": TestIntPair(1, 2), # ty:
ignore[too-many-positional-arguments]
+ "k": TestIntPair(1, 2),
}
)
c = tvm_ffi.Dict(
{
- "k": TestIntPair(1, 3), # ty:
ignore[too-many-positional-arguments]
+ "k": TestIntPair(1, 3),
}
)
assert RecursiveEq(a, b)
@@ -1137,10 +1137,10 @@ def test_cyclic_dict_ordering_raises() -> None:
def test_ordering_laws_on_int_pairs() -> None:
"""Verify ordering laws (trichotomy, antisymmetry, transitivity) on
TestIntPair."""
values = [
- TestIntPair(0, 0), # ty: ignore[too-many-positional-arguments]
- TestIntPair(0, 1), # ty: ignore[too-many-positional-arguments]
- TestIntPair(1, 0), # ty: ignore[too-many-positional-arguments]
- TestIntPair(1, 1), # ty: ignore[too-many-positional-arguments]
+ TestIntPair(0, 0),
+ TestIntPair(0, 1),
+ TestIntPair(1, 0),
+ TestIntPair(1, 1),
]
for a in values:
for b in values:
@@ -1194,21 +1194,21 @@ def test_depth_1000_nested_eq() -> None:
def test_custom_eq_ignores_label() -> None:
"""TestCustomCompare.__ffi_eq__ compares only `key`, ignoring `label`."""
- a = TestCustomCompare(42, "alpha") # ty:
ignore[too-many-positional-arguments]
- b = TestCustomCompare(42, "beta") # ty:
ignore[too-many-positional-arguments]
+ a = TestCustomCompare(42, "alpha")
+ b = TestCustomCompare(42, "beta")
assert RecursiveEq(a, b)
def test_custom_eq_different_key() -> None:
- a = TestCustomCompare(1, "same") # ty:
ignore[too-many-positional-arguments]
- b = TestCustomCompare(2, "same") # ty:
ignore[too-many-positional-arguments]
+ a = TestCustomCompare(1, "same")
+ b = TestCustomCompare(2, "same")
assert not RecursiveEq(a, b)
def test_custom_compare_ordering() -> None:
"""Ordering uses __ffi_compare__ hook (key only)."""
- a = TestCustomCompare(1, "zzz") # ty:
ignore[too-many-positional-arguments]
- b = TestCustomCompare(2, "aaa") # ty:
ignore[too-many-positional-arguments]
+ a = TestCustomCompare(1, "zzz")
+ b = TestCustomCompare(2, "aaa")
assert RecursiveLt(a, b)
assert not RecursiveLt(b, a)
@@ -1217,14 +1217,14 @@ def test_custom_eq_in_container() -> None:
"""Custom-hooked objects inside an Array."""
a = tvm_ffi.Array(
[
- TestCustomCompare(1, "x"), # ty:
ignore[too-many-positional-arguments]
- TestCustomCompare(2, "y"), # ty:
ignore[too-many-positional-arguments]
+ TestCustomCompare(1, "x"),
+ TestCustomCompare(2, "y"),
]
)
b = tvm_ffi.Array(
[
- TestCustomCompare(1, "different"), # ty:
ignore[too-many-positional-arguments]
- TestCustomCompare(2, "labels"), # ty:
ignore[too-many-positional-arguments]
+ TestCustomCompare(1, "different"),
+ TestCustomCompare(2, "labels"),
]
)
assert RecursiveEq(a, b)
@@ -1237,8 +1237,8 @@ def test_custom_eq_in_container() -> None:
def test_eq_only_type_eq_uses_hook() -> None:
"""__ffi_eq__-only type: RecursiveEq uses the hook (compares only key)."""
- a = TestEqWithoutHash(42, "alpha") # ty:
ignore[too-many-positional-arguments]
- b = TestEqWithoutHash(42, "beta") # ty:
ignore[too-many-positional-arguments]
+ a = TestEqWithoutHash(42, "alpha")
+ b = TestEqWithoutHash(42, "beta")
assert RecursiveEq(a, b)
@@ -1249,8 +1249,8 @@ def test_eq_only_type_ordering_uses_reflection() -> None:
though __ffi_eq__ ignores it. This is expected — register __ffi_compare__
for consistent ordering semantics.
"""
- a = TestEqWithoutHash(42, "alpha") # ty:
ignore[too-many-positional-arguments]
- b = TestEqWithoutHash(42, "beta") # ty:
ignore[too-many-positional-arguments]
+ a = TestEqWithoutHash(42, "alpha")
+ b = TestEqWithoutHash(42, "beta")
# Eq says equal (hook), but ordering sees label difference (reflection)
assert RecursiveEq(a, b)
assert RecursiveLt(a, b) # "alpha" < "beta"
@@ -1263,8 +1263,8 @@ def test_eq_only_type_ordering_uses_reflection() -> None:
def test_custom_compare_ordering_consistency() -> None:
"""TestCustomCompare has __ffi_compare__: Eq(a,b) implies not Lt/Gt and
both Le/Ge."""
- a = TestCustomCompare(42, "alpha") # ty:
ignore[too-many-positional-arguments]
- b = TestCustomCompare(42, "beta") # ty:
ignore[too-many-positional-arguments]
+ a = TestCustomCompare(42, "alpha")
+ b = TestCustomCompare(42, "beta")
assert RecursiveEq(a, b)
assert not RecursiveLt(a, b)
assert not RecursiveGt(a, b)
diff --git a/tests/python/test_dataclass_copy.py
b/tests/python/test_dataclass_copy.py
index df9b6367..6507e636 100644
--- a/tests/python/test_dataclass_copy.py
+++ b/tests/python/test_dataclass_copy.py
@@ -33,13 +33,13 @@ class TestShallowCopy:
"""Tests for copy.copy() / __copy__."""
def test_basic_fields(self) -> None:
- pair = tvm_ffi.testing.TestIntPair(1, 2) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(1, 2)
pair_copy = copy.copy(pair)
assert pair_copy.a == 1
assert pair_copy.b == 2
def test_creates_new_object(self) -> None:
- pair = tvm_ffi.testing.TestIntPair(3, 7) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(3, 7)
pair_copy = copy.copy(pair)
assert not pair.same_as(pair_copy)
@@ -83,10 +83,10 @@ class TestShallowCopy:
obj_copy = copy.copy(obj)
assert obj_copy.v_i64 == 2
assert obj_copy.v_i32 == 4
- assert not obj.same_as(obj_copy) # ty: ignore[unresolved-attribute]
+ assert not obj.same_as(obj_copy)
def test_non_copyable_type_raises(self) -> None:
- obj = tvm_ffi.testing.TestNonCopyable(42) # ty:
ignore[too-many-positional-arguments]
+ obj = tvm_ffi.testing.TestNonCopyable(42)
with pytest.raises(TypeError, match="does not support copy"):
copy.copy(obj)
@@ -98,14 +98,14 @@ class TestDeepCopy:
"""Tests for copy.deepcopy() / __deepcopy__."""
def test_basic_fields(self) -> None:
- pair = tvm_ffi.testing.TestIntPair(5, 10) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(5, 10)
pair_deep = copy.deepcopy(pair)
assert pair_deep.a == 5
assert pair_deep.b == 10
assert not pair.same_as(pair_deep)
def test_nested_objects_are_copied(self) -> None:
- inner = tvm_ffi.testing.TestIntPair(1, 2) # ty:
ignore[too-many-positional-arguments]
+ inner = tvm_ffi.testing.TestIntPair(1, 2)
v_array = tvm_ffi.convert([inner])
v_map = tvm_ffi.convert({"x": "y"})
obj = tvm_ffi.testing.create_object(
@@ -125,7 +125,7 @@ class TestDeepCopy:
def test_shared_references_preserved(self) -> None:
"""Two array slots pointing to the same object should still share
after deepcopy."""
- shared = tvm_ffi.testing.TestIntPair(7, 8) # ty:
ignore[too-many-positional-arguments]
+ shared = tvm_ffi.testing.TestIntPair(7, 8)
v_array = tvm_ffi.convert([shared, shared])
v_map = tvm_ffi.convert({"a": "b"})
obj = tvm_ffi.testing.create_object(
@@ -217,7 +217,7 @@ class TestDeepCopy:
def test_array_root(self) -> None:
"""Deepcopy with a bare Array as root should create a new array."""
- inner = tvm_ffi.testing.TestIntPair(1, 2) # ty:
ignore[too-many-positional-arguments]
+ inner = tvm_ffi.testing.TestIntPair(1, 2)
arr = tvm_ffi.convert([inner, "hello", 42])
arr_deep = copy.deepcopy(arr)
assert not arr.same_as(arr_deep)
@@ -230,7 +230,7 @@ class TestDeepCopy:
def test_map_root(self) -> None:
"""Deepcopy with a bare Map as root should create a new map."""
- inner = tvm_ffi.testing.TestIntPair(3, 4) # ty:
ignore[too-many-positional-arguments]
+ inner = tvm_ffi.testing.TestIntPair(3, 4)
m = tvm_ffi.convert({"key": inner})
m_deep = copy.deepcopy(m)
assert not m.same_as(m_deep)
@@ -240,7 +240,7 @@ class TestDeepCopy:
def test_dict_root(self) -> None:
"""Deepcopy with a bare Dict as root should create a new dict."""
- inner = tvm_ffi.testing.TestIntPair(3, 4) # ty:
ignore[too-many-positional-arguments]
+ inner = tvm_ffi.testing.TestIntPair(3, 4)
d = tvm_ffi.Dict({"key": inner})
d_deep = copy.deepcopy(d)
assert not d.same_as(d_deep)
@@ -255,10 +255,10 @@ class TestDeepCopy:
obj_deep = copy.deepcopy(obj)
assert obj_deep.v_i64 == 2
assert obj_deep.v_i32 == 4
- assert not obj.same_as(obj_deep) # ty: ignore[unresolved-attribute]
+ assert not obj.same_as(obj_deep)
def test_non_copyable_type_raises(self) -> None:
- obj = tvm_ffi.testing.TestNonCopyable(42) # ty:
ignore[too-many-positional-arguments]
+ obj = tvm_ffi.testing.TestNonCopyable(42)
with pytest.raises(TypeError, match="does not support deepcopy"):
copy.deepcopy(obj)
@@ -281,7 +281,7 @@ class TestDeepCopy:
def test_any_field_with_object(self) -> None:
"""Any-typed field containing an object must be recursively copied."""
- inner = tvm_ffi.testing.TestIntPair(3, 4) # ty:
ignore[too-many-positional-arguments]
+ inner = tvm_ffi.testing.TestIntPair(3, 4)
obj = tvm_ffi.testing.create_object("testing.TestDeepCopyEdges",
v_any=inner, v_obj=inner)
obj_deep = copy.deepcopy(obj)
assert not obj.same_as(obj_deep)
@@ -330,7 +330,7 @@ class TestDeepCopy:
def test_any_field_sharing_preserved(self) -> None:
"""Shared references through Any and ObjectRef fields are preserved."""
- shared = tvm_ffi.testing.TestIntPair(5, 6) # ty:
ignore[too-many-positional-arguments]
+ shared = tvm_ffi.testing.TestIntPair(5, 6)
obj = tvm_ffi.testing.create_object("testing.TestDeepCopyEdges",
v_any=shared, v_obj=shared)
obj_deep = copy.deepcopy(obj)
# Both fields should point to the same copied object
@@ -339,7 +339,7 @@ class TestDeepCopy:
# --------------------------------------------------------------------------- #
-# Deep copy branch coverage (C++ deep_copy.cc)
+# Deep copy branch coverage (C++ dataclass.cc)
# --------------------------------------------------------------------------- #
_deep_copy = tvm_ffi.get_global_func("ffi.DeepCopy")
@@ -401,7 +401,7 @@ class TestDeepCopyBranches:
def test_array_mixed_with_objects_and_containers(self) -> None:
"""Array with int, str, None, object, nested array, nested map."""
- inner_obj = tvm_ffi.testing.TestIntPair(1, 2) # ty:
ignore[too-many-positional-arguments]
+ inner_obj = tvm_ffi.testing.TestIntPair(1, 2)
inner_arr = tvm_ffi.convert([10, 20])
inner_map = tvm_ffi.convert({"k": "v"})
arr = tvm_ffi.convert([42, "hello", None, inner_obj, inner_arr,
inner_map])
@@ -533,7 +533,7 @@ class TestDeepCopyBranches:
def test_shared_object_across_array_and_map(self) -> None:
"""Same object referenced from both v_array and v_map."""
- pair = tvm_ffi.testing.TestIntPair(7, 8) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(7, 8)
v_array = tvm_ffi.convert([pair])
v_map = tvm_ffi.convert({"p": pair})
obj = tvm_ffi.testing.create_object(
@@ -644,13 +644,13 @@ class TestDeepCopyBranches:
assert m_deep["k"].v_i32 == 4
def test_non_copyable_type_in_array(self) -> None:
- obj = tvm_ffi.testing.TestNonCopyable(1) # ty:
ignore[too-many-positional-arguments]
+ obj = tvm_ffi.testing.TestNonCopyable(1)
arr = tvm_ffi.convert([obj])
with pytest.raises(RuntimeError, match="not copy-constructible"):
copy.deepcopy(arr)
def test_non_copyable_type_in_map_value(self) -> None:
- obj = tvm_ffi.testing.TestNonCopyable(1) # ty:
ignore[too-many-positional-arguments]
+ obj = tvm_ffi.testing.TestNonCopyable(1)
m = tvm_ffi.convert({"k": obj})
with pytest.raises(RuntimeError, match="not copy-constructible"):
copy.deepcopy(m)
@@ -659,7 +659,7 @@ class TestDeepCopyBranches:
def test_deeply_nested_containers(self) -> None:
"""Array > Map > Array > object — all levels resolved."""
- pair = tvm_ffi.testing.TestIntPair(9, 10) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(9, 10)
inner_arr = tvm_ffi.convert([pair])
inner_map = tvm_ffi.convert({"items": inner_arr})
outer = tvm_ffi.convert([inner_map])
@@ -671,7 +671,7 @@ class TestDeepCopyBranches:
def test_object_with_deeply_nested_field(self) -> None:
"""Object whose array field contains a map containing an object."""
- pair = tvm_ffi.testing.TestIntPair(5, 6) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(5, 6)
inner_map = tvm_ffi.convert({"pair": pair})
v_array = tvm_ffi.convert([inner_map])
obj = tvm_ffi.testing.create_object(
@@ -920,7 +920,7 @@ class TestReplace:
assert obj.v_i64 == 5 # ty: ignore[unresolved-attribute]
def test_replace_readonly_field_raises(self) -> None:
- pair = tvm_ffi.testing.TestIntPair(3, 4) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(3, 4)
with pytest.raises(AttributeError):
pair.__replace__(a=10) # ty: ignore[unresolved-attribute]
@@ -931,9 +931,9 @@ class TestReplace:
obj2 = obj.__replace__(v_i64=99) # ty: ignore[unresolved-attribute]
assert obj2.v_i64 == 99
assert obj2.v_i32 == 4
- assert not obj.same_as(obj2) # ty: ignore[unresolved-attribute]
+ assert not obj.same_as(obj2)
def test_non_copyable_type_raises(self) -> None:
- obj = tvm_ffi.testing.TestNonCopyable(42) # ty:
ignore[too-many-positional-arguments]
+ obj = tvm_ffi.testing.TestNonCopyable(42)
with pytest.raises(TypeError, match="does not support replace"):
obj.__replace__() # ty: ignore[unresolved-attribute]
diff --git a/tests/python/test_dataclass_hash.py
b/tests/python/test_dataclass_hash.py
index 7149e06c..ec58c2b7 100644
--- a/tests/python/test_dataclass_hash.py
+++ b/tests/python/test_dataclass_hash.py
@@ -326,14 +326,14 @@ def test_dict_hash_different() -> None:
def test_reflected_obj_hash_equal() -> None:
- a = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- b = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
+ a = TestIntPair(1, 2)
+ b = TestIntPair(1, 2)
assert RecursiveHash(a) == RecursiveHash(b)
def test_reflected_obj_hash_different() -> None:
- a = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- c = TestIntPair(1, 3) # ty: ignore[too-many-positional-arguments]
+ a = TestIntPair(1, 2)
+ c = TestIntPair(1, 3)
assert RecursiveHash(a) != RecursiveHash(c)
@@ -344,20 +344,20 @@ def test_reflected_obj_hash_different() -> None:
def test_hash_off_ignored_field() -> None:
"""hash_ignored is excluded from hashing via Hash(false)."""
- a = TestHash(1, "x", 100) # ty: ignore[too-many-positional-arguments]
- b = TestHash(1, "x", 999) # ty: ignore[too-many-positional-arguments]
+ a = TestHash(1, "x", 100)
+ b = TestHash(1, "x", 999)
assert RecursiveHash(a) == RecursiveHash(b)
def test_hash_off_key_differs() -> None:
- a = TestHash(1, "x", 100) # ty: ignore[too-many-positional-arguments]
- b = TestHash(2, "x", 100) # ty: ignore[too-many-positional-arguments]
+ a = TestHash(1, "x", 100)
+ b = TestHash(2, "x", 100)
assert RecursiveHash(a) != RecursiveHash(b)
def test_hash_off_name_differs() -> None:
- a = TestHash(1, "a", 100) # ty: ignore[too-many-positional-arguments]
- b = TestHash(1, "b", 100) # ty: ignore[too-many-positional-arguments]
+ a = TestHash(1, "a", 100)
+ b = TestHash(1, "b", 100)
assert RecursiveHash(a) != RecursiveHash(b)
@@ -368,8 +368,8 @@ def test_hash_off_name_differs() -> None:
def test_compare_off_implies_hash_off() -> None:
"""Fields with Compare(false) are also excluded from hashing."""
- a = TestCompare(1, "x", 100) # ty: ignore[too-many-positional-arguments]
- b = TestCompare(1, "x", 999) # ty: ignore[too-many-positional-arguments]
+ a = TestCompare(1, "x", 100)
+ b = TestCompare(1, "x", 999)
assert RecursiveHash(a) == RecursiveHash(b)
@@ -379,7 +379,7 @@ def test_compare_off_implies_hash_off() -> None:
def test_same_pointer_hash() -> None:
- x = TestIntPair(42, 99) # ty: ignore[too-many-positional-arguments]
+ x = TestIntPair(42, 99)
assert RecursiveHash(x) == RecursiveHash(x)
@@ -437,23 +437,22 @@ def test_map_with_array_values_hash() -> None:
def test_inherited_fields_hash_equal() -> None:
- a = _TestCxxClassDerived(10, 20, 1.5, 2.5) # ty:
ignore[too-many-positional-arguments]
- b = _TestCxxClassDerived(10, 20, 1.5, 2.5) # ty:
ignore[too-many-positional-arguments]
+ a = _TestCxxClassDerived(10, 20, 1.5, 2.5)
+ b = _TestCxxClassDerived(10, 20, 1.5, 2.5)
assert RecursiveHash(a) == RecursiveHash(b)
def test_inherited_fields_differ_in_base_hash() -> None:
- a = _TestCxxClassDerived(10, 20, 1.5, 2.5) # ty:
ignore[too-many-positional-arguments]
- b = _TestCxxClassDerived(99, 20, 1.5, 2.5) # ty:
ignore[too-many-positional-arguments]
+ a = _TestCxxClassDerived(10, 20, 1.5, 2.5)
+ b = _TestCxxClassDerived(99, 20, 1.5, 2.5)
assert RecursiveHash(a) != RecursiveHash(b)
def test_three_level_inheritance_hash() -> None:
- # Positional order: required (v_i64, v_i32, v_f64, v_bool), then optional
(v_f32, v_str)
- a = _TestCxxClassDerivedDerived(1, 2, 3.0, True, 4.0, "hi") # ty:
ignore[too-many-positional-arguments]
- b = _TestCxxClassDerivedDerived(1, 2, 3.0, True, 4.0, "hi") # ty:
ignore[too-many-positional-arguments]
+ a = _TestCxxClassDerivedDerived(1, 2, 3.0, True, 4.0, "hi")
+ b = _TestCxxClassDerivedDerived(1, 2, 3.0, True, 4.0, "hi")
assert RecursiveHash(a) == RecursiveHash(b)
- c = _TestCxxClassDerivedDerived(1, 2, 3.0, False, 4.0, "hi") # ty:
ignore[too-many-positional-arguments]
+ c = _TestCxxClassDerivedDerived(1, 2, 3.0, False, 4.0, "hi")
assert RecursiveHash(a) != RecursiveHash(c)
@@ -616,34 +615,34 @@ def test_consistency_containers() -> None:
def test_consistency_reflected_objects() -> None:
"""Verify hash consistency for reflected objects."""
- a = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
- b = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
+ a = TestIntPair(1, 2)
+ b = TestIntPair(1, 2)
assert RecursiveEq(a, b)
assert RecursiveHash(a) == RecursiveHash(b)
def test_consistency_compare_off() -> None:
"""Fields excluded from comparison are also excluded from hash."""
- a = TestCompare(1, "x", 100) # ty: ignore[too-many-positional-arguments]
- b = TestCompare(1, "x", 999) # ty: ignore[too-many-positional-arguments]
+ a = TestCompare(1, "x", 100)
+ b = TestCompare(1, "x", 999)
assert RecursiveEq(a, b)
assert RecursiveHash(a) == RecursiveHash(b)
def test_consistency_hash_off() -> None:
"""Fields excluded from hashing produce same hash when they differ."""
- a = TestHash(1, "x", 100) # ty: ignore[too-many-positional-arguments]
- b = TestHash(1, "x", 999) # ty: ignore[too-many-positional-arguments]
+ a = TestHash(1, "x", 100)
+ b = TestHash(1, "x", 999)
assert RecursiveHash(a) == RecursiveHash(b)
def test_consistency_law_on_int_pairs() -> None:
"""Verify: RecursiveEq(a, b) => RecursiveHash(a) == RecursiveHash(b)."""
values = [
- TestIntPair(0, 0), # ty: ignore[too-many-positional-arguments]
- TestIntPair(0, 1), # ty: ignore[too-many-positional-arguments]
- TestIntPair(1, 0), # ty: ignore[too-many-positional-arguments]
- TestIntPair(1, 1), # ty: ignore[too-many-positional-arguments]
+ TestIntPair(0, 0),
+ TestIntPair(0, 1),
+ TestIntPair(1, 0),
+ TestIntPair(1, 1),
]
for a in values:
for b in values:
@@ -666,12 +665,12 @@ def _make_nested_singleton_array(depth: int) -> object:
def test_aliasing_consistency_array_of_reflected_objects() -> None:
- shared = TestIntPair(11, 22) # ty: ignore[too-many-positional-arguments]
+ shared = TestIntPair(11, 22)
aliased = tvm_ffi.Array([shared, shared])
duplicated = tvm_ffi.Array(
[
- TestIntPair(11, 22), # ty: ignore[too-many-positional-arguments]
- TestIntPair(11, 22), # ty: ignore[too-many-positional-arguments]
+ TestIntPair(11, 22),
+ TestIntPair(11, 22),
]
)
assert RecursiveEq(aliased, duplicated)
@@ -679,12 +678,12 @@ def
test_aliasing_consistency_array_of_reflected_objects() -> None:
def test_aliasing_consistency_list_of_reflected_objects() -> None:
- shared = TestIntPair(13, 26) # ty: ignore[too-many-positional-arguments]
+ shared = TestIntPair(13, 26)
aliased = tvm_ffi.List([shared, shared])
duplicated = tvm_ffi.List(
[
- TestIntPair(13, 26), # ty: ignore[too-many-positional-arguments]
- TestIntPair(13, 26), # ty: ignore[too-many-positional-arguments]
+ TestIntPair(13, 26),
+ TestIntPair(13, 26),
]
)
assert RecursiveEq(aliased, duplicated)
@@ -716,12 +715,12 @@ def test_aliasing_consistency_shape_objects() -> None:
def test_aliasing_consistency_map_shared_values() -> None:
- shared = TestIntPair(31, 41) # ty: ignore[too-many-positional-arguments]
+ shared = TestIntPair(31, 41)
aliased = tvm_ffi.Map({"x": shared, "y": shared})
duplicated = tvm_ffi.Map(
{
- "x": TestIntPair(31, 41), # ty:
ignore[too-many-positional-arguments]
- "y": TestIntPair(31, 41), # ty:
ignore[too-many-positional-arguments]
+ "x": TestIntPair(31, 41),
+ "y": TestIntPair(31, 41),
}
)
assert RecursiveEq(aliased, duplicated)
@@ -737,7 +736,7 @@ def test_aliasing_consistency_dict_shared_values() -> None:
def test_aliasing_consistency_reflected_object_fields() -> None:
- shared = TestIntPair(5, 6) # ty: ignore[too-many-positional-arguments]
+ shared = TestIntPair(5, 6)
aliased = create_object(
"testing.TestObjectDerived",
v_i64=1,
@@ -751,8 +750,8 @@ def test_aliasing_consistency_reflected_object_fields() ->
None:
v_i64=1,
v_f64=2.0,
v_str="shared",
- v_map=tvm_ffi.Map({"k": TestIntPair(5, 6)}), # ty:
ignore[too-many-positional-arguments]
- v_array=tvm_ffi.Array([TestIntPair(5, 6)]), # ty:
ignore[too-many-positional-arguments]
+ v_map=tvm_ffi.Map({"k": TestIntPair(5, 6)}),
+ v_array=tvm_ffi.Array([TestIntPair(5, 6)]),
)
assert RecursiveEq(aliased, duplicated)
assert RecursiveHash(aliased) == RecursiveHash(duplicated)
@@ -764,7 +763,7 @@ def test_aliasing_consistency_reflected_object_fields() ->
None:
def test_map_hash_order_independent_with_shared_values() -> None:
- shared = TestIntPair(1, 2) # ty: ignore[too-many-positional-arguments]
+ shared = TestIntPair(1, 2)
a = tvm_ffi.Map({"a": shared, "b": shared, "c": shared})
b = tvm_ffi.Map({"b": shared, "a": shared, "c": shared})
assert RecursiveEq(a, b)
@@ -874,14 +873,14 @@ def test_shared_dag_hash_scaling_not_exponential() ->
None:
def test_custom_hash_ignores_label() -> None:
"""TestCustomHash hashes only `key`, ignoring `label`."""
- a = TestCustomHash(42, "alpha") # ty:
ignore[too-many-positional-arguments]
- b = TestCustomHash(42, "beta") # ty: ignore[too-many-positional-arguments]
+ a = TestCustomHash(42, "alpha")
+ b = TestCustomHash(42, "beta")
assert RecursiveHash(a) == RecursiveHash(b)
def test_custom_hash_different_key() -> None:
- a = TestCustomHash(1, "same") # ty: ignore[too-many-positional-arguments]
- b = TestCustomHash(2, "same") # ty: ignore[too-many-positional-arguments]
+ a = TestCustomHash(1, "same")
+ b = TestCustomHash(2, "same")
assert RecursiveHash(a) != RecursiveHash(b)
@@ -889,14 +888,14 @@ def test_custom_hash_in_container() -> None:
"""Custom-hooked objects inside an Array."""
a = tvm_ffi.Array(
[
- TestCustomHash(1, "x"), # ty:
ignore[too-many-positional-arguments]
- TestCustomHash(2, "y"), # ty:
ignore[too-many-positional-arguments]
+ TestCustomHash(1, "x"),
+ TestCustomHash(2, "y"),
]
)
b = tvm_ffi.Array(
[
- TestCustomHash(1, "different"), # ty:
ignore[too-many-positional-arguments]
- TestCustomHash(2, "labels"), # ty:
ignore[too-many-positional-arguments]
+ TestCustomHash(1, "different"),
+ TestCustomHash(2, "labels"),
]
)
assert RecursiveHash(a) == RecursiveHash(b)
@@ -904,8 +903,8 @@ def test_custom_hash_in_container() -> None:
def test_custom_hash_consistency_with_eq() -> None:
"""RecursiveEq(a,b) => RecursiveHash(a)==RecursiveHash(b) for
TestCustomCompare."""
- a = TestCustomCompare(42, "alpha") # ty:
ignore[too-many-positional-arguments]
- b = TestCustomCompare(42, "beta") # ty:
ignore[too-many-positional-arguments]
+ a = TestCustomCompare(42, "alpha")
+ b = TestCustomCompare(42, "beta")
assert RecursiveEq(a, b)
assert RecursiveHash(a) == RecursiveHash(b)
@@ -920,8 +919,8 @@ def test_custom_hash_consistency_with_eq() -> None:
def test_custom_compare_eq_implies_hash_same_direct(
key: int, lhs_label: str, rhs_label: str
) -> None:
- lhs = TestCustomCompare(key, lhs_label) # ty:
ignore[too-many-positional-arguments]
- rhs = TestCustomCompare(key, rhs_label) # ty:
ignore[too-many-positional-arguments]
+ lhs = TestCustomCompare(key, lhs_label)
+ rhs = TestCustomCompare(key, rhs_label)
assert RecursiveEq(lhs, rhs)
assert RecursiveHash(lhs) == RecursiveHash(rhs)
@@ -943,8 +942,8 @@ def test_custom_compare_eq_implies_hash_same_direct(
def test_custom_compare_eq_implies_hash_same_in_wrappers(
key: int, wrap: Callable[[object], object]
) -> None:
- lhs_obj = TestCustomCompare(key, "left") # ty:
ignore[too-many-positional-arguments]
- rhs_obj = TestCustomCompare(key, "right") # ty:
ignore[too-many-positional-arguments]
+ lhs_obj = TestCustomCompare(key, "left")
+ rhs_obj = TestCustomCompare(key, "right")
lhs = wrap(lhs_obj)
rhs = wrap(rhs_obj)
assert RecursiveEq(lhs, rhs)
@@ -958,14 +957,14 @@ def test_custom_compare_eq_implies_hash_same_in_wrappers(
def test_eq_without_hash_raises() -> None:
"""RecursiveHash rejects types that define __ffi_eq__ but not
__ffi_hash__."""
- obj = TestEqWithoutHash(1, "hello") # ty:
ignore[too-many-positional-arguments]
+ obj = TestEqWithoutHash(1, "hello")
with pytest.raises(ValueError, match="__ffi_eq__ or __ffi_compare__ but
not __ffi_hash__"):
RecursiveHash(obj)
def test_eq_without_hash_inside_container_raises() -> None:
"""The guard also triggers when the object is nested inside a container."""
- obj = TestEqWithoutHash(1, "hello") # ty:
ignore[too-many-positional-arguments]
+ obj = TestEqWithoutHash(1, "hello")
arr = tvm_ffi.Array([obj])
with pytest.raises(ValueError, match="__ffi_eq__ or __ffi_compare__ but
not __ffi_hash__"):
RecursiveHash(arr)
diff --git a/tests/python/test_dataclass_repr.py
b/tests/python/test_dataclass_repr.py
index aea03cae..165ae13b 100644
--- a/tests/python/test_dataclass_repr.py
+++ b/tests/python/test_dataclass_repr.py
@@ -193,7 +193,7 @@ def test_repr_user_object_all_fields() -> None:
def test_repr_user_object_repr_off() -> None:
"""Test repr of object with Repr(false) fields excluded."""
# Positional order: required first (v_i64, v_i32, v_f64), then optional
(v_f32)
- obj = tvm_ffi.testing._TestCxxClassDerived(1, 2, 3.5, 4.5) # ty:
ignore[too-many-positional-arguments]
+ obj = tvm_ffi.testing._TestCxxClassDerived(1, 2, 3.5, 4.5)
assert ReprPrint(obj) == "testing.TestCxxClassDerived(v_f64=3.5,
v_f32=4.5)"
@@ -393,7 +393,7 @@ def test_repr_map_with_object_values() -> None:
def test_repr_derived_derived_shows_all_own_fields() -> None:
"""TestCxxClassDerivedDerived should show v_f64, v_f32, v_str, v_bool (not
v_i64, v_i32)."""
# Positional order: required (v_i64, v_i32, v_f64, v_bool), then optional
(v_f32, v_str)
- obj = tvm_ffi.testing._TestCxxClassDerivedDerived(1, 2, 3.0, True, 4.0,
"test") # ty: ignore[too-many-positional-arguments]
+ obj = tvm_ffi.testing._TestCxxClassDerivedDerived(1, 2, 3.0, True, 4.0,
"test")
assert (
ReprPrint(obj)
== 'testing.TestCxxClassDerivedDerived(v_f64=3, v_f32=4, v_str="test",
v_bool=True)'
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index 8f563433..afed61bf 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -35,7 +35,7 @@ def test_make_object() -> None:
def test_make_object_via_init() -> None:
- obj0 = tvm_ffi.testing.TestIntPair(1, 2) # ty:
ignore[too-many-positional-arguments]
+ obj0 = tvm_ffi.testing.TestIntPair(1, 2)
assert obj0.a == 1
assert obj0.b == 2
@@ -49,7 +49,7 @@ def test_method() -> None:
def test_attribute() -> None:
- obj = tvm_ffi.testing.TestIntPair(3, 4) # ty:
ignore[too-many-positional-arguments]
+ obj = tvm_ffi.testing.TestIntPair(3, 4)
assert obj.a == 3
assert obj.b == 4
assert type(obj).a.__doc__ == "Field `a`"
@@ -137,8 +137,14 @@ def test_opaque_type_error() -> None:
def test_object_init() -> None:
+ # Registered class with auto-generated __ffi_init__ (all fields have
defaults)
+ obj = tvm_ffi.testing.TestObjectBase()
+ assert obj.v_i64 == 10
+ assert obj.v_f64 == 10.0
+ assert obj.v_str == "hello"
+
# Registered class with __c_ffi_init__ should work fine
- pair = tvm_ffi.testing.TestIntPair(3, 4) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(3, 4)
assert pair.a == 3 and pair.b == 4
# FFI-returned objects should work fine
@@ -194,7 +200,7 @@ def test_unregistered_object_fallback() -> None:
),
(
tvm_ffi.testing.TestIntPair,
- lambda: tvm_ffi.testing.TestIntPair(1, 2), # ty:
ignore[too-many-positional-arguments]
+ lambda: tvm_ffi.testing.TestIntPair(1, 2),
),
(
tvm_ffi.testing.TestObjectDerived,
diff --git a/tests/python/test_serialization.py
b/tests/python/test_serialization.py
index 83a9cf96..8d4a8681 100644
--- a/tests/python/test_serialization.py
+++ b/tests/python/test_serialization.py
@@ -406,7 +406,7 @@ class TestObjectSerialization:
def test_int_pair_roundtrip(self) -> None:
"""TestIntPair has refl::init and POD int64 fields."""
- pair = tvm_ffi.testing.TestIntPair(3, 7) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(3, 7)
s = to_json_graph_str(pair)
result = from_json_graph_str(s)
assert result.a == 3
@@ -414,21 +414,21 @@ class TestObjectSerialization:
def test_int_pair_zero_values(self) -> None:
"""TestIntPair with zero values roundtrips correctly."""
- pair = tvm_ffi.testing.TestIntPair(0, 0) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(0, 0)
result = _roundtrip(pair)
assert result.a == 0
assert result.b == 0
def test_int_pair_negative_values(self) -> None:
"""TestIntPair with negative values roundtrips correctly."""
- pair = tvm_ffi.testing.TestIntPair(-100, -200) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(-100, -200)
result = _roundtrip(pair)
assert result.a == -100
assert result.b == -200
def test_int_pair_large_values(self) -> None:
"""TestIntPair with large values roundtrips correctly."""
- pair = tvm_ffi.testing.TestIntPair(10**15, -(10**15)) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(10**15, -(10**15))
result = _roundtrip(pair)
assert result.a == 10**15
assert result.b == -(10**15)
@@ -516,7 +516,7 @@ class TestJSONStructure:
def test_object_pod_fields_are_inlined(self) -> None:
"""POD fields (int, bool, float) are inlined directly via
field_static_type_index."""
- pair = tvm_ffi.testing.TestIntPair(3, 7) # ty:
ignore[too-many-positional-arguments]
+ pair = tvm_ffi.testing.TestIntPair(3, 7)
s = to_json_graph_str(pair)
parsed = json.loads(s)
root = parsed["nodes"][parsed["root_index"]]