This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 3bcfc293bc [Unity] LazyTransformParams now only works on non-dataflow
blocks (#14864)
3bcfc293bc is described below
commit 3bcfc293bc5634109b22fc7168908846f7d5e9d0
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed May 17 17:30:51 2023 -0400
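[Unity] LazyTransformParams now only works on non-dataflow blocks (#14864)
LazyTransformParams' liveness analysis now visits non-dataflow binding blocks and plain Vars (visit_binding_block_ / visit_var_) instead of dataflow blocks and DataflowVars; the tests are updated to drop R.dataflow().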
---
.../tvm/relax/transform/lazy_transform_params.py | 4 +-
.../relax/test_transform_lazy_transform_params.py | 54 +++++++++-------------
2 files changed, 25 insertions(+), 33 deletions(-)
diff --git a/python/tvm/relax/transform/lazy_transform_params.py
b/python/tvm/relax/transform/lazy_transform_params.py
index d7b5945031..01deee8197 100644
--- a/python/tvm/relax/transform/lazy_transform_params.py
+++ b/python/tvm/relax/transform/lazy_transform_params.py
@@ -85,11 +85,11 @@ class LivenessAnalysis(PyExprVisitor):
self.input_params = input_params
self.var_liveness_end = {}
- def visit_dataflow_block_(self, block: relax.DataflowBlock) -> None:
+ def visit_binding_block_(self, block: relax.BindingBlock) -> None:
for binding in reversed(block.bindings):
self.visit_binding(binding)
- def visit_dataflow_var_(self, op: relax.DataflowVar) -> None:
+ def visit_var_(self, op: relax.Var) -> None:
if op in self.input_params:
self.last_appear_in_var_binding.append(op)
self.input_params.remove(op)
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py
b/tests/python/relax/test_transform_lazy_transform_params.py
index bfc1d282ab..0fc08d5ef4 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -45,19 +45,17 @@ def test_lazy_transform_params():
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3),
dtype="float32")
):
cls = Before
- with R.dataflow():
- lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
- lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
- lv2 = R.call_tir(
- cls.transform_layout_IOHW_to_OIHW,
- (lv1,),
- out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
- )
- gv: R.Tuple(
- R.Tensor((16, 16, 3, 3), dtype="float32"),
- R.Tensor((16, 3, 3, 3), dtype="float32"),
- ) = (lv, lv2)
- R.output(gv)
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+ lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
+ lv2 = R.call_tir(
+ cls.transform_layout_IOHW_to_OIHW,
+ (lv1,),
+ out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+ )
+ gv: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((16, 3, 3, 3), dtype="float32"),
+ ) = (lv, lv2)
return gv
@I.ir_module
@@ -77,24 +75,18 @@ def test_lazy_transform_params():
@R.function
def main_transform_params() -> R.Tuple(R.Object, R.Object):
cls = Expected
- with R.dataflow():
- lv: R.Object = R.call_packed("get_item", R.prim_value(1),
sinfo_args=(R.Object,))
- lv1: R.Object = R.call_packed(
- "set_item", R.prim_value(0), lv, sinfo_args=(R.Object,)
- )
- lv2: R.Tuple = R.vm.kill_object(lv)
- lv1_1: R.Object = R.call_packed("get_item", R.prim_value(0),
sinfo_args=(R.Object,))
- lv3 = R.call_tir(
- cls.transform_layout_IOHW_to_OIHW,
- (lv1_1,),
- out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
- )
- lv4: R.Object = R.call_packed(
- "set_item", R.prim_value(1), lv3, sinfo_args=(R.Object,)
- )
- lv5: R.Tuple = R.vm.kill_object(lv1_1)
- gv: R.Tuple(R.Object, R.Object) = (lv1, lv4)
- R.output(gv)
+ lv: R.Object = R.call_packed("get_item", R.prim_value(1),
sinfo_args=(R.Object,))
+ lv1: R.Object = R.call_packed("set_item", R.prim_value(0), lv,
sinfo_args=(R.Object,))
+ lv2: R.Tuple = R.vm.kill_object(lv)
+ lv1_1: R.Object = R.call_packed("get_item", R.prim_value(0),
sinfo_args=(R.Object,))
+ lv3 = R.call_tir(
+ cls.transform_layout_IOHW_to_OIHW,
+ (lv1_1,),
+ out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
+ )
+ lv4: R.Object = R.call_packed("set_item", R.prim_value(1), lv3,
sinfo_args=(R.Object,))
+ lv5: R.Tuple = R.vm.kill_object(lv1_1)
+ gv: R.Tuple(R.Object, R.Object) = (lv1, lv4)
return gv
after = LazyTransformParams()(Before)