This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c0b8953a1f [Unity] Dynamic-shape param support in LazyTransformParams
(#15713)
c0b8953a1f is described below
commit c0b8953a1f4b79f8290217a4732365d0baa4c419
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Sep 9 13:46:27 2023 -0400
[Unity] Dynamic-shape param support in LazyTransformParams (#15713)
This PR adds support for dynamic-shape parameters to the
LazyTransformParams pass.
Prior to this PR, the symbolic shape variables in dynamic-shape parameters
were not properly bound (defined) in the transformed function. This PR
wraps each `get_item` result in a MatchCast, so that those symbolic
variables are always bound, thereby supporting dynamic-shape parameters.
This PR also fixes a previously failing test (`test_output_with_use_site`,
which was skipped before and is re-enabled here).
---
.../tvm/relax/transform/lazy_transform_params.py | 13 ++-
.../relax/test_transform_lazy_transform_params.py | 120 +++++++++++++++++++--
2 files changed, 118 insertions(+), 15 deletions(-)
diff --git a/python/tvm/relax/transform/lazy_transform_params.py
b/python/tvm/relax/transform/lazy_transform_params.py
index 69f724067c..90e56c8dbb 100644
--- a/python/tvm/relax/transform/lazy_transform_params.py
+++ b/python/tvm/relax/transform/lazy_transform_params.py
@@ -164,12 +164,15 @@ class LazyTransformParamsMutator(PyExprMutator):
# rewrite get item
tuple_get_item = super().visit_tuple_getitem_(op)
if tuple_get_item.tuple_value == self.input_tuple_param:
- return relax.Call(
- relax.ExternFunc("get_item"),
- [relax.PrimValue(tuple_get_item.index)],
- None,
- [relax.ObjectStructInfo()],
+ get_item_result = self.builder_.emit(
+ relax.Call(
+ relax.ExternFunc("get_item"),
+ [relax.PrimValue(tuple_get_item.index)],
+ None,
+ [relax.ObjectStructInfo()],
+ )
)
+ return self.builder_.match_cast(get_item_result, op.struct_info)
else:
return tuple_get_item
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py
b/tests/python/relax/test_transform_lazy_transform_params.py
index 478580ff8d..94f2181daf 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -79,15 +79,23 @@ def test_lazy_transform_params():
R.func_attr({"relax.force_pure": True})
cls = Expected
lv: R.Object = R.call_packed("get_item", R.prim_value(1),
sinfo_args=(R.Object,))
- _: R.Object = R.call_packed("set_item", R.prim_value(0), lv,
sinfo_args=(R.Object,))
- _1: R.Tuple = R.vm.kill_object(lv)
+ gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
+ lv, R.Tensor((16, 16, 3, 3), dtype="float32")
+ )
+ lv_m: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+ _: R.Object = R.call_packed("set_item", R.prim_value(0), lv_m,
sinfo_args=(R.Object,))
+ _1: R.Tuple = R.vm.kill_object(lv_m)
lv1: R.Object = R.call_packed("get_item", R.prim_value(0),
sinfo_args=(R.Object,))
+ gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
+ lv1, R.Tensor((3, 16, 3, 3), dtype="float32")
+ )
+ lv1_m: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
lv2 = R.call_tir(
cls.transform_layout_IOHW_to_OIHW,
- (lv1,),
+ (lv1_m,),
out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
)
- _2: R.Tuple = R.vm.kill_object(lv1)
+ _2: R.Tuple = R.vm.kill_object(lv1_m)
_3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2,
sinfo_args=(R.Object,))
gv: R.Tuple = R.tuple()
return gv
@@ -146,13 +154,17 @@ def test_lazy_transform_params_with_symbolic_vars():
slice_index = T.int64()
param = R.call_packed("get_item", R.prim_value(0),
sinfo_args=(R.Object,))
+ gv: R.Tensor((16, 16), dtype="float32") = R.match_cast(
+ param, R.Tensor((16, 16), dtype="float32")
+ )
+ param_m: R.Tensor((16, 16), dtype="float32") = gv
transformed = R.call_tir(
cls.slice_buffer,
- (param,),
+ (param_m,),
tir_vars=[slice_index],
out_sinfo=R.Tensor((16,), dtype="float32"),
)
- unused_1_ = R.vm.kill_object(param)
+ unused_1_ = R.vm.kill_object(param_m)
unused_2_ = R.call_packed(
"set_item", R.prim_value(0), transformed,
sinfo_args=(R.Object,)
)
@@ -175,14 +187,100 @@ def test_lazy_transform_params_with_symbolic_vars():
tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
-# TODO(tvm-team): remove once regression get fixed
-@pytest.mark.skip("temp disable, minor regression on read/write region in zero
dim buffer")
+def test_param_shape_symbolic():
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle):
+ ic = T.int32()
+ w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32")
+ out = T.match_buffer(var_out, (16, ic, 3, 3), "float32")
+ for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3):
+ with T.block("layout_transform"):
+ o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(w1[i, o, h, w])
+ T.writes(out[o, i, h, w])
+ out[o, i, h, w] = w1[i, o, h, w]
+
+ @R.function
+ def main_transform_params(
+ params: R.Tuple(
+ R.Tensor((3, "ic", 3, 3), dtype="float32"),
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ )
+ ) -> R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3,
3), dtype="float32")
+ ):
+ ic = T.int64()
+ # we expect ToNonDataflow and RemovePurityTracking to be invoked
first
+ R.func_attr({"relax.force_pure": True})
+ cls = Before
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+ lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = params[0]
+ lv2 = R.call_tir(
+ cls.transform_layout_IOHW_to_OIHW,
+ (lv1,),
+ out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"),
+ )
+ gv: R.Tuple(
+ R.Tensor((16, 16, 3, 3), dtype="float32"),
+ R.Tensor((ic, 3, 3, 3), dtype="float32"),
+ ) = (lv, lv2)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle):
+ ic = T.int32()
+ w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32")
+ out = T.match_buffer(var_out, (16, ic, 3, 3), "float32")
+ for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3):
+ with T.block("layout_transform"):
+ o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(w1[i, o, h, w])
+ T.writes(out[o, i, h, w])
+ out[o, i, h, w] = w1[i, o, h, w]
+
+ @R.function
+ def main_transform_params() -> R.Tuple:
+ R.func_attr({"relax.force_pure": True})
+ ic = T.int64()
+ cls = Expected
+ gv: R.Object = R.call_packed("get_item", R.prim_value(1),
sinfo_args=(R.Object,))
+ gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
+ gv, R.Tensor((16, 16, 3, 3), dtype="float32")
+ )
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+ _: R.Object = R.call_packed("set_item", R.prim_value(0), lv,
sinfo_args=(R.Object,))
+ _1: R.Tuple = R.vm.kill_object(lv)
+ gv2: R.Object = R.call_packed("get_item", R.prim_value(0),
sinfo_args=(R.Object,))
+ gv3: R.Tensor((3, ic, 3, 3), dtype="float32") = R.match_cast(
+ gv2, R.Tensor((3, ic, 3, 3), dtype="float32")
+ )
+ lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = gv3
+ lv2 = R.call_tir(
+ cls.transform_layout_IOHW_to_OIHW,
+ (lv1,),
+ out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"),
+ )
+ _2: R.Tuple = R.vm.kill_object(lv1)
+ _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2,
sinfo_args=(R.Object,))
+ gv4: R.Tuple = R.tuple()
+ return gv4
+
+ after = LazyTransformParams()(Before)
+ tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
+
+
def test_output_with_use_site():
@I.ir_module
class Module:
@T.prim_func
def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")):
with T.block("block"):
+ T.reads(x[()])
+ T.writes(y[()])
y[()] = x[()]
@R.function
@@ -212,8 +310,10 @@ def test_output_with_use_site():
R.func_attr({"relax.force_pure": True})
cls = Expected
x: R.Object = R.call_packed("get_item", R.prim_value(0),
sinfo_args=(R.Object,))
- y = R.call_tir(cls.copy, (x,), out_sinfo=R.Tensor((),
dtype="float32"))
- _: R.Tuple = R.vm.kill_object(x)
+ gv: R.Tensor((), dtype="float32") = R.match_cast(x, R.Tensor((),
dtype="float32"))
+ x_m: R.Tensor((), dtype="float32") = gv
+ y = R.call_tir(cls.copy, (x_m,), out_sinfo=R.Tensor((),
dtype="float32"))
+ _: R.Tuple = R.vm.kill_object(x_m)
z = R.call_tir(cls.copy, (y,), out_sinfo=R.Tensor((),
dtype="float32"))
_1: R.Object = R.call_packed("set_item", R.prim_value(0), y,
sinfo_args=(R.Object,))
_2: R.Object = R.call_packed("set_item", R.prim_value(1), z,
sinfo_args=(R.Object,))