This is an automated email from the ASF dual-hosted git repository.

hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 171ef61797 [Unity] make LazyTransformParam more general (#16088)
171ef61797 is described below

commit 171ef617977ca3ceb4c6b347b9ac3f5df783dc71
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu Nov 9 14:14:08 2023 -0800

    [Unity] make LazyTransformParam more general (#16088)
    
    * lazy transform params
    
    * format
    
    * address comment
    
    * fix
    
    * fix ci
    
    * fix ci
---
 .../tvm/relax/transform/lazy_transform_params.py   | 117 +++++++++++---
 .../relax/test_transform_lazy_transform_params.py  | 174 +++++++++++++++++++++
 2 files changed, 268 insertions(+), 23 deletions(-)

diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py
index 6a8adcb64b..7f734f8a3c 100644
--- a/python/tvm/relax/transform/lazy_transform_params.py
+++ b/python/tvm/relax/transform/lazy_transform_params.py
@@ -14,8 +14,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-argument, missing-function-docstring, 
abstract-method
+# pylint: disable=invalid-name, unused-argument, missing-function-docstring, 
abstract-method, missing-class-docstring
 """Relax LazyTransformParams pass."""
+from typing import Optional
+
 import tvm
 from tvm import IRModule, relax
 from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, 
visitor
@@ -107,8 +109,7 @@ class LivenessAnalysis(PyExprVisitor):
         self.var_liveness_end[binding.var] = self.last_appear_in_var_binding
 
 
-@mutator
-class LazyTransformParamsMutator(PyExprMutator):
+class LazyTransformParamsFuncCreator:
     """
     Transform transform_params functions into a lazy version.
 
@@ -118,16 +119,25 @@ class LazyTransformParamsMutator(PyExprMutator):
         The module to be transformed
     """
 
-    def __init__(self, mod: IRModule = None) -> None:
-        super().__init__(mod)
+    def __init__(
+        self,
+        fget_item,
+        fset_item,
+        extra_get_item_params,
+        extra_set_item_params,
+        mod: IRModule = None,
+    ) -> None:
         self.mod = mod
+        self.fget_item = fget_item
+        self.extra_get_item_params = extra_get_item_params
+        self.fset_item = fset_item
+        self.extra_set_item_params = extra_set_item_params
         # the only input param, which should be a Tuple
         self.input_tuple_param = None
         self.input_params_set = None
         self.out_tuple_map = None
         self.out_tuple_var = None
         self.memory_free_insertion = None
-        self.killed_vars = set()
 
     def transform(self, func: relax.Function) -> relax.Function:
         self.input_tuple_param = func.params[0]
@@ -147,11 +157,18 @@ class LazyTransformParamsMutator(PyExprMutator):
         self.memory_free_insertion = liveness.var_liveness_end
 
         # Step 3. rewrite get item and set item
-        new_body = self.visit_expr(func.body)
+        new_body = func.body
+        if self.fget_item is not None:
+            new_body = LazyInputMutator(self, self.mod).visit_expr(new_body)
+
+        if self.fset_item is not None:
+            new_body = LazyOutputMutator(self, self.mod).visit_expr(new_body)
 
-        # Step 4. Find all shape parameters that should be retained as
+        # Step 4. Add parameters of get_item and set_item (except index) to 
the function.
+        params = [*self.extra_get_item_params, *self.extra_set_item_params]
+
+        # Step 5. Find all shape parameters that should be retained as
         # parameters.
-        params = []
         symbolic_vars = relax.analysis.defined_symbolic_vars(func)
         if symbolic_vars:
             # direct iterate over the struct info annotation
@@ -167,14 +184,22 @@ class LazyTransformParamsMutator(PyExprMutator):
             is_pure=False,
         ).without_attr("relax.force_pure")
 
+
+@mutator
+class LazyInputMutator(PyExprMutator):
+    def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
+        self.func_creator = func_creator
+        super().__init__(mod)
+
     def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
         # rewrite get item
         tuple_get_item = super().visit_tuple_getitem_(op)
-        if tuple_get_item.tuple_value == self.input_tuple_param:
+        if tuple_get_item.tuple_value == self.func_creator.input_tuple_param:
             get_item_result = self.builder_.emit(
                 relax.Call(
-                    relax.ExternFunc("get_item"),
-                    [relax.PrimValue(tuple_get_item.index)],
+                    relax.ExternFunc(self.func_creator.fget_item),
+                    self.func_creator.extra_get_item_params
+                    + [relax.PrimValue(tuple_get_item.index)],
                     None,
                     [relax.ObjectStructInfo()],
                 )
@@ -183,12 +208,20 @@ class LazyTransformParamsMutator(PyExprMutator):
         else:
             return tuple_get_item
 
+
+@mutator
+class LazyOutputMutator(PyExprMutator):
+    def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
+        self.func_creator = func_creator
+        self.killed_vars = set()
+        super().__init__(mod)
+
     def visit_var_(self, var: relax.Var) -> None:
         assert var not in self.killed_vars
         return super().visit_var_(var)
 
     def visit_var_binding_(self, binding: relax.VarBinding) -> None:
-        if binding.var == self.out_tuple_var:
+        if binding.var == self.func_creator.out_tuple_var:
             # The function after rewriting returns a empty tuple.
             func_output = self.builder_.emit(relax.Tuple([]))
             self.set_var_remap(binding.var.vid, func_output)
@@ -196,23 +229,24 @@ class LazyTransformParamsMutator(PyExprMutator):
 
         super().visit_var_binding_(binding)
 
-        if binding.var in self.memory_free_insertion:
-            for var in self.memory_free_insertion[binding.var]:
-                if var in self.out_tuple_map:
+        if binding.var in self.func_creator.memory_free_insertion:
+            for var in self.func_creator.memory_free_insertion[binding.var]:
+                if var in self.func_creator.out_tuple_map:
                     self.killed_vars.add(var)
-                    for index in self.out_tuple_map[var]:
+                    for index in self.func_creator.out_tuple_map[var]:
                         # rewrite set item
                         self.builder_.emit(
                             relax.Call(
-                                relax.ExternFunc("set_item"),
-                                [index, super().visit_var_(var)],
+                                relax.ExternFunc(self.func_creator.fset_item),
+                                self.func_creator.extra_set_item_params
+                                + [index, super().visit_var_(var)],
                                 None,
                                 [relax.ObjectStructInfo()],
                             ),
                             name_hint="_",
                         )
 
-                if var in self.input_params_set:
+                if var in self.func_creator.input_params_set:
                     self.builder_.emit(
                         relax.op.vm.kill_object(super().visit_var_(var)), 
name_hint="_"
                     )
@@ -225,16 +259,53 @@ class LazyTransformParams:
     (Load the input to memory on demand, and immediately free it after the 
last use.)
 
     Note: ToNonDataflow() and RemovePurityTracking() should be invoked before 
this pass.
+
+    Parameters
+    ----------
+    fget_item: str
+        The name of the get_item function.
+    fset_item: str
+        The name of the set_item function.
+    extra_get_item_params: list of relax.Var
+        The parameters of the get_item function except index.
+        The given parameters will be placed before index.
+        For example, if extra_get_item_params is [param1, param2], then the 
pass will generate
+        call_packed(fget_item, [param1, param2, index])
+    extra_set_item_params: list of relax.Var
+        The parameters of the set_item function except index and value.
+        The given parameters will be placed before index and value.
+        For example, if extra_set_item_params is [param1, param2], then the 
pass will generate
+        call_packed(fset_item, [param1, param2, index, value])
     """
 
+    def __init__(
+        self,
+        fget_item="get_item",
+        fset_item="set_item",
+        extra_get_item_params=None,
+        extra_set_item_params=None,
+    ) -> None:
+        self.fget_item = fget_item
+        self.extra_get_item_params = [] if extra_get_item_params is None else 
extra_get_item_params
+        assert self.fget_item is not None, "transforming set_item only is not 
supported"
+        self.fset_item = fset_item
+        self.extra_set_item_params = [] if extra_set_item_params is None else 
extra_set_item_params
+
     def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) 
-> IRModule:
-        lazy_mutator = LazyTransformParamsMutator(mod)
+        lazy_mutator = LazyTransformParamsFuncCreator(
+            self.fget_item,
+            self.fset_item,
+            self.extra_get_item_params,
+            self.extra_set_item_params,
+            mod,
+        )
+        builder = relax.BlockBuilder(mod)
         for gv, _ in mod.functions_items():
             if gv.name_hint.endswith("transform_params"):
                 func = mod[gv]
                 if not isinstance(func, relax.Function):
                     continue
                 func = lazy_mutator.transform(func)
-                lazy_mutator.builder_.update_func(gv, func)
+                builder.update_func(gv, func)
 
-        return lazy_mutator.builder_.get()
+        return builder.get()
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py
index af7ed1956b..8f958429c7 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -106,6 +106,180 @@ def test_lazy_transform_params():
     tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
 
 
+def test_get_item_only():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def transform_layout_IOHW_to_OIHW(
+            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 
3), "float32")
+        ):
+            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
+                with T.block("layout_transform"):
+                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(w1[i, o, h, w])
+                    T.writes(out[o, i, h, w])
+                    out[o, i, h, w] = w1[i, o, h, w]
+
+        @R.function
+        def main_transform_params(
+            params: R.Tuple(
+                R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 
3), dtype="float32")
+            )
+        ) -> R.Tuple(
+            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), 
dtype="float32")
+        ):
+            # we expect ToNonDataflow and RemovePurityTracking to be invoked 
first
+            R.func_attr({"relax.force_pure": True})
+            cls = Before
+            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
+            lv2 = R.call_tir(
+                cls.transform_layout_IOHW_to_OIHW,
+                (lv1,),
+                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+            )
+            lv3 = R.add(lv2, R.const(1, "float32"))
+            gv: R.Tuple(
+                R.Tensor((16, 16, 3, 3), dtype="float32"),
+                R.Tensor((16, 3, 3, 3), dtype="float32"),
+            ) = (lv, lv3)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def transform_layout_IOHW_to_OIHW(
+            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 
3), "float32")
+        ):
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
+                with T.block("layout_transform"):
+                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(w1[i, o, h, w])
+                    T.writes(out[o, i, h, w])
+                    out[o, i, h, w] = w1[i, o, h, w]
+
+        @R.function(pure=False)
+        def main_transform_params() -> (
+            R.Tuple(
+                R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 
3), dtype="float32")
+            )
+        ):
+            cls = Expected
+            gv: R.Object = R.call_packed("get_item_0", R.prim_value(1), 
sinfo_args=(R.Object,))
+            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
+                gv, R.Tensor((16, 16, 3, 3), dtype="float32")
+            )
+            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+            gv2: R.Object = R.call_packed("get_item_0", R.prim_value(0), 
sinfo_args=(R.Object,))
+            gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
+                gv2, R.Tensor((3, 16, 3, 3), dtype="float32")
+            )
+            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
+            lv2 = R.call_tir(
+                cls.transform_layout_IOHW_to_OIHW,
+                (lv1,),
+                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+            )
+            lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2, 
R.const(1, "float32"))
+            gv_1: R.Tuple(
+                R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 
3), dtype="float32")
+            ) = (lv, lv3)
+            return gv_1
+
+    after = LazyTransformParams(fget_item="get_item_0", fset_item=None)(Before)
+    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
+
+
+def test_extra_params():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def transform_layout_IOHW_to_OIHW(
+            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 
3), "float32")
+        ):
+            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
+                with T.block("layout_transform"):
+                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(w1[i, o, h, w])
+                    T.writes(out[o, i, h, w])
+                    out[o, i, h, w] = w1[i, o, h, w]
+
+        @R.function
+        def main_transform_params(
+            params: R.Tuple(
+                R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 
3), dtype="float32")
+            )
+        ) -> R.Tuple(
+            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), 
dtype="float32")
+        ):
+            # we expect ToNonDataflow and RemovePurityTracking to be invoked 
first
+            R.func_attr({"relax.force_pure": True})
+            cls = Before
+            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
+            lv2 = R.call_tir(
+                cls.transform_layout_IOHW_to_OIHW,
+                (lv1,),
+                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+            )
+            lv3 = R.add(lv2, R.const(1, "float32"))
+            gv: R.Tuple(
+                R.Tensor((16, 16, 3, 3), dtype="float32"),
+                R.Tensor((16, 3, 3, 3), dtype="float32"),
+            ) = (lv, lv3)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def transform_layout_IOHW_to_OIHW(
+            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 
3), "float32")
+        ):
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
+                with T.block("layout_transform"):
+                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(w1[i, o, h, w])
+                    T.writes(out[o, i, h, w])
+                    out[o, i, h, w] = w1[i, o, h, w]
+
+        @R.function(pure=False)
+        def main_transform_params(loader: R.Object) -> R.Tuple:
+            cls = Expected
+            gv: R.Object = R.call_packed(
+                "get_item", loader, R.prim_value(1), sinfo_args=(R.Object,)
+            )
+            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
+                gv, R.Tensor((16, 16, 3, 3), dtype="float32")
+            )
+            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+            _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, 
sinfo_args=(R.Object,))
+            _1: R.Tuple = R.vm.kill_object(lv)
+            gv2: R.Object = R.call_packed(
+                "get_item", loader, R.prim_value(0), sinfo_args=(R.Object,)
+            )
+            gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
+                gv2, R.Tensor((3, 16, 3, 3), dtype="float32")
+            )
+            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
+            lv2 = R.call_tir(
+                cls.transform_layout_IOHW_to_OIHW,
+                (lv1,),
+                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+            )
+            _2: R.Tuple = R.vm.kill_object(lv1)
+            lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2, 
R.const(1, "float32"))
+            _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv3, 
sinfo_args=(R.Object,))
+            gv_1: R.Tuple = R.tuple()
+            return gv_1
+
+    after = LazyTransformParams(
+        extra_get_item_params=[relax.Var("loader", relax.ObjectStructInfo())]
+    )(Before)
+    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
+
+
 def test_lazy_transform_params_with_symbolic_vars():
     @I.ir_module
     class Before:

Reply via email to