This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 49973d1fc4 [Relax][PyTorch] Support advanced range constraints (multiplication) (#18463)
49973d1fc4 is described below

commit 49973d1fc4dda847feff7e8f35b30c1db5c68b87
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Nov 18 18:39:22 2025 +0800

    [Relax][PyTorch] Support advanced range constraints (multiplication) (#18463)
    
    ## Related Issue
    
    - https://github.com/apache/tvm/issues/17818
    
    ## Why
    
    - Add support for multiplication expressions (e.g., `s0 * 2`) in PyTorch
    dynamic shape constraints; a usage sketch follows below
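
    A usage sketch mirroring the new test (the model and concrete shapes here
    are illustrative):

    ```python
    import torch
    from tvm.relax.frontend.torch import from_exported_program

    class ConcatModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.cat([x, y], dim=0)

    # Dim 0 of "y" is constrained to be exactly twice dim 0 of "x".
    batch = torch.export.Dim("batch", min=1, max=64)
    dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}}

    exported_program = torch.export.export(
        ConcatModel(),
        args=(torch.randn(8, 4), torch.randn(16, 4)),
        dynamic_shapes=dynamic_shapes,
    )
    mod = from_exported_program(exported_program, run_ep_decomposition=True)
    ```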
    
    ## How
    
    - Parse `SymPy` multiplication expressions from PyTorch's
    range_constraints; a simplified parsing sketch follows below
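
    A simplified standalone sketch of the parsing idea (the `mul_to_tir`
    helper is hypothetical; the real logic extends `_process_derived_symbol`):

    ```python
    import sympy
    import tvm

    def mul_to_tir(expr, var_map):
        """Fold the factors of a sympy.Mul into a TIR expression."""
        tir_expr = None
        for arg in expr.args:
            if isinstance(arg, sympy.Integer):
                term = tvm.tir.IntImm("int64", int(arg))
            elif isinstance(arg, sympy.Symbol):
                term = var_map.setdefault(str(arg), tvm.tir.SizeVar(str(arg), "int64"))
            else:
                return None  # unsupported factor; fall back to the plain symbol name
            tir_expr = term if tir_expr is None else tir_expr * term
        return tir_expr

    s0 = sympy.Symbol("s0")
    print(mul_to_tir(s0 * 2, {}))  # a tvm.tir.Mul over an IntImm and a SizeVar
    ```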
---
 .../frontend/torch/exported_program_translator.py  | 35 ++++++++---
 .../relax/test_frontend_from_exported_program.py   | 70 +++++++++++++++++++++-
 2 files changed, 96 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 3b982b6b46..6aa118ee5c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1191,7 +1191,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         if isinstance(symbol, sympy.Symbol):
             return str(symbol), None
 
-        if not isinstance(symbol, sympy.Add):
+        if not isinstance(symbol, (sympy.Add, sympy.Mul)):
             return str(symbol), None
 
         tir_expr = None
@@ -1207,13 +1207,24 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
             if term is None:
                 return str(symbol), None
-            tir_expr = term if tir_expr is None else tir_expr + term
+
+            if tir_expr is None:
+                tir_expr = term
+            elif isinstance(symbol, sympy.Mul):
+                tir_expr = tir_expr * term
+            elif isinstance(symbol, sympy.Add):
+                tir_expr = tir_expr + term
 
         if isinstance(tir_expr, tvm.tir.Add):
             for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
                 if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
                     return f"{var.name}___{const.value}", tir_expr
 
+        if isinstance(tir_expr, tvm.tir.Mul):
+            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
+                if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
+                    return f"{var.name}_{const.value}", tir_expr
+
         return str(symbol), tir_expr
 
     def create_input_vars(
@@ -1256,12 +1267,20 @@ class ExportedProgramImporter(BaseFXGraphImporter):
                 torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype
 
-            relax_shape = [
-                torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
-                if isinstance(s, torch.SymInt)
-                else s
-                for s in torch_shape
-            ]
+            relax_shape = []
+            for s in torch_shape:
+                if isinstance(s, torch.SymInt):
+                    sympy_node = s.node.expr if hasattr(s.node, "expr") else s.node
+                    symbol_name, _ = self._process_derived_symbol(
+                        sympy_node, torch_symbol_to_relax_var
+                    )
+
+                    size_var = torch_symbol_to_relax_var.setdefault(
+                        symbol_name, tvm.tir.SizeVar(symbol_name, "int64")
+                    )
+                    relax_shape.append(size_var)
+                else:
+                    relax_shape.append(s)
             dtype = self._convert_data_type(torch_dtype)
 
             relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype))
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 87022a2d7d..92140a54b8 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7028,7 +7028,7 @@ def test_dynamic_shape_with_range_constraints():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-def test_dynamic_shape_with_derived_range_constraints():
+def test_dynamic_shape_with_addition_constraints():
     class ConcatModel(torch.nn.Module):
         def forward(self, x, y):
             return torch.cat([x, y], dim=0)
@@ -7062,5 +7062,73 @@ def test_dynamic_shape_with_derived_range_constraints():
     tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
 
 
+def test_dynamic_shape_with_subtraction_constraints():
+    class ConcatModel(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.cat([x, y], dim=0)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s1___1", 4), dtype="float32"), y: R.Tensor(("s1", 
4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s1___1 + s1", 4), dtype="float32")):
+            s1___1 = T.int64(is_size_var=True)
+            s1 = T.int64(is_size_var=True)
+            R.func_attr(
+                {
+                    "tir_var_lower_bound": {"s1": 0, "s1___1": 1},
+                    "tir_var_upper_bound": {"s1": 63, "s1___1": 64},
+                }
+            )
+            with R.dataflow():
+                lv: R.Tensor((s1___1 + s1, 4), dtype="float32") = R.concat((x, y), axis=0)
+                gv: R.Tuple(R.Tensor((s1___1 + s1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    batch = torch.export.Dim("batch", min=1, max=64)
+    example_args = (torch.randn(8, 4), torch.randn(7, 4))
+    dynamic_shapes = {"x": {0: batch}, "y": {0: batch - 1}}
+    exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+
+
+def test_dynamic_shape_with_multiplication_constraints():
+    class ConcatModel(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.cat([x, y], dim=0)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0_2", 4), 
dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0 + s0_2", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            s0_2 = T.int64(is_size_var=True)
+            R.func_attr(
+                {
+                    "tir_var_lower_bound": {"s0": 1, "s0_2": 2},
+                    "tir_var_upper_bound": {"s0": 64, "s0_2": 128},
+                }
+            )
+            with R.dataflow():
+                lv: R.Tensor((s0 + s0_2, 4), dtype="float32") = R.concat((x, y), axis=0)
+                gv: R.Tuple(R.Tensor((s0 + s0_2, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    batch = torch.export.Dim("batch", min=1, max=64)
+    example_args = (torch.randn(8, 4), torch.randn(16, 4))
+    dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}}
+    exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
