This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 119d6bef5d [Relax] Add support for func attr inheritance in SplitLayoutRewritePreproc (#17682)
119d6bef5d is described below
commit 119d6bef5def3ff713bc109805071c9c0db288a5
Author: Honglin Zhu <[email protected]>
AuthorDate: Fri Feb 28 15:50:38 2025 +0800
[Relax] Add support for func attr inheritance in SplitLayoutRewritePreproc (#17682)
* Add support for func attr inheritance in SplitLayoutRewritePreproc
  - Fix a bug in the test
  - Delete layout_free_buffers
* Rebase onto latest main
* Delete layout_free_buffers
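
In short: the two PrimFuncs produced by the split now inherit the original function's attributes, dropping only "layout_free_buffers" and suffixing "global_symbol" with "_weight_prepack" / "_prepacked". A minimal usage sketch, assuming an IRModule named Before like the one in the test added below:

    import tvm
    from tvm import relax

    # Split the layout-rewrite preproc blocks into a separate PrimFunc.
    mod = relax.transform.SplitLayoutRewritePreproc()(Before)

    # Inspect the inherited attrs on the resulting functions; attrs such as
    # "tir.noalias" are kept, while "layout_free_buffers" is removed.
    for gvar in mod.get_global_vars():
        print(gvar.name_hint, mod[gvar].attrs)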
---
.../transform/split_layout_rewrite_preproc.cc | 23 +++++-
.../test_transform_split_layout_rewrite_preproc.py | 84 ++++++++++++++++++++++
2 files changed, 105 insertions(+), 2 deletions(-)
diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc
index 5fee946c26..69b0313397 100644
--- a/src/relax/transform/split_layout_rewrite_preproc.cc
+++ b/src/relax/transform/split_layout_rewrite_preproc.cc
@@ -81,7 +81,16 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator {
                         Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                               /*name_hint=*/"root", body));
-    PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map);
+    Map<String, ObjectRef> dict;
+    for (const auto& [key, original_value] : original_func_->attrs->dict) {
+      if (key == "global_symbol") {
+        dict.Set(key, Downcast<String>(original_value) + "_weight_prepack");
+      } else if (key != "layout_free_buffers") {
+        dict.Set(key, original_value);
+      }
+    }
+    DictAttrs attrs(dict);
+    PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map, attrs);
     return RenewDefs(func);
   }
@@ -118,7 +127,17 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator {
                   /*init=*/NullOpt,
                   /*alloc_buffers=*/alloc_buffers));
-    PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map);
+    Map<String, ObjectRef> dict;
+    for (const auto& [key, original_value] : original_func_->attrs->dict) {
+      if (key == "global_symbol") {
+        dict.Set(key, Downcast<String>(original_value) + "_prepacked");
+      } else if (key != "layout_free_buffers") {
+        dict.Set(key, original_value);
+      }
+    }
+    DictAttrs attrs(dict);
+    PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map, attrs);
+
     return RenewDefs(func);
   }
diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
index e6b4c8ec4e..a5b74283fe 100644
--- a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
+++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
@@ -216,5 +216,89 @@ def test_multiple_buffers():
     tvm.ir.assert_structural_equal(mod, After)


+def test_attr_inheritance():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def tir_func(
+            X: T.Buffer((224, 224), "float32"),
+            W: T.Buffer((224, 224), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
+            W_rewrite = T.alloc_buffer((4, 4, 56, 56))
+            for i, j in T.grid(224, 224):
+                with T.block("W_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)})
+                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class After:
+        @T.prim_func(private=True)
+        def tir_func_prepacked(
+            X: T.Buffer((224, 224), "float32"),
+            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+
+        @T.prim_func(private=True)
+        def tir_func_weight_prepack(
+            W: T.Buffer((224, 224), "float32"),
+            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i, j in T.grid(224, 224):
+                with T.block("W_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = After
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32")
+                )
+                lv1 = R.call_tir(
+                    cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32")
+                )
+                gv: R.Tensor((224, 224), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    mod = relax.transform.SplitLayoutRewritePreproc()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()