This is an automated email from the ASF dual-hosted git repository.
hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 171ef61797 [Unity] make LazyTransformParam more general (#16088)
171ef61797 is described below
commit 171ef617977ca3ceb4c6b347b9ac3f5df783dc71
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu Nov 9 14:14:08 2023 -0800
[Unity] make LazyTransformParam more general (#16088)
* lazy transform params
* format
* address comment
* fix
* fix ci
* fix ci
---
.../tvm/relax/transform/lazy_transform_params.py | 117 +++++++++++---
.../relax/test_transform_lazy_transform_params.py | 174 +++++++++++++++++++++
2 files changed, 268 insertions(+), 23 deletions(-)
diff --git a/python/tvm/relax/transform/lazy_transform_params.py
b/python/tvm/relax/transform/lazy_transform_params.py
index 6a8adcb64b..7f734f8a3c 100644
--- a/python/tvm/relax/transform/lazy_transform_params.py
+++ b/python/tvm/relax/transform/lazy_transform_params.py
@@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name, unused-argument, missing-function-docstring,
abstract-method
+# pylint: disable=invalid-name, unused-argument, missing-function-docstring,
abstract-method, missing-class-docstring
"""Relax LazyTransformParams pass."""
+from typing import Optional
+
import tvm
from tvm import IRModule, relax
from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator,
visitor
@@ -107,8 +109,7 @@ class LivenessAnalysis(PyExprVisitor):
self.var_liveness_end[binding.var] = self.last_appear_in_var_binding
-@mutator
-class LazyTransformParamsMutator(PyExprMutator):
+class LazyTransformParamsFuncCreator:
"""
Transform transform_params functions into a lazy version.
@@ -118,16 +119,25 @@ class LazyTransformParamsMutator(PyExprMutator):
The module to be transformed
"""
- def __init__(self, mod: IRModule = None) -> None:
- super().__init__(mod)
+ def __init__(
+ self,
+ fget_item,
+ fset_item,
+ extra_get_item_params,
+ extra_set_item_params,
+ mod: IRModule = None,
+ ) -> None:
self.mod = mod
+ self.fget_item = fget_item
+ self.extra_get_item_params = extra_get_item_params
+ self.fset_item = fset_item
+ self.extra_set_item_params = extra_set_item_params
# the only input param, which should be a Tuple
self.input_tuple_param = None
self.input_params_set = None
self.out_tuple_map = None
self.out_tuple_var = None
self.memory_free_insertion = None
- self.killed_vars = set()
def transform(self, func: relax.Function) -> relax.Function:
self.input_tuple_param = func.params[0]
@@ -147,11 +157,18 @@ class LazyTransformParamsMutator(PyExprMutator):
self.memory_free_insertion = liveness.var_liveness_end
# Step 3. rewrite get item and set item
- new_body = self.visit_expr(func.body)
+ new_body = func.body
+ if self.fget_item is not None:
+ new_body = LazyInputMutator(self, self.mod).visit_expr(new_body)
+
+ if self.fset_item is not None:
+ new_body = LazyOutputMutator(self, self.mod).visit_expr(new_body)
- # Step 4. Find all shape parameters that should be retained as
+ # Step 4. Add parameters of get_item and set_item (except index) to
the function.
+ params = [*self.extra_get_item_params, *self.extra_set_item_params]
+
+ # Step 5. Find all shape parameters that should be retained as
# parameters.
- params = []
symbolic_vars = relax.analysis.defined_symbolic_vars(func)
if symbolic_vars:
# direct iterate over the struct info annotation
@@ -167,14 +184,22 @@ class LazyTransformParamsMutator(PyExprMutator):
is_pure=False,
).without_attr("relax.force_pure")
+
+@mutator
+class LazyInputMutator(PyExprMutator):
+ def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
+ self.func_creator = func_creator
+ super().__init__(mod)
+
def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
# rewrite get item
tuple_get_item = super().visit_tuple_getitem_(op)
- if tuple_get_item.tuple_value == self.input_tuple_param:
+ if tuple_get_item.tuple_value == self.func_creator.input_tuple_param:
get_item_result = self.builder_.emit(
relax.Call(
- relax.ExternFunc("get_item"),
- [relax.PrimValue(tuple_get_item.index)],
+ relax.ExternFunc(self.func_creator.fget_item),
+ self.func_creator.extra_get_item_params
+ + [relax.PrimValue(tuple_get_item.index)],
None,
[relax.ObjectStructInfo()],
)
@@ -183,12 +208,20 @@ class LazyTransformParamsMutator(PyExprMutator):
else:
return tuple_get_item
+
+@mutator
+class LazyOutputMutator(PyExprMutator):
+ def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
+ self.func_creator = func_creator
+ self.killed_vars = set()
+ super().__init__(mod)
+
def visit_var_(self, var: relax.Var) -> None:
assert var not in self.killed_vars
return super().visit_var_(var)
def visit_var_binding_(self, binding: relax.VarBinding) -> None:
- if binding.var == self.out_tuple_var:
+ if binding.var == self.func_creator.out_tuple_var:
            # The function after rewriting returns an empty tuple.
func_output = self.builder_.emit(relax.Tuple([]))
self.set_var_remap(binding.var.vid, func_output)
@@ -196,23 +229,24 @@ class LazyTransformParamsMutator(PyExprMutator):
super().visit_var_binding_(binding)
- if binding.var in self.memory_free_insertion:
- for var in self.memory_free_insertion[binding.var]:
- if var in self.out_tuple_map:
+ if binding.var in self.func_creator.memory_free_insertion:
+ for var in self.func_creator.memory_free_insertion[binding.var]:
+ if var in self.func_creator.out_tuple_map:
self.killed_vars.add(var)
- for index in self.out_tuple_map[var]:
+ for index in self.func_creator.out_tuple_map[var]:
# rewrite set item
self.builder_.emit(
relax.Call(
- relax.ExternFunc("set_item"),
- [index, super().visit_var_(var)],
+ relax.ExternFunc(self.func_creator.fset_item),
+ self.func_creator.extra_set_item_params
+ + [index, super().visit_var_(var)],
None,
[relax.ObjectStructInfo()],
),
name_hint="_",
)
- if var in self.input_params_set:
+ if var in self.func_creator.input_params_set:
self.builder_.emit(
relax.op.vm.kill_object(super().visit_var_(var)),
name_hint="_"
)
@@ -225,16 +259,53 @@ class LazyTransformParams:
(Load the input to memory on demand, and immediately free it after the
last use.)
Note: ToNonDataflow() and RemovePurityTracking() should be invoked before
this pass.
+
+ Parameters
+ ----------
+ fget_item: str
+ The name of the get_item function.
+ fset_item: str
+ The name of the set_item function.
+ extra_get_item_params: list of relax.Var
+ The parameters of the get_item function except index.
+ The given parameters will be placed before index.
+ For example, if extra_get_item_params is [param1, param2], then the
pass will generate
+ call_packed(fget_item, [param1, param2, index])
+ extra_set_item_params: list of relax.Var
+ The parameters of the set_item function except index and value.
+ The given parameters will be placed before index and value.
+ For example, if extra_set_item_params is [param1, param2], then the
pass will generate
+ call_packed(fset_item, [param1, param2, index, value])
"""
+ def __init__(
+ self,
+ fget_item="get_item",
+ fset_item="set_item",
+ extra_get_item_params=None,
+ extra_set_item_params=None,
+ ) -> None:
+ self.fget_item = fget_item
+ self.extra_get_item_params = [] if extra_get_item_params is None else
extra_get_item_params
+ assert self.fget_item is not None, "transforming set_item only is not
supported"
+ self.fset_item = fset_item
+ self.extra_set_item_params = [] if extra_set_item_params is None else
extra_set_item_params
+
def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext)
-> IRModule:
- lazy_mutator = LazyTransformParamsMutator(mod)
+ lazy_mutator = LazyTransformParamsFuncCreator(
+ self.fget_item,
+ self.fset_item,
+ self.extra_get_item_params,
+ self.extra_set_item_params,
+ mod,
+ )
+ builder = relax.BlockBuilder(mod)
for gv, _ in mod.functions_items():
if gv.name_hint.endswith("transform_params"):
func = mod[gv]
if not isinstance(func, relax.Function):
continue
func = lazy_mutator.transform(func)
- lazy_mutator.builder_.update_func(gv, func)
+ builder.update_func(gv, func)
- return lazy_mutator.builder_.get()
+ return builder.get()
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py
b/tests/python/relax/test_transform_lazy_transform_params.py
index af7ed1956b..8f958429c7 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -106,6 +106,180 @@ def test_lazy_transform_params():
tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
+def test_get_item_only():
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def transform_layout_IOHW_to_OIHW(
+ w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3,
3), "float32")
+ ):
+ for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
+ with T.block("layout_transform"):
+ o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(w1[i, o, h, w])
+ T.writes(out[o, i, h, w])
+ out[o, i, h, w] = w1[i, o, h, w]
+
+ @R.function
+ def main_transform_params(
+ params: R.Tuple(
+ R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3,
3), dtype="float32")
+ )
+ ) -> R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3),
dtype="float32")
+ ):
+ # we expect ToNonDataflow and RemovePurityTracking to be invoked
first
+ R.func_attr({"relax.force_pure": True})
+ cls = Before
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+ lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
+ lv2 = R.call_tir(
+ cls.transform_layout_IOHW_to_OIHW,
+ (lv1,),
+ out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+ )
+ lv3 = R.add(lv2, R.const(1, "float32"))
+ gv: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((16, 3, 3, 3), dtype="float32"),
+ ) = (lv, lv3)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def transform_layout_IOHW_to_OIHW(
+ w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3,
3), "float32")
+ ):
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
+ with T.block("layout_transform"):
+ o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(w1[i, o, h, w])
+ T.writes(out[o, i, h, w])
+ out[o, i, h, w] = w1[i, o, h, w]
+
+ @R.function(pure=False)
+ def main_transform_params() -> (
+ R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3,
3), dtype="float32")
+ )
+ ):
+ cls = Expected
+ gv: R.Object = R.call_packed("get_item_0", R.prim_value(1),
sinfo_args=(R.Object,))
+ gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
+ gv, R.Tensor((16, 16, 3, 3), dtype="float32")
+ )
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+ gv2: R.Object = R.call_packed("get_item_0", R.prim_value(0),
sinfo_args=(R.Object,))
+ gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
+ gv2, R.Tensor((3, 16, 3, 3), dtype="float32")
+ )
+ lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
+ lv2 = R.call_tir(
+ cls.transform_layout_IOHW_to_OIHW,
+ (lv1,),
+ out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+ )
+ lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2,
R.const(1, "float32"))
+ gv_1: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3,
3), dtype="float32")
+ ) = (lv, lv3)
+ return gv_1
+
+ after = LazyTransformParams(fget_item="get_item_0", fset_item=None)(Before)
+ tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
+
+
+def test_extra_params():
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def transform_layout_IOHW_to_OIHW(
+ w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3,
3), "float32")
+ ):
+ for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
+ with T.block("layout_transform"):
+ o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(w1[i, o, h, w])
+ T.writes(out[o, i, h, w])
+ out[o, i, h, w] = w1[i, o, h, w]
+
+ @R.function
+ def main_transform_params(
+ params: R.Tuple(
+ R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3,
3), dtype="float32")
+ )
+ ) -> R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3),
dtype="float32")
+ ):
+ # we expect ToNonDataflow and RemovePurityTracking to be invoked
first
+ R.func_attr({"relax.force_pure": True})
+ cls = Before
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+ lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
+ lv2 = R.call_tir(
+ cls.transform_layout_IOHW_to_OIHW,
+ (lv1,),
+ out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+ )
+ lv3 = R.add(lv2, R.const(1, "float32"))
+ gv: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((16, 3, 3, 3), dtype="float32"),
+ ) = (lv, lv3)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def transform_layout_IOHW_to_OIHW(
+ w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3,
3), "float32")
+ ):
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
+ with T.block("layout_transform"):
+ o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(w1[i, o, h, w])
+ T.writes(out[o, i, h, w])
+ out[o, i, h, w] = w1[i, o, h, w]
+
+ @R.function(pure=False)
+ def main_transform_params(loader: R.Object) -> R.Tuple:
+ cls = Expected
+ gv: R.Object = R.call_packed(
+ "get_item", loader, R.prim_value(1), sinfo_args=(R.Object,)
+ )
+ gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
+ gv, R.Tensor((16, 16, 3, 3), dtype="float32")
+ )
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+ _: R.Object = R.call_packed("set_item", R.prim_value(0), lv,
sinfo_args=(R.Object,))
+ _1: R.Tuple = R.vm.kill_object(lv)
+ gv2: R.Object = R.call_packed(
+ "get_item", loader, R.prim_value(0), sinfo_args=(R.Object,)
+ )
+ gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
+ gv2, R.Tensor((3, 16, 3, 3), dtype="float32")
+ )
+ lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
+ lv2 = R.call_tir(
+ cls.transform_layout_IOHW_to_OIHW,
+ (lv1,),
+ out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+ )
+ _2: R.Tuple = R.vm.kill_object(lv1)
+ lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2,
R.const(1, "float32"))
+ _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv3,
sinfo_args=(R.Object,))
+ gv_1: R.Tuple = R.tuple()
+ return gv_1
+
+ after = LazyTransformParams(
+ extra_get_item_params=[relax.Var("loader", relax.ObjectStructInfo())]
+ )(Before)
+ tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
+
+
def test_lazy_transform_params_with_symbolic_vars():
@I.ir_module
class Before: