This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 03d55dfe1a [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(5) (#18417)
03d55dfe1a is described below

commit 03d55dfe1a29a100a320e27d2d93d7001d9d260e
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Nov 3 23:11:58 2025 -0500

    [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(5) (#18417)
    
    * f1
    
    * f2
    
    * f3
    
    * f5
    
    * f7
---
 .../relax/test_frontend_from_exported_program.py   | 120 +++++++++++----------
 1 file changed, 62 insertions(+), 58 deletions(-)

diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 9f63743faa..8a9fe66a0f 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4580,12 +4580,13 @@ def test_stack():
     class Expected0:
         @R.function
         def main(
-            inp_0: R.Tensor((2, 3), dtype="float32"),
-            inp_1: R.Tensor((2, 3), dtype="float32"),
+            x: R.Tensor((2, 3), dtype="float32"),
+            y: R.Tensor((2, 3), dtype="float32"),
         ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, 
inp_1), axis=0)
-                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+                lv: R.Tensor((4, 3), dtype="float32") = R.concat((x, y), 
axis=0)
+                lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, 
R.shape([2, 2, 3]))
+                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
 
@@ -4593,12 +4594,13 @@ def test_stack():
     class Expected1:
         @R.function
         def main(
-            inp_0: R.Tensor((2, 3), dtype="float32"),
-            inp_1: R.Tensor((2, 3), dtype="float32"),
+            x: R.Tensor((2, 3), dtype="float32"),
+            y: R.Tensor((2, 3), dtype="float32"),
         ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, 
inp_1), axis=1)
-                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+                lv: R.Tensor((2, 6), dtype="float32") = R.concat((x, y), 
axis=1)
+                lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, 
R.shape([2, 2, 3]))
+                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
 
@@ -4606,21 +4608,23 @@ def test_stack():
     class Expected3:
         @R.function
         def main(
-            inp_0: R.Tensor((2, 3), dtype="float32"),
-            inp_1: R.Tensor((2, 3), dtype="float32"),
+            x: R.Tensor((2, 3), dtype="float32"),
+            y: R.Tensor((2, 3), dtype="float32"),
         ) -> R.Tuple(R.Tensor((2, 3, 2), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((2, 3, 2), dtype="float32") = R.stack((inp_0, 
inp_1), axis=-1)
-                gv: R.Tuple(R.Tensor((2, 3, 2), dtype="float32")) = (lv,)
+                lv: R.Tensor((2, 3, 1), dtype="float32") = R.expand_dims(x, 
axis=[2])
+                lv1: R.Tensor((2, 3, 1), dtype="float32") = R.expand_dims(y, 
axis=[2])
+                lv2: R.Tensor((2, 3, 2), dtype="float32") = R.concat((lv, 
lv1), axis=-1)
+                gv: R.Tuple(R.Tensor((2, 3, 2), dtype="float32")) = (lv2,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, 
dtype=torch.float32))
 
-    verify_model(Stack0(), example_args, {}, Expected0)
-    verify_model(Stack1(), example_args, {}, Expected1)
-    verify_model(Stack2(), example_args, {}, Expected1)
-    verify_model(Stack3(), example_args, {}, Expected3)
+    verify_model(Stack0(), example_args, {}, Expected0, 
run_ep_decomposition=True)
+    verify_model(Stack1(), example_args, {}, Expected1, 
run_ep_decomposition=True)
+    verify_model(Stack2(), example_args, {}, Expected1, 
run_ep_decomposition=True)
+    verify_model(Stack3(), example_args, {}, Expected3, 
run_ep_decomposition=True)
 
 
 def test_tile():
@@ -4644,7 +4648,7 @@ def test_tile():
         ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2])
+                lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, repeats=[1, 
2])
                 gv: R.Tuple(R.Tensor((1, 6), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
@@ -4657,15 +4661,15 @@ def test_tile():
         ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
+                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, repeats=[4, 
2])
                 gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(1, 3, dtype=torch.float32),)
-    verify_model(Tile1(), example_args, {}, expected1)
-    verify_model(Tile2(), example_args, {}, expected2)
-    verify_model(Tile3(), example_args, {}, expected2)
+    verify_model(Tile1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Tile2(), example_args, {}, expected2, 
run_ep_decomposition=True)
+    verify_model(Tile3(), example_args, {}, expected2, 
run_ep_decomposition=True)
 
 
 def test_transpose():
