This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 49973d1fc4 [Relax][PyTorch] Support advanced range constraints (multiplication) (#18463)
49973d1fc4 is described below
commit 49973d1fc4dda847feff7e8f35b30c1db5c68b87
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Nov 18 18:39:22 2025 +0800
[Relax][PyTorch] Support advanced range constraints (multiplication) (#18463)
## Related Issue
- https://github.com/apache/tvm/issues/17818
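## Why
- PyTorch dynamic shape constraints can define a dimension as a multiple
of another symbol (e.g., `s0 * 2`). The importer previously handled only
additive expressions (`sympy.Add`), so add support for multiplication
expressions as well.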
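## How
- Parse `SymPy` multiplication expressions from PyTorch's
`range_constraints` and name each derived variable `<base>_<factor>`
(e.g., `s0_2` for `s0 * 2`).
- Resolve symbolic dimensions of tensors taken from the exported
program's `state_dict` through the same derived-symbol handling.
- Add tests for subtraction- and multiplication-based constraints, and
rename the existing derived-constraint test to
`test_dynamic_shape_with_addition_constraints`.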
---
.../frontend/torch/exported_program_translator.py | 35 ++++++++---
.../relax/test_frontend_from_exported_program.py | 70 +++++++++++++++++++++-
2 files changed, 96 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 3b982b6b46..6aa118ee5c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1191,7 +1191,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
if isinstance(symbol, sympy.Symbol):
return str(symbol), None
- if not isinstance(symbol, sympy.Add):
+ if not isinstance(symbol, (sympy.Add, sympy.Mul)):
return str(symbol), None
tir_expr = None
@@ -1207,13 +1207,24 @@ class ExportedProgramImporter(BaseFXGraphImporter):
if term is None:
return str(symbol), None
- tir_expr = term if tir_expr is None else tir_expr + term
+
+ if tir_expr is None:
+ tir_expr = term
+ elif isinstance(symbol, sympy.Mul):
+ tir_expr = tir_expr * term
+ elif isinstance(symbol, sympy.Add):
+ tir_expr = tir_expr + term
if isinstance(tir_expr, tvm.tir.Add):
for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b,
tir_expr.a)]:
if isinstance(const, tvm.tir.IntImm) and isinstance(var,
tvm.tir.Var):
return f"{var.name}___{const.value}", tir_expr
+ if isinstance(tir_expr, tvm.tir.Mul):
+ for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b,
tir_expr.a)]:
+ if isinstance(const, tvm.tir.IntImm) and isinstance(var,
tvm.tir.Var):
+ return f"{var.name}_{const.value}", tir_expr
+
return str(symbol), tir_expr
def create_input_vars(
@@ -1256,12 +1267,20 @@ class ExportedProgramImporter(BaseFXGraphImporter):
torch_shape = exported_program.state_dict[spec.target].shape
torch_dtype = exported_program.state_dict[spec.target].dtype
- relax_shape = [
- torch_symbol_to_relax_var.setdefault(str(s),
tvm.tir.SizeVar(str(s), "int64"))
- if isinstance(s, torch.SymInt)
- else s
- for s in torch_shape
- ]
+ relax_shape = []
+ for s in torch_shape:
+ if isinstance(s, torch.SymInt):
+ sympy_node = s.node.expr if hasattr(s.node, "expr") else
s.node
+ symbol_name, _ = self._process_derived_symbol(
+ sympy_node, torch_symbol_to_relax_var
+ )
+
+ size_var = torch_symbol_to_relax_var.setdefault(
+ symbol_name, tvm.tir.SizeVar(symbol_name, "int64")
+ )
+ relax_shape.append(size_var)
+ else:
+ relax_shape.append(s)
dtype = self._convert_data_type(torch_dtype)
relax_var = relax.Var(name_hint,
relax.TensorStructInfo(relax_shape, dtype))
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 87022a2d7d..92140a54b8 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7028,7 +7028,7 @@ def test_dynamic_shape_with_range_constraints():
tvm.ir.assert_structural_equal(mod, Expected)
-def test_dynamic_shape_with_derived_range_constraints():
+def test_dynamic_shape_with_addition_constraints():
class ConcatModel(torch.nn.Module):
def forward(self, x, y):
return torch.cat([x, y], dim=0)
@@ -7062,5 +7062,73 @@ def test_dynamic_shape_with_derived_range_constraints():
tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+def test_dynamic_shape_with_subtraction_constraints():
+ class ConcatModel(torch.nn.Module):
+ def forward(self, x, y):
+ return torch.cat([x, y], dim=0)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor(("s1___1", 4), dtype="float32"), y: R.Tensor(("s1",
4), dtype="float32")
+ ) -> R.Tuple(R.Tensor(("s1___1 + s1", 4), dtype="float32")):
+ s1___1 = T.int64(is_size_var=True)
+ s1 = T.int64(is_size_var=True)
+ R.func_attr(
+ {
+ "tir_var_lower_bound": {"s1": 0, "s1___1": 1},
+ "tir_var_upper_bound": {"s1": 63, "s1___1": 64},
+ }
+ )
+ with R.dataflow():
+ lv: R.Tensor((s1___1 + s1, 4), dtype="float32") = R.concat((x,
y), axis=0)
+ gv: R.Tuple(R.Tensor((s1___1 + s1, 4), dtype="float32")) =
(lv,)
+ R.output(gv)
+ return gv
+
+ batch = torch.export.Dim("batch", min=1, max=64)
+ example_args = (torch.randn(8, 4), torch.randn(7, 4))
+ dynamic_shapes = {"x": {0: batch}, "y": {0: batch - 1}}
+ exported_program = export(ConcatModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
+
+ mod = from_exported_program(exported_program, run_ep_decomposition=True)
+ tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+
+
+def test_dynamic_shape_with_multiplication_constraints():
+ class ConcatModel(torch.nn.Module):
+ def forward(self, x, y):
+ return torch.cat([x, y], dim=0)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0_2", 4),
dtype="float32")
+ ) -> R.Tuple(R.Tensor(("s0 + s0_2", 4), dtype="float32")):
+ s0 = T.int64(is_size_var=True)
+ s0_2 = T.int64(is_size_var=True)
+ R.func_attr(
+ {
+ "tir_var_lower_bound": {"s0": 1, "s0_2": 2},
+ "tir_var_upper_bound": {"s0": 64, "s0_2": 128},
+ }
+ )
+ with R.dataflow():
+ lv: R.Tensor((s0 + s0_2, 4), dtype="float32") = R.concat((x,
y), axis=0)
+ gv: R.Tuple(R.Tensor((s0 + s0_2, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ batch = torch.export.Dim("batch", min=1, max=64)
+ example_args = (torch.randn(8, 4), torch.randn(16, 4))
+ dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}}
+ exported_program = export(ConcatModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
+
+ mod = from_exported_program(exported_program, run_ep_decomposition=True)
+ tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+
+
if __name__ == "__main__":
tvm.testing.main()