This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ca61bbf6d [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(4) (#18414)
5ca61bbf6d is described below

commit 5ca61bbf6dee9f938e629802f4c395b078b441ce
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Nov 2 20:27:24 2025 -0500

    [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(4) (#18414)
---
 .../frontend/torch/exported_program_translator.py  |   4 +
 .../relax/test_frontend_from_exported_program.py   | 154 +++++++++++++--------
 2 files changed, 100 insertions(+), 58 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 48ae002c05..3be255a29a 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1003,6 +1003,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "flip.default": self._flip,
             "gather.default": self._gather,
             "index.Tensor": self._index_tensor,
+            "index_put.default": self._index_put,
             "index_put_.default": self._index_put,
             "meshgrid.indexing": self._meshgrid,
             "meshgrid.default": self._meshgrid,
@@ -1041,6 +1042,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "contiguous.default": lambda node: self.env[node.args[0]],  # no-op
             "clone.default": lambda node: self.env[node.args[0]],
             "bernoulli.p": lambda node: self.env[node.args[0]],  # Dropout: just return input
+            "_assert_tensor_metadata.default": lambda node: self.env[
+                node.args[0]
+            ],  # metadata assertion: no-op
             "empty.memory_format": self._empty,
             "empty_permuted.default": self._empty,  # Similar to empty with permuted layout
             "empty_like.default": self._empty_like,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 019d649558..9f63743faa 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5222,17 +5222,17 @@ def test_empty_like():
     class Expected:
         @R.function
         def main(
-            inp_0: R.Tensor((5,), dtype="float32"),
+            data: R.Tensor((5,), dtype="float32"),
         ) -> R.Tuple(R.Tensor((5,), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0, 
dtype="void")
+                lv: R.Tensor((5,), dtype="float32") = R.zeros(R.shape([5]), 
dtype="float32")
                 gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(5, dtype=torch.float32),)
 
-    verify_model(EmptyLike(), example_args, {}, Expected)
+    verify_model(EmptyLike(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_one_hot():
@@ -5244,19 +5244,22 @@ def test_one_hot():
     class Expected:
         @R.function
         def main(
-            inp_0: R.Tensor((5,), dtype="int64"),
+            indices: R.Tensor((5,), dtype="int64"),
         ) -> R.Tuple(R.Tensor((5, 10), dtype="int64")):
             with R.dataflow():
-                lv: R.Tensor((5, 10), dtype="int64") = R.one_hot(
-                    inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1
+                lv: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), 
dtype="int64"
                 )
-                gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv,)
+                lv1: R.Tensor((5, 1), dtype="int64") = R.expand_dims(indices, 
axis=[-1])
+                lv2: R.Tensor((5, 10), dtype="bool") = R.equal(lv1, lv)
+                lv3: R.Tensor((5, 10), dtype="int64") = R.astype(lv2, 
dtype="int64")
+                gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv3,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),)
 
-    verify_model(OneHot(), example_args, {}, Expected)
+    verify_model(OneHot(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_ones_like():
@@ -5271,14 +5274,16 @@ def test_ones_like():
             input: R.Tensor((128, 128), dtype="float32")
         ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(input, 
dtype="void")
+                lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
+                    input, R.const(1, "int32"), dtype="void"
+                )
                 gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
     example_args = (torch.rand(128, 128, dtype=torch.float32),)
 
-    verify_model(OnesLike(), example_args, {}, Expected)
+    verify_model(OnesLike(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_zero_inplace():
@@ -5291,16 +5296,23 @@ def test_zero_inplace():
         @R.function
         def main(
             input: R.Tensor((128, 128), dtype="float32")
-        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 
128), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((128, 128), dtype="float32") = 
R.zeros_like(input, dtype="void")
-                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
+                lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
+                    input, R.const(0, "int32"), dtype="void"
+                )
+                gv: R.Tuple(
+                    R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 
128), dtype="float32")
+                ) = (
+                    lv,
+                    lv,
+                )
                 R.output(gv)
             return gv
 
     example_args = (torch.rand(128, 128, dtype=torch.float32),)
 
-    verify_model(ZeroInplace(), example_args, {}, Expected)
+    verify_model(ZeroInplace(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_zeros():
@@ -5315,14 +5327,16 @@ def test_zeros():
             input: R.Tensor((128, 128), dtype="float32")
         ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((5, 2), dtype="float32") = R.zeros(R.shape([5, 
2]), dtype="float32")
+                lv: R.Tensor((5, 2), dtype="float32") = R.full(
+                    R.shape([5, 2]), R.const(0.0, "float32"), dtype="float32"
+                )
                 gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
     example_args = (torch.rand(128, 128, dtype=torch.float32),)
 
-    verify_model(Zeros(), example_args, {}, Expected)
+    verify_model(Zeros(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_zeros_like():
@@ -5337,13 +5351,15 @@ def test_zeros_like():
             input: R.Tensor((128, 128), dtype="float32")
         ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((128, 128), dtype="float32") = 
R.zeros_like(input, dtype="void")
+                lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
+                    input, R.const(0, "int32"), dtype="void"
+                )
                 gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
     example_args = (torch.rand(128, 128, dtype=torch.float32),)
-    verify_model(ZerosLike(), example_args, {}, Expected)
+    verify_model(ZerosLike(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_type_as():
@@ -5369,7 +5385,7 @@ def test_type_as():
         torch.rand(128, 128, dtype=torch.float16),
     )
 
-    verify_model(TypeAs(), example_args, {}, Expected)
+    verify_model(TypeAs(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_select():
@@ -5391,7 +5407,7 @@ def test_select():
 
     example_args = (torch.randn(2, 3, dtype=torch.float32),)
 
-    verify_model(Select(), example_args, {}, Expected)
+    verify_model(Select(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_unflatten():
@@ -5417,8 +5433,8 @@ def test_unflatten():
 
     example_args = (torch.randn(2, 15, 7, dtype=torch.float32),)
 
-    verify_model(Unflatten(), example_args, {}, Expected)
-    verify_model(Unflatten1(), example_args, {}, Expected)
+    verify_model(Unflatten(), example_args, {}, Expected, 
run_ep_decomposition=True)
+    verify_model(Unflatten1(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_gather():
@@ -5495,10 +5511,10 @@ def test_gather():
         torch.randint(0, 3, (2, 3), dtype=torch.int64),
     )
 
-    verify_model(Gather0(), example_args, {}, Expected0)
-    verify_model(Gather1(), example_args, {}, Expected1)
-    verify_model(Gather2(), example_args, {}, Expected2)
-    verify_model(Gather3(), example_args, {}, Expected3)
+    verify_model(Gather0(), example_args, {}, Expected0, 
run_ep_decomposition=True)
+    verify_model(Gather1(), example_args, {}, Expected1, 
run_ep_decomposition=True)
+    verify_model(Gather2(), example_args, {}, Expected2, 
run_ep_decomposition=True)
+    verify_model(Gather3(), example_args, {}, Expected3, 
run_ep_decomposition=True)
 
 
 def test_index_put():
@@ -5521,12 +5537,15 @@ def test_index_put():
             data: R.Tensor((64,), dtype="float32"),
             indices_0: R.Tensor((128,), dtype="int64"),
             values: R.Tensor((128,), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((64,), dtype="float32")):
+        ) -> R.Tuple(R.Tensor((64,), dtype="float32"), R.Tensor((64,), 
dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((64,), dtype="float32") = R.index_put(
                     data, R.tuple(indices_0), values, accumulate=False
                 )
-                gv: R.Tuple(R.Tensor((64,), dtype="float32")) = (lv,)
+                gv: R.Tuple(R.Tensor((64,), dtype="float32"), R.Tensor((64,), 
dtype="float32")) = (
+                    lv,
+                    lv,
+                )
                 R.output(gv)
             return gv
 
@@ -5551,12 +5570,14 @@ def test_index_put():
             indices_0: R.Tensor((128,), dtype="int64"),
             indices_1: R.Tensor((128,), dtype="int64"),
             values: R.Tensor((128,), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((32, 64), dtype="float32")):
+        ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), 
dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((32, 64), dtype="float32") = R.index_put(
                     data, R.tuple(indices_0, indices_1), values, 
accumulate=False
                 )
-                gv: R.Tuple(R.Tensor((32, 64), dtype="float32")) = (lv,)
+                gv: R.Tuple(
+                    R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), 
dtype="float32")
+                ) = (lv, lv)
                 R.output(gv)
             return gv
 
@@ -5583,12 +5604,16 @@ def test_index_put():
             indices_1: R.Tensor((128,), dtype="int64"),
             indices_2: R.Tensor((128,), dtype="int64"),
             values: R.Tensor((128,), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((16, 32, 64), dtype="float32")):
+        ) -> R.Tuple(
+            R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), 
dtype="float32")
+        ):
             with R.dataflow():
                 lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put(
                     data, R.tuple(indices_0, indices_1, indices_2), values, 
accumulate=False
                 )
-                gv: R.Tuple(R.Tensor((16, 32, 64), dtype="float32")) = (lv,)
+                gv: R.Tuple(
+                    R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 
64), dtype="float32")
+                ) = (lv, lv)
                 R.output(gv)
             return gv
 
@@ -5617,7 +5642,10 @@ def test_index_put():
             indices_2: R.Tensor((128,), dtype="int64"),
             indices_3: R.Tensor((128,), dtype="int64"),
             values: R.Tensor((128,), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")):
+        ) -> R.Tuple(
+            R.Tensor((8, 16, 32, 64), dtype="float32"),
+            R.Tensor((8, 16, 32, 64), dtype="float32"),
+        ):
             with R.dataflow():
                 lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put(
                     data,
@@ -5625,7 +5653,10 @@ def test_index_put():
                     values,
                     accumulate=False,
                 )
-                gv: R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")) = (lv,)
+                gv: R.Tuple(
+                    R.Tensor((8, 16, 32, 64), dtype="float32"),
+                    R.Tensor((8, 16, 32, 64), dtype="float32"),
+                ) = (lv, lv)
                 R.output(gv)
             return gv
 
@@ -5656,7 +5687,10 @@ def test_index_put():
             indices_3: R.Tensor((128,), dtype="int64"),
             indices_4: R.Tensor((128,), dtype="int64"),
             values: R.Tensor((128,), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")):
+        ) -> R.Tuple(
+            R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+            R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+        ):
             with R.dataflow():
                 lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = 
R.index_put(
                     data,
@@ -5664,16 +5698,19 @@ def test_index_put():
                     values,
                     accumulate=False,
                 )
-                gv: R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")) = 
(lv,)
+                gv: R.Tuple(
+                    R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+                    R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+                ) = (lv, lv)
                 R.output(gv)
             return gv
 
     # Run verification for each case
-    verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
-    verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
-    verify_model(IndexPut3D(), example_args_3d, {}, Expected3D)
-    verify_model(IndexPut4D(), example_args_4d, {}, Expected4D)
-    verify_model(IndexPut5D(), example_args_5d, {}, Expected5D)
+    verify_model(IndexPut1D(), example_args_1d, {}, Expected1D, 
run_ep_decomposition=True)
+    verify_model(IndexPut2D(), example_args_2d, {}, Expected2D, 
run_ep_decomposition=True)
+    verify_model(IndexPut3D(), example_args_3d, {}, Expected3D, 
run_ep_decomposition=True)
+    verify_model(IndexPut4D(), example_args_4d, {}, Expected4D, 
run_ep_decomposition=True)
+    verify_model(IndexPut5D(), example_args_5d, {}, Expected5D, 
run_ep_decomposition=True)
 
 
 def test_flip():
@@ -5711,8 +5748,8 @@ def test_flip():
 
     example_args = (torch.randn(2, 2, dtype=torch.float32),)
 
-    verify_model(Flip0(), example_args, {}, Expected0)
-    verify_model(Flip1(), example_args, {}, Expected1)
+    verify_model(Flip0(), example_args, {}, Expected0, 
run_ep_decomposition=True)
+    verify_model(Flip1(), example_args, {}, Expected1, 
run_ep_decomposition=True)
 
 
 def test_take():
@@ -5724,12 +5761,12 @@ def test_take():
     class Expected:
         @R.function
         def main(
-            inp_0: R.Tensor((5,), dtype="float32"),
-            inp_1: R.Tensor((3,), dtype="int64"),
+            data: R.Tensor((5,), dtype="float32"),
+            indices: R.Tensor((3,), dtype="int64"),
         ) -> R.Tuple(R.Tensor((3,), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((3,), dtype="int32") = R.astype(inp_1, 
dtype="int32")
-                lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv, 
axis=None)
+                lv: R.Tensor((5,), dtype="float32") = R.reshape(data, 
R.shape([5]))
+                lv1: R.Tensor((3,), dtype="float32") = R.index_tensor(lv, 
(indices,))
                 gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
@@ -5739,7 +5776,7 @@ def test_take():
         torch.randint(0, 5, (3,), dtype=torch.int64),
     )
 
-    verify_model(Take(), example_args, {}, Expected)
+    verify_model(Take(), example_args, {}, Expected, run_ep_decomposition=True)
 
 
 def test_std():
@@ -5751,16 +5788,17 @@ def test_std():
     class Expected:
         @R.function
         def main(
-            inp_0: R.Tensor((5, 3), dtype="float32"),
+            x: R.Tensor((5, 3), dtype="float32"),
         ) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((), dtype="float32") = R.std(inp_0, axis=None, 
keepdims=False)
-                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, 
keepdims=False)
+                lv1: R.Tensor((), dtype="float32") = R.sqrt(lv)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Std(), example_args, {}, Expected)
+    verify_model(Std(), example_args, {}, Expected, run_ep_decomposition=True)
 
 
 def test_var():
@@ -5772,16 +5810,16 @@ def test_var():
     class Expected:
         @R.function
         def main(
-            inp_0: R.Tensor((5, 3), dtype="float32"),
+            x: R.Tensor((5, 3), dtype="float32"),
         ) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((), dtype="float32") = R.variance(inp_0, 
axis=None, keepdims=False)
+                lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, 
keepdims=False)
                 gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Var(), example_args, {}, Expected)
+    verify_model(Var(), example_args, {}, Expected, run_ep_decomposition=True)
 
 
 def test_prod():
@@ -5793,16 +5831,16 @@ def test_prod():
     class Expected:
         @R.function
         def main(
-            inp_0: R.Tensor((5, 3), dtype="float32"),
+            x: R.Tensor((5, 3), dtype="float32"),
         ) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None, 
keepdims=False)
+                lv: R.Tensor((), dtype="float32") = R.prod(x, axis=None, 
keepdims=False)
                 gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Prod(), example_args, {}, Expected)
+    verify_model(Prod(), example_args, {}, Expected, run_ep_decomposition=True)
 
 
 def test_cumprod():

Reply via email to