Lunderberg commented on code in PR #16785:
URL: https://github.com/apache/tvm/pull/16785#discussion_r1539399718
##########
python/tvm/relax/frontend/nn/exporter.py:
##########
@@ -190,34 +207,64 @@ def _convert_input(arg):
def _params(mode: str) -> typing.List[rx.Var]:
inputs: typing.List[rx.Var] = []
- def _get_var(shape_var: tir.Var) -> tir.Var:
- name = shape_var.name
- if name in str2var_params:
- return str2var_params[name]
- var = tir.Var(name, "int64")
- str2var_params[name] = var
- return var
+ def _normalize_dim(dim: typing.Union[int, str, tir.Var]) ->
tir.PrimExpr:
+ if isinstance(dim, int):
+ return tir.IntImm("int64", dim)
+ elif isinstance(dim, str):
+ if dim in str2var_params:
+ return str2var_params[dim]
+ else:
+ new_var = tir.Var(dim, "int64")
+ str2var_params[dim] = new_var
+ return new_var
+ elif isinstance(dim, tir.Var):
+ return dim
+ else:
+ raise TypeError(
+ f"Expected dim to be int, str, or tir.Var, "
+ f"but {dim} was of type {type(dim)}."
+ )
for name, param in params:
# Make sure that a symbolic shape is not re-registered (same as
_method_spec_to_inputs)
# e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1`
for `embed_tokens`
- new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in
param.shape]
- var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr
+ new_shape = [_normalize_dim(dim) for dim in param._shape]
+ # var_cls = rx.DataflowVar if mode == "packed" else rx.Var
Review Comment:
Whoops, that was a test during dev work. Removing the commented-out
`var_cls` line.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]