This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bf8a907bbf [Relax][PyTorch] Add dynamic shape support to `torch.ops.aten.sym_size.int` in ExportedProgram frontend (#18485)
bf8a907bbf is described below
commit bf8a907bbffa911946a077a47b779dacc07fa2d8
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Nov 22 02:15:13 2025 +0900
[Relax][PyTorch] Add dynamic shape support to `torch.ops.aten.sym_size.int` in ExportedProgram frontend (#18485)
As per title: `torch.ops.aten.sym_size.int` was previously folded into an `int32` constant via `int(shape[dim])`, which only works when the queried dimension is static. The importer now returns the symbolic shape value directly for dynamic dimensions, and a test using `torch.export.Dim` covers the new path.
cc @tlopex
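
For context, the newly supported pattern looks roughly like this (a minimal
sketch, assuming a recent PyTorch with `torch.export`; the `Reshape` module
and the `Dim` name are illustrative, while `from_exported_program` is the
frontend entry point this diff touches):

    import torch
    from torch.export import Dim, export

    class Reshape(torch.nn.Module):
        def forward(self, x):
            # sym_size.int yields a symbolic SymInt once dim 0 is dynamic
            n = torch.ops.aten.sym_size.int(x, 0)
            return x.reshape(n, -1)

    # Mark the first dimension as dynamic so sym_size.int cannot be folded
    ep = export(Reshape(), (torch.randn(2, 3, 4),), dynamic_shapes={"x": {0: Dim("n")}})

    from tvm.relax.frontend.torch import from_exported_program
    mod = from_exported_program(ep)  # previously raised on int() of a symbolic dim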
---
.../frontend/torch/base_fx_graph_translator.py | 6 +++-
.../relax/test_frontend_from_exported_program.py | 36 +++++++++++++++++++---
2 files changed, 37 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index ed7811dd71..d2c888cdd1 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2361,7 +2361,11 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
x = self.env[node.args[0]]
shape = self.shape_of(x)
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
- return self.block_builder.emit(relax.const(int(shape[dim]), "int32"))
+
+ shape_dim = shape[dim]
+ if hasattr(shape_dim, "value"):
+            return self.block_builder.emit(relax.const(shape_dim.value, dtype="int32"))
+ return shape_dim
def _zeros_inplace(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
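
The `hasattr(shape_dim, "value")` check distinguishes static extents, which
Relax stores as `tir.IntImm` objects carrying a concrete `.value`, from
symbolic extents such as `tir.Var`, which have no such field. A small
illustration of that distinction (assuming `tvm.tir`; not part of this diff):

    from tvm import tir

    static_dim = tir.IntImm("int64", 3)   # concrete extent: exposes .value
    dynamic_dim = tir.Var("s0", "int64")  # symbolic extent: no .value attribute

    assert hasattr(static_dim, "value") and static_dim.value == 3
    assert not hasattr(dynamic_dim, "value")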
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 60a9120445..fcf131965c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7520,7 +7520,7 @@ def test_sym_size_int():
return torch.add(x[0], torch.ops.aten.sym_size.int(x, self.dim))
@I.ir_module
- class Expected:
+ class Expected1:
@R.function
def main(
x: R.Tensor((1, 3, 4), dtype="float32")
@@ -7534,9 +7534,37 @@ def test_sym_size_int():
R.output(gv)
return gv
- example_args = (torch.randn(1, 3, 4),)
- verify_model(SymSizeInt(dim=1), example_args, {}, Expected)
- verify_model(SymSizeInt(dim=-2), example_args, {}, Expected)
+ example_args_1 = (torch.randn(1, 3, 4),)
+ verify_model(SymSizeInt(dim=1), example_args_1, {}, Expected1)
+ verify_model(SymSizeInt(dim=-2), example_args_1, {}, Expected1)
+
+ class SymSizeIntDynamic(Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ shape_dim = torch.ops.aten.sym_size.int(x, self.dim)
+ return x.reshape(shape_dim, -1)
+
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ x: R.Tensor(("s0", 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor(("s0", 12), dtype="float32")):
+ s0 = T.int64(is_size_var=True)
+ with R.dataflow():
+                lv: R.Tensor((s0, 12), dtype="float32") = R.reshape(x, R.shape([s0, 12]))
+ gv: R.Tuple(R.Tensor((s0, 12), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args_2 = (torch.randn(2, 3, 4),)
+ dynamic_shapes = {"x": {0: torch.export.Dim("dim")}}
+ verify_model(
+        SymSizeIntDynamic(dim=0), example_args_2, {}, Expected2, dynamic_shapes=dynamic_shapes
+ )
if __name__ == "__main__":
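
To exercise just the updated test locally, something like the following should
work (a sketch assuming a development install of TVM with pytest available):

    import pytest

    # Equivalent to:
    #   pytest tests/python/relax/test_frontend_from_exported_program.py -k test_sym_size_int
    pytest.main(
        ["tests/python/relax/test_frontend_from_exported_program.py", "-k", "test_sym_size_int"]
    )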