This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 161049ef85 [Relax][PyTorch] Enhance handling of unbounded upper bound constraints (#18489)
161049ef85 is described below

commit 161049ef85b62e1e178c86a53ae0102a86a452e2
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Nov 26 13:14:32 2025 +0800

    [Relax][PyTorch] Enhance handling of unbounded upper bound constraints (#18489)
    
    ## Why
    
    PyTorch uses `int_oo` (IntInfinity) for unbounded constraints, which
    would make our current implementation crash.
    
    ## How
    
    - Update the type hint for `create_input_vars` to allow for optional
    upper bounds.
    - Modify the logic to handle unbounded constraints by setting upper
    bounds to None when applicable.
    - Add a new test case (a minimal sketch of the bound handling follows below).
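
    A minimal illustrative sketch of the resulting logic (the helper name
    `to_bounds` is hypothetical; the only assumption, which the diff below
    also relies on, is that `float()` applied to `int_oo` yields infinity):

    ```python
    import math

    def to_bounds(value_range):
        # Lower bounds stay finite in practice, so convert directly.
        lower = int(value_range.lower)
        # int_oo converts to float('inf'); represent "no upper bound" as
        # None instead of calling int() on it (the previous crash path).
        upper = None if math.isinf(float(value_range.upper)) else int(value_range.upper)
        return lower, upper
    ```

    Variables whose upper bound is None are then simply omitted from the
    `tir_var_upper_bound` function attribute.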
---
 .../frontend/torch/exported_program_translator.py  | 19 +++++++++++---
 .../relax/test_frontend_from_exported_program.py   | 30 ++++++++++++++++++++++
 2 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ac79024acf..95b0e05361 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1383,7 +1383,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]:
+    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, Optional[int]]]]:
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
@@ -1391,11 +1391,16 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         range_constraints = {}
 
         if hasattr(exported_program, "range_constraints"):
+            import math
+
             for symbol, value_range in exported_program.range_constraints.items():
                 if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
                     try:
+                        # PyTorch uses int_oo (IntInfinity) for unbounded constraints
                         lower = int(value_range.lower)
-                        upper = int(value_range.upper)
+                        upper = (
+                            None if math.isinf(float(value_range.upper)) else int(value_range.upper)
+                        )
 
                         symbol_name, _ = self._process_derived_symbol(
                             symbol, torch_symbol_to_relax_var
@@ -1472,10 +1477,16 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             func_attrs["tir_var_lower_bound"] = {
                 var_name: lower for var_name, (lower, _) in range_constraints.items()
             }
-            func_attrs["tir_var_upper_bound"] = {
-                var_name: upper for var_name, (_, upper) in range_constraints.items()
+
+            upper_bounds = {
+                var_name: upper
+                for var_name, (_, upper) in range_constraints.items()
+                if upper is not None
             }
 
+            if upper_bounds:
+                func_attrs["tir_var_upper_bound"] = upper_bounds
+
         nodes: List[fx.Node] = exported_program.graph.nodes
 
         # Find all the missing function types
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 78a8a09a3c..d4c23bfdd5 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7206,6 +7206,7 @@ def test_dynamic_shape():
             lhs: R.Tensor((B, 4), dtype="float32"),
             rhs: R.Tensor((B, 4), dtype="float32"),
         ) -> R.Tuple(R.Tensor((B, 4), dtype="float32")):
+            R.func_attr({"tir_var_lower_bound": {"s0": 0}})
             with R.dataflow():
                 lv: R.Tensor((B, 4), dtype="float32") = R.add(lhs, rhs)
                 gv: R.Tuple(R.Tensor((B, 4), dtype="float32")) = (lv,)
@@ -7909,6 +7910,34 @@ def test_dynamic_shape_with_multiplication_constraints():
     tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
 
 
+def test_dynamic_shape_with_unbounded_constraints():
+    class DynamicModel(torch.nn.Module):
+        def forward(self, x):
+            return torch.ops.aten.add.Tensor(x, x)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s0", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            R.func_attr({"tir_var_lower_bound": {"s0": 2}})
+            with R.dataflow():
+                lv: R.Tensor((s0, 4), dtype="float32") = R.add(x, x)
+                gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(8, 4),)
+    batch = torch.export.Dim("batch", min=2)
+    dynamic_shapes = {"x": {0: batch}}
+    exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_sym_size_int():
     class SymSizeInt(Module):
         def __init__(self, dim):
@@ -7955,6 +7984,7 @@ def test_sym_size_int():
             x: R.Tensor(("s0", 3, 4), dtype="float32")
         ) -> R.Tuple(R.Tensor(("s0", 12), dtype="float32")):
             s0 = T.int64(is_size_var=True)
+            R.func_attr({"tir_var_lower_bound": {"s0": 0}})
             with R.dataflow():
                 lv: R.Tensor((s0, 12), dtype="float32") = R.reshape(x, R.shape([s0, 12]))
                 gv: R.Tuple(R.Tensor((s0, 12), dtype="float32")) = (lv,)
