This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9249061ad8 [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(3) (#18410)
9249061ad8 is described below

commit 9249061ad80f5ab7d06ff9d259dbdc4b190b7e6c
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Nov 1 10:10:57 2025 -0400

    [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(3) (#18410)
    
    * finish1
    
    * finish2
    
    * finish3
---
 .../frontend/torch/base_fx_graph_translator.py     |   3 +
 .../frontend/torch/exported_program_translator.py  |   2 +
 .../relax/test_frontend_from_exported_program.py   | 122 ++++++++++++++++-----
 3 files changed, 97 insertions(+), 30 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index b17f62738f..aedef8acf8 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1722,6 +1722,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
     def _squeeze(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 
None)
+        # Support both "dim" and "dims" parameters
+        if dim is None:
+            dim = node.kwargs.get("dims", None)
         return self.block_builder.emit(relax.op.squeeze(x, dim))
 
     def _stack(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5bb7a9ea8b..48ae002c05 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1018,6 +1018,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "split_with_sizes.default": self._split,
             "squeeze.default": self._squeeze,
             "squeeze.dim": self._squeeze,
+            "squeeze.dims": self._squeeze,
             "stack.default": self._stack,
             "take.default": self._take,
             "tile.default": self._tile,
@@ -1075,6 +1076,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             # other
             "getitem": self._getitem,
             "item.default": self._item,
+            "_local_scalar_dense.default": self._item,
         }
 
     def create_input_vars(
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ac36c3fe8f..019d649558 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5823,7 +5823,7 @@ def test_cumprod():
             return gv
 
     example_input = torch.randn(5, 3, dtype=torch.float32)
-    verify_model(Cumprod(), (example_input,), {}, Expected)
+    verify_model(Cumprod(), (example_input,), {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_where():
@@ -5849,7 +5849,7 @@ def test_where():
     x = torch.randn(5, 3, dtype=torch.float32)
     y = torch.randn(5, 3, dtype=torch.float32)
 
-    verify_model(Where(), (condition, x, y), {}, Expected)
+    verify_model(Where(), (condition, x, y), {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_bucketize():
@@ -5874,7 +5874,7 @@ def test_bucketize():
     input_tensor = torch.arange(0, 20)
     boundaries = torch.arange(0, 20, 2)
 
-    verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected)
+    verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_argsort():
@@ -5890,12 +5890,18 @@ def test_argsort():
                 lv: R.Tensor((5, 3), dtype="int32") = R.argsort(
                     x, axis=1, descending=True, dtype="int32"
                 )
-                gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,)
+                lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(x, 
lv, axis=1)
+                lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 
3), dtype="int32")) = (
+                    lv1,
+                    lv,
+                )
+                lv3: R.Tensor((5, 3), dtype="int32") = lv2[1]
+                gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv3,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Argsort(), example_args, {}, Expected)
+    verify_model(Argsort(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_topk():
@@ -5923,7 +5929,7 @@ def test_topk():
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Topk(), example_args, {}, Expected)
+    verify_model(Topk(), example_args, {}, Expected, run_ep_decomposition=True)
 
 
 def test_dynamic_shape():
@@ -5972,7 +5978,7 @@ def test_broadcast_to():
             return gv
 
     example_args = (torch.randn(5, 1, dtype=torch.float32),)
-    verify_model(BroadcastTo(), example_args, {}, Expected)
+    verify_model(BroadcastTo(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_narrow():
@@ -5992,6 +5998,7 @@ def test_narrow():
                     (R.prim_value(1),),
                     (R.prim_value(0),),
                     (R.prim_value(2),),
+                    (R.prim_value(1),),
                     assume_inbound=False,
                 )
                 gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
@@ -6000,7 +6007,7 @@ def test_narrow():
             return gv
 
     example_args = (torch.randn(5, 3, dtype=torch.float32),)
-    verify_model(Narrow(), example_args, {}, Expected)
+    verify_model(Narrow(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_item():
@@ -6019,7 +6026,7 @@ def test_item():
             return gv
 
     example_args = (torch.randn(1, dtype=torch.float32),)
-    verify_model(Item(), example_args, {}, Expected)
+    verify_model(Item(), example_args, {}, Expected, run_ep_decomposition=True)
 
 
 def test_norm():
@@ -6131,7 +6138,9 @@ def test_norm():
     example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)
 
     for (p, dim, keepdim), expected in norms:
-        verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, 
expected)
+        verify_model(
+            Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected, 
run_ep_decomposition=True
+        )
 
 
 def test_eye():
@@ -6146,8 +6155,20 @@ def test_eye():
             input: R.Tensor((3, 5), dtype="float32")
         ) -> R.Tuple(R.Tensor((3, 5), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, 
dtype="float32")
-                gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,)
+                lv: R.Tensor((3,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(3), R.prim_value(1), 
dtype="int64"
+                )
+                lv1: R.Tensor((5,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(5), R.prim_value(1), 
dtype="int64"
+                )
+                lv2: R.Tensor((3, 1), dtype="int64") = R.expand_dims(lv, 
axis=[-1])
+                lv3: R.Tensor((3, 5), dtype="bool") = R.equal(lv2, lv1)
+                lv4: R.Tensor((1,), dtype="float32") = R.full(
+                    R.shape([1]), R.const(1.0, "float32"), dtype="float32"
+                )
+                lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+                lv6: R.Tensor((3, 5), dtype="float32") = R.where(lv3, lv4, lv5)
+                gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv6,)
                 R.output(gv)
             return gv
 
@@ -6162,16 +6183,28 @@ def test_eye():
             input: R.Tensor((5,), dtype="float32")
         ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, 
dtype="float32")
-                gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,)
+                lv: R.Tensor((5,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(5), R.prim_value(1), 
dtype="int64"
+                )
+                lv1: R.Tensor((5,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(5), R.prim_value(1), 
dtype="int64"
+                )
+                lv2: R.Tensor((5, 1), dtype="int64") = R.expand_dims(lv, 
axis=[-1])
+                lv3: R.Tensor((5, 5), dtype="bool") = R.equal(lv2, lv1)
+                lv4: R.Tensor((1,), dtype="float32") = R.full(
+                    R.shape([1]), R.const(1.0, "float32"), dtype="float32"
+                )
+                lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+                lv6: R.Tensor((5, 5), dtype="float32") = R.where(lv3, lv4, lv5)
+                gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv6,)
                 R.output(gv)
             return gv
 
     example_args1 = (torch.randn(3, 5, dtype=torch.float32),)
-    verify_model(Eye1(), example_args1, {}, Expected1)
+    verify_model(Eye1(), example_args1, {}, Expected1, 
run_ep_decomposition=True)
 
     example_args2 = (torch.randn(5, dtype=torch.float32),)
-    verify_model(Eye2(), example_args2, {}, Expected2)
+    verify_model(Eye2(), example_args2, {}, Expected2, 
run_ep_decomposition=True)
 
 
 def test_cross_entropy():
@@ -6187,21 +6220,39 @@ def test_cross_entropy():
     @tvm.script.ir_module
     class Expected1:
         @R.function
-        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), 
dtype="float32")):
+        def main(x: R.Tensor((4, 3), dtype="float32")) -> 
R.Tuple(R.Tensor((4,), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, 
axis=-1)
-                lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss(
-                    lv,
-                    targets=R.const([0, 1, 2, 1], dtype="int64"),
-                    reduction="mean",
-                    ignore_index=-100,
+                lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, 
dtype="float32")
+                lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, 
axis=1)
+                lv2: R.Tensor((4,), dtype="bool") = R.not_equal(
+                    R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, 
"int64")
+                )
+                lv3: R.Tensor((), dtype="int64") = R.const(0, "int64")
+                lv4: R.Tensor((4,), dtype="int64") = R.where(
+                    lv2, R.const([0, 1, 2, 1], dtype="int64"), lv3
                 )
