This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ea89f21ec5 [Relax][PyTorch] Support advanced range constraints (addition) (#18452)
ea89f21ec5 is described below
commit ea89f21ec53e86ddc7b1799d940b0d8ca569666a
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Mon Nov 17 15:27:45 2025 +0800
[Relax][PyTorch] Support advanced range constraints (addition) (#18452)
## Related Issue
- https://github.com/apache/tvm/issues/17818
## Why
- Add support for addition expressions (e.g., `s0 + 1`) in PyTorch dynamic shape constraints
## How
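- Parse SymPy addition expressions (`sympy.Add`) from the ExportedProgram's `range_constraints` into TIR expressions, naming each derived dimension after its base symbol and constant offset (e.g., `s0 + 1` becomes `s0___1`)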
---
.../frontend/torch/exported_program_translator.py | 46 +++++++++++++++++++---
.../relax/test_frontend_from_exported_program.py | 34 ++++++++++++++++
2 files changed, 75 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c6243c113e..44e967ec0e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -20,7 +20,7 @@
"""PyTorch ExportedProgram of Relax."""
from collections import ChainMap, OrderedDict
from functools import partial
-from typing import Callable, Dict, List, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
import torch
import tvm
@@ -1181,6 +1181,40 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"_local_scalar_dense.default": self._item,
}
+ def _process_derived_symbol(
+ self, symbol, torch_symbol_to_relax_var: Dict[str, tvm.tir.Var]
+ ) -> Tuple[str, Optional[tvm.tir.PrimExpr]]:
+ """Process a sympy symbol to generate a descriptive name and TIR
expression."""
+ import sympy
+
+ if isinstance(symbol, sympy.Symbol):
+ return str(symbol), None
+
+ if not isinstance(symbol, sympy.Add):
+ return str(symbol), None
+
+ tir_expr = None
+ for arg in symbol.args:
+ if isinstance(arg, sympy.Integer):
+ term = tvm.tir.IntImm("int64", int(arg))
+ elif isinstance(arg, sympy.Symbol):
+ term = torch_symbol_to_relax_var.setdefault(
+ str(arg), tvm.tir.SizeVar(str(arg), "int64")
+ )
+ else:
+ _, term = self._process_derived_symbol(arg,
torch_symbol_to_relax_var)
+
+ if term is None:
+ return str(symbol), None
+ tir_expr = term if tir_expr is None else tir_expr + term
+
+ if isinstance(tir_expr, tvm.tir.Add):
+ for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b,
tir_expr.a)]:
+ if isinstance(const, tvm.tir.IntImm) and isinstance(var,
tvm.tir.Var):
+ return f"{var.name}___{const.value}", tir_expr
+
+ return str(symbol), tir_expr
+
def create_input_vars(
self, exported_program: torch.export.ExportedProgram
) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str,
Tuple[int, int]]]:
@@ -1192,12 +1226,16 @@ class ExportedProgramImporter(BaseFXGraphImporter):
if hasattr(exported_program, "range_constraints"):
for symbol, value_range in
exported_program.range_constraints.items():
- symbol_name = str(symbol)
if hasattr(value_range, "lower") and hasattr(value_range,
"upper"):
try:
lower = int(value_range.lower)
upper = int(value_range.upper)
+
+ symbol_name, _ = self._process_derived_symbol(
+ symbol, torch_symbol_to_relax_var
+ )
range_constraints[symbol_name] = (lower, upper)
+
except (OverflowError, AttributeError, TypeError):
continue
@@ -1255,10 +1293,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
# Initialize the block builder with a function and a dataflow block.
self.block_builder = relax.BlockBuilder()
func_name = "main"
- func_attrs = {"num_input": len(user_input_vars)} if
keep_params_as_input else None
+ func_attrs = {"num_input": len(user_input_vars)} if
keep_params_as_input else {}
if range_constraints:
- if func_attrs is None:
- func_attrs = {}
func_attrs["tir_var_lower_bound"] = {
var_name: lower for var_name, (lower, _) in
range_constraints.items()
}
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 6cf293d96b..ef2736778f 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7000,5 +7000,39 @@ def test_dynamic_shape_with_range_constraints():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_dynamic_shape_with_derived_range_constraints():
+ class ConcatModel(torch.nn.Module):
+ def forward(self, x, y):
+ return torch.cat([x, y], dim=0)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0___1",
4), dtype="float32")
+ ) -> R.Tuple(R.Tensor(("s0 + s0___1", 4), dtype="float32")):
+ s0 = T.int64(is_size_var=True)
+ s0___1 = T.int64(is_size_var=True)
+ R.func_attr(
+ {
+ "tir_var_lower_bound": {"s0": 1, "s0___1": 2},
+ "tir_var_upper_bound": {"s0": 64, "s0___1": 65},
+ }
+ )
+ with R.dataflow():
+ lv: R.Tensor((s0 + s0___1, 4), dtype="float32") = R.concat((x,
y), axis=0)
+ gv: R.Tuple(R.Tensor((s0 + s0___1, 4), dtype="float32")) =
(lv,)
+ R.output(gv)
+ return gv
+
+ batch = torch.export.Dim("batch", min=1, max=64)
+ example_args = (torch.randn(8, 4), torch.randn(9, 4))
+ dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}
+ exported_program = export(ConcatModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
+
+ mod = from_exported_program(exported_program, run_ep_decomposition=True)
+ tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+
+
if __name__ == "__main__":
tvm.testing.main()