This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 52fe358fa2 [Relax][PyTorch] Support dynamic shapes in ExportedProgram frontend (#17817)
52fe358fa2 is described below

commit 52fe358fa240a9239ee168810d7d33f16c4f62a3
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Wed Apr 9 18:50:56 2025 +0900

    [Relax][PyTorch] Support dynamic shapes in ExportedProgram frontend (#17817)
    
    support dynamic shape
---
 .../frontend/torch/exported_program_translator.py  | 18 ++++++++++---
 .../relax/test_frontend_from_exported_program.py   | 31 ++++++++++++++++++++--
 2 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c05858fd88..5e38d2ff6c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -447,24 +447,34 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
+        torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
+
         for spec in exported_program.graph_signature.input_specs:
             name_hint = spec.arg.name
            if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
-                shape = exported_program.tensor_constants[spec.target].shape
+                torch_shape = exported_program.tensor_constants[spec.target].shape
                torch_dtype = exported_program.tensor_constants[spec.target].dtype
            elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
                for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target):
                     if node.name == name_hint and "tensor_meta" in node.meta:
-                        shape = node.meta["tensor_meta"].shape
+                        torch_shape = node.meta["tensor_meta"].shape
                         torch_dtype = node.meta["tensor_meta"].dtype
                         break
             else:
                 # PARAMETER or BUFFER
-                shape = exported_program.state_dict[spec.target].shape
+                torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype
 
+            # TODO(mshr-h): Support range constraints
+            relax_shape = [
+                torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
+                if isinstance(s, torch.SymInt)
+                else s
+                for s in torch_shape
+            ]
             dtype = self._convert_data_type(torch_dtype)
-            relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype))
+
+            relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype))
             if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
                 user_inputs[name_hint] = relax_var
             else:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index dd4ead9e59..a3c939fcb6 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -32,8 +32,8 @@ from packaging import version
 torch_version = torch.__version__
 
 
-def verify_model(torch_model, example_args, binding, expected):
-    exported_program = export(torch_model, args=example_args)
+def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=None):
+    exported_program = export(torch_model, args=example_args, dynamic_shapes=dynamic_shapes)
     mod = from_exported_program(exported_program)
 
     binding = {k: tvm.nd.array(v) for k, v in binding.items()}
@@ -3961,5 +3961,32 @@ def test_topk():
     verify_model(Topk(), example_args, {}, Expected)
 
 
+def test_dynamic_shape():
+    class DynamicModel(torch.nn.Module):
+        def forward(self, x1, x2):
+            return torch.ops.aten.add.Tensor(x1, x2)
+
+    B = tvm.tir.SizeVar("BatchSize", dtype="int64")
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((B, 4), dtype="float32"),
+            rhs: R.Tensor((B, 4), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((B, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((B, 4), dtype="float32") = R.add(lhs, rhs)
+                gv: R.Tuple(R.Tensor((B, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 4), torch.randn(2, 4))
+    batch = torch.export.Dim("batch")
+    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
+
+    verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to