slyubomirsky commented on code in PR #16602:
URL: https://github.com/apache/tvm/pull/16602#discussion_r1496856295
##########
python/tvm/relax/transform/lazy_transform_params.py:
##########
@@ -157,24 +159,60 @@ def transform(self, func: relax.Function) ->
relax.Function:
self.memory_free_insertion = liveness.var_liveness_end
# Step 3. rewrite get item and set item
- new_body = func.body
if self.fget_item is not None:
- new_body = LazyInputMutator(self, self.mod).visit_expr(new_body)
+ new_func = LazyInputMutator(self, self.mod).visit_expr(func)
+ new_body = new_func.body
if self.fset_item is not None:
+ leaf_outputs = {
+ expr: indices
+ for expr, indices in self.out_tuple_map.items()
+ if not isinstance(expr, relax.Var)
+ }
+ if leaf_outputs:
+ new_bindings = [
+ relax.VarBinding(
+ relax.Var("_", relax.ObjectStructInfo()),
+ relax.Call(
+ relax.ExternFunc(self.fset_item),
+ [*self.extra_set_item_params, index, expr],
+ None,
+ [relax.ObjectStructInfo()],
+ ),
+ )
+ for expr, indices in leaf_outputs.items()
+ for index in indices
Review Comment:
You've got to love the syntax for nested list comprehensions 🙂
##########
python/tvm/relax/transform/lazy_transform_params.py:
##########
@@ -157,24 +159,60 @@ def transform(self, func: relax.Function) ->
relax.Function:
self.memory_free_insertion = liveness.var_liveness_end
# Step 3. rewrite get item and set item
- new_body = func.body
if self.fget_item is not None:
- new_body = LazyInputMutator(self, self.mod).visit_expr(new_body)
+ new_func = LazyInputMutator(self, self.mod).visit_expr(func)
+ new_body = new_func.body
if self.fset_item is not None:
+ leaf_outputs = {
+ expr: indices
+ for expr, indices in self.out_tuple_map.items()
+ if not isinstance(expr, relax.Var)
+ }
+ if leaf_outputs:
+ new_bindings = [
+ relax.VarBinding(
+ relax.Var("_", relax.ObjectStructInfo()),
+ relax.Call(
+ relax.ExternFunc(self.fset_item),
+ [*self.extra_set_item_params, index, expr],
+ None,
+ [relax.ObjectStructInfo()],
+ ),
+ )
+ for expr, indices in leaf_outputs.items()
+ for index in indices
+ ]
+ new_body = relax.SeqExpr(
+ [*new_body.blocks, relax.BindingBlock(new_bindings)],
new_body.body
+ )
+
Review Comment:
I presume these additions handle the case mentioned in the PR description, where an output is not a `relax.Var`?
##########
tests/python/relax/test_transform_lazy_transform_params.py:
##########
@@ -602,5 +602,77 @@ def main_transform_params() -> R.Tuple:
tvm.ir.assert_structural_equal(after, Expected)
+def test_params_without_tuple():
+ @I.ir_module
+ class Before:
+ @R.function
+ def transform_params(A: R.Tensor([16, 16], "float32"), B:
R.Tensor([16, 16], "float32")):
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ return (D, B)
+
+ @I.ir_module
+ class Expected:
+ @R.function(pure=False)
+ def transform_params():
+ A = R.call_packed("get_item", R.prim_value(0),
sinfo_args=[R.Object])
+ A = R.match_cast(A, R.Tensor([16, 16], "float32"))
+ C = R.multiply(A, R.const(2, "float32"))
+
+ B = R.call_packed("get_item", R.prim_value(1),
sinfo_args=[R.Object])
+ B = R.match_cast(B, R.Tensor([16, 16], "float32"))
+ D = R.add(C, B)
+ return (D, B)
+
+ After = LazyTransformParams(fset_item=None)(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_retain_before_num_input():
+ """Only lazily load parameters after num_input"""
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def transform_params(
+ relax_rank: R.Prim(value="rank"),
+ A: R.Tensor([16, 16], "float32"),
+ B: R.Tensor([16, 16], "float32"),
+ ):
+ R.func_attr({"num_input": 1})
+ rank = T.int64()
+ A_sharded = R.strided_slice(
+ A, axes=[0], begin=[rank * 8], end=[(rank + 1) * 8],
assume_inbound=True
+ )
+ B_sharded = R.strided_slice(
+ B, axes=[1], begin=[rank * 8], end=[(rank + 1) * 8],
assume_inbound=True
+ )
+ return (A_sharded, B_sharded)
+
+ @I.ir_module
+ class Expected:
+ @R.function(pure=False)
+ def transform_params(relax_rank: R.Prim(value="rank")):
+ R.func_attr({"num_input": 1})
+ rank = T.int64()
+
+ A = R.call_packed("get_item", R.prim_value(0),
sinfo_args=[R.Object])
+ A = R.match_cast(A, R.Tensor([16, 16], "float32"))
+ A_sharded = R.strided_slice(
+ A, axes=[0], begin=[rank * 8], end=[(rank + 1) * 8],
assume_inbound=True
+ )
+
+ B = R.call_packed("get_item", R.prim_value(1),
sinfo_args=[R.Object])
+ B = R.match_cast(B, R.Tensor([16, 16], "float32"))
+ B_sharded = R.strided_slice(
+ B, axes=[1], begin=[rank * 8], end=[(rank + 1) * 8],
assume_inbound=True
+ )
+
+ return (A_sharded, B_sharded)
+
+ After = LazyTransformParams(fset_item=None)(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
Review Comment:
Are there any test cases that use the extra parameters for get_item and set_item? If not, they should be tested. If there is also no test for an output that is not a var (I'm not sure exactly what that would look like, as I haven't used this pass), it would be good to add one too.
##########
python/tvm/relax/transform/lazy_transform_params.py:
##########
@@ -157,24 +159,60 @@ def transform(self, func: relax.Function) ->
relax.Function:
self.memory_free_insertion = liveness.var_liveness_end
# Step 3. rewrite get item and set item
- new_body = func.body
if self.fget_item is not None:
- new_body = LazyInputMutator(self, self.mod).visit_expr(new_body)
+ new_func = LazyInputMutator(self, self.mod).visit_expr(func)
+ new_body = new_func.body
if self.fset_item is not None:
+ leaf_outputs = {
+ expr: indices
+ for expr, indices in self.out_tuple_map.items()
+ if not isinstance(expr, relax.Var)
+ }
+ if leaf_outputs:
+ new_bindings = [
+ relax.VarBinding(
+ relax.Var("_", relax.ObjectStructInfo()),
+ relax.Call(
+ relax.ExternFunc(self.fset_item),
+ [*self.extra_set_item_params, index, expr],
+ None,
+ [relax.ObjectStructInfo()],
+ ),
+ )
+ for expr, indices in leaf_outputs.items()
+ for index in indices
+ ]
+ new_body = relax.SeqExpr(
+ [*new_body.blocks, relax.BindingBlock(new_bindings)],
new_body.body
+ )
+
new_body = LazyOutputMutator(self, self.mod).visit_expr(new_body)
# Step 4. Add parameters of get_item and set_item (except index) to
the function.
- params = [*self.extra_get_item_params, *self.extra_set_item_params]
+ params = [
+ *func.params[:num_input],
+ *self.extra_get_item_params,
+ *self.extra_set_item_params,
+ ]
# Step 5. Find all shape parameters that should be retained as
# parameters.
symbolic_vars = relax.analysis.defined_symbolic_vars(func)
if symbolic_vars:
+
+ def unpack_sinfo(sinfo):
+ if isinstance(sinfo, relax.TupleStructInfo):
+ for field in sinfo.fields:
+ yield from unpack_sinfo(field)
Review Comment:
This is the first time I've seen `yield from`; it seems like a good use for it.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]