This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 161049ef85 [Relax][PyTorch] Enhance handling of unbounded upper bound
constraints (#18489)
161049ef85 is described below
commit 161049ef85b62e1e178c86a53ae0102a86a452e2
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Nov 26 13:14:32 2025 +0800
[Relax][PyTorch] Enhance handling of unbounded upper bound constraints
(#18489)
## Why
PyTorch uses int_oo (IntInfinity) to represent unbounded constraints, which
makes our current implementation crash because it unconditionally converts
the upper bound with int()
## How
- Update the type hint for `create_input_vars` to allow for optional
upper bounds.
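- Modify the logic to handle unbounded constraints by setting the upper
bound to None when it is infinite, and omit such variables from
`tir_var_upper_bound` (the attribute is no longer emitted when no finite
upper bounds remain).
- Add a new test case (`test_dynamic_shape_with_unbounded_constraints`) and
update existing expected modules, which now carry only `tir_var_lower_bound`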
---
.../frontend/torch/exported_program_translator.py | 19 +++++++++++---
.../relax/test_frontend_from_exported_program.py | 30 ++++++++++++++++++++++
2 files changed, 45 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ac79024acf..95b0e05361 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1383,7 +1383,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
def create_input_vars(
self, exported_program: torch.export.ExportedProgram
- ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str,
Tuple[int, int]]]:
+ ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str,
Tuple[int, Optional[int]]]]:
"""Create relax input vars."""
parameters_buffers_constants = OrderedDict()
user_inputs = OrderedDict()
@@ -1391,11 +1391,16 @@ class ExportedProgramImporter(BaseFXGraphImporter):
range_constraints = {}
if hasattr(exported_program, "range_constraints"):
+ import math
+
for symbol, value_range in
exported_program.range_constraints.items():
if hasattr(value_range, "lower") and hasattr(value_range,
"upper"):
try:
+ # PyTorch uses int_oo (IntInfinity) for unbounded
constraints
lower = int(value_range.lower)
- upper = int(value_range.upper)
+ upper = (
+ None if math.isinf(float(value_range.upper)) else
int(value_range.upper)
+ )
symbol_name, _ = self._process_derived_symbol(
symbol, torch_symbol_to_relax_var
@@ -1472,10 +1477,16 @@ class ExportedProgramImporter(BaseFXGraphImporter):
func_attrs["tir_var_lower_bound"] = {
var_name: lower for var_name, (lower, _) in
range_constraints.items()
}
- func_attrs["tir_var_upper_bound"] = {
- var_name: upper for var_name, (_, upper) in
range_constraints.items()
+
+ upper_bounds = {
+ var_name: upper
+ for var_name, (_, upper) in range_constraints.items()
+ if upper is not None
}
+ if upper_bounds:
+ func_attrs["tir_var_upper_bound"] = upper_bounds
+
nodes: List[fx.Node] = exported_program.graph.nodes
# Find all the missing function types
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 78a8a09a3c..d4c23bfdd5 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7206,6 +7206,7 @@ def test_dynamic_shape():
lhs: R.Tensor((B, 4), dtype="float32"),
rhs: R.Tensor((B, 4), dtype="float32"),
) -> R.Tuple(R.Tensor((B, 4), dtype="float32")):
+ R.func_attr({"tir_var_lower_bound": {"s0": 0}})
with R.dataflow():
lv: R.Tensor((B, 4), dtype="float32") = R.add(lhs, rhs)
gv: R.Tuple(R.Tensor((B, 4), dtype="float32")) = (lv,)
@@ -7909,6 +7910,34 @@ def test_dynamic_shape_with_multiplication_constraints():
tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+def test_dynamic_shape_with_unbounded_constraints():
+ class DynamicModel(torch.nn.Module):
+ def forward(self, x):
+ return torch.ops.aten.add.Tensor(x, x)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor(("s0", 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
+ s0 = T.int64(is_size_var=True)
+ R.func_attr({"tir_var_lower_bound": {"s0": 2}})
+ with R.dataflow():
+ lv: R.Tensor((s0, 4), dtype="float32") = R.add(x, x)
+ gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(8, 4),)
+ batch = torch.export.Dim("batch", min=2)
+ dynamic_shapes = {"x": {0: batch}}
+ exported_program = export(DynamicModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
+
+ mod = from_exported_program(exported_program)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_sym_size_int():
class SymSizeInt(Module):
def __init__(self, dim):
@@ -7955,6 +7984,7 @@ def test_sym_size_int():
x: R.Tensor(("s0", 3, 4), dtype="float32")
) -> R.Tuple(R.Tensor(("s0", 12), dtype="float32")):
s0 = T.int64(is_size_var=True)
+ R.func_attr({"tir_var_lower_bound": {"s0": 0}})
with R.dataflow():
lv: R.Tensor((s0, 12), dtype="float32") = R.reshape(x,
R.shape([s0, 12]))
gv: R.Tuple(R.Tensor((s0, 12), dtype="float32")) = (lv,)