This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch tvm-direct-tvm-ffi-structural-apis in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 3c42a3ef0d6de31a198be492a26bb151dfb38715 Author: Tianqi Chen <[email protected]> AuthorDate: Wed Jun 3 14:10:34 2026 +0000 [FFI][REFACTOR] Direct structural APIs to tvm-ffi Python callers should reach the canonical tvm-ffi structural helpers directly instead of going through a TVM-side redirect layer. This makes the public tvm.ir bindings exact aliases of the tvm_ffi APIs and exposes get_first_structural_mismatch from tvm.ir. Main changes: - Import structural_equal, get_first_structural_mismatch, and structural_hash directly from tvm_ffi - Remove the pure wrappers from tvm.ir.base while keeping assert_structural_equal's TVM-specific formatting - Update mismatch tests and add identity coverage for the direct bindings --- python/tvm/ir/__init__.py | 10 +- python/tvm/ir/base.py | 119 +-------------------- tests/python/ir/test_container_structural_equal.py | 8 +- .../tirx-base/test_tir_structural_equal_hash.py | 4 +- 4 files changed, 17 insertions(+), 124 deletions(-) diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 50073a942a..7d89256c92 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -29,8 +29,6 @@ from .base import ( assert_structural_equal, load_json, save_json, - structural_equal, - structural_hash, ) from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelaxExpr from .function import BaseFunc, CallingConv @@ -39,4 +37,10 @@ from .module import IRModule from .op import Op, register_intrin_lowering, register_op_attr from .type import FuncType, PointerType, PrimType, TupleType, Type -from tvm_ffi import Array, Map +from tvm_ffi import ( + Array, + Map, + get_first_structural_mismatch, + structural_equal, + structural_hash, +) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index cff43bb8c1..cfd4857db4 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -160,81 +160,6 @@ def save_json(node) -> str: return _ffi_node_api.SaveJSON(node) -def structural_equal(lhs, rhs, map_free_vars=False): - """Check structural equality of lhs and rhs. - - The structural equality is recursively defined in the DAG of IRNodes. - There are two kinds of nodes: - - - Graph node: a graph node in lhs can only be mapped as equal to - one and only one graph node in rhs. - - Normal node: equality is recursively defined without the restriction - of graph nodes. - - Vars(tirx::Var, relax::Var) are graph nodes. - - A var-type node(e.g. tirx::Var) can be mapped as equal to another var - with the same type if one of the following condition holds: - - - They appear in a same definition point(e.g. function argument). - - They points to the same VarNode via the same_as relation. - - They appear in a same usage point, and map_free_vars is set to be True. - - The rules for var are used to remap variables occurs in function - arguments and let-bindings. - - Parameters - ---------- - lhs : Object - The left operand. - - rhs : Object - The left operand. - - map_free_vars : bool - Whether free variables (i.e. variables without a definition site) should be mapped - as equal to each other. - - Return - ------ - result : bool - The comparison result. - - See Also - -------- - structural_hash - assert_strucural_equal - """ - return tvm_ffi.structural_equal(lhs, rhs, map_free_vars) - - -def get_first_structural_mismatch(lhs, rhs, map_free_vars=False, skip_tensor_content=False): - """Like structural_equal(), but returns the AccessPath pair of the first detected mismatch. - - Parameters - ---------- - lhs : Object - The left operand. - - rhs : Object - The left operand. - - map_free_vars : bool - Whether free variables (i.e. variables without a definition site) should be mapped - as equal to each other. - - skip_tensor_content : bool - Whether to skip the content of ndarray. - - Returns - ------- - mismatch: Optional[Tuple[AccessPath, AccessPath]] - `None` if `lhs` and `rhs` are structurally equal. - Otherwise, a tuple of two AccessPath objects that point to the first detected mismtach. - """ - return tvm_ffi.get_first_structural_mismatch(lhs, rhs, map_free_vars, skip_tensor_content) - - def assert_structural_equal(lhs, rhs, map_free_vars=False): """Assert lhs and rhs are structurally equal to each other. @@ -256,7 +181,7 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): See Also -------- - structural_equal + tvm.ir.structural_equal """ first_mismatch = tvm_ffi.get_first_structural_mismatch(lhs, rhs, map_free_vars) if first_mismatch is not None: @@ -276,48 +201,6 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): ) -def structural_hash(node, map_free_vars=False): - """Compute structural hash of node - - The structural hash value is recursively defined in the DAG of IRNodes. - There are two kinds of nodes: - - - Normal node: the hash value is defined by its content and type only. - - Graph node: each graph node will be assigned a unique index ordered by the - first occurrence during the visit. The hash value of a graph node is - combined from the hash values of its contents and the index. - - structural_hash is made to be concistent with structural_equal. - If two nodes are structurally equal to each other, - then their structural hash (with the same map_free_vars option) - should be equal to each other as well. - - If the structural hash of two nodes equals to each other, - then it is highly likely(except for rare hash value collison cases) - that the two nodes are structurally equal to each other. - - Parameters - ---------- - node : Object - The input to be hashed. - - map_free_vars : bool - If map_free_vars is set to true, we will hash free variables - by the order of their occurrences. Otherwise, we will hash by - their in-memory pointer address. - - Return - ------ - result : int - The hash result - - See Also - -------- - structrual_equal - """ - return tvm_ffi.structural_hash(node, map_free_vars) - - def deprecated( method_name: str, new_method_name: str, diff --git a/tests/python/ir/test_container_structural_equal.py b/tests/python/ir/test_container_structural_equal.py index 1d9d575af8..9a8a6182ce 100644 --- a/tests/python/ir/test_container_structural_equal.py +++ b/tests/python/ir/test_container_structural_equal.py @@ -20,7 +20,7 @@ from tvm_ffi.access_path import AccessPath import tvm import tvm.testing -from tvm.ir.base import get_first_structural_mismatch +from tvm.ir import get_first_structural_mismatch def get_first_mismatch_ensure_symmetry(a, b): @@ -49,6 +49,12 @@ def get_first_mismatch_ensure_symmetry(a, b): return mismatch +def test_structural_api_redirects_to_tvm_ffi(): + assert tvm.ir.structural_equal is tvm_ffi.structural_equal + assert tvm.ir.get_first_structural_mismatch is tvm_ffi.get_first_structural_mismatch + assert tvm.ir.structural_hash is tvm_ffi.structural_hash + + @pytest.mark.parametrize( "a, b, expected_a_path, expected_b_path", [ diff --git a/tests/python/tirx-base/test_tir_structural_equal_hash.py b/tests/python/tirx-base/test_tir_structural_equal_hash.py index 1efef38e3f..6e470de75d 100644 --- a/tests/python/tirx-base/test_tir_structural_equal_hash.py +++ b/tests/python/tirx-base/test_tir_structural_equal_hash.py @@ -45,8 +45,8 @@ def consistent_equal(x, y, map_free_vars=False): def get_sequal_mismatch(x, y, map_free_vars=False): - mismatch_0 = tvm.ir.base.get_first_structural_mismatch(x, y, map_free_vars) - mismatch_1 = tvm.ir.base.get_first_structural_mismatch(y, x, map_free_vars) + mismatch_0 = tvm.ir.get_first_structural_mismatch(x, y, map_free_vars) + mismatch_1 = tvm.ir.get_first_structural_mismatch(y, x, map_free_vars) if mismatch_0 is None and mismatch_1 is None: return None