@@ -4687,7 +4691,7 @@ def test_transpose():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Transpose(), example_args, {}, expected1)
+    verify_model(Transpose(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
 
 def test_unsqueeze():
@@ -4727,8 +4731,8 @@ def test_unsqueeze():
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
 
-    verify_model(Unsqueeze1(), example_args, {}, expected1)
-    verify_model(Unsqueeze2(), example_args, {}, expected2)
+    verify_model(Unsqueeze1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Unsqueeze2(), example_args, {}, expected2, 
run_ep_decomposition=True)
 
 
 def test_view():
@@ -4750,7 +4754,7 @@ def test_view():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(View(), example_args, {}, expected1)
+    verify_model(View(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
 
 def test_arange():
@@ -4771,7 +4775,7 @@ def test_arange():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(Arange(), example_args, {}, Expected)
+    verify_model(Arange(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_hamming_window():
@@ -4798,7 +4802,7 @@ def test_hamming_window():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(HammingWindow(), example_args, {}, Expected)
+    verify_model(HammingWindow(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_contiguous():
@@ -4818,7 +4822,7 @@ def test_contiguous():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(Contiguous(), example_args, {}, Expected)
+    verify_model(Contiguous(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_clone():
@@ -4838,7 +4842,7 @@ def test_clone():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(Clone(), example_args, {}, Expected)
+    verify_model(Clone(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_empty():
@@ -4850,7 +4854,7 @@ def test_empty():
     class Expected:
         @R.function
         def main(
-            inp_0: R.Tensor((10, 10), dtype="float32")
+            input: R.Tensor((10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((10, 10), dtype="float32") = R.zeros(
@@ -4861,7 +4865,7 @@ def test_empty():
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(Empty(), example_args, {}, Expected)
+    verify_model(Empty(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_fill():
@@ -4873,18 +4877,18 @@ def test_fill():
     class Expected:
         @R.function
         def main(
-            inp_0: R.Tensor((10, 10), dtype="float32")
+            input: R.Tensor((10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.full(
-                    R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32"
+                lv: R.Tensor((10, 10), dtype="float32") = R.full_like(
+                    input, R.const(1.5, "float32"), dtype="void"
                 )
                 gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
-    verify_model(Fill(), example_args, {}, Expected)
+    verify_model(Fill(), example_args, {}, Expected, run_ep_decomposition=True)
 
 
 def test_fill_inplace():
@@ -4897,18 +4901,20 @@ def test_fill_inplace():
     class Expected:
         @R.function
         def main(
-            x: R.Tensor((2, 3), dtype="float32")
-        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            input: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), 
dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((2, 3), dtype="float32") = R.full(
-                    R.shape([2, 3]), R.const(42.0, "float32"), dtype="float32"
+                lv: R.Tensor((2, 3), dtype="float32") = R.full_like(
+                    input, R.const(42.0, "float32"), dtype="void"
                 )
-                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
+                gv: R.Tuple(
+                    R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), 
dtype="float32")
+                ) = (lv, lv)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(2, 3, dtype=torch.float32),)
-    verify_model(FillInplace(), example_args, {}, Expected)
+    verify_model(FillInplace(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_masked_fill():
@@ -4923,16 +4929,14 @@ def test_masked_fill():
             input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 
128), dtype="bool")
         ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
-                    input, R.const(0, "int32"), dtype="void"
-                )
+                lv: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
                 lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, 
input)
                 gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(128, 128, dtype=torch.float32), 
torch.rand(128, 128) < 0.5)
-    verify_model(Masked_Fill(), example_args, {}, Expected)
+    verify_model(Masked_Fill(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_masked_fill_inplace():
@@ -4945,18 +4949,18 @@ def test_masked_fill_inplace():
         @R.function
         def main(
             input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 
128), dtype="bool")
-        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 
128), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
-                    input, R.const(1.5, "float32"), dtype="void"
-                )
+                lv: R.Tensor((), dtype="float32") = R.const(1.5, "float32")
                 lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, 
input)
-                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,)
+                gv: R.Tuple(
+                    R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 
128), dtype="float32")
+                ) = (lv1, lv1)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(128, 128, dtype=torch.float32), 
torch.rand(128, 128) < 0.5)
-    verify_model(Masked_Fill_Inplace(), example_args, {}, Expected)
+    verify_model(Masked_Fill_Inplace(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_new_ones():
@@ -4980,7 +4984,7 @@ def test_new_ones():
             return gv
 
     example_args = (torch.randn(1, 2, 3, dtype=torch.float32),)
-    verify_model(NewOnes(), example_args, {}, expected1)
+    verify_model(NewOnes(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
 
 def test_new_zeros():
@@ -5003,7 +5007,7 @@ def test_new_zeros():
             return gv
 
     example_args = (torch.randn(1, 128, 128, dtype=torch.float32),)
-    verify_model(NewZeros(), example_args, {}, expected1)
+    verify_model(NewZeros(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
 
 def test_to_copy():
@@ -5094,11 +5098,11 @@ def test_to_copy():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(ToFloat(), example_args, {}, expected_float)
-    verify_model(ToHalf(), example_args, {}, expected_half)
-    verify_model(Type(), example_args, {}, expected_type)
-    verify_model(To1(), example_args, {}, expected_to1)
-    verify_model(To2(), example_args, {}, expected_to2)
+    verify_model(ToFloat(), example_args, {}, expected_float, 
run_ep_decomposition=True)
+    verify_model(ToHalf(), example_args, {}, expected_half, 
run_ep_decomposition=True)
+    verify_model(Type(), example_args, {}, expected_type, 
run_ep_decomposition=True)
+    verify_model(To1(), example_args, {}, expected_to1, 
run_ep_decomposition=True)
+    verify_model(To2(), example_args, {}, expected_to2, 
run_ep_decomposition=True)
 
 
 def test_keep_params():

Reply via email to