gemini-code-assist[bot] commented on code in PR #18452:
URL: https://github.com/apache/tvm/pull/18452#discussion_r2531073666
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -6920,5 +6920,40 @@ def main(
tvm.ir.assert_structural_equal(mod, Expected)
+def test_dynamic_shape_with_derived_range_constraints():
+    class ConcatModel(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.cat([x, y], dim=0)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0___1", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0 + s0___1", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            s0___1 = T.int64(is_size_var=True)
+            R.func_attr(
+                {
+                    "tir_var_expr": {"s0 + 1": 1 + s0},
+                    "tir_var_lower_bound": {"s0": 1, "s0 + 1": 2},
+                    "tir_var_upper_bound": {"s0": 64, "s0 + 1": 65},
+                }
+            )
+            with R.dataflow():
+                lv: R.Tensor((s0 + s0___1, 4), dtype="float32") = R.concat((x, y), axis=0)
+                gv: R.Tuple(R.Tensor((s0 + s0___1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
Review Comment:

The `Expected` IRModule looks inconsistent with what the translator should
generate. Given `dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}`, the
translator should create a `SizeVar` named `"s0 + 1"` for the dynamic dimension
of `y`. The `y` parameter in `main` should therefore probably have the shape
`R.Tensor(("s0 + 1", 4), ...)` rather than `R.Tensor(("s0___1", 4), ...)`. The
output tensor shape would then be `R.Tensor(("s0 + s0 + 1", 4), ...)`, and the
free variable `s0___1` would no longer be needed. Could you please double-check
the `Expected` module definition?
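
As a rough cross-check of the bounds in `R.func_attr`, the ranges for the
derived dimension can be worked out by hand. The sketch below is plain-Python
interval arithmetic, not a TVM or PyTorch API; `shift_range` and `add_ranges`
are hypothetical helpers, and the range for the concatenated dimension is my
own derivation rather than something the test asserts.

```python
from typing import Tuple

Range = Tuple[int, int]  # inclusive (lower, upper) bounds


def shift_range(bounds: Range, offset: int) -> Range:
    """Bounds of `dim + offset` given the bounds of `dim`."""
    lower, upper = bounds
    return (lower + offset, upper + offset)


def add_ranges(a: Range, b: Range) -> Range:
    """Bounds of the sum of two dimensions that both grow with `s0`."""
    return (a[0] + b[0], a[1] + b[1])


# Bounds of `s0` as listed in tir_var_lower_bound / tir_var_upper_bound.
s0 = (1, 64)

# Derived dimension of `y`, i.e. the expression named "s0 + 1".
s0_plus_1 = shift_range(s0, 1)

# Dimension 0 of cat([x, y], dim=0), i.e. s0 + (s0 + 1).
concat_dim = add_ranges(s0, s0_plus_1)

print(s0_plus_1, concat_dim)
```

The `(2, 65)` range for `"s0 + 1"` matches the values already present in the
attribute dictionaries, which suggests the attributes describe `y` in terms of
`s0` even though the signature introduces a separate `s0___1`.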
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1105,14 +1105,49 @@ def create_convert_map(
"_local_scalar_dense.default": self._item,
}
+    def _parse_sympy_to_tir_expr(
+        self, symbol, torch_symbol_to_relax_var: Dict[str, tvm.tir.Var]
+    ) -> tvm.tir.PrimExpr:
+        import sympy
+
+        if isinstance(symbol, sympy.Symbol):
+            return None
+
+        # Handle addition
+        if isinstance(symbol, sympy.Add):
+            result = None
+            for arg in symbol.args:
+                if isinstance(arg, sympy.Integer):
+                    term = tvm.tir.IntImm("int64", int(arg))
+                elif isinstance(arg, sympy.Symbol):
+                    var_name = str(arg)
+                    term = torch_symbol_to_relax_var.setdefault(
+                        var_name, tvm.tir.SizeVar(var_name, "int64")
+                    )
+                else:
+                    # Recursively parse nested expressions
+                    term = self._parse_sympy_to_tir_expr(arg, torch_symbol_to_relax_var)
+
+                result = term if result is None else result + term
Review Comment:

The recursive call to `_parse_sympy_to_tir_expr` can return `None` for an
unsupported expression type. In that case `term` is `None`, and the expression
`result + term` on line 1131 will raise a `TypeError`. This should be handled
by checking whether `term` is `None` and, if so, returning `None` from
`_parse_sympy_to_tir_expr` to indicate that the expression cannot be parsed.
I've also slightly refactored the loop for clarity.
```suggestion
            for arg in symbol.args:
                if isinstance(arg, sympy.Integer):
                    term = tvm.tir.IntImm("int64", int(arg))
                elif isinstance(arg, sympy.Symbol):
                    var_name = str(arg)
                    term = torch_symbol_to_relax_var.setdefault(
                        var_name, tvm.tir.SizeVar(var_name, "int64")
                    )
                else:
                    # Recursively parse nested expressions
                    term = self._parse_sympy_to_tir_expr(arg, torch_symbol_to_relax_var)

                if term is None:
                    # An unsupported sub-expression was found, so we cannot
                    # parse this expression.
                    return None

                result = term if result is None else result + term
```
```
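
To illustrate the failure mode without needing sympy or TVM, here is a minimal
stand-alone sketch of the same loop. Everything in it is hypothetical: nodes
are plain ints, symbol-name strings, or `(tag, *args)` tuples, `parse_add` is
an illustrative stand-in for the real method, and the `check_none` flag exists
only to compare the guarded and unguarded behavior.

```python
from typing import Dict, Optional, Tuple, Union

# int constant, symbol name, or (tag, *args) for a compound expression
Node = Union[int, str, Tuple]


def parse_add(node: Node, values: Dict[str, int], check_none: bool) -> Optional[int]:
    """Sum the args of an ("add", ...) node; return None for anything else."""
    if not (isinstance(node, tuple) and node and node[0] == "add"):
        return None  # unsupported node type

    result = None
    for arg in node[1:]:
        if isinstance(arg, int):
            term = arg
        elif isinstance(arg, str):
            term = values[arg]
        else:
            # Recursive call; may return None for an unsupported sub-expression.
            term = parse_add(arg, values, check_none)

        if check_none and term is None:
            return None  # propagate the failure to the caller

        result = term if result is None else result + term
    return result


values = {"s0": 8}
nested_ok = ("add", "s0", ("add", 1, 2))
nested_bad = ("add", "s0", ("mul", 2, "s0"))  # "mul" is not handled

guarded_ok = parse_add(nested_ok, values, check_none=True)
guarded_bad = parse_add(nested_bad, values, check_none=True)

try:
    parse_add(nested_bad, values, check_none=False)
    unguarded_raised = False
except TypeError:
    unguarded_raised = True
```

In this sketch the unguarded path raises only when a supported term has
already been accumulated. If the unsupported term comes first, `result` simply
stays `None` and that term is silently dropped, which is another reason to
return early.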
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]