gemini-code-assist[bot] commented on code in PR #18452:
URL: https://github.com/apache/tvm/pull/18452#discussion_r2531073666


##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -6920,5 +6920,40 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_dynamic_shape_with_derived_range_constraints():
+    class ConcatModel(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.cat([x, y], dim=0)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0___1", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0 + s0___1", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            s0___1 = T.int64(is_size_var=True)
+            R.func_attr(
+                {
+                    "tir_var_expr": {"s0 + 1": 1 + s0},
+                    "tir_var_lower_bound": {"s0": 1, "s0 + 1": 2},
+                    "tir_var_upper_bound": {"s0": 64, "s0 + 1": 65},
+                }
+            )
+            with R.dataflow():
+                lv: R.Tensor((s0 + s0___1, 4), dtype="float32") = R.concat((x, y), axis=0)
+                gv: R.Tuple(R.Tensor((s0 + s0___1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `Expected` IRModule seems to have some inconsistencies with what the translator is expected to generate. Given `dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}`, the translator should generate a `SizeVar` named `"s0 + 1"` for the dynamic dimension of `y`. Therefore, the `y` parameter in `main` should probably have the shape `R.Tensor(("s0 + 1", 4), ...)` instead of `R.Tensor(("s0___1", 4), ...)`. Consequently, the output tensor shape would be `R.Tensor(("s0 + s0 + 1", 4), ...)`, and the free variable `s0___1` would not be needed. Could you please double-check the `Expected` module definition?
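
   For context, here is a minimal sketch (not taken from the PR; the example tensor sizes and `Dim` bounds are illustrative assumptions chosen to match the `tir_var_lower_bound`/`tir_var_upper_bound` attrs above) of how such a `dynamic_shapes` spec could be produced with `torch.export`:

   ```python
   import torch
   from torch.export import Dim, export

   class ConcatModel(torch.nn.Module):
       def forward(self, x, y):
           return torch.cat([x, y], dim=0)

   # `batch + 1` is a derived dim: torch.export records y's first axis as the
   # sympy expression s0 + 1 rather than as a fresh independent symbol, which
   # is what the translator is expected to turn into a SizeVar named "s0 + 1".
   batch = Dim("batch", min=1, max=64)
   ep = export(
       ConcatModel(),
       (torch.randn(8, 4), torch.randn(9, 4)),
       dynamic_shapes={"x": {0: batch}, "y": {0: batch + 1}},
   )
   ```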



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1105,14 +1105,49 @@ def create_convert_map(
             "_local_scalar_dense.default": self._item,
         }
 
+    def _parse_sympy_to_tir_expr(
+        self, symbol, torch_symbol_to_relax_var: Dict[str, tvm.tir.Var]
+    ) -> tvm.tir.PrimExpr:
+        import sympy
+
+        if isinstance(symbol, sympy.Symbol):
+            return None
+
+        # Handle addition
+        if isinstance(symbol, sympy.Add):
+            result = None
+            for arg in symbol.args:
+                if isinstance(arg, sympy.Integer):
+                    term = tvm.tir.IntImm("int64", int(arg))
+                elif isinstance(arg, sympy.Symbol):
+                    var_name = str(arg)
+                    term = torch_symbol_to_relax_var.setdefault(
+                        var_name, tvm.tir.SizeVar(var_name, "int64")
+                    )
+                else:
+                    # Recursively parse nested expressions
+                    term = self._parse_sympy_to_tir_expr(arg, torch_symbol_to_relax_var)
+
+                result = term if result is None else result + term

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The recursive call to `_parse_sympy_to_tir_expr` can return `None` for an unsupported expression type. In that case, `term` becomes `None`, and the expression `result + term` on line 1131 will raise a `TypeError`. This should be handled by checking whether `term` is `None` and returning `None` from `_parse_sympy_to_tir_expr` to indicate that the expression cannot be parsed. I've also slightly refactored the loop for clarity.
   
   ```suggestion
               for arg in symbol.args:
                   if isinstance(arg, sympy.Integer):
                       term = tvm.tir.IntImm("int64", int(arg))
                   elif isinstance(arg, sympy.Symbol):
                       var_name = str(arg)
                       term = torch_symbol_to_relax_var.setdefault(
                           var_name, tvm.tir.SizeVar(var_name, "int64")
                       )
                   else:
                       # Recursively parse nested expressions
                       term = self._parse_sympy_to_tir_expr(arg, torch_symbol_to_relax_var)
   
                   if term is None:
                       # An unsupported sub-expression was found, so we cannot parse this expression.
                       return None
   
                   result = term if result is None else result + term
   ```
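
   As a quick sanity check on the fix, here is a standalone sketch (the helper name `parse_add` is hypothetical and not part of the PR) showing how the `None` check turns an unsupported sub-expression into a clean parse failure instead of a `TypeError`:

   ```python
   import sympy
   import tvm

   def parse_add(expr, var_map):
       # Mirrors the suggested loop: bail out with None on any
       # sub-expression the parser does not understand.
       if not isinstance(expr, sympy.Add):
           return None
       result = None
       for arg in expr.args:
           if isinstance(arg, sympy.Integer):
               term = tvm.tir.IntImm("int64", int(arg))
           elif isinstance(arg, sympy.Symbol):
               term = var_map.setdefault(str(arg), tvm.tir.SizeVar(str(arg), "int64"))
           else:
               term = parse_add(arg, var_map)
           if term is None:
               return None  # propagate the failure instead of raising TypeError
           result = term if result is None else result + term
       return result

   s0 = sympy.Symbol("s0")
   assert parse_add(s0 + 1, {}) is not None   # s0 + 1 parses fine
   assert parse_add(s0 + s0**2, {}) is None   # sympy.Pow is unsupported, no TypeError
   ```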



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

