mshr-h commented on code in PR #17898:
URL: https://github.com/apache/tvm/pull/17898#discussion_r2065386482


##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -519,13 +521,18 @@ def create_input_vars(
                 torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype
 
-            # TODO(mshr-h): Support range constraints
-            relax_shape = [
-                torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
-                if isinstance(s, torch.SymInt)
-                else s
-                for s in torch_shape
-            ]
+            # UPDATED: Create SizeVars and map SymInts (removed original shape creation)

Review Comment:
   Why is this update needed? Looks like there's no functional change.



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -534,7 +541,47 @@ def create_input_vars(
             else:
                 parameters_buffers_constants[name_hint] = relax_var
 
-        return parameters_buffers_constants, user_inputs
+        # NEW: Process range constraints (basic support for simple SymInt keys)
+        if hasattr(exported_program, "range_constraints"):
+            for torch_sym_expr, value_range in exported_program.range_constraints.items():
+                # Basic support: Only handle constraints where the key is a simple SymInt
+                if isinstance(torch_sym_expr, torch.SymInt):
+                    s_str = str(torch_sym_expr)
+                    if s_str in torch_symbol_to_relax_var:
+                        relax_tir_var = torch_symbol_to_relax_var[s_str]
+
+                        # Extract bounds, using None for infinity
+                        min_val = int(value_range.lower) if value_range.lower != -sympy.oo else None
+                        max_val = int(value_range.upper) if value_range.upper != sympy.oo else None
+
+                        if relax_tir_var not in relax_range_constraints:
+                            relax_range_constraints[relax_tir_var] = (min_val, max_val)
+                        else:
+                            # Refine existing constraints if the new one is tighter

Review Comment:
   Please also include a testcase for this condition.
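   
   As a starting point, here is a hypothetical, self-contained sketch of the condition to cover. The `refine_bounds` helper is invented for illustration (the PR inlines this logic); it shows the intended merge rule: keep the larger lower bound and the smaller upper bound, with `None` meaning unbounded.
   
   ```python
   def refine_bounds(existing, new):
       # Larger lower bound and smaller upper bound win; None = unbounded.
       lo = max((v for v in (existing[0], new[0]) if v is not None), default=None)
       hi = min((v for v in (existing[1], new[1]) if v is not None), default=None)
       return (lo, hi)


   def test_refine_bounds():
       assert refine_bounds((2, 10), (4, None)) == (4, 10)  # tighter lower bound
       assert refine_bounds((None, 8), (1, 12)) == (1, 8)   # tighter upper bound
   ```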



##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -4625,6 +4625,62 @@ def main(
     dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
 
     verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)
+    
+def test_dynamic_shape_with_constraints():
+        B = torch.export.Dim("B", min=2, max=10)
+        S = torch.export.Dim("S", min=1)
+        # Use a tuple for args
+        example_args = (torch.randn(3, 4, dtype=torch.float32),)
+        # Dynamic shapes dict maps arg index to shape spec {dim_index: Dim obj}
+        dynamic_shapes = {0: {0: B, 1: S}}
+
+        class SimpleDynamic(torch.nn.Module):
+            def forward(self, x):
+                return torch.relu(x)
+
+        # Explicit export and import

Review Comment:
   Any reason to manually check structural equality? If no, please use the `verify_model()` utility function.
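   
   For reference, a sketch following the same call pattern as the existing dynamic-shape test above (`Expected` stands for the TVMScript module the test would define, omitted here):
   
   ```python
   verify_model(SimpleDynamic(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)
   ```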


