This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e3c5b47eda [Relax][PyTorch] Unify tests using shared verify_model (#18517)
e3c5b47eda is described below

commit e3c5b47eda74f081957edc4beaa50bb5e0146a60
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Nov 27 17:50:11 2025 +0800

    [Relax][PyTorch] Unify tests using shared verify_model (#18517)
    
    ## Why
    
    The tests already have a shared verify_model helper; using it in every
    test keeps the tests consistent.
---
 .../relax/test_frontend_from_exported_program.py   | 70 ++++++++++++++--------
 1 file changed, 46 insertions(+), 24 deletions(-)

diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 98c6c6d014..93218190fc 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -32,14 +32,29 @@ from tvm.relax.frontend.torch import from_exported_program
 
 
 def verify_model(
-    torch_model, example_args, binding, expected, dynamic_shapes=None, 
run_ep_decomposition=True
+    torch_model,
+    example_args,
+    binding,
+    expected,
+    dynamic_shapes=None,
+    run_ep_decomposition=True,
+    keep_params_as_input=False,
+    unwrap_unit_return_tuple=False,
+    no_bind_return_tuple=False,
+    map_free_vars=False,
 ):
     exported_program = export(torch_model, args=example_args, 
dynamic_shapes=dynamic_shapes)
-    mod = from_exported_program(exported_program, 
run_ep_decomposition=run_ep_decomposition)
+    mod = from_exported_program(
+        exported_program,
+        run_ep_decomposition=run_ep_decomposition,
+        keep_params_as_input=keep_params_as_input,
+        unwrap_unit_return_tuple=unwrap_unit_return_tuple,
+        no_bind_return_tuple=no_bind_return_tuple,
+    )
 
     binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()}
     expected = relax.transform.BindParams("main", binding)(expected)
-    tvm.ir.assert_structural_equal(mod, expected)
+    tvm.ir.assert_structural_equal(mod, expected, map_free_vars=map_free_vars)
 
 
 operator_basic_unary = [
@@ -6282,6 +6297,7 @@ def test_keep_params():
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
     model = Conv2D1()
+
     exported_program = torch.export.export(model, example_args)
     mod = from_exported_program(exported_program, keep_params_as_input=True)
     mod, params = detach_params(mod)
@@ -6318,9 +6334,7 @@ def test_unwrap_unit_return_tuple():
             return gv
 
     example_args = (torch.randn(256, 256, dtype=torch.float32),)
-    exported_program = export(Identity(), args=example_args)
-    mod = from_exported_program(exported_program, 
unwrap_unit_return_tuple=True)
-    tvm.ir.assert_structural_equal(mod, Expected)
+    verify_model(Identity(), example_args, {}, Expected, 
unwrap_unit_return_tuple=True)
 
 
 def test_no_bind_return_tuple():
@@ -6348,9 +6362,7 @@ def test_no_bind_return_tuple():
         torch.randn(256, 256, dtype=torch.float32),
         torch.randn(256, 256, dtype=torch.float32),
     )
-    exported_program = export(Identity(), args=example_args)
-    mod = from_exported_program(exported_program, no_bind_return_tuple=True)
-    tvm.ir.assert_structural_equal(mod, Expected)
+    verify_model(Identity(), example_args, {}, Expected, 
no_bind_return_tuple=True)
 
 
 def test_empty_like():
@@ -7839,10 +7851,15 @@ def test_dynamic_shape_with_range_constraints():
     example_args = (torch.randn(8, 4), torch.randn(8, 4))
     batch = torch.export.Dim("batch", min=1, max=64)
     dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
-    exported_program = export(DynamicModel(), args=example_args, 
dynamic_shapes=dynamic_shapes)
 
-    mod = from_exported_program(exported_program)
-    tvm.ir.assert_structural_equal(mod, Expected)
+    verify_model(
+        DynamicModel(),
+        example_args,
+        {},
+        Expected,
+        dynamic_shapes=dynamic_shapes,
+        map_free_vars=True,
+    )
 
 
 def test_dynamic_shape_with_addition_constraints():
@@ -7873,10 +7890,10 @@ def test_dynamic_shape_with_addition_constraints():
     batch = torch.export.Dim("batch", min=1, max=64)
     example_args = (torch.randn(8, 4), torch.randn(9, 4))
     dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}
-    exported_program = export(ConcatModel(), args=example_args, 
dynamic_shapes=dynamic_shapes)
 
-    mod = from_exported_program(exported_program)
-    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+    verify_model(
+        ConcatModel(), example_args, {}, Expected, 
dynamic_shapes=dynamic_shapes, map_free_vars=True
+    )
 
 
 def test_dynamic_shape_with_subtraction_constraints():
@@ -7907,10 +7924,10 @@ def test_dynamic_shape_with_subtraction_constraints():
     batch = torch.export.Dim("batch", min=1, max=64)
     example_args = (torch.randn(8, 4), torch.randn(7, 4))
     dynamic_shapes = {"x": {0: batch}, "y": {0: batch - 1}}
-    exported_program = export(ConcatModel(), args=example_args, 
dynamic_shapes=dynamic_shapes)
 
-    mod = from_exported_program(exported_program)
-    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+    verify_model(
+        ConcatModel(), example_args, {}, Expected, 
dynamic_shapes=dynamic_shapes, map_free_vars=True
+    )
 
 
 def test_dynamic_shape_with_multiplication_constraints():
@@ -7941,10 +7958,10 @@ def 
test_dynamic_shape_with_multiplication_constraints():
     batch = torch.export.Dim("batch", min=1, max=64)
     example_args = (torch.randn(8, 4), torch.randn(16, 4))
     dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}}
-    exported_program = export(ConcatModel(), args=example_args, 
dynamic_shapes=dynamic_shapes)
 
-    mod = from_exported_program(exported_program)
-    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+    verify_model(
+        ConcatModel(), example_args, {}, Expected, 
dynamic_shapes=dynamic_shapes, map_free_vars=True
+    )
 
 
 def test_dynamic_shape_with_unbounded_constraints():
@@ -7969,10 +7986,15 @@ def test_dynamic_shape_with_unbounded_constraints():
     example_args = (torch.randn(8, 4),)
     batch = torch.export.Dim("batch", min=2)
     dynamic_shapes = {"x": {0: batch}}
-    exported_program = export(DynamicModel(), args=example_args, 
dynamic_shapes=dynamic_shapes)
 
-    mod = from_exported_program(exported_program)
-    tvm.ir.assert_structural_equal(mod, Expected)
+    verify_model(
+        DynamicModel(),
+        example_args,
+        {},
+        Expected,
+        dynamic_shapes=dynamic_shapes,
+        map_free_vars=True,
+    )
 
 
 def test_sym_size_int():

Reply via email to