guan404ming commented on code in PR #18452:
URL: https://github.com/apache/tvm/pull/18452#discussion_r2531964065
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1105,14 +1105,49 @@ def create_convert_map(
"_local_scalar_dense.default": self._item,
}
+ def _parse_sympy_to_tir_expr(
+ self, symbol, torch_symbol_to_relax_var: Dict[str, tvm.tir.Var]
+ ) -> tvm.tir.PrimExpr:
+ import sympy
+
+ if isinstance(symbol, sympy.Symbol):
+ return None
+
+ # Handle addition
+ if isinstance(symbol, sympy.Add):
+ result = None
+ for arg in symbol.args:
+ if isinstance(arg, sympy.Integer):
+ term = tvm.tir.IntImm("int64", int(arg))
+ elif isinstance(arg, sympy.Symbol):
+ var_name = str(arg)
+ term = torch_symbol_to_relax_var.setdefault(
+ var_name, tvm.tir.SizeVar(var_name, "int64")
+ )
+ else:
+ # Recursively parse nested expressions
+ term = self._parse_sympy_to_tir_expr(arg,
torch_symbol_to_relax_var)
+
+ result = term if result is None else result + term
Review Comment:
This makes sense to me; let me apply the suggestion.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]