This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4a60bb2532 [Relax][PyTorch] Add support for
`torch.ops.aten.sym_size.int` in ExportedProgram frontend (#18473)
4a60bb2532 is described below
commit 4a60bb253276aec01c3a77cf2cbfbc88b8a9fcb7
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Nov 21 14:46:57 2025 +0900
[Relax][PyTorch] Add support for `torch.ops.aten.sym_size.int` in
ExportedProgram frontend (#18473)
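As per the title: register `sym_size.int` in the ExportedProgram importer's converter map and move `_sym_size_int` from `TorchFXImporter` into the shared `BaseFXGraphImporter`, so both frontends use the same lowering.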
cc @tlopex
---
.../frontend/torch/base_fx_graph_translator.py | 6 +++++
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/frontend/torch/fx_translator.py | 6 -----
.../relax/test_frontend_from_exported_program.py | 31 ++++++++++++++++++++++
4 files changed, 38 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 753b0d7914..ed7811dd71 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2357,6 +2357,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
x = self.env[node.args[0]]
return self.block_builder.emit(relax.op.take(x, relax.const(0,
"int64"), axis=0))
+ def _sym_size_int(self, node: fx.Node) -> relax.Expr:  # emit x's size along `dim` as an int32 relax constant
+ x = self.env[node.args[0]]
+ shape = self.shape_of(x)
+ dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)  # dim may be positional or a kwarg; falls back to 0 if absent
+ return self.block_builder.emit(relax.const(int(shape[dim]), "int32"))  # NOTE(review): int() assumes a static dim; a symbolic (dynamic) dim presumably fails here — confirm
+
def _zeros_inplace(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
output = self.block_builder.emit(relax.op.zeros_like(x))
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a2b9b2afa4..782c14e91c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1189,6 +1189,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
# other
"getitem": self._getitem,
"item.default": self._item,
+ "sym_size.int": self._sym_size_int,
"_local_scalar_dense.default": self._item,
}
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index a93f788669..6bf164430a 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -730,12 +730,6 @@ class TorchFXImporter(BaseFXGraphImporter):
return self.shape_of(self.env[node.args[0]])
return getattr(self.env[node.args[0]], node.args[1])
- def _sym_size_int(self, node: fx.Node) -> relax.Expr:
- x = self.env[node.args[0]]
- shape = self.shape_of(x)
- idx = node.args[1]
- return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
-
def create_input_vars(self, input_info: List[Tuple[Tuple[int], str]]) ->
List[relax.Var]:
inputs = list()
for idx, (shape, dtype) in enumerate(input_info):
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 1429dec5e7..60a9120445 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7508,5 +7508,36 @@ def test_dynamic_shape_with_multiplication_constraints():
tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+def test_sym_size_int():
+ class SymSizeInt(Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ # TODO(@mshr-h): Returning `torch.ops.aten.sym_size.int(x, self.dim)` directly would be
+ # ideal, but the ep frontend currently cannot handle that, so the size is added to x[0].
+ return torch.add(x[0], torch.ops.aten.sym_size.int(x, self.dim))
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((1, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((3, 4), dtype="float32") = R.take(
+ x, R.const(0, "int64"), axis=0, mode="fast"
+ )
+ lv1: R.Tensor((3, 4), dtype="float32") = R.add(lv,
R.const(3.0, "float32"))
+ gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 4),)
+ verify_model(SymSizeInt(dim=1), example_args, {}, Expected)
+ verify_model(SymSizeInt(dim=-2), example_args, {}, Expected)  # negative dim: also size 3 of shape (1, 3, 4)
+
if __name__ == "__main__":
tvm.testing.main()