This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 52fe358fa2 [Relax][PyTorch] Support dynamic shapes in ExportedProgram
frontend (#17817)
52fe358fa2 is described below
commit 52fe358fa240a9239ee168810d7d33f16c4f62a3
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Wed Apr 9 18:50:56 2025 +0900
[Relax][PyTorch] Support dynamic shapes in ExportedProgram frontend (#17817)
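Support dynamic (symbolic) input shapes: each `torch.SymInt` dimension is mapped to a `tvm.tir.SizeVar`, shared across inputs that use the same symbol.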
---
.../frontend/torch/exported_program_translator.py | 18 ++++++++++---
.../relax/test_frontend_from_exported_program.py | 31 ++++++++++++++++++++--
2 files changed, 43 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c05858fd88..5e38d2ff6c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -447,24 +447,34 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"""Create relax input vars."""
parameters_buffers_constants = OrderedDict()
user_inputs = OrderedDict()
+ torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
+
for spec in exported_program.graph_signature.input_specs:
name_hint = spec.arg.name
if spec.kind is
torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
- shape = exported_program.tensor_constants[spec.target].shape
+ torch_shape =
exported_program.tensor_constants[spec.target].shape
torch_dtype =
exported_program.tensor_constants[spec.target].dtype
elif spec.kind is
torch.export.graph_signature.InputKind.USER_INPUT:
for node in
exported_program.graph.find_nodes(op="placeholder", target=spec.target):
if node.name == name_hint and "tensor_meta" in node.meta:
- shape = node.meta["tensor_meta"].shape
+ torch_shape = node.meta["tensor_meta"].shape
torch_dtype = node.meta["tensor_meta"].dtype
break
else:
# PARAMETER or BUFFER
- shape = exported_program.state_dict[spec.target].shape
+ torch_shape = exported_program.state_dict[spec.target].shape
torch_dtype = exported_program.state_dict[spec.target].dtype
+ # TODO(mshr-h): Support range constraints
+ relax_shape = [
+ torch_symbol_to_relax_var.setdefault(str(s),
tvm.tir.SizeVar(str(s), "int64"))
+ if isinstance(s, torch.SymInt)
+ else s
+ for s in torch_shape
+ ]
dtype = self._convert_data_type(torch_dtype)
- relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape,
dtype))
+
+ relax_var = relax.Var(name_hint,
relax.TensorStructInfo(relax_shape, dtype))
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
user_inputs[name_hint] = relax_var
else:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index dd4ead9e59..a3c939fcb6 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -32,8 +32,8 @@ from packaging import version
torch_version = torch.__version__
-def verify_model(torch_model, example_args, binding, expected):
- exported_program = export(torch_model, args=example_args)
+def verify_model(torch_model, example_args, binding, expected,
dynamic_shapes=None):
+ exported_program = export(torch_model, args=example_args,
dynamic_shapes=dynamic_shapes)
mod = from_exported_program(exported_program)
binding = {k: tvm.nd.array(v) for k, v in binding.items()}
@@ -3961,5 +3961,32 @@ def test_topk():
verify_model(Topk(), example_args, {}, Expected)
+def test_dynamic_shape():
+ class DynamicModel(torch.nn.Module):
+ def forward(self, x1, x2):
+ return torch.ops.aten.add.Tensor(x1, x2)
+
+ B = tvm.tir.SizeVar("BatchSize", dtype="int64")
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ lhs: R.Tensor((B, 4), dtype="float32"),
+ rhs: R.Tensor((B, 4), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((B, 4), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((B, 4), dtype="float32") = R.add(lhs, rhs)
+ gv: R.Tuple(R.Tensor((B, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(2, 4), torch.randn(2, 4))
+ batch = torch.export.Dim("batch")
+ dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
+
+ verify_model(DynamicModel(), example_args, {}, Expected,
dynamic_shapes=dynamic_shapes)
+
+
if __name__ == "__main__":
tvm.testing.main()