This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e3c5b47eda [Relax][PyTorch] Unify tests using shared verify_model (#18517)
e3c5b47eda is described below
commit e3c5b47eda74f081957edc4beaa50bb5e0146a60
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Nov 27 17:50:11 2025 +0800
[Relax][PyTorch] Unify tests using shared verify_model (#18517)
## Why
The tests already have a shared `verify_model` helper; using it in every
test keeps the tests consistent with one another.
---
.../relax/test_frontend_from_exported_program.py | 70 ++++++++++++++--------
1 file changed, 46 insertions(+), 24 deletions(-)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 98c6c6d014..93218190fc 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -32,14 +32,29 @@ from tvm.relax.frontend.torch import from_exported_program
def verify_model(
- torch_model, example_args, binding, expected, dynamic_shapes=None,
run_ep_decomposition=True
+ torch_model,
+ example_args,
+ binding,
+ expected,
+ dynamic_shapes=None,
+ run_ep_decomposition=True,
+ keep_params_as_input=False,
+ unwrap_unit_return_tuple=False,
+ no_bind_return_tuple=False,
+ map_free_vars=False,
):
exported_program = export(torch_model, args=example_args,
dynamic_shapes=dynamic_shapes)
- mod = from_exported_program(exported_program,
run_ep_decomposition=run_ep_decomposition)
+ mod = from_exported_program(
+ exported_program,
+ run_ep_decomposition=run_ep_decomposition,
+ keep_params_as_input=keep_params_as_input,
+ unwrap_unit_return_tuple=unwrap_unit_return_tuple,
+ no_bind_return_tuple=no_bind_return_tuple,
+ )
binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()}
expected = relax.transform.BindParams("main", binding)(expected)
- tvm.ir.assert_structural_equal(mod, expected)
+ tvm.ir.assert_structural_equal(mod, expected, map_free_vars=map_free_vars)
operator_basic_unary = [
@@ -6282,6 +6297,7 @@ def test_keep_params():
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
model = Conv2D1()
+
exported_program = torch.export.export(model, example_args)
mod = from_exported_program(exported_program, keep_params_as_input=True)
mod, params = detach_params(mod)
@@ -6318,9 +6334,7 @@ def test_unwrap_unit_return_tuple():
return gv
example_args = (torch.randn(256, 256, dtype=torch.float32),)
- exported_program = export(Identity(), args=example_args)
- mod = from_exported_program(exported_program,
unwrap_unit_return_tuple=True)
- tvm.ir.assert_structural_equal(mod, Expected)
+ verify_model(Identity(), example_args, {}, Expected,
unwrap_unit_return_tuple=True)
def test_no_bind_return_tuple():
@@ -6348,9 +6362,7 @@ def test_no_bind_return_tuple():
torch.randn(256, 256, dtype=torch.float32),
torch.randn(256, 256, dtype=torch.float32),
)
- exported_program = export(Identity(), args=example_args)
- mod = from_exported_program(exported_program, no_bind_return_tuple=True)
- tvm.ir.assert_structural_equal(mod, Expected)
+ verify_model(Identity(), example_args, {}, Expected,
no_bind_return_tuple=True)
def test_empty_like():
@@ -7839,10 +7851,15 @@ def test_dynamic_shape_with_range_constraints():
example_args = (torch.randn(8, 4), torch.randn(8, 4))
batch = torch.export.Dim("batch", min=1, max=64)
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
- exported_program = export(DynamicModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
- mod = from_exported_program(exported_program)
- tvm.ir.assert_structural_equal(mod, Expected)
+ verify_model(
+ DynamicModel(),
+ example_args,
+ {},
+ Expected,
+ dynamic_shapes=dynamic_shapes,
+ map_free_vars=True,
+ )
def test_dynamic_shape_with_addition_constraints():
@@ -7873,10 +7890,10 @@ def test_dynamic_shape_with_addition_constraints():
batch = torch.export.Dim("batch", min=1, max=64)
example_args = (torch.randn(8, 4), torch.randn(9, 4))
dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}
- exported_program = export(ConcatModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
- mod = from_exported_program(exported_program)
- tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+ verify_model(
+ ConcatModel(), example_args, {}, Expected,
dynamic_shapes=dynamic_shapes, map_free_vars=True
+ )
def test_dynamic_shape_with_subtraction_constraints():
@@ -7907,10 +7924,10 @@ def test_dynamic_shape_with_subtraction_constraints():
batch = torch.export.Dim("batch", min=1, max=64)
example_args = (torch.randn(8, 4), torch.randn(7, 4))
dynamic_shapes = {"x": {0: batch}, "y": {0: batch - 1}}
- exported_program = export(ConcatModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
- mod = from_exported_program(exported_program)
- tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+ verify_model(
+ ConcatModel(), example_args, {}, Expected,
dynamic_shapes=dynamic_shapes, map_free_vars=True
+ )
def test_dynamic_shape_with_multiplication_constraints():
@@ -7941,10 +7958,10 @@ def
test_dynamic_shape_with_multiplication_constraints():
batch = torch.export.Dim("batch", min=1, max=64)
example_args = (torch.randn(8, 4), torch.randn(16, 4))
dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}}
- exported_program = export(ConcatModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
- mod = from_exported_program(exported_program)
- tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+ verify_model(
+ ConcatModel(), example_args, {}, Expected,
dynamic_shapes=dynamic_shapes, map_free_vars=True
+ )
def test_dynamic_shape_with_unbounded_constraints():
@@ -7969,10 +7986,15 @@ def test_dynamic_shape_with_unbounded_constraints():
example_args = (torch.randn(8, 4),)
batch = torch.export.Dim("batch", min=2)
dynamic_shapes = {"x": {0: batch}}
- exported_program = export(DynamicModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
- mod = from_exported_program(exported_program)
- tvm.ir.assert_structural_equal(mod, Expected)
+ verify_model(
+ DynamicModel(),
+ example_args,
+ {},
+ Expected,
+ dynamic_shapes=dynamic_shapes,
+ map_free_vars=True,
+ )
def test_sym_size_int():