slyubomirsky commented on code in PR #16785:
URL: https://github.com/apache/tvm/pull/16785#discussion_r1538475836
##########
python/tvm/relax/frontend/nn/exporter.py:
##########
@@ -190,34 +207,64 @@ def _convert_input(arg):
def _params(mode: str) -> typing.List[rx.Var]:
inputs: typing.List[rx.Var] = []
- def _get_var(shape_var: tir.Var) -> tir.Var:
- name = shape_var.name
- if name in str2var_params:
- return str2var_params[name]
- var = tir.Var(name, "int64")
- str2var_params[name] = var
- return var
+ def _normalize_dim(dim: typing.Union[int, str, tir.Var]) ->
tir.PrimExpr:
+ if isinstance(dim, int):
+ return tir.IntImm("int64", dim)
+ elif isinstance(dim, str):
+ if dim in str2var_params:
+ return str2var_params[dim]
+ else:
+ new_var = tir.Var(dim, "int64")
+ str2var_params[dim] = new_var
+ return new_var
+ elif isinstance(dim, tir.Var):
+ return dim
+ else:
+ raise TypeError(
+ f"Expected dim to be int, str, or tir.Var, "
+ f"but {dim} was of type {type(dim)}."
+ )
for name, param in params:
# Make sure a symbolic shape is not re-registered (same as
_method_spec_to_inputs)
# e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1`
for `embed_tokens`
- new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in
param.shape]
- var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr
+ new_shape = [_normalize_dim(dim) for dim in param._shape]
+ # var_cls = rx.DataflowVar if mode == "packed" else rx.Var
Review Comment:
Is this line meant to be used at any point?
##########
python/tvm/relax/frontend/nn/op.py:
##########
@@ -676,12 +676,31 @@ def permute_dims(x: Tensor, axes: Optional[List[int]] =
None, name: str = None)
result : Tensor
The transposed result.
"""
+
+ # TODO(Lunderberg): This is a more extensive auto-naming than
+ # intended here. Is this still worth it?
Review Comment:
Do we expect these chains of definitions to be deep? If they can be, this
might be undesirable.
##########
tests/python/relax/test_frontend_nn_packing.py:
##########
@@ -25,7 +25,9 @@ def _iter_binding_names(mod):
"""Helper function to compare the names of relax variables"""
for block in mod["forward"].body.blocks:
for binding in block.bindings:
- yield binding.var.name_hint
+ # Relax variable names may contain '.' even though they
+ # cannot be expressed in TVMScript.
Review Comment:
I wonder if this is something we should just check for and prohibit.
##########
python/tvm/relax/frontend/nn/exporter.py:
##########
@@ -135,9 +136,18 @@ def _effects() -> typing.List[typing.Tuple[str,
core.Effect]]:
with self.builder.dataflow():
outputs, inputs = _emit_method(self.builder,
method_spec, params, effects)
self.builder.emit_func_output(outputs, inputs)
+
+ # TODO(Lunderberg): Make a `ir.transform.ConvertSSA`,
+ # similar to the existing `tir.transform.ConvertSSA`,
+ # that converts an entire module to SSA, including TIR
+ # variable definitions used in either TIR or Relax.
Review Comment:
What would this conversion do on the Relax side? I thought vars already had
exactly one point of definition in Relax.
##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -591,7 +609,22 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor,
Sequence[Tensor]]:
The computed result.
"""
if not isinstance(expr, rx.DataflowVar):
- expr = BlockBuilder.current().emit(expr, name)
+ block_builder = BlockBuilder.current()
+ if block_builder is None:
+ # Normalize to make sure we have valid StructInfo, but
+ # wait until we are actually building the function to
+ # flatten nested expressions.
+ #
+ # TODO(Lunderberg): Make this easier to call. Inferring
+ # struct info for a nested expression should be doable in
+ # a free function, without requiring an active
+ # BlockBuilder and an active FunctionFrame.
Review Comment:
Yeah... we might be able to decouple some of the functionality and allow for
calling things like that separately.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]