This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ea89f21ec5 [Relax][PyTorch] Support advanced range constraints (addition) (#18452)
ea89f21ec5 is described below

commit ea89f21ec53e86ddc7b1799d940b0d8ca569666a
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Mon Nov 17 15:27:45 2025 +0800

    [Relax][PyTorch] Support advanced range constraints (addition) (#18452)
    
    ## Related Issue
    
    - https://github.com/apache/tvm/issues/17818
    
    ## Why
    
    - Add support for addition expressions (e.g., `s0 + 1`) in PyTorch
    dynamic shape constraints
    
    ## How
    
    - Parse `SymPy` addition expressions from PyTorch's `range_constraints`,
    as sketched below
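    
    As a minimal sketch of the `torch.export` side (not part of this patch;
    `ConcatModel` and `batch` mirror the new test added below), a derived dim
    such as `batch + 1` is what lands in `range_constraints` as a `sympy.Add`:
    
    ```python
    import torch
    
    class ConcatModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.cat([x, y], dim=0)
    
    # Derived dim: y's batch dimension is constrained to always be batch + 1.
    batch = torch.export.Dim("batch", min=1, max=64)
    ep = torch.export.export(
        ConcatModel(),
        (torch.randn(8, 4), torch.randn(9, 4)),
        dynamic_shapes={"x": {0: batch}, "y": {0: batch + 1}},
    )
    # range_constraints now keys a sympy.Add (s0 + 1) alongside the plain
    # symbol s0; this patch names the derived var "s0___1" and records its
    # shifted bounds (2, 65).
    print(ep.range_constraints)
    ```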
---
 .../frontend/torch/exported_program_translator.py  | 46 +++++++++++++++++++---
 .../relax/test_frontend_from_exported_program.py   | 34 ++++++++++++++++
 2 files changed, 75 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c6243c113e..44e967ec0e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -20,7 +20,7 @@
 """PyTorch ExportedProgram of Relax."""
 from collections import ChainMap, OrderedDict
 from functools import partial
-from typing import Callable, Dict, List, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
 import tvm
@@ -1181,6 +1181,40 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "_local_scalar_dense.default": self._item,
         }
 
+    def _process_derived_symbol(
+        self, symbol, torch_symbol_to_relax_var: Dict[str, tvm.tir.Var]
+    ) -> Tuple[str, Optional[tvm.tir.PrimExpr]]:
+        """Process a sympy symbol to generate a descriptive name and TIR 
expression."""
+        import sympy
+
+        if isinstance(symbol, sympy.Symbol):
+            return str(symbol), None
+
+        if not isinstance(symbol, sympy.Add):
+            return str(symbol), None
+
+        tir_expr = None
+        for arg in symbol.args:
+            if isinstance(arg, sympy.Integer):
+                term = tvm.tir.IntImm("int64", int(arg))
+            elif isinstance(arg, sympy.Symbol):
+                term = torch_symbol_to_relax_var.setdefault(
+                    str(arg), tvm.tir.SizeVar(str(arg), "int64")
+                )
+            else:
+                _, term = self._process_derived_symbol(arg, torch_symbol_to_relax_var)
+
+            if term is None:
+                return str(symbol), None
+            tir_expr = term if tir_expr is None else tir_expr + term
+
+        if isinstance(tir_expr, tvm.tir.Add):
+            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
+                if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
+                    return f"{var.name}___{const.value}", tir_expr
+
+        return str(symbol), tir_expr
+
     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
     ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]:
@@ -1192,12 +1226,16 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
         if hasattr(exported_program, "range_constraints"):
             for symbol, value_range in exported_program.range_constraints.items():
-                symbol_name = str(symbol)
                 if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
                     try:
                         lower = int(value_range.lower)
                         upper = int(value_range.upper)
+
+                        symbol_name, _ = self._process_derived_symbol(
+                            symbol, torch_symbol_to_relax_var
+                        )
                         range_constraints[symbol_name] = (lower, upper)
+
                     except (OverflowError, AttributeError, TypeError):
                         continue
 
@@ -1255,10 +1293,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         # Initialize the block builder with a function and a dataflow block.
         self.block_builder = relax.BlockBuilder()
         func_name = "main"
-        func_attrs = {"num_input": len(user_input_vars)} if 
keep_params_as_input else None
+        func_attrs = {"num_input": len(user_input_vars)} if 
keep_params_as_input else {}
         if range_constraints:
-            if func_attrs is None:
-                func_attrs = {}
             func_attrs["tir_var_lower_bound"] = {
                 var_name: lower for var_name, (lower, _) in range_constraints.items()
             }
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 6cf293d96b..ef2736778f 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7000,5 +7000,39 @@ def test_dynamic_shape_with_range_constraints():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_dynamic_shape_with_derived_range_constraints():
+    class ConcatModel(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.cat([x, y], dim=0)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0___1", 
4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0 + s0___1", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            s0___1 = T.int64(is_size_var=True)
+            R.func_attr(
+                {
+                    "tir_var_lower_bound": {"s0": 1, "s0___1": 2},
+                    "tir_var_upper_bound": {"s0": 64, "s0___1": 65},
+                }
+            )
+            with R.dataflow():
+                lv: R.Tensor((s0 + s0___1, 4), dtype="float32") = R.concat((x, y), axis=0)
+                gv: R.Tuple(R.Tensor((s0 + s0___1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    batch = torch.export.Dim("batch", min=1, max=64)
+    example_args = (torch.randn(8, 4), torch.randn(9, 4))
+    dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}
+    exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
