This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4a60bb2532 [Relax][PyTorch] Add support for `torch.ops.aten.sym_size.int` in ExportedProgram frontend (#18473)
4a60bb2532 is described below

commit 4a60bb253276aec01c3a77cf2cbfbc88b8a9fcb7
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Nov 21 14:46:57 2025 +0900

    [Relax][PyTorch] Add support for `torch.ops.aten.sym_size.int` in ExportedProgram frontend (#18473)
    
    As per title.
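
    For reference, a minimal sketch of how the new mapping can be
    exercised end-to-end (the toy module below is illustrative and not
    part of this change; `torch.export.export` and
    `from_exported_program` are the existing entry points):

        import torch
        from torch.export import export
        from tvm.relax.frontend.torch import from_exported_program

        class M(torch.nn.Module):
            def forward(self, x):
                # Add the (static) size of dim 1 to x[0]; mirrors the new
                # test, which wraps sym_size.int in an add for the same
                # reason noted in its TODO.
                return x[0] + torch.ops.aten.sym_size.int(x, 1)

        ep = export(M(), (torch.randn(1, 3, 4),))
        mod = from_exported_program(ep)  # sym_size.int handled by _sym_size_int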
    
    cc @tlopex
---
 .../frontend/torch/base_fx_graph_translator.py     |  6 +++++
 .../frontend/torch/exported_program_translator.py  |  1 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  6 -----
 .../relax/test_frontend_from_exported_program.py   | 31 ++++++++++++++++++++++
 4 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 753b0d7914..ed7811dd71 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2357,6 +2357,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         x = self.env[node.args[0]]
         return self.block_builder.emit(relax.op.take(x, relax.const(0, "int64"), axis=0))
 
+    def _sym_size_int(self, node: fx.Node) -> relax.Expr:
+        x = self.env[node.args[0]]
+        shape = self.shape_of(x)
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.const(int(shape[dim]), "int32"))
+
     def _zeros_inplace(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         output = self.block_builder.emit(relax.op.zeros_like(x))
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a2b9b2afa4..782c14e91c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1189,6 +1189,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             # other
             "getitem": self._getitem,
             "item.default": self._item,
+            "sym_size.int": self._sym_size_int,
             "_local_scalar_dense.default": self._item,
         }
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a93f788669..6bf164430a 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -730,12 +730,6 @@ class TorchFXImporter(BaseFXGraphImporter):
                 return self.shape_of(self.env[node.args[0]])
         return getattr(self.env[node.args[0]], node.args[1])
 
-    def _sym_size_int(self, node: fx.Node) -> relax.Expr:
-        x = self.env[node.args[0]]
-        shape = self.shape_of(x)
-        idx = node.args[1]
-        return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
-
     def create_input_vars(self, input_info: List[Tuple[Tuple[int], str]]) -> List[relax.Var]:
         inputs = list()
         for idx, (shape, dtype) in enumerate(input_info):
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 1429dec5e7..60a9120445 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7508,5 +7508,36 @@ def test_dynamic_shape_with_multiplication_constraints():
     tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
 
 
+def test_sym_size_int():
+    class SymSizeInt(Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
+
+        def forward(self, x):
+            # TODO(@mshr-h): `torch.ops.aten.sym_size.int(x, self.dim)` would be ideal, but currently
+            # the ep frontend is not able to handle it.
+            return torch.add(x[0], torch.ops.aten.sym_size.int(x, self.dim))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.take(
+                    x, R.const(0, "int64"), axis=0, mode="fast"
+                )
+                lv1: R.Tensor((3, 4), dtype="float32") = R.add(lv, R.const(3.0, "float32"))
+                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 4),)
+    verify_model(SymSizeInt(dim=1), example_args, {}, Expected)
+    verify_model(SymSizeInt(dim=-2), example_args, {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
