This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new eea6268c92 [TIR] Handle DeclBuffer in Inline/ComputeAt/ReverseComputeAt (#15038)
eea6268c92 is described below

commit eea6268c928a5d92b0c4b9c864c841edd0740c68
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Jun 10 01:03:18 2023 -0400

    [TIR] Handle DeclBuffer in Inline/ComputeAt/ReverseComputeAt (#15038)
    
    * [Util] Handle AllocateConst in MergeNest
    
    * [TIR] Handle DeclBuffer in Inline/ComputeAt/ReverseComputeAt
    
    Part of changes being split out from
    https://github.com/apache/tvm/pull/14778 into independent portions.
    This commit allows TIR `compute_inline`, `compute_at`, and
    `reverse_compute_at` schedule primitives to preserve `DeclBuffer`
    nodes.
---
 src/tir/schedule/transform.cc                      | 28 +++---
 src/tir/transforms/ir_utils.cc                     |  5 ++
 .../unittest/test_tir_schedule_compute_at.py       | 99 ++++++++++++++++++++++
 3 files changed, 122 insertions(+), 10 deletions(-)

diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index baa7f44bbc..9c209658c3 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -17,6 +17,7 @@
  * under the License.
  */
 
+#include "../transforms/ir_utils.h"
 #include "./utils.h"
 
 namespace tvm {
@@ -261,21 +262,28 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
   if (const auto* block = sref->StmtAs<BlockNode>()) {
     auto body = block->body;
     // Peel off AllocateConst nodes at the beginning of the block body.
-    std::vector<const AllocateConstNode*> allocs;
-    while (const auto* alloc = body.as<AllocateConstNode>()) {
-      allocs.push_back(alloc);
-      body = alloc->body;
+    std::vector<Stmt> allocs;
+    while (true) {
+      if (auto opt = body.as<AllocateConst>()) {
+        auto alloc = opt.value();
+        body = alloc->body;
+        alloc.CopyOnWrite()->body = Evaluate(0);
+        allocs.push_back(alloc);
+      } else if (auto opt = body.as<DeclBuffer>()) {
+        auto decl_buffer = opt.value();
+        body = decl_buffer->body;
+        decl_buffer.CopyOnWrite()->body = Evaluate(0);
+        allocs.push_back(decl_buffer);
+      } else {
+        break;
+      }
     }
+
     if (const auto* seq = body.as<SeqStmtNode>()) {
       ObjectPtr<BlockNode> n = make_object<BlockNode>(*block);
       auto new_seq = RemoveFromSeqStmt(GetRef<SeqStmt>(seq), 
GetRef<Stmt>(last_stmt));
       // Re-attach AllocateConst nodes
-      auto new_body = new_seq;
-      for (int i = 0; i < static_cast<int>(allocs.size()); ++i) {
-        auto alloc = allocs[allocs.size() - 1 - i];
-        new_body = AllocateConst(alloc->buffer_var, alloc->dtype, 
alloc->extents, alloc->data,
-                                 new_body, alloc->annotations, alloc->span);
-      }
+      auto new_body = MergeNest(allocs, new_seq);
       n->body = new_body;
       *src_stmt = GetRef<Stmt>(block);
       *tgt_stmt = Stmt(std::move(n));
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 604dbed325..43bf6b983e 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -75,6 +75,11 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
       ICHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
+    } else if (const auto* alloc = s.as<AllocateConstNode>()) {
+      auto n = make_object<AllocateConstNode>(*alloc);
+      ICHECK(is_no_op(n->body));
+      n->body = body;
+      body = Stmt(n);
     } else if (const auto* decl_buffer = s.as<DeclBufferNode>()) {
       auto n = make_object<DeclBufferNode>(*decl_buffer);
       ICHECK(is_no_op(n->body));
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py
index 0623fb02f3..7efb4cccc0 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1672,5 +1672,104 @@ def test_reverse_compute_at_layout_trans():
     verify_trace_roundtrip(sch=sch, mod=before)
 
 
+@pytest.mark.parametrize("use_decl_buffer", [True, False])
+@pytest.mark.parametrize("use_reverse_compute_at", [True, False])
+def test_compute_at_allocate_const(use_decl_buffer, use_reverse_compute_at):
+    def apply_decl_buffer(*args, **kwargs):
+        if use_decl_buffer:
+            return T.decl_buffer(*args, **kwargs)
+        else:
+            return T.Buffer(*args, **kwargs)
+
+    @T.prim_func
+    def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
+        B = T.alloc_buffer([4])
+
+        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
+        offset = apply_decl_buffer([4], data=offset_ptr)
+        for i in range(4):
+            with T.block("compute_B"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = 10.0 * vi + offset[vi]
+
+        for i, j in T.grid(4, 256):
+            with T.block("compute_C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi] + 100.0 * vj
+
+    @T.prim_func
+    def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
+        B = T.alloc_buffer([4])
+
+        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
+        offset = apply_decl_buffer([4], data=offset_ptr)
+        for i in range(4):
+            with T.block("compute_B"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = 10.0 * vi + offset[vi]
+
+            for j in range(256):
+                with T.block("compute_C"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    C[vi, vj] = B[vi] + 100.0 * vj
+
+    sch = tir.Schedule(before, debug_mask="all")
+    if use_reverse_compute_at:
+        block = sch.get_block("compute_C")
+        axis = sch.get_loops("compute_B")[0]
+        sch.reverse_compute_at(block, axis)
+    else:
+        block = sch.get_block("compute_B")
+        axis = sch.get_loops("compute_C")[0]
+        sch.compute_at(block, axis)
+
+    after = sch.mod["main"]
+
+    tvm.ir.assert_structural_equal(expected, after)
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
+@pytest.mark.parametrize("use_decl_buffer", [True, False])
+def test_compute_inline_allocate_const(use_decl_buffer):
+    def apply_decl_buffer(*args, **kwargs):
+        if use_decl_buffer:
+            return T.decl_buffer(*args, **kwargs)
+        else:
+            return T.Buffer(*args, **kwargs)
+
+    @T.prim_func
+    def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
+        B = T.alloc_buffer([4])
+
+        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
+        offset = apply_decl_buffer([4], data=offset_ptr)
+        for i in range(4):
+            with T.block("compute_B"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = 10.0 * vi + offset[vi]
+
+        for i, j in T.grid(4, 256):
+            with T.block("compute_C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi] + 100.0 * vj
+
+    @T.prim_func
+    def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
+        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
+        offset = apply_decl_buffer([4], data=offset_ptr)
+        for i, j in T.grid(4, 256):
+            with T.block("compute_C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = (10.0 * vi + offset[vi]) + 100.0 * vj
+
+    sch = tir.Schedule(before, debug_mask="all")
+    block = sch.get_block("compute_B")
+    sch.compute_inline(block)
+    after = sch.mod["main"]
+
+    tvm.ir.assert_structural_equal(expected, after)
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to