-                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+                lv5: R.Tensor((4, 1), dtype="int64") = R.expand_dims(lv4, 
axis=[1])
+                lv6: R.Tensor((4, 1), dtype="float32") = 
R.gather_elements(lv1, lv5, axis=1)
+                lv7: R.Tensor((4,), dtype="float32") = R.squeeze(lv6, axis=[1])
+                lv8: R.Tensor((4,), dtype="float32") = R.negative(lv7)
+                lv9: R.Tensor((4,), dtype="bool") = R.not_equal(
+                    R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, 
"int64")
+                )
+                lv10: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+                lv11: R.Tensor((4,), dtype="float32") = R.where(lv9, lv8, lv10)
+                lv12: R.Tensor((4,), dtype="bool") = R.not_equal(
+                    R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, 
"int64")
+                )
+                lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], 
keepdims=False)
+                lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, 
dtype="float32")
+                lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], 
keepdims=False)
+                lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14)
+                gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,)
                 R.output(gv)
             return gv
 
     example_args1 = (torch.randn(4, 3, dtype=torch.float32),)
-    verify_model(CrossEntropyModule(), example_args1, {}, Expected1)
+    verify_model(CrossEntropyModule(), example_args1, {}, Expected1, 
run_ep_decomposition=True)
 
 
 def test_linspace():
@@ -6216,13 +6267,24 @@ def test_linspace():
             input: R.Tensor((9, 9), dtype="float32")
         ) -> R.Tuple(R.Tensor((9,), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 
0.125, dtype="float32")
-                gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,)
+                lv: R.Tensor((9,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(9), R.prim_value(1), 
dtype="int64"
+                )
+                lv1: R.Tensor((9,), dtype="bool") = R.less(lv, R.const(4, 
"int64"))
+                lv2: R.Tensor((9,), dtype="float32") = R.astype(lv, 
dtype="float32")
+                lv3: R.Tensor((9,), dtype="float32") = R.multiply(lv2, 
R.const(0.125, "float32"))
+                lv4: R.Tensor((9,), dtype="float32") = R.add(lv3, R.const(0.0, 
"float32"))
+                lv5: R.Tensor((9,), dtype="int64") = R.subtract(R.const(8, 
"int64"), lv)
+                lv6: R.Tensor((9,), dtype="float32") = R.astype(lv5, 
dtype="float32")
+                lv7: R.Tensor((9,), dtype="float32") = R.multiply(lv6, 
R.const(0.125, "float32"))
+                lv8: R.Tensor((9,), dtype="float32") = R.subtract(R.const(1.0, 
"float32"), lv7)
+                lv9: R.Tensor((9,), dtype="float32") = R.where(lv1, lv4, lv8)
+                gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv9,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(9, 9, dtype=torch.float32),)
-    verify_model(Linspace(), example_args, {}, Expected)
+    verify_model(Linspace(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 @pytest.mark.parametrize(
@@ -6259,7 +6321,7 @@ def test_dtypes(torch_dtype, relax_dtype):
                 R.output(gv)
             return gv
 
-    verify_model(Model(), example_args, {}, Expected)
+    verify_model(Model(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_mm():
@@ -6285,7 +6347,7 @@ def test_mm():
                 R.output(gv)
             return gv
 
-    verify_model(MatrixMultiply(), example_args, {}, Expected)
+    verify_model(MatrixMultiply(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
 
 def test_lstm():

Reply via email to