This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 3bcfc293bc [Unity]Lazy transform param now only work on non-dataflow block (#14864)
3bcfc293bc is described below

commit 3bcfc293bc5634109b22fc7168908846f7d5e9d0
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed May 17 17:30:51 2023 -0400

    [Unity]Lazy transform param now only work on non-dataflow block (#14864)
    
    lazy transform params work on non-dataflow-block
---
 .../tvm/relax/transform/lazy_transform_params.py   |  4 +-
 .../relax/test_transform_lazy_transform_params.py  | 54 +++++++++-------------
 2 files changed, 25 insertions(+), 33 deletions(-)

diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py
index d7b5945031..01deee8197 100644
--- a/python/tvm/relax/transform/lazy_transform_params.py
+++ b/python/tvm/relax/transform/lazy_transform_params.py
@@ -85,11 +85,11 @@ class LivenessAnalysis(PyExprVisitor):
         self.input_params = input_params
         self.var_liveness_end = {}
 
-    def visit_dataflow_block_(self, block: relax.DataflowBlock) -> None:
+    def visit_binding_block_(self, block: relax.BindingBlock) -> None:
         for binding in reversed(block.bindings):
             self.visit_binding(binding)
 
-    def visit_dataflow_var_(self, op: relax.DataflowVar) -> None:
+    def visit_var_(self, op: relax.Var) -> None:
         if op in self.input_params:
             self.last_appear_in_var_binding.append(op)
             self.input_params.remove(op)
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py
index bfc1d282ab..0fc08d5ef4 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -45,19 +45,17 @@ def test_lazy_transform_params():
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
         ):
             cls = Before
-            with R.dataflow():
-                lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
-                lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
-                lv2 = R.call_tir(
-                    cls.transform_layout_IOHW_to_OIHW,
-                    (lv1,),
-                    out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
-                )
-                gv: R.Tuple(
-                    R.Tensor((16, 16, 3, 3), dtype="float32"),
-                    R.Tensor((16, 3, 3, 3), dtype="float32"),
-                ) = (lv, lv2)
-                R.output(gv)
+            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
+            lv2 = R.call_tir(
+                cls.transform_layout_IOHW_to_OIHW,
+                (lv1,),
+                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+            )
+            gv: R.Tuple(
+                R.Tensor((16, 16, 3, 3), dtype="float32"),
+                R.Tensor((16, 3, 3, 3), dtype="float32"),
+            ) = (lv, lv2)
             return gv
 
     @I.ir_module
@@ -77,24 +75,18 @@ def test_lazy_transform_params():
         @R.function
         def main_transform_params() -> R.Tuple(R.Object, R.Object):
             cls = Expected
-            with R.dataflow():
-                lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
-                lv1: R.Object = R.call_packed(
-                    "set_item", R.prim_value(0), lv, sinfo_args=(R.Object,)
-                )
-                lv2: R.Tuple = R.vm.kill_object(lv)
-                lv1_1: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
-                lv3 = R.call_tir(
-                    cls.transform_layout_IOHW_to_OIHW,
-                    (lv1_1,),
-                    out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
-                )
-                lv4: R.Object = R.call_packed(
-                    "set_item", R.prim_value(1), lv3, sinfo_args=(R.Object,)
-                )
-                lv5: R.Tuple = R.vm.kill_object(lv1_1)
-                gv: R.Tuple(R.Object, R.Object) = (lv1, lv4)
-                R.output(gv)
+            lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
+            lv1: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,))
+            lv2: R.Tuple = R.vm.kill_object(lv)
+            lv1_1: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
+            lv3 = R.call_tir(
+                cls.transform_layout_IOHW_to_OIHW,
+                (lv1_1,),
+                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+            )
+            lv4: R.Object = R.call_packed("set_item", R.prim_value(1), lv3, sinfo_args=(R.Object,))
+            lv5: R.Tuple = R.vm.kill_object(lv1_1)
+            gv: R.Tuple(R.Object, R.Object) = (lv1, lv4)
             return gv
 
     after = LazyTransformParams()(Before)

Reply via email to