This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 394f668e0d [Relax][Pytorch] Support basic range constraints (#18429)
394f668e0d is described below

commit 394f668e0d568b23930b60d7c8e3e91f0bd2d667
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Nov 12 14:57:37 2025 +0800

    [Relax][Pytorch] Support basic range constraints (#18429)
    
    * Support basic range constraints
    
    * Apply gemini-code-assist suggestions
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
    
    * Apply reviewer comments
    
    * Fix lint error
    
    * Refactor frontend test to use consistent size variable
    
    ---------
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .../frontend/torch/exported_program_translator.py  | 30 +++++++++++++++++++---
 .../relax/test_frontend_from_exported_program.py   | 28 ++++++++++++++++++++
 2 files changed, 54 insertions(+), 4 deletions(-)

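For context, the feature this commit adds can be exercised end to end much like the new test below: a dynamic dimension declared with explicit bounds via torch.export.Dim(min=..., max=...) now surfaces on the imported Relax function as a "tir_var_upper_bound" attribute. A minimal sketch of that flow, assuming the public from_exported_program entry point (the AddModel module and the symbol name "s0" are illustrative, taken from the test rather than guaranteed by the API):

    import torch
    from tvm.relax.frontend.torch import from_exported_program

    class AddModel(torch.nn.Module):  # illustrative stand-in for the test's DynamicModel
        def forward(self, x1, x2):
            return x1 + x2

    # Declare a dynamic batch dimension with explicit integer bounds.
    batch = torch.export.Dim("batch", min=1, max=64)
    exported = torch.export.export(
        AddModel(),
        args=(torch.randn(8, 4), torch.randn(8, 4)),
        dynamic_shapes={"x1": {0: batch}, "x2": {0: batch}},
    )

    mod = from_exported_program(exported)
    # With this commit, the importer records the upper bound on "main",
    # e.g. {"tir_var_upper_bound": {"s0": 64}} per the new test.
    print(mod["main"].attrs)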
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ddd19f2b58..0dfa4cc6da 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1099,11 +1099,23 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
+    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]:
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
         torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
+        range_constraints = {}
+
+        if hasattr(exported_program, "range_constraints"):
+            for symbol, value_range in exported_program.range_constraints.items():
+                symbol_name = str(symbol)
+                if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
+                    try:
+                        lower = int(value_range.lower)
+                        upper = int(value_range.upper)
+                        range_constraints[symbol_name] = (lower, upper)
+                    except (OverflowError, AttributeError, TypeError):
+                        continue
 
         for spec in exported_program.graph_signature.input_specs:
             name_hint = spec.arg.name
@@ -1121,7 +1133,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
                 torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype
 
-            # TODO(mshr-h): Support range constraints
             relax_shape = [
                torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
                 if isinstance(s, torch.SymInt)
@@ -1136,7 +1147,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             else:
                 parameters_buffers_constants[name_hint] = relax_var
 
-        return parameters_buffers_constants, user_inputs
+        return parameters_buffers_constants, user_inputs, range_constraints
 
     def from_exported_program(
         self,
@@ -1149,7 +1160,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         from torch import fx  # type: ignore
 
         # Create input variables.
-        parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
+        (
+            parameter_buffer_constant_vars,
+            user_input_vars,
+            range_constraints,
+        ) = self.create_input_vars(exported_program)
         inputs_vars = user_input_vars.copy()
         inputs_vars.update(parameter_buffer_constant_vars)
 
@@ -1157,6 +1172,13 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         self.block_builder = relax.BlockBuilder()
         func_name = "main"
        func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None
+        if range_constraints:
+            if func_attrs is None:
+                func_attrs = {}
+            tir_var_upper_bound = {
+                var_name: upper for var_name, (_, upper) in range_constraints.items()
+            }
+            func_attrs["tir_var_upper_bound"] = tir_var_upper_bound
 
         nodes: List[fx.Node] = exported_program.graph.nodes
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index fb4f77567e..ba14356e8e 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6663,5 +6663,33 @@ def test_gru():
    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
 
 
+def test_dynamic_shape_with_range_constraints():
+    class DynamicModel(torch.nn.Module):
+        def forward(self, x1, x2):
+            return torch.ops.aten.add.Tensor(x1, x2)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            R.func_attr({"tir_var_upper_bound": {"s0": 64}})
+            with R.dataflow():
+                lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2)
+                gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(8, 4), torch.randn(8, 4))
+    batch = torch.export.Dim("batch", min=1, max=64)
+    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
+    exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

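For readers skimming the translator change, the extraction it performs reduces to the following standalone sketch, where exported_program is any torch.export.ExportedProgram (this mirrors the hunk above rather than quoting TVM's exact code; the helper name is hypothetical):

    from typing import Any, Dict, Tuple

    def extract_range_constraints(exported_program: Any) -> Dict[str, Tuple[int, int]]:
        """Collect (lower, upper) integer bounds per symbolic dim, skipping
        unbounded or non-integer ranges, as the translator change above does."""
        constraints: Dict[str, Tuple[int, int]] = {}
        ranges = getattr(exported_program, "range_constraints", {})
        for symbol, value_range in ranges.items():
            if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
                try:
                    constraints[str(symbol)] = (int(value_range.lower), int(value_range.upper))
                except (OverflowError, AttributeError, TypeError):
                    continue  # e.g. a sympy infinity bound has no int() conversion
        return constraints

    # Only the upper bounds are attached to the Relax function attributes:
    # func_attrs["tir_var_upper_bound"] = {name: hi for name, (_, hi) in constraints.items()}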