This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new d8bd1890 feat(python): add structural equality/hashing support to
`py_class` and `field` (#507)
d8bd1890 is described below
commit d8bd1890ba407181f18284a6714d1c8cca57dd52
Author: Junru Shao <[email protected]>
AuthorDate: Sun Mar 22 10:59:29 2026 -0700
feat(python): add structural equality/hashing support to `py_class` and
`field` (#507)
## Summary
- Add `structure=` parameter to `@py_class()` and `field()` controlling
structural equality/hashing behavior. Class-level values: `"var"`,
`"tree"`, `"const-tree"`, `"dag"`, `"singleton"`. Field-level values: `"ignore"`,
`"def"`. Dispatched via Cython-accelerated `_ffi_seq_hash_kind` and
`_ffi_field_info_seq_hash` attributes, mapping to C++
`TVMFFISEqHashKind` and `SEqHashIgnore`/`SEqHashDef` field flags.
- Expand `py_class` test coverage with 349 tests ported from mlc-python
(field parsing, defaults, inheritance, JSON serialization, nested
structures, edge cases).
- Add `docs/concepts/structural_eq_hash.rst` documenting the structural
equality/hashing design.
- Pin `astral-sh/setup-uv` to v7.3.1 and `pypa/cibuildwheel` to v3.3.1
(Apache-approved SHAs), fixing CI permission failures in
https://github.com/apache/tvm-ffi/actions/runs/23389552989.
## Test plan
- [x] `uv run pytest -vvs tests/python/test_dataclass_py_class.py
tests/python/test_structural_py_class.py` — 349 passed
- [x] All pre-commit hooks pass
---
.github/actions/build-wheel-for-publish/action.yml | 4 +-
docs/.rstcheck.cfg | 2 +-
docs/concepts/structural_eq_hash.rst | 715 +++++++++++++++
docs/index.rst | 1 +
python/tvm_ffi/cython/type_info.pxi | 45 +-
python/tvm_ffi/dataclasses/field.py | 41 +-
python/tvm_ffi/dataclasses/py_class.py | 44 +-
src/ffi/extra/serialization.cc | 8 +-
tests/python/test_dataclass_copy.py | 282 +++++-
tests/python/test_dataclass_py_class.py | 978 ++++++++++++++++++++-
tests/python/test_dataclass_repr.py | 62 ++
tests/python/test_structural_py_class.py | 361 ++++++++
12 files changed, 2520 insertions(+), 23 deletions(-)
diff --git a/.github/actions/build-wheel-for-publish/action.yml
b/.github/actions/build-wheel-for-publish/action.yml
index 60047f64..b7e581e8 100644
--- a/.github/actions/build-wheel-for-publish/action.yml
+++ b/.github/actions/build-wheel-for-publish/action.yml
@@ -54,7 +54,7 @@ runs:
python-version: 3.8
- name: Set up uv
- uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0
+ uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
- name: Check out source
uses: actions/checkout@v5
@@ -93,7 +93,7 @@ runs:
- name: Build wheels
if: ${{ inputs.build_wheels == 'true' }}
- uses: pypa/[email protected]
+ uses: pypa/cibuildwheel@298ed2fb2c105540f5ed055e8a6ad78d82dd3a7e # v3.3.1
env:
CIBW_ARCHS_MACOS: ${{ inputs.arch }}
CIBW_ARCHS_LINUX: ${{ inputs.arch }}
diff --git a/docs/.rstcheck.cfg b/docs/.rstcheck.cfg
index 4e532c10..829528df 100644
--- a/docs/.rstcheck.cfg
+++ b/docs/.rstcheck.cfg
@@ -1,5 +1,5 @@
[rstcheck]
report_level = warning
-ignore_directives = automodule, autosummary, currentmodule, toctree, ifconfig, tab-set, collapse, tabs, dropdown
+ignore_directives = automodule, autosummary, currentmodule, toctree, ifconfig, tab-set, collapse, tabs, dropdown, mermaid
ignore_roles = ref, cpp:class, cpp:func, py:func, c:macro, external+data-api:doc, external+scikit_build_core:doc, external+dlpack:doc
ignore_languages = cpp, python
diff --git a/docs/concepts/structural_eq_hash.rst
b/docs/concepts/structural_eq_hash.rst
new file mode 100644
index 00000000..006114b0
--- /dev/null
+++ b/docs/concepts/structural_eq_hash.rst
@@ -0,0 +1,715 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Structural Equality and Hashing
+===============================
+
+TVM FFI provides ``structural_equal`` and ``structural_hash`` for the
+object graph. These compare objects by **content** — recursively walking
+fields — rather than by pointer identity.
+
+The behavior is controlled by two layers of annotation on
+:func:`~tvm_ffi.dataclasses.py_class`:
+
+1. **Type-level** ``structure=`` — what *role* does this type play in the
+ IR graph?
+2. **Field-level** ``structure=`` on :func:`~tvm_ffi.dataclasses.field` —
+ should this field be skipped, or does it introduce new variable bindings?
+
+This document explains what each annotation means, when to use it, and how
+they compose.
+
+
+Type-Level Annotation
+---------------------
+
+The ``structure`` parameter on ``@py_class`` declares how instances of the
+type participate in structural equality and hashing:
+
+.. code-block:: python
+
+ @py_class(structure="tree")
+ class Expr(Object):
+ ...
+
+Quick reference
+~~~~~~~~~~~~~~~
+
+.. list-table::
+ :header-rows: 1
+ :widths: 18 37 45
+
+ * - ``structure=``
+ - Meaning
+ - Use when...
+ * - ``"tree"``
+ - A regular IR node
+ - The usual choice for most IR nodes
+ * - ``"const-tree"``
+ - An immutable value node (with pointer shortcut)
+ - The type has no transitive ``"var"`` children
+ * - ``"dag"``
+ - A node in a dataflow graph
+ - Pointer sharing is semantically meaningful
+ * - ``"var"``
+ - A bound variable
+ - The type represents a variable binding
+ * - ``"singleton"``
+ - A unique instance, compared by pointer
+ - Exactly one instance per logical identity (e.g. registry entries)
+ * - ``None``
+ - Not comparable
+ - The type should never be compared structurally
+
+
+``"tree"`` — The Common Case
+-----------------------------
+
+.. code-block:: python
+
+ @py_class(structure="tree")
+ class Add(Object):
+ lhs: Expr
+ rhs: Expr
+
+**Meaning**: "This node is defined by its fields. Two nodes are equal if and
+only if all their fields are recursively equal."
+
+This is the right choice for the vast majority of IR nodes: expressions,
+statements, types, attributes, buffers, etc.
+
+**Example.**
+
+.. code-block:: text
+
+ 1 + 2 vs 1 + 2 → Equal
+ 1 + 2 vs 1 + 3 → Not equal (rhs differs)
+
+Sharing is invisible
+~~~~~~~~~~~~~~~~~~~~
+
+``"tree"`` treats every reference independently. If the same object is
+referenced multiple times, each reference is compared by content separately.
+Sharing is **not** part of the structural identity:
+
+.. code-block:: text
+
+ let s = x + 1
+
+ (s, s) ← same object referenced twice
+ (x + 1, x + 1) ← two independent copies with same content
+
+ These are EQUAL under "tree" — sharing is not detected.
+
+The following diagram illustrates this. Under ``"tree"``, the **DAG** on the
+left and the **tree** on the right are considered structurally equal because
+every node has the same content:
+
+.. mermaid::
+
+ graph TD
+ subgraph "DAG — shared node"
+ T1["(_, _)"]
+ S1["s = x + 1"]
+ T1 -->|".0"| S1
+ T1 -->|".1"| S1
+ end
+
+ subgraph "Tree — independent copies"
+ T2["(_, _)"]
+ A1["x + 1"]
+ A2["x + 1"]
+ T2 -->|".0"| A1
+ T2 -->|".1"| A2
+ end
+
+ style S1 fill:#d4edda
+ style A1 fill:#d4edda
+ style A2 fill:#d4edda
+
+If sharing needs to matter, use ``"dag"`` instead.
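A minimal pure-Python sketch of the ``"tree"`` rule. ``Node`` and ``tree_equal`` are hypothetical toy names, not tvm_ffi API; the point is only that every reference is walked by content, so a shared child and two independent copies compare equal:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(eq=False)  # eq=False keeps identity-based ==, so only tree_equal compares content
class Node:
    """Toy IR node (hypothetical): an operator label plus ordered children."""

    op: str
    children: tuple = ()


def tree_equal(lhs: Any, rhs: Any) -> bool:
    """Content-only comparison: each reference is walked independently."""
    if isinstance(lhs, Node) and isinstance(rhs, Node):
        return (
            lhs.op == rhs.op
            and len(lhs.children) == len(rhs.children)
            and all(tree_equal(a, b) for a, b in zip(lhs.children, rhs.children))
        )
    return type(lhs) is type(rhs) and lhs == rhs  # leaves such as ints


x = Node("x")
s = Node("+", (x, 1))  # s = x + 1
shared = Node("tuple", (s, s))  # (s, s): one object referenced twice
copies = Node("tuple", (Node("+", (x, 1)), Node("+", (x, 1))))  # (x + 1, x + 1)

print(tree_equal(shared, copies))  # True: sharing is not part of the identity
```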
+
+
+``"const-tree"`` — Tree with a Fast Path
+-----------------------------------------
+
+.. code-block:: python
+
+ @py_class(structure="const-tree")
+ class DeviceMesh(Object):
+ shape: list[int]
+ device_ids: list[int]
+
+**Meaning**: "Same as ``"tree"``, but if two references point to the same
+object, they are guaranteed equal — skip the field comparison."
+
+This is purely a **performance optimization**. The only behavioral difference
+from ``"tree"`` is that pointer identity short-circuits to ``True``.
+
+When is this safe?
+~~~~~~~~~~~~~~~~~~
+
+When the type satisfies two conditions:
+
+1. **Immutable** — content doesn't change after construction, so same-pointer
+ always implies same-content.
+2. **No transitive** ``"var"`` **children** — skipping field traversal won't
+ cause variable mappings to be missed (see :ref:`var-kind` for why this
+ matters).
+
+Why not use it everywhere?
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Most IR nodes are immutable, but many transitively contain variables
+(e.g., ``x + 1`` contains the ``"var"`` node ``x``). If the pointer
+shortcut fires, the traversal skips ``x``, and a variable mapping that should
+have been established is silently missed.
+
+The following diagram shows the danger. Suppose the ``+`` node were
+incorrectly annotated as ``"const-tree"``. When comparing two trees that
+share a sub-expression, the pointer shortcut fires on the shared node, and
+the ``"var"`` ``x`` inside it is never visited — so no ``x ↔ y`` mapping
+is recorded:
+
+.. mermaid::
+
+ graph TD
+ subgraph "lhs"
+ LT["(_, _)"]
+ LE["x + 1"]
+ LX["x : var"]
+ LT -->|".0"| LE
+ LT -->|".1"| LX
+ LE -->|".lhs"| LX2["x"]
+ end
+
+ subgraph "rhs"
+ RT["(_, _)"]
+ RE["y + 1"]
+ RY["y : var"]
+ RT -->|".0"| RE
+ RT -->|".1"| RY
+ RE -->|".lhs"| RY2["y"]
+ end
+
+ LE -. "const-tree would skip here<br/>(misses x ↔ y mapping)" .-> RE
+ LX -. "Later comparison fails:<br/>x has no recorded mapping" .-> RY
+
+ style LE fill:#fff3cd
+ style RE fill:#fff3cd
+ style LX fill:#f8d7da
+ style RY fill:#f8d7da
+ style LX2 fill:#f8d7da
+ style RY2 fill:#f8d7da
+
+
+``"dag"`` — Sharing-Aware Comparison
+-------------------------------------
+
+.. code-block:: python
+
+ @py_class(structure="dag")
+ class Binding(Object):
+ var: Var
+ value: Expr
+
+**Meaning**: "This node lives in a graph where pointer sharing is
+semantically meaningful. Two graphs are equal only if they have the same
+content **and** the same sharing structure."
+
+Why it exists
+~~~~~~~~~~~~~
+
+In dataflow IR, sharing matters. Consider:
+
+.. code-block:: text
+
+ # Program A: shared — compute once, use twice
+ let s = x + 1 in (s, s)
+
+ # Program B: independent — compute twice
+ (x + 1, x + 1)
+
+Program A computes ``x + 1`` once and references it twice; Program B
+computes it independently twice. Under ``"tree"`` these are equal;
+under ``"dag"`` they are **not**:
+
+.. mermaid::
+
+ graph TD
+ subgraph "Program A — DAG"
+ TA["(_, _)"]
+ SA["s = x + 1"]
+ TA -->|".0"| SA
+ TA -->|".1"| SA
+ end
+
+ subgraph "Program B — Tree"
+ TB["(_, _)"]
+ A1["x + 1"]
+ A2["x + 1"]
+ TB -->|".0"| A1
+ TB -->|".1"| A2
+ end
+
+ SA -. "NOT EQUAL under dag<br/>(sharing structure differs)" .-> A1
+
+ style SA fill:#d4edda
+ style A1 fill:#d4edda
+ style A2 fill:#f8d7da
+
+How ``"dag"`` detects sharing
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+``"dag"`` maintains a bijective (one-to-one) mapping between objects that
+have been successfully compared. When the same object appears again, it
+checks whether the *pairing* is consistent:
+
+.. code-block:: text
+
+ Comparing Program A vs Program B:
+
+ .0: s ↔ (x+1)₁ → content equal, record pairing: s ↔ (x+1)₁
+ .1: s ↔ (x+1)₂ → s already paired with (x+1)₁, not (x+1)₂
+ → NOT EQUAL
+
+The mapping is **bijective**: if ``a`` is paired with ``b``, no other object
+can pair with either ``a`` or ``b``. This prevents false positives in both
+directions.
+
+**Example of the reverse direction.**
+
+.. code-block:: text
+
+ lhs: (a, b) rhs: (a, a) where a ≅ b (same content)
+
+ .0: a₁ ↔ a₂ → equal, record a₁ ↔ a₂
+ .1: b₁ ↔ a₂ → b₁ is new, but a₂ already paired with a₁
+ → NOT EQUAL
+
+Without the reverse check, the second comparison would proceed to content
+comparison, find ``b₁ ≅ a₂``, and incorrectly succeed.
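The bijective pairing can be sketched in pure Python. ``Node`` and ``DagEqual`` are hypothetical toy names, and keying the maps by ``id()`` stands in for pointer identity; this illustrates the rule described above rather than tvm_ffi's C++ implementation:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(eq=False)
class Node:
    """Toy dataflow node (hypothetical): an operator label plus children."""

    op: str
    children: tuple = ()


class DagEqual:
    """Content comparison plus a bijective map between already-paired nodes."""

    def __init__(self) -> None:
        self.l2r: dict[int, Node] = {}  # id(lhs node) -> paired rhs node
        self.r2l: dict[int, Node] = {}  # id(rhs node) -> paired lhs node

    def __call__(self, lhs: Any, rhs: Any) -> bool:
        if not (isinstance(lhs, Node) and isinstance(rhs, Node)):
            return type(lhs) is type(rhs) and lhs == rhs  # leaves
        # Seen before on either side: the recorded pairing must be exactly (lhs, rhs).
        if id(lhs) in self.l2r or id(rhs) in self.r2l:
            return self.l2r.get(id(lhs)) is rhs and self.r2l.get(id(rhs)) is lhs
        ok = (
            lhs.op == rhs.op
            and len(lhs.children) == len(rhs.children)
            and all(self(a, b) for a, b in zip(lhs.children, rhs.children))
        )
        if ok:  # record the pairing in both directions
            self.l2r[id(lhs)] = rhs
            self.r2l[id(rhs)] = lhs
        return ok


x = Node("x")
s = Node("+", (x, 1))
program_a = Node("tuple", (s, s))  # let s = x + 1 in (s, s)
program_b = Node("tuple", (Node("+", (x, 1)), Node("+", (x, 1))))  # (x + 1, x + 1)
print(DagEqual()(program_a, program_b))  # False: sharing structure differs

a1, b1, a2 = (Node("+", (x, 1)) for _ in range(3))  # distinct, equal-content nodes
print(DagEqual()(Node("tuple", (a1, b1)), Node("tuple", (a2, a2))))  # False: a2 already paired
```

Checking both directions is what rejects the ``(a, b)`` vs ``(a, a)`` case; a one-sided map would let ``b₁`` fall through to content comparison.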
+
+Full comparison: ``"tree"`` vs ``"dag"``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. list-table::
+ :header-rows: 1
+ :widths: 48 13 13
+
+ * - Scenario
+ - ``"tree"``
+ - ``"dag"``
+ * - both trees with same content
+ - Equal
+ - Equal
+ * - both DAGs, same sharing shape
+ - Equal
+ - Equal
+ * - ``let s = e in (s, s)`` vs ``(e, e')`` where ``e ≅ e'``
+ - Equal
+ - **Not equal**
+ * - ``(a, b)`` vs ``(a, a)`` where ``a ≅ b``
+ - Equal
+ - **Not equal**
+
+
+.. _var-kind:
+
+``"var"`` — Bound Variables
+----------------------------
+
+.. code-block:: python
+
+ @py_class(structure="var")
+ class Var(Object):
+ name: str
+
+**Meaning**: "This is a variable. Two variables are equal if they are
+**bound in corresponding positions**, not if they have the same name or
+content."
+
+The problem
+~~~~~~~~~~~
+
+.. code-block:: text
+
+ fun x → x + 1 should equal fun y → y + 1
+
+Variables are not defined by their content (name, type annotation). They
+are defined by **where they are introduced** and **how they are used**.
+``x`` and ``y`` above are interchangeable because they occupy the same
+binding position and are used in the same way.
+
+How it works: definition regions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+``"var"`` works together with ``field(structure="def")`` (see
+:ref:`field-annotations`). A field marked ``structure="def"`` is a
+**definition region** — it's where new variable bindings are introduced.
+
+- **Inside a definition region**: encountering two different variables
+ establishes a correspondence ("treat ``x`` as equivalent to ``y``").
+- **Outside a definition region**: variables are only equal if a prior
+ correspondence already exists, or they are the same pointer.
+
+The following diagram traces the comparison of two alpha-equivalent functions:
+
+.. mermaid::
+
+ sequenceDiagram
+ participant C as Comparator
+ participant L as lhs: fun x → x + 1
+ participant R as rhs: fun y → y + 1
+
+ Note over C: Field "params" has structure="def"
+ C->>L: get params → [x]
+ C->>R: get params → [y]
+ Note over C: Enter definition region
+ C->>C: Compare x ↔ y: both are Vars
+ Note over C: Record mapping: x ↔ y
+ Note over C: Exit definition region
+
+ Note over C: Field "body" — normal region
+ C->>L: get body → x + 1
+ C->>R: get body → y + 1
+ C->>C: Compare + fields...
+ C->>C: x ↔ y: lookup finds x→y ✓
+ C->>C: 1 ↔ 1: equal ✓
+ Note over C: Result: EQUAL ✓
+
+**Without** a definition region, the same variables would **not** be equal:
+
+.. code-block:: text
+
+ # Bare expressions, no enclosing function:
+ x + 1 vs y + 1 → NOT EQUAL (no definition region, different pointers)
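A toy model of definition regions, assuming hypothetical ``Var``, ``Add`` and ``Lambda`` classes in which ``Lambda.params`` plays the role of a ``field(structure="def")``. Inside the region an unmapped pair of variables is recorded in a bijective map; outside it, an unmapped variable only matches itself. This is a sketch of the semantics, not tvm_ffi's comparator:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(eq=False)
class Var:
    name: str  # the name is never compared


@dataclass(eq=False)
class Add:
    lhs: Any
    rhs: Any


@dataclass(eq=False)
class Lambda:
    params: tuple  # stands in for field(structure="def")
    body: Any


class AlphaEqual:
    def __init__(self) -> None:
        self.l2r: dict[int, Var] = {}  # id(lhs var) -> rhs var
        self.r2l: dict[int, Var] = {}  # id(rhs var) -> lhs var
        self.in_def = False  # True while comparing a definition region

    def __call__(self, lhs: Any, rhs: Any) -> bool:
        if isinstance(lhs, Var) and isinstance(rhs, Var):
            return self._var(lhs, rhs)
        if type(lhs) is not type(rhs):
            return False
        if isinstance(lhs, Lambda):
            if len(lhs.params) != len(rhs.params):
                return False
            prev, self.in_def = self.in_def, True  # enter the definition region
            ok = all(self(a, b) for a, b in zip(lhs.params, rhs.params))
            self.in_def = prev  # leave it before comparing the body
            return ok and self(lhs.body, rhs.body)
        if isinstance(lhs, Add):
            return self(lhs.lhs, rhs.lhs) and self(lhs.rhs, rhs.rhs)
        return lhs == rhs  # literals

    def _var(self, lhs: Var, rhs: Var) -> bool:
        if id(lhs) in self.l2r or id(rhs) in self.r2l:  # already mapped
            return self.l2r.get(id(lhs)) is rhs and self.r2l.get(id(rhs)) is lhs
        if self.in_def:  # new binding: record the correspondence
            self.l2r[id(lhs)], self.r2l[id(rhs)] = rhs, lhs
            return True
        return lhs is rhs  # unmapped, outside a region: pointer identity only


x, y = Var("x"), Var("y")
print(AlphaEqual()(Lambda((x,), Add(x, 1)), Lambda((y,), Add(y, 1))))  # True
print(AlphaEqual()(Add(x, 1), Add(y, 1)))  # False
```

The same bijective check also rejects ``fun (x, y) → x + x`` vs ``fun (a, b) → a + b``: after ``x ↔ a`` is recorded, the second ``x ↔ b`` fails.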
+
+Full comparison: with and without definition regions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. list-table::
+ :header-rows: 1
+ :widths: 55 22 22
+
+ * - Scenario
+ - With ``"def"``
+ - Without ``"def"``
+ * - ``fun x → x + 1`` vs ``fun y → y + 1``
+ - Equal
+ - n/a
+ * - ``fun x → x + 1`` vs ``fun y → x + 1``
+ - **Not equal** (body uses ``x`` but mapping says ``y``)
+ - n/a
+ * - ``fun (x, y) → x + y`` vs ``fun (a, b) → a + b``
+ - Equal (x↔a, y↔b)
+ - n/a
+ * - ``fun (x, y) → x + y`` vs ``fun (a, b) → b + a``
+ - **Not equal** (x↔a but body uses ``x`` where ``b`` appears)
+ - n/a
+ * - ``x + 1`` vs ``y + 1`` (bare)
+ - n/a
+ - **Not equal**
+ * - ``x + 1`` vs ``x + 1`` (same pointer)
+ - n/a
+ - Equal
+
+Inconsistent variable usage
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The bijective mapping catches inconsistencies. Consider:
+
+.. code-block:: text
+
+ fun (x, y) → x + x vs fun (a, b) → a + b
+
+.. mermaid::
+
+ sequenceDiagram
+ participant C as Comparator
+ participant L as lhs: fun (x, y) → x + x
+ participant R as rhs: fun (a, b) → a + b
+
+ Note over C: Definition region (params)
+ C->>C: x ↔ a → record x↔a ✓
+ C->>C: y ↔ b → record y↔b ✓
+
+ Note over C: Body: x + x vs a + b
+ C->>C: x ↔ a → lookup x→a, matches ✓
+ C->>C: x ↔ b → lookup x→a, but rhs is b ≠ a → FAIL ✗
+ Note over C: Result: NOT EQUAL ✓
+
+The ``map_free_vars`` flag
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+``structural_equal(lhs, rhs, map_free_vars=True)`` starts the comparison
+in "definition region" mode. This is useful for comparing standalone
+expressions where you want alpha-equivalence at the top level without an
+enclosing function:
+
+.. code-block:: python
+
+ # With map_free_vars=True:
+ structural_equal(x + 1, y + 1, map_free_vars=True) # → True
+
+ # With map_free_vars=False (default):
+ structural_equal(x + 1, y + 1) # → False
+
+
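A compact toy of the flag, assuming hypothetical ``Var``/``Add`` classes and a ``toy_structural_equal`` function (not the real ``structural_equal``). With ``map_free_vars=True``, unmapped variables pair up on first sight, which approximates starting the whole comparison inside a definition region:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(eq=False)
class Var:
    name: str


@dataclass(eq=False)
class Add:
    lhs: Any
    rhs: Any


def toy_structural_equal(lhs: Any, rhs: Any, map_free_vars: bool = False) -> bool:
    l2r: dict[int, Var] = {}  # id(lhs var) -> rhs var
    r2l: dict[int, Var] = {}  # id(rhs var) -> lhs var

    def eq(a: Any, b: Any) -> bool:
        if isinstance(a, Var) and isinstance(b, Var):
            if id(a) in l2r or id(b) in r2l:  # already mapped: must be consistent
                return l2r.get(id(a)) is b and r2l.get(id(b)) is a
            if map_free_vars:  # treat the top level as a definition region
                l2r[id(a)], r2l[id(b)] = b, a
                return True
            return a is b
        if isinstance(a, Add) and isinstance(b, Add):
            return eq(a.lhs, b.lhs) and eq(a.rhs, b.rhs)
        return type(a) is type(b) and a == b

    return eq(lhs, rhs)


x, y = Var("x"), Var("y")
print(toy_structural_equal(Add(x, 1), Add(y, 1), map_free_vars=True))  # True
print(toy_structural_equal(Add(x, 1), Add(y, 1)))  # False
```

The mapping stays bijective even in this mode, so ``x + x`` vs ``x + y`` is still unequal.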
+``"singleton"`` — Singletons
+------------------------------
+
+.. code-block:: python
+
+ @py_class(structure="singleton")
+ class Op(Object):
+ name: str
+
+**Meaning**: "There is exactly one instance of this object per logical
+identity. Pointer equality is the only valid comparison."
+
+No content comparison is ever performed. Different pointers are always
+unequal; same pointer is always equal.
+
+.. code-block:: python
+
+ op_conv = Op.get("nn.conv2d")
+ op_relu = Op.get("nn.relu")
+
+ structural_equal(op_conv, op_conv) # → True (same pointer)
+ structural_equal(op_conv, op_relu) # → False (different pointers)
+
+
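A minimal sketch of the singleton contract, assuming a hypothetical in-process registry behind ``Op.get``. Equality never looks at ``name``, only at object identity:

```python
from __future__ import annotations


class Op:
    """Toy singleton type (hypothetical): one shared instance per name."""

    _registry: dict[str, Op] = {}

    def __init__(self, name: str) -> None:
        self.name = name

    @classmethod
    def get(cls, name: str) -> Op:
        # Create the instance on first lookup, then always return the same object.
        if name not in cls._registry:
            cls._registry[name] = cls(name)
        return cls._registry[name]


def singleton_equal(lhs: Op, rhs: Op) -> bool:
    return lhs is rhs  # pointer identity only; fields are never inspected


conv = Op.get("nn.conv2d")
print(singleton_equal(conv, Op.get("nn.conv2d")))  # True: same registered object
print(singleton_equal(conv, Op("nn.conv2d")))  # False: same name, different object
```

Constructing ``Op("nn.conv2d")`` directly bypasses the registry, which is exactly the case ``"singleton"`` reports as unequal.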
+.. _field-annotations:
+
+Field-Level Annotations
+-----------------------
+
+The ``structure`` parameter on :func:`~tvm_ffi.dataclasses.field` controls
+how structural equality/hashing treats that specific field.
+
+``structure="ignore"`` — Exclude a field
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ @py_class(structure="tree")
+ class MyNode(Object):
+ value: int
+ span: str = field(structure="ignore")
+
+**Meaning**: "This field is not part of the node's structural identity.
+Skip it during comparison and hashing."
+
+Use for:
+
+- **Source locations** (``span``) — where the node came from in source code
+ doesn't affect what it means.
+- **Cached/derived values** — computed from other fields, so comparing them
+ is redundant.
+- **Debug annotations** — names, comments, metadata for human consumption.
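The effect of ``structure="ignore"`` can be sketched with the standard library, using stdlib ``dataclasses.field(metadata=...)`` as a stand-in for tvm_ffi's ``field``; ``toy_equal`` and ``toy_hash`` are hypothetical helpers. Ignored fields are skipped by both, so objects that compare equal also hash equally:

```python
from __future__ import annotations

import dataclasses
from collections.abc import Iterator
from typing import Any


@dataclasses.dataclass(eq=False)
class MyNode:
    value: int
    # Stand-in for field(structure="ignore"): tag the field via stdlib metadata.
    span: str = dataclasses.field(default="", metadata={"structure": "ignore"})


def structural_fields(obj: Any) -> Iterator[tuple[str, Any]]:
    """Yield (name, value) for every field that takes part in comparison."""
    for f in dataclasses.fields(obj):
        if f.metadata.get("structure") != "ignore":
            yield f.name, getattr(obj, f.name)


def toy_equal(lhs: Any, rhs: Any) -> bool:
    return type(lhs) is type(rhs) and (
        list(structural_fields(lhs)) == list(structural_fields(rhs))
    )


def toy_hash(obj: Any) -> int:
    # Hash exactly the fields toy_equal compares, keeping hash consistent with equality.
    return hash((type(obj).__name__, tuple(structural_fields(obj))))


a = MyNode(1, span="a.py:1")
b = MyNode(1, span="b.py:5")
print(toy_equal(a, b), toy_hash(a) == toy_hash(b))  # True True
```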
+
+``structure="def"`` — Definition region
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ @py_class(structure="tree")
+ class Lambda(Object):
+ params: list[Var] = field(structure="def")
+ body: Expr
+
+**Meaning**: "This field introduces new variable bindings. When comparing
+or hashing this field, allow new variable correspondences to be
+established."
+
+This is the counterpart to ``"var"``. A ``"var"`` type says "I am a
+variable"; ``structure="def"`` says "this field is where variables are
+defined." Together they enable alpha-equivalence: comparing functions up
+to consistent variable renaming.
+
+Use for:
+
+- **Function parameter lists**
+- **Let-binding left-hand sides**
+- **Any field that introduces names into scope**
+
+
+Custom Equality and Hashing
+----------------------------
+
+For types where the default field-by-field traversal is insufficient, you
+can register custom callbacks as type attributes:
+
+- ``__s_equal__`` — custom equality logic
+- ``__s_hash__`` — custom hashing logic
+
+These are registered per type via ``EnsureTypeAttrColumn``. When present,
+they replace the default field iteration. The system still manages all
+kind-specific logic (``"dag"`` memoization, ``"var"`` mapping, etc.) — the
+custom callback only controls which sub-values to compare/hash and in what
+order.
+
+
+All Kinds at a Glance
+---------------------
+
+The following diagram visualizes the five comparable kinds, arranged by how
+much structural information they track:
+
+.. mermaid::
+
+ graph LR
+ UI["singleton<br/><i>pointer only</i>"]
+ TN["tree<br/><i>content only</i>"]
+ CTN["const-tree<br/><i>content + pointer shortcut</i>"]
+ DN["dag<br/><i>content + sharing</i>"]
+ FV["var<br/><i>content + binding position</i>"]
+
+ UI --- TN
+ TN --- CTN
+ TN --- DN
+ TN --- FV
+
+ style UI fill:#e2e3e5
+ style TN fill:#d4edda
+ style CTN fill:#d4edda
+ style DN fill:#cce5ff
+ style FV fill:#fff3cd
+
+.. list-table::
+ :header-rows: 1
+ :widths: 18 18 18 18 18
+
+ * -
+ - Content comparison
+ - Pointer shortcut
+ - Tracks sharing
+ - Tracks binding position
+ * - ``"singleton"``
+ - No
+ - Yes (only)
+ - No
+ - No
+ * - ``"tree"``
+ - Yes
+ - No
+ - No
+ - No
+ * - ``"const-tree"``
+ - Yes
+ - Yes (fast path)
+ - No
+ - No
+ * - ``"dag"``
+ - Yes
+ - No
+ - Yes
+ - No
+ * - ``"var"``
+ - Yes
+ - No
+ - No
+ - Yes
+
+
+Decision Guide
+--------------
+
+When defining a new type:
+
+.. mermaid::
+
+ graph TD
+ Start["New @py_class type"] --> Q1{"Singleton?<br/>(one instance per<br/>logical identity)"}
+ Q1 -->|Yes| UI["structure='singleton'"]
+ Q1 -->|No| Q2{"Represents a<br/>variable binding?"}
+ Q2 -->|Yes| FV["structure='var'"]
+ Q2 -->|No| Q3{"Pointer sharing<br/>semantically<br/>meaningful?"}
+ Q3 -->|Yes| DN["structure='dag'"]
+ Q3 -->|No| Q4{"Immutable AND<br/>no transitive<br/>var children?"}
+ Q4 -->|Yes| CTN["structure='const-tree'"]
+ Q4 -->|No| TN["structure='tree'"]
+
+ style UI fill:#e2e3e5
+ style FV fill:#fff3cd
+ style DN fill:#cce5ff
+ style CTN fill:#d4edda
+ style TN fill:#d4edda
+
+For fields:
+
+.. mermaid::
+
+ graph TD
+ Start["field() parameter"] --> Q1{"Irrelevant to<br/>structural identity?<br/>(span, cache, debug)"}
+ Q1 -->|Yes| IGN["structure='ignore'"]
+ Q1 -->|No| Q2{"Introduces new<br/>variable bindings?"}
+ Q2 -->|Yes| DEF["structure='def'"]
+ Q2 -->|No| NONE["No flag needed"]
+
+ style IGN fill:#f8d7da
+ style DEF fill:#fff3cd
+ style NONE fill:#d4edda
+
+
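The two flowcharts can also be read as plain conditionals. This is an illustrative sketch: ``choose_type_structure`` and ``choose_field_structure`` are hypothetical helpers, and their boolean inputs are judgments the type author makes, not properties tvm_ffi detects:

```python
from __future__ import annotations


def choose_type_structure(
    *,
    singleton: bool,
    is_variable: bool,
    sharing_matters: bool,
    immutable: bool,
    has_transitive_var_children: bool,
) -> str:
    """Walk the type-level decision guide top to bottom (illustrative only)."""
    if singleton:  # one instance per logical identity
        return "singleton"
    if is_variable:  # represents a variable binding
        return "var"
    if sharing_matters:  # pointer sharing is semantically meaningful
        return "dag"
    if immutable and not has_transitive_var_children:
        return "const-tree"  # safe to take the pointer shortcut
    return "tree"


def choose_field_structure(*, irrelevant_to_identity: bool, introduces_bindings: bool) -> str | None:
    """Walk the field-level decision guide (illustrative only)."""
    if irrelevant_to_identity:  # span, cache, debug info
        return "ignore"
    if introduces_bindings:  # parameter lists, let-binding left-hand sides
        return "def"
    return None  # no flag needed


# An immutable expression node that may contain variables stays "tree".
print(choose_type_structure(singleton=False, is_variable=False, sharing_matters=False,
                            immutable=True, has_transitive_var_children=True))  # tree
```

The order of the checks matters: a singleton or variable type is classified before immutability is even considered.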
+Worked Example
+--------------
+
+Putting it all together for a function node with parameters, body, and
+source location:
+
+.. code-block:: python
+
+ @py_class(structure="tree")
+ class Lambda(Object):
+ params: list[Var] = field(structure="def")
+ body: Expr
+ span: str = field(structure="ignore", default="")
+
+ @py_class(structure="var")
+ class Var(Object):
+ name: str
+
+ @py_class(structure="singleton")
+ class Op(Object):
+ name: str
+
+With these annotations, alpha-equivalent functions are structurally equal:
+
+.. code-block:: text
+
+ # These two are structurally equal:
+ fun [x] → x + 1 (span="a.py:1")
+ fun [y] → y + 1 (span="b.py:5")
+
+ # - params has structure="def" → x maps to y
+ # - body uses that mapping → (x + 1) ≅ (y + 1)
+ # - span has structure="ignore" → locations don't matter
+
+And in Python:
+
+.. code-block:: python
+
+ from tvm_ffi import structural_equal, structural_hash
+
+ x, y = Var("x"), Var("y")
+ f1 = Lambda([x], x + 1, span="a.py:1")
+ f2 = Lambda([y], y + 1, span="b.py:5")
+
+ assert structural_equal(f1, f2) # alpha-equivalent
+ assert structural_hash(f1) == structural_hash(f2) # same hash
diff --git a/docs/index.rst b/docs/index.rst
index 5fb646ef..0c070dfd 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -67,6 +67,7 @@ Table of Contents
concepts/tensor.rst
concepts/func_module.rst
concepts/exception_handling.rst
+ concepts/structural_eq_hash.rst
.. toctree::
:maxdepth: 1
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 7a0b8c5e..98c5f8ca 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -704,7 +704,7 @@ class TypeInfo:
end = f_end
return (end + 7) & ~7 # align to 8 bytes
- def _register_fields(self, fields):
+ def _register_fields(self, fields, structure_kind=None):
"""Register Field descriptors and set up __ffi_new__/__ffi_init__.
Delegates to the module-level _register_fields function,
@@ -712,11 +712,19 @@ class TypeInfo:
then reads back methods registered by C++ via _register_methods.
Can only be called once (fields must be None beforehand).
+
+ Parameters
+ ----------
+ fields : list[Field]
+ The Field descriptors to register.
+ structure_kind : int | None
+ The structural equality/hashing kind (``TVMFFISEqHashKind`` integer).
+ ``None`` or ``0`` means unsupported (no metadata registered).
"""
assert self.fields is None, (
f"_register_fields already called for {self.type_key!r}"
)
- self.fields = _register_fields(self, fields)
+ self.fields = _register_fields(self, fields, structure_kind)
self._register_methods()
def _register_methods(self):
@@ -811,6 +819,12 @@ cdef _register_one_field(
flags |= kTVMFFIFieldFlagBitMaskCompareOff
if py_field.kw_only:
flags |= kTVMFFIFieldFlagBitMaskKwOnly
+ # Structural equality/hashing field annotations
+ cdef object field_structure = getattr(py_field, "structure", None)
+ if field_structure == "ignore":
+ flags |= kTVMFFIFieldFlagBitMaskSEqHashIgnore
+ elif field_structure == "def":
+ flags |= kTVMFFIFieldFlagBitMaskSEqHashDef
info.flags = flags
# --- native layout ---
@@ -888,7 +902,7 @@ cdef int _f_type_convert(void* type_converter, const TVMFFIAny* value, TVMFFIAny
return -1
-def _register_fields(type_info, fields):
+def _register_fields(type_info, fields, structure_kind=None):
"""Register Field descriptors for a Python-defined type and set up
__ffi_new__/__ffi_init__.
For each Field:
@@ -897,8 +911,9 @@ def _register_fields(type_info, fields):
3. Creates a FunctionObj setter with type conversion
4. Registers via TVMFFITypeRegisterField
- After all fields, registers __ffi_new__ (object allocator) and
- __ffi_init__ (auto-generated constructor).
+ After all fields, registers __ffi_new__ (object allocator),
+ __ffi_init__ (auto-generated constructor), and optionally
+ type metadata (structural_eq_hash_kind).
Parameters
----------
@@ -906,6 +921,9 @@ def _register_fields(type_info, fields):
The TypeInfo of the type being defined.
fields : list[Field]
The Field descriptors to register.
+ structure_kind : int | None
+ The structural equality/hashing kind (``TVMFFISEqHashKind`` integer).
+ ``None`` or ``0`` means unsupported (no metadata registered).
Returns
-------
@@ -990,12 +1008,27 @@ def _register_fields(type_info, fields):
# 7. Register __ffi_new__ + deleter
_make_ffi_new(type_index, total_size)
- # 8. Register __ffi_init__ (auto-generated constructor)
+ # 8. Register type metadata (structural_eq_hash_kind) if specified.
+ if structure_kind is not None and structure_kind != 0:
+ _register_type_metadata(type_index, total_size, structure_kind)
+
+ # 9. Register __ffi_init__ (auto-generated constructor)
_register_auto_init(type_index)
return type_fields
+cdef _register_type_metadata(int32_t type_index, int32_t total_size, int structure_kind):
+ """Register TVMFFITypeMetadata for the given type with structural eq/hash kind."""
+ cdef TVMFFITypeMetadata metadata
+ metadata.doc.data = NULL
+ metadata.doc.size = 0
+ metadata.creator = NULL
+ metadata.total_size = total_size
+ metadata.structural_eq_hash_kind = <TVMFFISEqHashKind>structure_kind
+ CHECK_CALL(TVMFFITypeRegisterMetadata(type_index, &metadata))
+
+
def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(self: Any, *args: Any) -> Any:
return method_func(self, *args)
diff --git a/python/tvm_ffi/dataclasses/field.py
b/python/tvm_ffi/dataclasses/field.py
index 9295f13a..97e8864e 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import sys
from collections.abc import Callable
-from typing import Any
+from typing import Any, ClassVar
from ..core import MISSING, TypeSchema
@@ -70,6 +70,17 @@ class Field:
kw_only : bool | None
Whether this field is keyword-only in ``__init__``.
``None`` means "inherit from the decorator-level *kw_only* flag".
+ structure : str | None
+ Structural equality/hashing annotation for this field. Valid
+ values are:
+
+ - ``None`` (default): the field participates normally in
+ structural comparison and hashing.
+ - ``"ignore"``: the field is excluded from structural equality
+ and hashing entirely (e.g. source spans, caches).
+ - ``"def"``: the field is a **definition region** that introduces
+ new variable bindings. Free variables encountered inside this
+ field are mapped by position, enabling alpha-equivalence.
doc : str | None
Optional docstring for the field.
@@ -85,6 +96,7 @@ class Field:
"kw_only",
"name",
"repr",
+ "structure",
"ty",
)
name: str | None
@@ -96,9 +108,13 @@ class Field:
hash: bool | None
compare: bool
kw_only: bool | None
+ structure: str | None
doc: str | None
- def __init__(
+ #: Valid values for the *structure* parameter.
+ _VALID_STRUCTURE_VALUES: ClassVar[frozenset[str | None]] = frozenset({None, "ignore", "def"})
+
+ def __init__( # noqa: PLR0913
self,
name: str | None = None,
ty: TypeSchema | None = None,
@@ -110,6 +126,7 @@ class Field:
hash: bool | None = True,
compare: bool = False,
kw_only: bool | None = False,
+ structure: str | None = None,
doc: str | None = None,
) -> None:
# MISSING means "parameter not provided".
@@ -122,6 +139,11 @@ class Field:
raise TypeError(
f"default_factory must be a callable, got {type(default_factory).__name__}"
)
+ if structure not in Field._VALID_STRUCTURE_VALUES:
+ raise ValueError(
+ f"structure must be one of {sorted(Field._VALID_STRUCTURE_VALUES, key=str)}, "
+ f"got {structure!r}"
+ )
self.name = name
self.ty = ty
self.default = default
@@ -131,6 +153,7 @@ class Field:
self.hash = hash
self.compare = compare
self.kw_only = kw_only
+ self.structure = structure
self.doc = doc
@@ -143,6 +166,7 @@ def field(
hash: bool | None = None,
compare: bool = True,
kw_only: bool | None = None,
+ structure: str | None = None,
doc: str | None = None,
) -> Any:
"""Customize a field in a ``@py_class``-decorated class.
@@ -174,6 +198,11 @@ def field(
kw_only
Whether this field is keyword-only in ``__init__``.
``None`` means "inherit from the decorator-level ``kw_only`` flag".
+ structure
+ Structural equality/hashing annotation. ``None`` (default) means
+ the field participates normally. ``"ignore"`` excludes the field
+ from structural comparison and hashing. ``"def"`` marks the field
+ as a definition region for variable binding.
doc
Optional docstring for the field.
@@ -191,6 +220,13 @@ def field(
x: float
y: float = field(default=0.0, repr=False)
+
+ @py_class(structure="tree")
+ class MyFunc(Object):
+ params: Array = field(structure="def")
+ body: Expr
+ span: Object = field(structure="ignore")
+
"""
return Field(
default=default,
@@ -200,5 +236,6 @@ def field(
hash=hash,
compare=compare,
kw_only=kw_only,
+ structure=structure,
doc=doc,
)
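The ``key=str`` in the new ``Field`` error message matters because ``None`` and strings cannot be ordered against each other in Python 3. A standalone sketch of the same check; ``VALID_STRUCTURE_VALUES`` and ``check_field_structure`` are local stand-ins, not the ``Field`` class itself:

```python
from __future__ import annotations

VALID_STRUCTURE_VALUES = frozenset({None, "ignore", "def"})


def check_field_structure(structure: str | None) -> str | None:
    """Mirror of the validation added to Field.__init__ (standalone sketch)."""
    if structure not in VALID_STRUCTURE_VALUES:
        raise ValueError(
            f"structure must be one of {sorted(VALID_STRUCTURE_VALUES, key=str)}, "
            f"got {structure!r}"
        )
    return structure


try:
    sorted(VALID_STRUCTURE_VALUES)  # comparing None with str raises TypeError
except TypeError:
    print("plain sorted() fails on mixed None/str values")

try:
    check_field_structure("ignored")
except ValueError as err:
    print(err)  # structure must be one of [None, 'def', 'ignore'], got 'ignored'
```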
diff --git a/python/tvm_ffi/dataclasses/py_class.py
b/python/tvm_ffi/dataclasses/py_class.py
index 9080fae3..103f2c26 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -225,7 +225,9 @@ def _phase2_register_fields(
own_fields = _collect_own_fields(cls, hints, params["kw_only"])
- type_info._register_fields(own_fields)
+ # Register fields and type-level structural eq/hash kind with the C layer.
+ structure_kind = _STRUCTURE_KIND_MAP.get(params.get("structure"))
+ type_info._register_fields(own_fields, structure_kind)
_add_class_attrs(cls, type_info)
# Remove deferred __init__ and restore user-defined __init__ if saved
@@ -338,6 +340,17 @@ def _install_deferred_init(
# ---------------------------------------------------------------------------
+#: Mapping from Python string names to C-level ``TVMFFISEqHashKind`` enum values.
+_STRUCTURE_KIND_MAP: dict[str | None, int] = {
+ None: 0, # kTVMFFISEqHashKindUnsupported (default; no metadata registered)
+ "tree": 1, # kTVMFFISEqHashKindTreeNode
+ "var": 2, # kTVMFFISEqHashKindFreeVar
+ "dag": 3, # kTVMFFISEqHashKindDAGNode
+ "const-tree": 4, # kTVMFFISEqHashKindConstTreeNode
+ "singleton": 5, # kTVMFFISEqHashKindUniqueInstance
+}
+
+
@dataclass_transform(
eq_default=False,
order_default=False,
@@ -354,6 +367,7 @@ def py_class(
order: bool = False,
unsafe_hash: bool = False,
kw_only: bool = False,
+ structure: str | None = None,
slots: bool = True,
) -> Callable[[_T], _T] | _T:
"""Register a Python-defined FFI class with dataclass-style semantics.
@@ -379,6 +393,12 @@ def py_class(
@py_class("my.Point", eq=True) # both
class Point(Object): ...
+
+ @py_class(structure="tree") # structural eq/hash kind
+ class MyNode(Object):
+ value: int
+ span: Object = field(structure="ignore")
+
Parameters
----------
cls_or_type_key
@@ -400,6 +420,21 @@ def py_class(
If True, generate ``__hash__`` (unsafe for mutable objects).
kw_only
If True, all fields are keyword-only in ``__init__`` by default.
+ structure
+ Structural equality/hashing kind for this type. Controls how
+ instances participate in ``StructuralEqual`` and ``StructuralHash``.
+ Valid values are:
+
+ - ``None`` (default): structural comparison is not supported.
+ - ``"tree"``: content-based comparison, the safe default for
+ most IR nodes.
+ - ``"var"``: compared by binding position, for variable types.
+ - ``"dag"``: content + sharing-aware comparison, for dataflow
+ graph nodes.
+ - ``"const-tree"``: like ``"tree"`` with a pointer-equality
+ fast path (only safe for types with no transitive ``"var"``
+ children).
+ - ``"singleton"``: pointer equality only, for singleton types.
slots
Accepted for ``dataclass_transform`` compatibility. Object
subclasses always use ``__slots__ = ()`` via the metaclass.
@@ -412,6 +447,12 @@ def py_class(
"""
if order and not eq:
raise ValueError("order=True requires eq=True")
+ if structure not in _STRUCTURE_KIND_MAP:
+ raise ValueError(
+ f"structure must be one of "
+ f"{sorted(k for k in _STRUCTURE_KIND_MAP if k is not None)}"
+ f" or None, got {structure!r}"
+ )
effective_type_key = type_key
params: dict[str, Any] = {
@@ -421,6 +462,7 @@ def py_class(
"order": order,
"unsafe_hash": unsafe_hash,
"kw_only": kw_only,
+ "structure": structure,
}
def decorator(cls: _T) -> _T:
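The new `structure=` handling in `py_class` reduces to a validate-then-lookup step. Unknown strings raise `ValueError` at decoration time. An accepted value is later translated to an integer `TVMFFISEqHashKind` through `_STRUCTURE_KIND_MAP.get(...)` when fields are registered. Below is a standalone sketch of that logic, with the integer values copied from the diff. The name `resolve_structure_kind` is hypothetical and not part of tvm_ffi.

```python
from typing import Dict, Optional

# Integer values copied from _STRUCTURE_KIND_MAP in py_class.py above.
STRUCTURE_KIND_MAP: Dict[Optional[str], int] = {
    None: 0,  # structural comparison unsupported (default)
    "tree": 1,
    "var": 2,
    "dag": 3,
    "const-tree": 4,
    "singleton": 5,
}


def resolve_structure_kind(structure: Optional[str]) -> int:
    """Validate a ``structure=`` value and return its integer kind (hypothetical helper)."""
    if structure not in STRUCTURE_KIND_MAP:
        valid = sorted(k for k in STRUCTURE_KIND_MAP if k is not None)
        raise ValueError(f"structure must be one of {valid} or None, got {structure!r}")
    return STRUCTURE_KIND_MAP[structure]
```

For example, `resolve_structure_kind("dag")` returns `3`. The lookup is case-sensitive, so `resolve_structure_kind("Tree")` raises `ValueError: structure must be one of ['const-tree', 'dag', 'singleton', 'tree', 'var'] or None, got 'Tree'`.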
diff --git a/src/ffi/extra/serialization.cc b/src/ffi/extra/serialization.cc
index c1fb6211..2a21239f 100644
--- a/src/ffi/extra/serialization.cc
+++ b/src/ffi/extra/serialization.cc
@@ -201,10 +201,10 @@ class ObjectGraphSerializer {
}
// NOTE: invariant: lhs and rhs are already the same type
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(value.type_index());
- if (type_info->metadata == nullptr) {
- TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `"
- << String(type_info->type_key)
- << "`, so ToJSONGraph is not supported for this type";
+ if (!HasCreator(type_info)) {
+ TVM_FFI_THROW(TypeError) << "Type `" << String(type_info->type_key)
+ << "` does not support ToJSONGraph "
+ << "(no native creator or __ffi_new__ type attr)";
}
const Object* obj = value.cast<const Object*>();
json::Object data;
diff --git a/tests/python/test_dataclass_copy.py b/tests/python/test_dataclass_copy.py
index 6507e636..128b6b4b 100644
--- a/tests/python/test_dataclass_copy.py
+++ b/tests/python/test_dataclass_copy.py
@@ -14,16 +14,31 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# ruff: noqa: D102
+# ruff: noqa: D102, UP006, UP045
"""Tests for __copy__, __deepcopy__, and __replace__ on FFI objects."""
from __future__ import annotations
import copy
+import itertools
+import pickle
+import sys
+from typing import Dict, List, Optional
import pytest
import tvm_ffi
import tvm_ffi.testing
+from tvm_ffi._ffi_api import DeepCopy
+from tvm_ffi.core import Object
+from tvm_ffi.dataclasses import py_class
+
+_needs_310 = pytest.mark.skipif(sys.version_info < (3, 10), reason="X | Y syntax requires 3.10+")
+
+_counter_pc = itertools.count()
+
+
+def _unique_key_pc(base: str) -> str:
+ return f"testing.copy_pc.{base}_{next(_counter_pc)}"
# --------------------------------------------------------------------------- #
@@ -937,3 +952,268 @@ class TestReplace:
obj = tvm_ffi.testing.TestNonCopyable(42)
with pytest.raises(TypeError, match="does not support replace"):
obj.__replace__() # ty: ignore[unresolved-attribute]
+
+
+# --------------------------------------------------------------------------- #
+# @py_class copy/deepcopy with rich field types
+# --------------------------------------------------------------------------- #
+class TestPyClassCopyRichFields:
+ """copy.copy / copy.deepcopy with container and optional fields on @py_class."""
+
+ def test_shallow_copy_containers(self) -> None:
+ @py_class(_unique_key_pc("SCCont"))
+ class SCCont(Object):
+ x: int
+ items: List[int]
+ data: Dict[str, int]
+
+ obj = SCCont(x=42, items=[1, 2, 3], data={"a": 1})
+ obj2 = copy.copy(obj)
+ assert obj2.x == 42
+ assert len(obj2.items) == 3
+ assert obj2.data["a"] == 1
+ assert obj.items.same_as(obj2.items) # ty:ignore[unresolved-attribute]
+ assert obj.data.same_as(obj2.data) # ty:ignore[unresolved-attribute]
+
+ def test_deep_copy_containers(self) -> None:
+ @py_class(_unique_key_pc("DCCont"))
+ class DCCont(Object):
+ x: int
+ items: List[int]
+ data: Dict[str, int]
+
+ obj = DCCont(x=42, items=[1, 2, 3], data={"a": 1})
+ obj2 = copy.deepcopy(obj)
+ assert obj2.x == 42
+ assert len(obj2.items) == 3
+ assert obj2.data["a"] == 1
+ assert not obj.items.same_as(obj2.items) # ty:ignore[unresolved-attribute]
+ assert not obj.data.same_as(obj2.data) # ty:ignore[unresolved-attribute]
+
+ def test_shallow_copy_nested_containers(self) -> None:
+ @py_class(_unique_key_pc("SCNest"))
+ class SCNest(Object):
+ matrix: List[List[int]]
+
+ obj = SCNest(matrix=[[1, 2], [3, 4]])
+ obj2 = copy.copy(obj)
+ assert obj.matrix.same_as(obj2.matrix) # ty:ignore[unresolved-attribute]
+
+ def test_deep_copy_nested_containers(self) -> None:
+ @py_class(_unique_key_pc("DCNest"))
+ class DCNest(Object):
+ matrix: List[List[int]]
+
+ obj = DCNest(matrix=[[1, 2], [3, 4]])
+ obj2 = copy.deepcopy(obj)
+ assert obj2.matrix[0][0] == 1
+ assert obj2.matrix[1][1] == 4
+ assert not obj.matrix.same_as(obj2.matrix) # ty:ignore[unresolved-attribute]
+
+ def test_deep_copy_mutation_independent(self) -> None:
+ @py_class(_unique_key_pc("DCMutInd"))
+ class DCMutInd(Object):
+ x: int
+ items: List[int]
+
+ obj = DCMutInd(x=1, items=[10, 20])
+ obj2 = copy.deepcopy(obj)
+ obj2.x = 99
+ assert obj.x == 1
+ obj2.items[0] = 999
+ assert obj.items[0] == 10
+
+ def test_shallow_copy_optional_fields(self) -> None:
+ @py_class(_unique_key_pc("SCOpt"))
+ class SCOpt(Object):
+ x: Optional[int]
+ items: Optional[List[int]]
+
+ obj = SCOpt(x=42, items=[1, 2])
+ obj2 = copy.copy(obj)
+ assert obj2.x == 42
+ assert len(obj2.items) == 2 # ty:ignore[invalid-argument-type]
+
+ def test_deep_copy_with_none_optional(self) -> None:
+ @py_class(_unique_key_pc("DCOptNone"))
+ class DCOptNone(Object):
+ x: Optional[int]
+ items: Optional[List[int]]
+
+ obj = DCOptNone(x=None, items=None)
+ obj2 = copy.deepcopy(obj)
+ assert obj2.x is None
+ assert obj2.items is None
+
+ def test_replace_with_containers(self) -> None:
+ @py_class(_unique_key_pc("ReplCont"))
+ class ReplCont(Object):
+ x: int
+ items: List[int]
+
+ obj = ReplCont(x=1, items=[1, 2, 3])
+ obj2 = obj.__replace__(x=99) # ty:ignore[unresolved-attribute]
+ assert obj2.x == 99
+ assert tuple(obj2.items) == (1, 2, 3)
+ assert obj.x == 1
+
+
+# --------------------------------------------------------------------------- #
+# DeepCopy FFI with @py_class containers
+# --------------------------------------------------------------------------- #
+class TestPyClassDeepCopyContainers:
+ """DeepCopy FFI function with @py_class container fields."""
+
+ def test_deep_copy_list_field(self) -> None:
+ @py_class(_unique_key_pc("DCList"))
+ class DCList(Object):
+ items: List[int]
+
+ obj = DCList(items=[1, 2, 3])
+ obj2 = DeepCopy(obj)
+ assert tuple(obj2.items) == (1, 2, 3)
+ assert not obj.items.same_as(obj2.items) # ty:ignore[unresolved-attribute]
+
+ def test_deep_copy_dict_field(self) -> None:
+ @py_class(_unique_key_pc("DCDict"))
+ class DCDict(Object):
+ data: Dict[str, int]
+
+ obj = DCDict(data={"a": 1, "b": 2})
+ obj2 = DeepCopy(obj)
+ assert obj2.data["a"] == 1
+ assert not obj.data.same_as(obj2.data) # ty:ignore[unresolved-attribute]
+
+ def test_deep_copy_nested(self) -> None:
+ @py_class(_unique_key_pc("DCNested"))
+ class DCNested(Object):
+ matrix: List[List[int]]
+
+ obj = DCNested(matrix=[[1, 2], [3, 4]])
+ obj2 = DeepCopy(obj)
+ assert obj2.matrix[0][0] == 1
+ assert not obj.matrix.same_as(obj2.matrix) # ty:ignore[unresolved-attribute]
+
+ def test_deep_copy_optional_none(self) -> None:
+ @py_class(_unique_key_pc("DCOptN"))
+ class DCOptN(Object):
+ items: Optional[List[int]]
+
+ obj = DCOptN(items=None)
+ assert DeepCopy(obj).items is None
+
+ def test_deep_copy_optional_value(self) -> None:
+ @py_class(_unique_key_pc("DCOptV"))
+ class DCOptV(Object):
+ items: Optional[List[int]]
+
+ obj = DCOptV(items=[1, 2, 3])
+ obj2 = DeepCopy(obj)
+ assert tuple(obj2.items) == (1, 2, 3)
+ assert not obj.items.same_as(obj2.items) # ty:ignore[unresolved-attribute]
+
+
+# --------------------------------------------------------------------------- #
+# Copy of @py_class with custom __init__
+# --------------------------------------------------------------------------- #
+class TestPyClassCopyCustomInit:
+ """Copy of @py_class with init=False and custom __init__."""
+
+ def _make_cls(self) -> type:
+ @py_class(_unique_key_pc("CopyCI"), init=False)
+ class CopyCI(Object):
+ a: int
+ b: str
+
+ def __init__(self, *, b: str, a: int) -> None:
+ self.__ffi_init__(a, b)
+
+ return CopyCI
+
+ def test_shallow_copy_custom_init(self) -> None:
+ CopyCI = self._make_cls()
+ src = CopyCI(a=1, b="hello")
+ dst = copy.copy(src)
+ assert not src.same_as(dst)
+ assert dst.a == 1
+ assert dst.b == "hello"
+
+ def test_deep_copy_custom_init(self) -> None:
+ CopyCI = self._make_cls()
+ src = CopyCI(a=1, b="hello")
+ dst = copy.deepcopy(src)
+ assert not src.same_as(dst)
+ assert dst.a == 1
+ assert dst.b == "hello"
+
+
+# --------------------------------------------------------------------------- #
+# Pickle roundtrip for @py_class
+# --------------------------------------------------------------------------- #
+
+# Pickle requires classes to be importable at module level.
+
+
+@py_class(_unique_key_pc("PickleBasic"))
+class _PickleBasic(Object):
+ a: int
+ b: float
+ c: str
+ d: bool
+
+
+@py_class(_unique_key_pc("PickleOptV"))
+class _PickleOptV(Object):
+ a: Optional[int]
+ b: Optional[str]
+
+
+@py_class(_unique_key_pc("PickleCont"))
+class _PickleCont(Object):
+ items: List[int]
+ data: Dict[str, int]
+
+
+@py_class(_unique_key_pc("PickleCI"), init=False)
+class _PickleCI(Object):
+ a: int
+ b: str
+
+ def __init__(self, *, b: str, a: int) -> None:
+ self.__ffi_init__(a, b)
+
+
+class TestPyClassPickleRoundtrip:
+ """Pickle serialization/deserialization for @py_class objects."""
+
+ def test_pickle_basic_fields(self) -> None:
+ obj = _PickleBasic(a=1, b=2.0, c="hello", d=True)
+ obj2 = pickle.loads(pickle.dumps(obj))
+ assert obj2.a == 1
+ assert obj2.b == 2.0
+ assert obj2.c == "hello"
+ assert obj2.d is True
+
+ def test_pickle_optional_with_values(self) -> None:
+ obj = _PickleOptV(a=42, b="world")
+ obj2 = pickle.loads(pickle.dumps(obj))
+ assert obj2.a == 42
+ assert obj2.b == "world"
+
+ def test_pickle_optional_with_none(self) -> None:
+ obj = _PickleOptV(a=None, b=None)
+ obj2 = pickle.loads(pickle.dumps(obj))
+ assert obj2.a is None
+ assert obj2.b is None
+
+ def test_pickle_container_fields(self) -> None:
+ obj = _PickleCont(items=[1, 2, 3], data={"a": 1, "b": 2})
+ obj2 = pickle.loads(pickle.dumps(obj))
+ assert tuple(obj2.items) == (1, 2, 3)
+ assert obj2.data["a"] == 1
+
+ def test_pickle_custom_init(self) -> None:
+ obj = _PickleCI(a=1, b="hello")
+ obj2 = pickle.loads(pickle.dumps(obj))
+ assert obj2.a == 1
+ assert obj2.b == "hello"
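The test file notes that pickle requires classes to be importable at module level. The standard-library reason is that pickle stores a class by reference, as its module plus qualified name, and looks it up again on load. A class defined inside a function or test method has `<locals>` in its qualified name and cannot be found that way. The plain-Python sketch below shows this without tvm_ffi; the helper names are hypothetical. It is an assumption, not something this diff states, that `@py_class` types are pickled by the same reference mechanism.

```python
import pickle
from fractions import Fraction


def can_round_trip(obj: object) -> bool:
    """Return True if obj survives pickle.dumps followed by pickle.loads."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (AttributeError, pickle.PicklingError):
        # CPython raises one of these (which one depends on the Python version)
        # when the object's class cannot be located by module and qualified name.
        return False


def make_local_instance() -> object:
    class Local:  # __qualname__ is "make_local_instance.<locals>.Local"
        pass

    return Local()


print(can_round_trip(Fraction(1, 3)))         # True: fractions.Fraction is importable
print(can_round_trip(make_local_instance()))  # False: Local is not reachable by name
```

This is why `_PickleBasic`, `_PickleOptV`, `_PickleCont`, and `_PickleCI` are declared at module scope. The copy tests can define their classes inside test methods, because `copy.copy` and `copy.deepcopy` work on live objects and never look classes up by name.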
diff --git a/tests/python/test_dataclass_py_class.py b/tests/python/test_dataclass_py_class.py
index fb4a0e67..520589ce 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -16,7 +16,7 @@
# under the License.
"""Tests for Python-defined TVM-FFI types: ``@py_class`` decorator and low-level Field API."""
-# ruff: noqa: D102, PLR0124, PLW1641
+# ruff: noqa: D102, PLR0124, PLW1641, UP006, UP045
from __future__ import annotations
import copy
@@ -25,7 +25,7 @@ import inspect
import itertools
import math
import sys
-from typing import ClassVar
+from typing import Any, ClassVar, Dict, List, Optional
import pytest
import tvm_ffi
@@ -67,7 +67,7 @@ def _unique_key_ff(base: str) -> str:
def _make_type(
name: str,
- fields: list[Field],
+ fields: List[Field],
*,
parent: type = core.Object,
eq: bool = False,
@@ -198,7 +198,7 @@ class TestFieldParsing:
def test_optional_field(self) -> None:
@py_class(_unique_key("OptFld"))
class OptFld(Object):
- x: int | None
+ x: Optional[int]
obj = OptFld(x=42)
assert obj.x == 42
@@ -1135,7 +1135,7 @@ class TestInitReorderingAdversarial:
def test_post_init_sees_reordered_fields(self) -> None:
"""__post_init__ sees correct values even when __init__ reorders fields."""
- seen: dict[str, int] = {}
+ seen: Dict[str, int] = {}
@py_class(_unique_key("PostReorder"))
class PostReorder(Object):
@@ -2778,7 +2778,7 @@ class TestMutualReferences:
info = core._register_py_class(parent_info, _unique_key_ff(name), cls)
return cls, info
- def _finalize(self, cls: type, info: core.TypeInfo, fields: list[Field]) -> None:
+ def _finalize(self, cls: type, info: core.TypeInfo, fields: List[Field]) -> None:
"""Register fields and install class attrs (phase 2 of two-phase)."""
info._register_fields(fields)
setattr(cls, "__tvm_ffi_type_info__", info)
@@ -3531,3 +3531,969 @@ class TestFFIGlobalFunctions:
def test_make_new_removed(self) -> None:
assert tvm_ffi.get_global_func("ffi.MakeNew", allow_missing=True) is None
+
+
+# ###########################################################################
+# 22. Container field annotations
+# ###########################################################################
+class TestContainerFieldAnnotations:
+ """Container field annotations: List[T], Dict[K,V], nested."""
+
+ def test_list_int_field(self) -> None:
+ @py_class(_unique_key("ListInt"))
+ class ListInt(Object):
+ items: List[int]
+
+ obj = ListInt(items=[1, 2, 3])
+ assert len(obj.items) == 3
+ assert obj.items[0] == 1
+ assert obj.items[2] == 3
+
+ def test_list_int_from_tuple(self) -> None:
+ @py_class(_unique_key("ListIntTup"))
+ class ListIntTup(Object):
+ items: List[int]
+
+ obj = ListIntTup(items=(10, 20, 30)) # ty:ignore[invalid-argument-type]
+ assert len(obj.items) == 3
+ assert obj.items[1] == 20
+
+ def test_dict_str_int_field(self) -> None:
+ @py_class(_unique_key("DictStrInt"))
+ class DictStrInt(Object):
+ mapping: Dict[str, int]
+
+ obj = DictStrInt(mapping={"a": 1, "b": 2})
+ assert len(obj.mapping) == 2
+ assert obj.mapping["a"] == 1
+ assert obj.mapping["b"] == 2
+
+ def test_list_list_int_field(self) -> None:
+ @py_class(_unique_key("ListListInt"))
+ class ListListInt(Object):
+ matrix: List[List[int]]
+
+ obj = ListListInt(matrix=[[1, 2, 3], [4, 5, 6]])
+ assert len(obj.matrix) == 2
+ assert len(obj.matrix[0]) == 3
+ assert obj.matrix[0][0] == 1
+ assert obj.matrix[1][2] == 6
+
+ def test_dict_str_list_int_field(self) -> None:
+ @py_class(_unique_key("DictStrListInt"))
+ class DictStrListInt(Object):
+ data: Dict[str, List[int]]
+
+ obj = DictStrListInt(data={"x": [1, 2, 3], "y": [4, 5, 6]})
+ assert len(obj.data) == 2
+ assert tuple(obj.data["x"]) == (1, 2, 3)
+ assert tuple(obj.data["y"]) == (4, 5, 6)
+
+ def test_container_field_set(self) -> None:
+ @py_class(_unique_key("ContSet"))
+ class ContSet(Object):
+ items: List[int]
+
+ obj = ContSet(items=[1, 2])
+ assert tuple(obj.items) == (1, 2)
+ obj.items = [3, 4, 5]
+ assert len(obj.items) == 3
+ assert obj.items[0] == 3
+
+ def test_dict_field_set(self) -> None:
+ @py_class(_unique_key("DictSet"))
+ class DictSet(Object):
+ mapping: Dict[str, int]
+
+ obj = DictSet(mapping={"a": 1})
+ obj.mapping = {"b": 2, "c": 3}
+ assert len(obj.mapping) == 2
+ assert obj.mapping["b"] == 2
+
+ def test_container_shared_reference(self) -> None:
+ @py_class(_unique_key("ContShare"))
+ class ContShare(Object):
+ a: List[int]
+ b: List[int]
+
+ obj = ContShare(a=[1, 2], b=[1, 2])
+ assert tuple(obj.a) == tuple(obj.b)
+
+ def test_untyped_list_field(self) -> None:
+ @py_class(_unique_key("UList"))
+ class UList(Object):
+ items: list
+
+ obj = UList(items=[1, "two", 3.0])
+ assert len(obj.items) == 3
+ assert obj.items[0] == 1
+ assert obj.items[1] == "two"
+
+ def test_untyped_dict_field(self) -> None:
+ @py_class(_unique_key("UDict"))
+ class UDict(Object):
+ data: dict
+
+ obj = UDict(data={"a": 1, "b": "two"})
+ assert len(obj.data) == 2
+
+
+# ###########################################################################
+# 23. Optional container fields
+# ###########################################################################
+class TestOptionalContainerFields:
+ """Optional[List[T]], Optional[Dict[K,V]] via @py_class."""
+
+ @_needs_310
+ def test_optional_list_int(self) -> None:
+ @py_class(_unique_key("OptListInt"))
+ class OptListInt(Object):
+ items: Optional[List[int]]
+
+ obj = OptListInt(items=[1, 2, 3])
+ assert len(obj.items) == 3 # ty:ignore[invalid-argument-type]
+ obj.items = None
+ assert obj.items is None
+ obj.items = [4, 5]
+ assert len(obj.items) == 2
+
+ @_needs_310
+ def test_optional_dict_str_int(self) -> None:
+ @py_class(_unique_key("OptDictStrInt"))
+ class OptDictStrInt(Object):
+ data: Optional[Dict[str, int]]
+
+ obj = OptDictStrInt(data={"a": 1})
+ assert obj.data["a"] == 1 # ty:ignore[not-subscriptable]
+ obj.data = None
+ assert obj.data is None
+ obj.data = {"b": 2}
+ assert obj.data["b"] == 2
+
+ @_needs_310
+ def test_optional_list_list_int(self) -> None:
+ @py_class(_unique_key("OptLLI"))
+ class OptLLI(Object):
+ matrix: Optional[List[List[int]]]
+
+ obj = OptLLI(matrix=[[1, 2], [3, 4]])
+ assert obj.matrix[0][0] == 1 # ty:ignore[not-subscriptable]
+ obj.matrix = None
+ assert obj.matrix is None
+
+ @_needs_310
+ def test_optional_dict_str_list_int(self) -> None:
+ @py_class(_unique_key("OptDSLI"))
+ class OptDSLI(Object):
+ data: Optional[Dict[str, List[int]]]
+
+ obj = OptDSLI(data={"x": [1, 2, 3]})
+ assert tuple(obj.data["x"]) == (1, 2, 3) # ty:ignore[not-subscriptable]
+ obj.data = None
+ assert obj.data is None
+
+ def test_optional_list_with_typing_optional(self) -> None:
+ @py_class(_unique_key("OptListTyping"))
+ class OptListTyping(Object):
+ items: Optional[List[int]]
+
+ obj = OptListTyping(items=[1, 2, 3])
+ assert len(obj.items) == 3 # ty:ignore[invalid-argument-type]
+ obj.items = None
+ assert obj.items is None
+
+ def test_optional_dict_with_typing_optional(self) -> None:
+ @py_class(_unique_key("OptDictTyping"))
+ class OptDictTyping(Object):
+ data: Optional[Dict[str, int]]
+
+ obj = OptDictTyping(data={"a": 1})
+ assert obj.data["a"] == 1 # ty:ignore[not-subscriptable]
+ obj.data = None
+ assert obj.data is None
+
+
+# ###########################################################################
+# 24. Callable / Function fields
+# ###########################################################################
+class TestFunctionField:
+ """Function/Callable field via @py_class decorator."""
+
+ def test_function_field(self) -> None:
+ @py_class(_unique_key("FuncFld"))
+ class FuncFld(Object):
+ func: tvm_ffi.Function
+
+ fn = tvm_ffi.convert(lambda x: x + 1)
+ obj = FuncFld(func=fn)
+ assert obj.func(1) == 2
+
+ def test_function_field_set(self) -> None:
+ @py_class(_unique_key("FuncSet"))
+ class FuncSet(Object):
+ func: tvm_ffi.Function
+
+ fn1 = tvm_ffi.convert(lambda x: x + 1)
+ fn2 = tvm_ffi.convert(lambda x: x + 2)
+ obj = FuncSet(func=fn1)
+ assert obj.func(1) == 2
+ obj.func = fn2
+ assert obj.func(1) == 3
+
+ @_needs_310
+ def test_optional_function_field(self) -> None:
+ @py_class(_unique_key("OptFunc"))
+ class OptFunc(Object):
+ func: Optional[tvm_ffi.Function]
+
+ obj = OptFunc(func=None)
+ assert obj.func is None
+ obj.func = tvm_ffi.convert(lambda x: x * 2)
+ assert obj.func(3) == 6
+ obj.func = None
+ assert obj.func is None
+
+
+# ###########################################################################
+# 25. Any-typed fields (decorator level)
+# ###########################################################################
+class TestAnyFieldDecorator:
+ """Any-typed field via @py_class decorator."""
+
+ def test_any_holds_int(self) -> None:
+ @py_class(_unique_key("AnyI"))
+ class AnyI(Object):
+ val: Any
+
+ assert AnyI(val=42).val == 42
+
+ def test_any_holds_str(self) -> None:
+ @py_class(_unique_key("AnyS"))
+ class AnyS(Object):
+ val: Any
+
+ assert AnyS(val="hello").val == "hello"
+
+ def test_any_holds_none(self) -> None:
+ @py_class(_unique_key("AnyN"))
+ class AnyN(Object):
+ val: Any = None
+
+ assert AnyN().val is None
+
+ def test_any_holds_list(self) -> None:
+ @py_class(_unique_key("AnyL"))
+ class AnyL(Object):
+ val: Any
+
+ assert len(AnyL(val=[1, 2, 3]).val) == 3
+
+ def test_any_type_change(self) -> None:
+ @py_class(_unique_key("AnyChg"))
+ class AnyChg(Object):
+ val: Any = None
+
+ obj = AnyChg()
+ assert obj.val is None
+ obj.val = 42
+ assert obj.val == 42
+ obj.val = "hello"
+ assert obj.val == "hello"
+ obj.val = tvm_ffi.Array([1, 2])
+ assert len(obj.val) == 2
+ obj.val = None
+ assert obj.val is None
+
+
+# ###########################################################################
+# 26. Post-init field mutation
+# ###########################################################################
+class TestPostInitMutation:
+ """__post_init__ that mutates field values."""
+
+ def test_post_init_mutates_str(self) -> None:
+ @py_class(_unique_key("PostMut"))
+ class PostMut(Object):
+ a: int
+ b: str
+
+ def __post_init__(self) -> None:
+ self.b = self.b.upper()
+
+ obj = PostMut(a=1, b="hello")
+ assert obj.a == 1
+ assert obj.b == "HELLO"
+
+ def test_post_init_computes_derived(self) -> None:
+ @py_class(_unique_key("PostDeriv"))
+ class PostDeriv(Object):
+ x: int
+ doubled: int = 0
+
+ def __post_init__(self) -> None:
+ self.doubled = self.x * 2
+
+ assert PostDeriv(x=5).doubled == 10
+
+
+# ###########################################################################
+# 27. Custom __init__ with init=False
+# ###########################################################################
+class TestCustomInitFalse:
+ """Custom __init__ with init=False and reordered parameters."""
+
+ def test_custom_init_reordered_params(self) -> None:
+ @py_class(_unique_key("CustomOrd"), init=False)
+ class CustomOrd(Object):
+ a: int
+ b: float
+ c: str
+ d: bool
+
+ def __init__(self, b: float, c: str, a: int, d: bool) -> None:
+ self.__ffi_init__(a, b, c, d)
+
+ obj = CustomOrd(b=2.0, c="3", a=1, d=True)
+ assert obj.a == 1
+ assert obj.b == 2.0
+ assert obj.c == "3"
+ assert obj.d is True
+
+ def test_custom_init_keyword_only(self) -> None:
+ @py_class(_unique_key("CustomKW"), init=False)
+ class CustomKW(Object):
+ a: int
+ b: str
+
+ def __init__(self, *, b: str, a: int) -> None:
+ self.__ffi_init__(a, b)
+
+ obj = CustomKW(a=1, b="hello")
+ assert obj.a == 1
+ assert obj.b == "hello"
+
+
+# ###########################################################################
+# 28. Inheritance with defaults and containers
+# ###########################################################################
+class TestInheritanceWithDefaults:
+ """Inheritance with default values and container fields."""
+
+ def test_base_with_default_factory(self) -> None:
+ @py_class(_unique_key("BaseDef"))
+ class BaseDef(Object):
+ a: int
+ b: List[int] = field(default_factory=list)
+
+ obj = BaseDef(a=42)
+ assert obj.a == 42
+ assert len(obj.b) == 0
+
+ def test_derived_adds_optional_fields(self) -> None:
+ @py_class(_unique_key("BaseD"))
+ class BaseD(Object):
+ a: int
+ b: List[int] = field(default_factory=list)
+
+ @py_class(_unique_key("DerivedD"))
+ class DerivedD(BaseD):
+ c: Optional[int] = None
+ d: Optional[str] = "default"
+
+ obj = DerivedD(a=12)
+ assert obj.a == 12
+ assert len(obj.b) == 0
+ assert obj.c is None
+ assert obj.d == "default"
+
+ def test_derived_interleaved_required_optional(self) -> None:
+ @py_class(_unique_key("BaseIL"))
+ class BaseIL(Object):
+ a: int
+ b: List[int] = field(default_factory=list)
+
+ @py_class(_unique_key("DerivedIL"))
+ class DerivedIL(BaseIL):
+ c: int
+ d: Optional[str] = "default"
+
+ obj = DerivedIL(a=1, c=2)
+ assert obj.a == 1
+ assert len(obj.b) == 0
+ assert obj.c == 2
+ assert obj.d == "default"
+
+ def test_three_level_with_defaults(self) -> None:
+ @py_class(_unique_key("L1D"))
+ class L1D(Object):
+ a: int
+
+ @py_class(_unique_key("L2D"))
+ class L2D(L1D):
+ b: Optional[int] = None
+ c: Optional[str] = "hello"
+
+ @py_class(_unique_key("L3D"))
+ class L3D(L2D):
+ d: str
+
+ obj = L3D(a=1, d="world")
+ assert obj.a == 1
+ assert obj.b is None
+ assert obj.c == "hello"
+ assert obj.d == "world"
+
+ def test_derived_with_container_init(self) -> None:
+ @py_class(_unique_key("BaseC"))
+ class BaseC(Object):
+ items: List[int]
+
+ @py_class(_unique_key("DerivedC"))
+ class DerivedC(BaseC):
+ name: str
+
+ obj = DerivedC(items=[1, 2, 3], name="test")
+ assert len(obj.items) == 3
+ assert obj.name == "test"
+
+
+# ###########################################################################
+# 29. Decorator-level type validation
+# ###########################################################################
+class TestFieldTypeValidation:
+ """Type validation on set for @py_class fields."""
+
+ def test_set_int_to_str_raises(self) -> None:
+ @py_class(_unique_key("ValInt"))
+ class ValInt(Object):
+ x: int
+
+ obj = ValInt(x=1)
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.x = "not_an_int" # ty:ignore[invalid-assignment]
+
+ def test_set_str_to_int_raises(self) -> None:
+ @py_class(_unique_key("ValStr"))
+ class ValStr(Object):
+ x: str
+
+ obj = ValStr(x="hello")
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.x = 42 # ty:ignore[invalid-assignment]
+
+ def test_set_bool_to_str_raises(self) -> None:
+ @py_class(_unique_key("ValBool"))
+ class ValBool(Object):
+ x: bool
+
+ obj = ValBool(x=True)
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.x = "not_a_bool" # ty:ignore[invalid-assignment]
+
+ def test_set_list_to_wrong_type_raises(self) -> None:
+ @py_class(_unique_key("ValList"))
+ class ValList(Object):
+ items: List[int]
+
+ obj = ValList(items=[1, 2])
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.items = "not_a_list" # ty:ignore[invalid-assignment]
+
+ def test_set_dict_to_wrong_type_raises(self) -> None:
+ @py_class(_unique_key("ValDict"))
+ class ValDict(Object):
+ data: Dict[str, int]
+
+ obj = ValDict(data={"a": 1})
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = "not_a_dict" # ty:ignore[invalid-assignment]
+
+
+# ###########################################################################
+# 30. Three-level inheritance with containers
+# ###########################################################################
+class TestDerivedDerivedContainers:
+ """Three-level inheritance with container fields and init reordering."""
+
+ def test_three_level_with_container_and_defaults(self) -> None:
+ @py_class(_unique_key("DD_L1"))
+ class L1(Object):
+ a: int
+ b: List[int] = field(default_factory=list)
+
+ @py_class(_unique_key("DD_L2"))
+ class L2(L1):
+ c: Optional[int] = None
+ d: Optional[str] = "hello"
+
+ @py_class(_unique_key("DD_L3"))
+ class L3(L2):
+ e: str
+
+ obj = L3(a=1, e="world", b=[1, 2])
+ assert obj.a == 1
+ assert tuple(obj.b) == (1, 2)
+ assert obj.c is None
+ assert obj.d == "hello"
+ assert obj.e == "world"
+
+ def test_three_level_positional_call(self) -> None:
+ @py_class(_unique_key("DD2_L1"))
+ class L1(Object):
+ a: int
+
+ @py_class(_unique_key("DD2_L2"))
+ class L2(L1):
+ b: List[int] = field(default_factory=list)
+
+ @py_class(_unique_key("DD2_L3"))
+ class L3(L2):
+ c: str
+
+ obj = L3(a=1, c="x")
+ assert obj.a == 1
+ assert len(obj.b) == 0
+ assert obj.c == "x"
+
+
+# ###########################################################################
+# 31. Container field mutation and type rejection
+# ###########################################################################
+class TestContainerFieldMutation:
+ """Container field set, mutation, and type rejection."""
+
+ def test_untyped_list_mutation(self) -> None:
+ obj = _make_multi_type_obj()
+ assert len(obj.list_any) == 3
+ assert obj.list_any[0] == 1
+ obj.list_any = [4, 3.0, "two"]
+ assert len(obj.list_any) == 3
+ assert obj.list_any[0] == 4
+ assert obj.list_any[2] == "two"
+
+ def test_untyped_dict_mutation(self) -> None:
+ obj = _make_multi_type_obj()
+ assert len(obj.dict_any) == 2
+ obj.dict_any = {"4": 4, "3": "two", "2": 3.0}
+ assert len(obj.dict_any) == 3
+ assert obj.dict_any["4"] == 4
+ assert obj.dict_any["3"] == "two"
+
+ def test_list_any_type_rejection(self) -> None:
+ @py_class(_unique_key("LAReject"))
+ class LAReject(Object):
+ items: List[Any]
+
+ obj = LAReject(items=[1, 2])
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.items = "wrong" # ty:ignore[invalid-assignment]
+
+ def test_list_list_int_type_rejection(self) -> None:
+ @py_class(_unique_key("LLReject"))
+ class LLReject(Object):
+ matrix: List[List[int]]
+
+ obj = LLReject(matrix=[[1, 2]])
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.matrix = [4, 3, 2, 1] # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.matrix = None # ty:ignore[invalid-assignment]
+ assert len(obj.matrix) == 1
+
+ def test_dict_any_any_type_rejection(self) -> None:
+ @py_class(_unique_key("DAAReject"))
+ class DAAReject(Object):
+ data: Dict[Any, Any]
+
+ obj = DAAReject(data={1: 2})
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = 42 # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = None # ty:ignore[invalid-assignment]
+
+ def test_dict_str_any_type_rejection(self) -> None:
+ @py_class(_unique_key("DSAReject"))
+ class DSAReject(Object):
+ data: Dict[str, Any]
+
+ obj = DSAReject(data={"a": 1})
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = 42 # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = None # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = {4: 4, 3.0: 3} # ty:ignore[invalid-assignment]
+
+ def test_dict_str_list_int_type_rejection(self) -> None:
+ @py_class(_unique_key("DSLReject"))
+ class DSLReject(Object):
+ data: Dict[str, List[int]]
+
+ obj = DSLReject(data={"a": [1, 2]})
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = 42 # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = None # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = {"a": 1, "b": [2]} # ty:ignore[invalid-assignment]
+
+
+# ###########################################################################
+# 32. Optional field set/unset cycles with type rejection
+# ###########################################################################
+class TestOptionalFieldCycles:
+ """Optional field set → None → set-back cycles with type rejection."""
+
+ def test_opt_func_type_rejection(self) -> None:
+ @py_class(_unique_key("OptFuncR"))
+ class OptFuncR(Object):
+ func: Optional[tvm_ffi.Function]
+
+ obj = OptFuncR(func=None)
+ assert obj.func is None
+ obj.func = tvm_ffi.convert(lambda x: x + 2)
+ assert obj.func(1) == 3
+ obj.func = None
+ assert obj.func is None
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.func = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.func = 42 # ty:ignore[invalid-assignment]
+
+ def test_opt_ulist_cycle(self) -> None:
+ @py_class(_unique_key("OptUListC"))
+ class OptUListC(Object):
+ items: Optional[list]
+
+ obj = OptUListC(items=None)
+ assert obj.items is None
+ obj.items = [4, 3.0, "two"]
+ assert len(obj.items) == 3
+ assert obj.items[0] == 4
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.items = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.items = 42 # ty:ignore[invalid-assignment]
+ obj.items = None
+ assert obj.items is None
+
+ def test_opt_udict_cycle(self) -> None:
+ @py_class(_unique_key("OptUDictC"))
+ class OptUDictC(Object):
+ data: Optional[dict]
+
+ obj = OptUDictC(data=None)
+ assert obj.data is None
+ obj.data = {"4": 4, "3": "two"}
+ assert len(obj.data) == 2
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = 42 # ty:ignore[invalid-assignment]
+ obj.data = None
+ assert obj.data is None
+
+ def test_opt_list_any_cycle(self) -> None:
+ @py_class(_unique_key("OptLAC"))
+ class OptLAC(Object):
+ items: Optional[List[Any]]
+
+ obj = OptLAC(items=[1, 2.0, "three"])
+ assert len(obj.items) == 3 # ty:ignore[invalid-argument-type]
+ obj.items = [4, 3.0, "two"]
+ assert obj.items[0] == 4
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.items = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.items = 42 # ty:ignore[invalid-assignment]
+ obj.items = None
+ assert obj.items is None
+
+ def test_opt_list_list_int_cycle(self) -> None:
+ @py_class(_unique_key("OptLLIC"))
+ class OptLLIC(Object):
+ matrix: Optional[List[List[int]]]
+
+ obj = OptLLIC(matrix=[[1, 2, 3], [4, 5, 6]])
+ assert tuple(obj.matrix[0]) == (1, 2, 3) # ty:ignore[not-subscriptable]
+ obj.matrix = [[4, 3, 2]]
+ assert len(obj.matrix) == 1
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.matrix = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.matrix = [1, 2, 3] # ty:ignore[invalid-assignment]
+ obj.matrix = None
+ assert obj.matrix is None
+
+ def test_opt_dict_any_any_cycle(self) -> None:
+ @py_class(_unique_key("OptDAAC"))
+ class OptDAAC(Object):
+ data: Optional[Dict[Any, Any]]
+
+ obj = OptDAAC(data=None)
+ assert obj.data is None
+ obj.data = {4: 4, "three": "two"}
+ assert len(obj.data) == 2
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = 42 # ty:ignore[invalid-assignment]
+ obj.data = None
+ assert obj.data is None
+
+ def test_opt_dict_str_any_cycle(self) -> None:
+ @py_class(_unique_key("OptDSAC"))
+ class OptDSAC(Object):
+ data: Optional[Dict[str, Any]]
+
+ obj = OptDSAC(data={"a": 1})
+ assert obj.data["a"] == 1 # ty:ignore[not-subscriptable]
+ obj.data = {}
+ assert len(obj.data) == 0
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = 42 # ty:ignore[invalid-assignment]
+ obj.data = None
+ assert obj.data is None
+
+ def test_opt_dict_any_str_cycle(self) -> None:
+ @py_class(_unique_key("OptDASC"))
+ class OptDASC(Object):
+ data: Optional[Dict[Any, str]]
+
+ obj = OptDASC(data={1: "a", "two": "b"})
+ assert obj.data[1] == "a" # ty:ignore[not-subscriptable]
+ obj.data = {4: "4", "three": "two"}
+ assert len(obj.data) == 2
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = 42 # ty:ignore[invalid-assignment]
+ obj.data = None
+ assert obj.data is None
+
+ def test_opt_dict_str_list_int_cycle(self) -> None:
+ @py_class(_unique_key("OptDSLIC"))
+ class OptDSLIC(Object):
+ data: Optional[Dict[str, List[int]]]
+
+ obj = OptDSLIC(data={"1": [1, 2, 3], "2": [4, 5, 6]})
+ assert tuple(obj.data["1"]) == (1, 2, 3) # ty:ignore[not-subscriptable]
+ obj.data = {"a": [7, 8]}
+ assert tuple(obj.data["a"]) == (7, 8)
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = "wrong" # ty:ignore[invalid-assignment]
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.data = 42 # ty:ignore[invalid-assignment]
+ obj.data = None
+ assert obj.data is None
+
+
+# ###########################################################################
+# 33. Multi-type field class
+# ###########################################################################
+@py_class(_unique_key("MultiType"))
+class _PyClassMultiType(Object):
+ """@py_class with many field types for cross-cutting field tests."""
+
+ bool_: bool
+ i64: int
+ f64: float
+ str_: str
+ any_val: Any
+ list_int: List[int]
+ list_any: list
+ dict_str_int: Dict[str, int]
+ dict_any: dict
+ list_list_int: List[List[int]]
+ dict_str_list_int: Dict[str, List[int]]
+ opt_bool: Optional[bool]
+ opt_int: Optional[int]
+ opt_float: Optional[float]
+ opt_str: Optional[str]
+ opt_list_int: Optional[List[int]]
+ opt_dict_str_int: Optional[Dict[str, int]]
+
+
+def _make_multi_type_obj() -> _PyClassMultiType:
+ return _PyClassMultiType(
+ bool_=False,
+ i64=64,
+ f64=2.5,
+ str_="world",
+ any_val="hello",
+ list_int=[1, 2, 3],
+ list_any=[1, "two", 3.0],
+ dict_str_int={"a": 1, "b": 2},
+ dict_any={"x": 1, "y": "two"},
+ list_list_int=[[1, 2, 3], [4, 5, 6]],
+ dict_str_list_int={"p": [1, 2], "q": [3, 4]},
+ opt_bool=True,
+ opt_int=-64,
+ opt_float=None,
+ opt_str=None,
+ opt_list_int=[10, 20],
+ opt_dict_str_int=None,
+ )
+
+
+class TestMultiTypeFieldOps:
+ """Per-field get/set/validation on a many-field @py_class type."""
+
+ def test_bool_field(self) -> None:
+ obj = _make_multi_type_obj()
+ assert obj.bool_ is False
+ obj.bool_ = True
+ assert obj.bool_ is True
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.bool_ = "not_a_bool" # ty:ignore[invalid-assignment]
+
+ def test_int_field(self) -> None:
+ obj = _make_multi_type_obj()
+ assert obj.i64 == 64
+ obj.i64 = -128
+ assert obj.i64 == -128
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.i64 = "wrong" # ty:ignore[invalid-assignment]
+
+ def test_float_field(self) -> None:
+ obj = _make_multi_type_obj()
+ assert abs(obj.f64 - 2.5) < 1e-10
+ obj.f64 = 5.0
+ assert abs(obj.f64 - 5.0) < 1e-10
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.f64 = "wrong" # ty:ignore[invalid-assignment]
+
+ def test_str_field(self) -> None:
+ obj = _make_multi_type_obj()
+ assert obj.str_ == "world"
+ obj.str_ = "hello"
+ assert obj.str_ == "hello"
+
+ def test_any_field(self) -> None:
+ obj = _make_multi_type_obj()
+ assert obj.any_val == "hello"
+ obj.any_val = 42
+ assert obj.any_val == 42
+ obj.any_val = [1, 2]
+ assert len(obj.any_val) == 2
+
+ def test_list_int_field(self) -> None:
+ obj = _make_multi_type_obj()
+ assert tuple(obj.list_int) == (1, 2, 3)
+ obj.list_int = [4, 5]
+ assert len(obj.list_int) == 2
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.list_int = "wrong" # ty:ignore[invalid-assignment]
+
+ def test_untyped_list_field(self) -> None:
+ obj = _make_multi_type_obj()
+ assert len(obj.list_any) == 3
+ assert obj.list_any[0] == 1
+ assert obj.list_any[1] == "two"
+ obj.list_any = [4, 3.0, "new"]
+ assert len(obj.list_any) == 3
+
+ def test_dict_str_int_field(self) -> None:
+ obj = _make_multi_type_obj()
+ assert obj.dict_str_int["a"] == 1
+ obj.dict_str_int = {"c": 3}
+ assert obj.dict_str_int["c"] == 3
+ with pytest.raises((TypeError, RuntimeError)):
+ obj.dict_str_int = "wrong" # ty:ignore[invalid-assignment]
+
+ def test_untyped_dict_field(self) -> None:
+ obj = _make_multi_type_obj()
+ assert len(obj.dict_any) == 2
+ obj.dict_any = {"new": 42}
+ assert obj.dict_any["new"] == 42
+
+ def test_nested_list_field(self) -> None:
+ obj = _make_multi_type_obj()
+ assert tuple(obj.list_list_int[0]) == (1, 2, 3)
+ assert tuple(obj.list_list_int[1]) == (4, 5, 6)
+
+ def test_nested_dict_field(self) -> None:
+ obj = _make_multi_type_obj()
+ assert tuple(obj.dict_str_list_int["p"]) == (1, 2)
+ assert tuple(obj.dict_str_list_int["q"]) == (3, 4)
+
+ def test_optional_bool(self) -> None:
+ obj = _make_multi_type_obj()
+ assert obj.opt_bool is True
+ obj.opt_bool = False
+ assert obj.opt_bool is False
+ obj.opt_bool = None
+ assert obj.opt_bool is None
+
+ def test_optional_int(self) -> None:
+ obj = _make_multi_type_obj()
+ assert obj.opt_int == -64
+ obj.opt_int = None
+ assert obj.opt_int is None
+ obj.opt_int = 128
+ assert obj.opt_int == 128
+
+ def test_optional_float(self) -> None:
+ obj = _make_multi_type_obj()
+ assert obj.opt_float is None
+ obj.opt_float = 1.5
+ assert abs(obj.opt_float - 1.5) < 1e-10
+ obj.opt_float = None
+ assert obj.opt_float is None
+
+ def test_optional_str(self) -> None:
+ obj = _make_multi_type_obj()
+ assert obj.opt_str is None
+ obj.opt_str = "hello"
+ assert obj.opt_str == "hello"
+ obj.opt_str = None
+ assert obj.opt_str is None
+
+ def test_optional_list_int(self) -> None:
+ obj = _make_multi_type_obj()
+ assert tuple(obj.opt_list_int) == (10, 20) # ty:ignore[invalid-argument-type]
+ obj.opt_list_int = None
+ assert obj.opt_list_int is None
+ obj.opt_list_int = [30]
+ assert len(obj.opt_list_int) == 1
+
+ def test_optional_dict_str_int(self) -> None:
+ obj = _make_multi_type_obj()
+ assert obj.opt_dict_str_int is None
+ obj.opt_dict_str_int = {"z": 99}
+ assert obj.opt_dict_str_int["z"] == 99
+ obj.opt_dict_str_int = None
+ assert obj.opt_dict_str_int is None
+
+
+class TestMultiTypeCopy:
+ """Copy with the comprehensive multi-type class."""
+
+ def test_shallow_copy_comprehensive(self) -> None:
+ obj = _make_multi_type_obj()
+ obj2 = copy.copy(obj)
+ assert obj2.bool_ == obj.bool_
+ assert obj2.i64 == obj.i64
+ assert obj2.f64 == obj.f64
+ assert obj2.str_ == obj.str_
+ assert obj2.any_val == obj.any_val
+ assert obj.list_int.same_as(obj2.list_int) # ty:ignore[unresolved-attribute]
+ assert obj.dict_str_int.same_as(obj2.dict_str_int) # ty:ignore[unresolved-attribute]
+
+ def test_deep_copy_comprehensive(self) -> None:
+ obj = _make_multi_type_obj()
+ obj2 = copy.deepcopy(obj)
+ assert obj2.bool_ == obj.bool_
+ assert obj2.i64 == obj.i64
+ assert not obj.list_int.same_as(obj2.list_int) # ty:ignore[unresolved-attribute]
+ assert not obj.dict_str_int.same_as(obj2.dict_str_int) # ty:ignore[unresolved-attribute]
+ assert tuple(obj2.list_int) == (1, 2, 3)
+ assert obj2.dict_str_int["a"] == 1
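The optional-field and multi-type tests above check that assignment to a typed field accepts `None` (for `Optional[...]`) and well-typed values, and rejects mismatches such as a string, an integer, or a flat list where `List[List[int]]` is expected. As a rough, stdlib-only illustration of that kind of recursive annotation check, here is a sketch. `conforms` is a hypothetical helper, not tvm_ffi's converter, and its policies (rejecting `bool` for `int`, accepting `int` for `float`) are assumptions of the sketch rather than documented tvm_ffi behavior.

```python
from typing import Any, Dict, List, Optional, Union, get_args, get_origin


def conforms(value: Any, annotation: Any) -> bool:
    """Hypothetical recursive check of `value` against a typing annotation."""
    if annotation is Any:
        return True
    if annotation is type(None):
        return value is None
    origin = get_origin(annotation)
    args = get_args(annotation)
    if origin is Union:  # Optional[X] is Union[X, None]
        return any(conforms(value, arg) for arg in args)
    if origin is list or annotation is list:
        if not isinstance(value, list):
            return False
        # A bare `list` annotation has no args, so its elements are unchecked.
        return not args or all(conforms(item, args[0]) for item in value)
    if origin is dict or annotation is dict:
        if not isinstance(value, dict):
            return False
        key_t, val_t = args if args else (Any, Any)
        return all(conforms(k, key_t) and conforms(v, val_t) for k, v in value.items())
    if annotation is int:
        # Sketch policy: bool subclasses int in Python, but is rejected here.
        return isinstance(value, int) and not isinstance(value, bool)
    if annotation is float:
        # Sketch policy: ints are accepted for float fields.
        return isinstance(value, (int, float)) and not isinstance(value, bool)
    return isinstance(value, annotation)


# A few checks mirroring assignments in the tests above.
assert conforms(None, Optional[List[int]])
assert not conforms("wrong", Optional[List[int]])
assert not conforms([1, 2, 3], Optional[List[List[int]]])
assert not conforms({"a": 1, "b": [2]}, Dict[str, List[int]])
```

A real binding layer would also convert values (for example, Python lists into FFI arrays) rather than only validate them; the sketch covers validation alone.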
diff --git a/tests/python/test_dataclass_repr.py b/tests/python/test_dataclass_repr.py
index 165ae13b..62e9aff7 100644
--- a/tests/python/test_dataclass_repr.py
+++ b/tests/python/test_dataclass_repr.py
@@ -16,6 +16,8 @@
# under the License.
"""Tests for __ffi_repr__ / ffi.ReprPrint."""
+from __future__ import annotations
+
import ast
import re
@@ -666,5 +668,65 @@ def test_repr_unregistered_object_no_duplicate_field_names() -> None:
assert result.count("v1=") == 1
+# --------------------------------------------------------------------------- #
+# @py_class repr
+# --------------------------------------------------------------------------- #
+
+import itertools as _itertools_repr
+from typing import Optional as _Optional_repr
+
+from tvm_ffi.core import Object as _Object_repr
+from tvm_ffi.dataclasses import py_class as _py_class_repr
+
+_counter_repr = _itertools_repr.count()
+
+
+def _unique_key_repr(base: str) -> str:
+ return f"testing.repr_pc.{base}_{next(_counter_repr)}"
+
+
+def test_repr_py_class_base() -> None:
+ """Repr of a simple @py_class contains field names and values."""
+
+ @_py_class_repr(_unique_key_repr("ReprBase"))
+ class ReprBase(_Object_repr):
+ a: int
+ b: str
+
+ r = repr(ReprBase(a=1, b="hello"))
+ assert "a=1" in r or "a: 1" in r
+ assert "hello" in r
+
+
+def test_repr_py_class_derived() -> None:
+ """Repr of a derived @py_class shows all fields including parent."""
+
+ @_py_class_repr(_unique_key_repr("ReprP"))
+ class ReprP(_Object_repr):
+ base_a: int
+ base_b: str
+
+ @_py_class_repr(_unique_key_repr("ReprD"))
+ class ReprD(ReprP):
+ derived_a: float
+ derived_b: _Optional_repr[str] # noqa: UP045
+
+ r = repr(ReprD(base_a=1, base_b="b", derived_a=2.0, derived_b="c"))
+ assert "1" in r
+ assert "2" in r
+
+
+def test_repr_py_class_in_array() -> None:
+ """@py_class objects inside Array have proper repr."""
+
+ @_py_class_repr(_unique_key_repr("ReprInArr"))
+ class ReprInArr(_Object_repr):
+ x: int
+
+ r = repr(tvm_ffi.Array([ReprInArr(x=1), ReprInArr(x=2)]))
+ assert "1" in r
+ assert "2" in r
+
+
if __name__ == "__main__":
pytest.main([__file__, "-v"])
diff --git a/tests/python/test_structural_py_class.py b/tests/python/test_structural_py_class.py
new file mode 100644
index 00000000..f280980d
--- /dev/null
+++ b/tests/python/test_structural_py_class.py
@@ -0,0 +1,361 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for structural equality/hashing on py_class-defined types.
+
+Mirrors the C++ tests in tests/cpp/extra/test_structural_equal_hash.cc,
+porting the object-level tests (FreeVar, FuncDefAndIgnoreField, etc.)
+to Python using ``@py_class(structure=...)`` and ``field(structure=...)``.
+"""
+
+from __future__ import annotations
+
+import pytest
+import tvm_ffi
+from tvm_ffi import get_first_structural_mismatch, structural_equal, structural_hash
+from tvm_ffi.dataclasses import field, py_class
+
+# ---------------------------------------------------------------------------
+# Type definitions (mirror testing_object.h)
+# ---------------------------------------------------------------------------
+
+
+@py_class("testing.py.Var", structure="var")
+class TVar(tvm_ffi.Object):
+ """Variable node — compared by binding position, not by name.
+
+ Mirrors C++ TVarObj with:
+ _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar
+ name field has SEqHashIgnore
+ """
+
+ name: str = field(structure="ignore")
+
+
+@py_class("testing.py.Int", structure="tree")
+class TInt(tvm_ffi.Object):
+ """Simple integer literal node."""
+
+ value: int
+
+
+@py_class("testing.py.Func", structure="tree")
+class TFunc(tvm_ffi.Object):
+ """Function node with definition region and ignored comment.
+
+ Mirrors C++ TFuncObj with:
+ params has SEqHashDef
+ comment has SEqHashIgnore
+ """
+
+ params: list = field(structure="def")
+ body: list
+ comment: str = field(structure="ignore", default="")
+
+
+@py_class("testing.py.Expr", structure="tree")
+class TExpr(tvm_ffi.Object):
+ """A simple expression node for tree-comparison tests."""
+
+ value: int
+
+
+@py_class("testing.py.Metadata", structure="const-tree")
+class TMetadata(tvm_ffi.Object):
+ """Immutable metadata node — pointer shortcut is safe (no var children)."""
+
+ tag: str
+ version: int
+
+
+@py_class("testing.py.Binding", structure="dag")
+class TBinding(tvm_ffi.Object):
+ """Binding node — sharing structure is semantically meaningful."""
+
+ name: str
+ value: int
+
+
+# ---------------------------------------------------------------------------
+# Tests: FreeVar (mirrors C++ FreeVar test)
+# ---------------------------------------------------------------------------
+
+
+class TestFreeVar:
+ """Test structure="var" kind (C++ kTVMFFISEqHashKindFreeVar)."""
+
+ def test_free_var_equal_with_mapping(self) -> None:
+ """Two different vars are equal when map_free_vars=True."""
+ a = TVar("a")
+ b = TVar("b")
+ assert structural_equal(a, b, map_free_vars=True)
+
+ def test_free_var_not_equal_without_mapping(self) -> None:
+ """Two different vars are NOT equal by default (no mapping)."""
+ a = TVar("a")
+ b = TVar("b")
+ assert not structural_equal(a, b)
+
+ def test_free_var_hash_differs_without_mapping(self) -> None:
+ """Without mapping, different vars produce different hashes."""
+ a = TVar("a")
+ b = TVar("b")
+ assert structural_hash(a) != structural_hash(b)
+
+ def test_free_var_hash_equal_with_mapping(self) -> None:
+ """With map_free_vars, positional hashing makes them equal."""
+ a = TVar("a")
+ b = TVar("b")
+ assert structural_hash(a, map_free_vars=True) == structural_hash(b, map_free_vars=True)
+
+ def test_free_var_same_pointer(self) -> None:
+ """Same variable is always equal to itself."""
+ x = TVar("x")
+ assert structural_equal(x, x)
+
+ def test_free_var_name_ignored(self) -> None:
+ """The name field is structure="ignore", so it doesn't affect comparison."""
+ a = TVar("different_name_a")
+ b = TVar("different_name_b")
+ # Names differ, but with mapping they are still equal
+ assert structural_equal(a, b, map_free_vars=True)
+
+
+# ---------------------------------------------------------------------------
+# Tests: FuncDefAndIgnoreField (mirrors C++ FuncDefAndIgnoreField test)
+# ---------------------------------------------------------------------------
+
+
+class TestFuncDefAndIgnore:
+ """Test structure="def" and structure="ignore" on fields."""
+
+ def test_alpha_equivalent_functions(self) -> None:
+ """fun(x){1, x} with comment_a == fun(y){1, y} with comment_b."""
+ x = TVar("x")
+ y = TVar("y")
+ fa = TFunc(params=[x], body=[TInt(1), x], comment="comment a")
+ fb = TFunc(params=[y], body=[TInt(1), y], comment="comment b")
+ assert structural_equal(fa, fb)
+ assert structural_hash(fa) == structural_hash(fb)
+
+ def test_different_body(self) -> None:
+ """fun(x){1, x} != fun(x){1, 2} — body differs at index 1."""
+ x = TVar("x")
+ fa = TFunc(params=[x], body=[TInt(1), x], comment="comment a")
+ fc = TFunc(params=[x], body=[TInt(1), TInt(2)], comment="comment c")
+ assert not structural_equal(fa, fc)
+
+ def test_mismatch_path(self) -> None:
+ """get_first_structural_mismatch reports a mismatch for differing bodies."""
+ x = TVar("x")
+ fa = TFunc(params=[x], body=[TInt(1), x])
+ fc = TFunc(params=[x], body=[TInt(1), TInt(2)])
+ mismatch = get_first_structural_mismatch(fa, fc)
+ assert mismatch is not None
+
+ def test_comment_ignored(self) -> None:
+ """Identical structure with different comments are equal."""
+ x = TVar("x")
+ f1 = TFunc(params=[x], body=[TInt(1)], comment="first")
+ f2 = TFunc(params=[x], body=[TInt(1)], comment="second")
+ assert structural_equal(f1, f2)
+ assert structural_hash(f1) == structural_hash(f2)
+
+ def test_inconsistent_var_usage(self) -> None:
+ """fun(x,y){x+x} != fun(a,b){a+b} — inconsistent variable mapping."""
+ x, y = TVar("x"), TVar("y")
+ a, b = TVar("a"), TVar("b")
+ f1 = TFunc(params=[x, y], body=[x, x])
+ f2 = TFunc(params=[a, b], body=[a, b])
+ assert not structural_equal(f1, f2)
+
+ def test_multi_param_alpha_equiv(self) -> None:
+ """fun(x,y){x, y} == fun(a,b){a, b} — consistent variable renaming."""
+ x, y = TVar("x"), TVar("y")
+ a, b = TVar("a"), TVar("b")
+ f1 = TFunc(params=[x, y], body=[x, y])
+ f2 = TFunc(params=[a, b], body=[a, b])
+ assert structural_equal(f1, f2)
+ assert structural_hash(f1) == structural_hash(f2)
+
+ def test_swapped_params_not_equal(self) -> None:
+ """fun(x,y){x, y} != fun(a,b){b, a} — reversed usage."""
+ x, y = TVar("x"), TVar("y")
+ a, b = TVar("a"), TVar("b")
+ f1 = TFunc(params=[x, y], body=[x, y])
+ f2 = TFunc(params=[a, b], body=[b, a])
+ assert not structural_equal(f1, f2)
+
+ def test_nested_functions(self) -> None:
+ """Nested function with inner binding — alpha-equivalence is scoped."""
+ x = TVar("x")
+ y = TVar("y")
+ inner_x = TFunc(params=[x], body=[x])
+ inner_y = TFunc(params=[y], body=[y])
+ outer_a = TFunc(params=[], body=[inner_x])
+ outer_b = TFunc(params=[], body=[inner_y])
+ assert structural_equal(outer_a, outer_b)
+ assert structural_hash(outer_a) == structural_hash(outer_b)
+
+
+# ---------------------------------------------------------------------------
+# Tests: tree kind basics
+# ---------------------------------------------------------------------------
+
+
+class TestTreeNode:
+ """Test structure="tree" kind."""
+
+ def test_equal_content(self) -> None:
+ """Two tree nodes with identical content are structurally equal."""
+ a = TExpr(value=42)
+ b = TExpr(value=42)
+ assert structural_equal(a, b)
+ assert structural_hash(a) == structural_hash(b)
+
+ def test_different_content(self) -> None:
+ """Two tree nodes with different content are not equal."""
+ a = TExpr(value=1)
+ b = TExpr(value=2)
+ assert not structural_equal(a, b)
+ assert structural_hash(a) != structural_hash(b)
+
+ def test_sharing_invisible(self) -> None:
+ """Under "tree", sharing doesn't affect equality."""
+ s = TExpr(value=10)
+ # Two arrays referencing the same object vs two copies
+ shared = tvm_ffi.Array([s, s])
+ copies = tvm_ffi.Array([TExpr(value=10), TExpr(value=10)])
+ assert structural_equal(shared, copies)
+ assert structural_hash(shared) == structural_hash(copies)
+
+
+# ---------------------------------------------------------------------------
+# Tests: const-tree kind
+# ---------------------------------------------------------------------------
+
+
+class TestConstTreeNode:
+ """Test structure="const-tree" kind."""
+
+ def test_equal_content(self) -> None:
+ """Two const-tree nodes with identical content are structurally equal."""
+ a = TMetadata(tag="v1", version=1)
+ b = TMetadata(tag="v1", version=1)
+ assert structural_equal(a, b)
+ assert structural_hash(a) == structural_hash(b)
+
+ def test_different_content(self) -> None:
+ """Two const-tree nodes with different content are not equal."""
+ a = TMetadata(tag="v1", version=1)
+ b = TMetadata(tag="v1", version=2)
+ assert not structural_equal(a, b)
+
+ def test_same_pointer_shortcircuits(self) -> None:
+ """Same pointer should be equal (the const-tree optimization)."""
+ a = TMetadata(tag="test", version=1)
+ assert structural_equal(a, a)
+
+
+# ---------------------------------------------------------------------------
+# Tests: dag kind
+# ---------------------------------------------------------------------------
+
+
+class TestDAGNode:
+ """Test structure="dag" kind."""
+
+ def test_same_dag_shape(self) -> None:
+ """Two DAGs with the same sharing shape are equal."""
+ s1 = TBinding(name="s", value=1)
+ s2 = TBinding(name="s", value=1)
+ dag1 = tvm_ffi.Array([s1, s1]) # shared
+ dag2 = tvm_ffi.Array([s2, s2]) # shared (same shape)
+ assert structural_equal(dag1, dag2)
+
+ def test_dag_vs_tree_not_equal(self) -> None:
+ """A DAG (shared) vs tree (independent copies) are NOT equal."""
+ shared = TBinding(name="s", value=1)
+ copy1 = TBinding(name="s", value=1)
+ copy2 = TBinding(name="s", value=1)
+ dag = tvm_ffi.Array([shared, shared])
+ tree = tvm_ffi.Array([copy1, copy2])
+ assert not structural_equal(dag, tree)
+
+ def test_dag_vs_tree_hash_differs(self) -> None:
+ """DAG and tree with same content should hash differently."""
+ shared = TBinding(name="s", value=1)
+ copy1 = TBinding(name="s", value=1)
+ copy2 = TBinding(name="s", value=1)
+ dag = tvm_ffi.Array([shared, shared])
+ tree = tvm_ffi.Array([copy1, copy2])
+ assert structural_hash(dag) != structural_hash(tree)
+
+ def test_reverse_bijection(self) -> None:
+ """(a, b) vs (a, a) where a ≅ b — reverse map detects inconsistency."""
+ a = TBinding(name="a", value=1)
+ b = TBinding(name="b", value=1) # same content as a
+ lhs = tvm_ffi.Array([a, b])
+ rhs = tvm_ffi.Array([a, a]) # note: same object twice
+ assert not structural_equal(lhs, rhs)
+
+
+# ---------------------------------------------------------------------------
+# Tests: unsupported kind (default)
+# ---------------------------------------------------------------------------
+
+
+class TestUnsupported:
+ """Test that types without structure= raise on structural comparison."""
+
+ def test_unsupported_raises_on_hash(self) -> None:
+ """Structural hashing raises TypeError for types without structure=."""
+
+ @py_class("testing.py.Plain")
+ class Plain(tvm_ffi.Object):
+ x: int
+
+ with pytest.raises(TypeError):
+ structural_hash(Plain(x=1))
+
+ def test_unsupported_raises_on_equal(self) -> None:
+ """Structural equality raises TypeError for types without structure=."""
+
+ @py_class("testing.py.Plain2")
+ class Plain2(tvm_ffi.Object):
+ x: int
+
+ with pytest.raises(TypeError):
+ structural_equal(Plain2(x=1), Plain2(x=1))
+
+
+# ---------------------------------------------------------------------------
+# Tests: validation
+# ---------------------------------------------------------------------------
+
+
+class TestValidation:
+ """Test that invalid structure= values are rejected."""
+
+ def test_invalid_type_structure(self) -> None:
+ """Invalid type-level structure= value raises ValueError."""
+ with pytest.raises(ValueError, match="structure"):
+ py_class(structure="invalid")
+
+ def test_invalid_field_structure(self) -> None:
+ """Invalid field-level structure= value raises ValueError."""
+ with pytest.raises(ValueError, match="structure"):
+ field(structure="bad_value")
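The structural tests above rely on a handful of rules: `var` nodes match through a one-to-one binding that is established at `def` fields (or on first encounter when `map_free_vars=True`), `ignore` fields are skipped, `tree` nodes compare by content regardless of sharing, `const-tree` nodes may short-circuit on pointer identity, and `dag` nodes must additionally agree on sharing shape in both directions. The following toy model in pure Python is meant only to make those tests easier to follow. It is not how tvm_ffi implements structural equality (the real logic is in C++), it omits hashing, and every name in it (`Var`, `Lit`, `Func`, `Binding`, `KIND`, `sketch_structural_equal`) is hypothetical.

```python
from dataclasses import dataclass, field, fields
from typing import Any


# Hypothetical node types; KIND plays the role of py_class(structure=...),
# and field metadata plays the role of field(structure=...).
@dataclass(eq=False)
class Var:
    name: str = field(metadata={"structure": "ignore"})
    KIND = "var"


@dataclass(eq=False)
class Lit:
    value: int
    KIND = "tree"


@dataclass(eq=False)
class Func:
    params: list = field(metadata={"structure": "def"})
    body: list
    comment: str = field(default="", metadata={"structure": "ignore"})
    KIND = "tree"


@dataclass(eq=False)
class Binding:
    name: str
    value: int
    KIND = "dag"


class _Comparer:
    def __init__(self, map_free_vars: bool) -> None:
        self.map_free_vars = map_free_vars
        self.lhs_to_rhs = {}  # id(lhs node) -> paired rhs node
        self.rhs_to_lhs = {}  # id(rhs node) -> paired lhs node

    def _bound(self, lhs: Any, rhs: Any) -> bool:
        return id(lhs) in self.lhs_to_rhs or id(rhs) in self.rhs_to_lhs

    def _bind(self, lhs: Any, rhs: Any) -> bool:
        # Record the pairing, or verify it, keeping it one-to-one both ways.
        prev_rhs = self.lhs_to_rhs.get(id(lhs))
        prev_lhs = self.rhs_to_lhs.get(id(rhs))
        if prev_rhs is None and prev_lhs is None:
            self.lhs_to_rhs[id(lhs)] = rhs
            self.rhs_to_lhs[id(rhs)] = lhs
            return True
        return prev_rhs is rhs and prev_lhs is lhs

    def equal(self, lhs: Any, rhs: Any, defining: bool = False) -> bool:
        if isinstance(lhs, list) and isinstance(rhs, list):
            return len(lhs) == len(rhs) and all(
                self.equal(a, b, defining) for a, b in zip(lhs, rhs)
            )
        kind = getattr(type(lhs), "KIND", None)
        if kind is None:
            return lhs == rhs  # plain values such as int and str
        if type(lhs) is not type(rhs):
            return False
        if kind == "var":
            if self._bound(lhs, rhs) or defining or self.map_free_vars:
                return self._bind(lhs, rhs)
            return lhs is rhs  # an unbound free var only equals itself
        if kind == "const-tree" and lhs is rhs:
            return True  # pointer shortcut
        if kind == "dag" and self._bound(lhs, rhs):
            return self._bind(lhs, rhs)  # sharing shape must agree
        for f in fields(lhs):
            mode = f.metadata.get("structure")
            if mode == "ignore":
                continue
            if not self.equal(getattr(lhs, f.name), getattr(rhs, f.name), mode == "def"):
                return False
        # dag nodes are paired only after their contents compare equal.
        return kind != "dag" or self._bind(lhs, rhs)


def sketch_structural_equal(lhs: Any, rhs: Any, map_free_vars: bool = False) -> bool:
    return _Comparer(map_free_vars).equal(lhs, rhs)
```

Because every `False` propagates straight to the top, the sketch never needs to undo bindings recorded along a failing branch. With it, `Func([x], [Lit(1), x])` equals `Func([y], [Lit(1), y])`, while `[s, s]` and two separate but identical `Binding` objects are not equal.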