This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 119d6bef5d [Relax] Add support for func attr inheritance in SplitLayoutRewritePreproc (#17682)
119d6bef5d is described below

commit 119d6bef5def3ff713bc109805071c9c0db288a5
Author: Honglin Zhu <[email protected]>
AuthorDate: Fri Feb 28 15:50:38 2025 +0800

    [Relax] Add support for func attr inheritance in SplitLayoutRewritePreproc (#17682)
    
    * Add support for func attr inheritance in SplitLayoutRewritePreproc
    
    fix bug in test
    
    delete layout_free_buffers
    
    * rebase latest main
    
    * delete layout_free_buffers
---
 .../transform/split_layout_rewrite_preproc.cc      | 23 +++++-
 .../test_transform_split_layout_rewrite_preproc.py | 84 ++++++++++++++++++++++
 2 files changed, 105 insertions(+), 2 deletions(-)

diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc
index 5fee946c26..69b0313397 100644
--- a/src/relax/transform/split_layout_rewrite_preproc.cc
+++ b/src/relax/transform/split_layout_rewrite_preproc.cc
@@ -81,7 +81,16 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator {
         Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
               /*name_hint=*/"root", body));
 
-    PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map);
+    Map<String, ObjectRef> dict;
+    for (const auto& [key, original_value] : original_func_->attrs->dict) {
+      if (key == "global_symbol") {
+        dict.Set(key, Downcast<String>(original_value) + "_weight_prepack");
+      } else if (key != "layout_free_buffers") {
+        dict.Set(key, original_value);
+      }
+    }
+    DictAttrs attrs(dict);
+    PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map, attrs);
 
     return RenewDefs(func);
   }
@@ -118,7 +127,17 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator {
               /*init=*/NullOpt,
               /*alloc_buffers=*/alloc_buffers));
 
-    PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map);
+    Map<String, ObjectRef> dict;
+    for (const auto& [key, original_value] : original_func_->attrs->dict) {
+      if (key == "global_symbol") {
+        dict.Set(key, Downcast<String>(original_value) + "_prepacked");
+      } else if (key != "layout_free_buffers") {
+        dict.Set(key, original_value);
+      }
+    }
+    DictAttrs attrs(dict);
+    PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map, attrs);
+
     return RenewDefs(func);
   }
 
diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
index e6b4c8ec4e..a5b74283fe 100644
--- a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
+++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
@@ -216,5 +216,89 @@ def test_multiple_buffers():
     tvm.ir.assert_structural_equal(mod, After)
 
 
+def test_attr_inheritance():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def tir_func(
+            X: T.Buffer((224, 224), "float32"),
+            W: T.Buffer((224, 224), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
+            W_rewrite = T.alloc_buffer((4, 4, 56, 56))
+            for i, j in T.grid(224, 224):
+                with T.block("W_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)})
+                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class After:
+        @T.prim_func(private=True)
+        def tir_func_prepacked(
+            X: T.Buffer((224, 224), "float32"),
+            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+
+        @T.prim_func(private=True)
+        def tir_func_weight_prepack(
+            W: T.Buffer((224, 224), "float32"),
+            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i, j in T.grid(224, 224):
+                with T.block("W_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = After
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32")
+                )
+                lv1 = R.call_tir(
+                    cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32")
+                )
+                gv: R.Tensor((224, 224), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    mod = relax.transform.SplitLayoutRewritePreproc()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to