This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dcf1b1a90e [TIR] Handle DeclBuffer in CacheReadWrite schedule primitive (#15037)
dcf1b1a90e is described below

commit dcf1b1a90e3d96d00368b19a99adefe7c0ef5f68
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jun 16 18:44:04 2023 -0400

    [TIR] Handle DeclBuffer in CacheReadWrite schedule primitive (#15037)
    
    * [Util] Handle AllocateConst in MergeNest
    
    * [TIR] Handle DeclBuffer in CacheReadWrite schedule primitive
    
    Part of changes being split out from
    https://github.com/apache/tvm/pull/14778 into independent portions.
    This commit allows TIR `cache_read` and `cache_write` schedule
    primitives to preserve `DeclBuffer` nodes.
---
 src/tir/schedule/primitive/cache_read_write.cc     |  59 ++++++---
 .../unittest/test_tir_schedule_cache_read_write.py | 134 +++++++++++----------
 2 files changed, 113 insertions(+), 80 deletions(-)

diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 74a960eefb..0a4cf2329e 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -20,6 +20,7 @@
 #include <unordered_set>
 
 #include "../../analysis/var_use_def_analysis.h"
+#include "../../transforms/ir_utils.h"
 #include "../utils.h"
 
 namespace tvm {
@@ -425,21 +426,43 @@ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref)
  * \return A SeqStmt, the result after insertion
  */
 Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
-  if (const auto* alloc = stmt.as<AllocateConstNode>()) {
-    auto seq_stmt = InsertCacheStage(alloc->body, pos, stage);
-    return AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data, seq_stmt,
-                         alloc->annotations, alloc->span);
-  }
-  if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
-    ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt);
-    result->seq.insert(result->seq.begin() + pos, stage);
-    return SeqStmt(result);
+  std::vector<Stmt> nest;
+  Stmt body = stmt;
+  while (true) {
+    if (auto opt = body.as<AllocateConst>()) {
+      auto alloc = opt.value();
+      body = alloc->body;
+      alloc.CopyOnWrite()->body = Evaluate(0);
+      nest.push_back(alloc);
+    } else if (auto opt = body.as<DeclBuffer>()) {
+      auto decl_buffer = opt.value();
+      body = decl_buffer->body;
+      decl_buffer.CopyOnWrite()->body = Evaluate(0);
+      nest.push_back(decl_buffer);
+    } else {
+      break;
+    }
   }
-  if (pos == 0) {
-    return SeqStmt({stage, stmt});
+
+  if (const auto* seq_stmt = body.as<SeqStmtNode>()) {
+    Array<Stmt> seq = seq_stmt->seq;
+    ICHECK_LE(pos, seq.size()) << "Cannot insert at position " << pos << " into sequence of length "
+                               << seq.size();
+    seq.insert(seq.begin() + pos, stage);
+    body = SeqStmt(seq);
+  } else if (pos == 0) {
+    body = SeqStmt({stage, stmt});
+  } else if (pos == 1) {
+    body = SeqStmt({stmt, stage});
+  } else {
+    LOG(FATAL) << "Cannot insert at position " << pos
+               << ".  When inserting adjacent to non-SeqStmt, "
+               << "only positions 0 and 1 are valid.";
   }
-  ICHECK_EQ(pos, 1);
-  return SeqStmt({stmt, stage});
+
+  body = MergeNest(nest, body);
+
+  return body;
 }
 
 /*!
@@ -550,8 +573,14 @@ class CacheLocDetector : public StmtVisitor {
 
       auto block_body = scope_sref->StmtAs<BlockNode>()->body;
       // Find the SeqStmtNode within (potentially nested) AllocateConstNodes
-      while (block_body->IsInstance<AllocateConstNode>()) {
-        block_body = block_body.as<AllocateConstNode>()->body;
+      while (true) {
+        if (auto* ptr = block_body.as<AllocateConstNode>()) {
+          block_body = ptr->body;
+        } else if (auto* ptr = block_body.as<DeclBufferNode>()) {
+          block_body = ptr->body;
+        } else {
+          break;
+        }
       }
       const auto* body = block_body.as<SeqStmtNode>();
       info->loc_pos = body == nullptr ? 1 : body->size();
diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py
index 454557a2bd..95955646c6 100644
--- a/tests/python/unittest/test_tir_schedule_cache_read_write.py
+++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py
@@ -1172,67 +1172,6 @@ def block_predicate_cache_write_output_buf() -> None:
 use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
 
 
-@T.prim_func
-def cache_write_allocate_const(
-    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")
-):
-    B = T.alloc_buffer([128, 128], dtype="float32")
-    const = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
-    const_1 = T.Buffer([8], dtype="float32", data=const)
-    const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
-    const_2 = T.Buffer([8], dtype="float32", data=const)
-    for i, j in T.grid(128, 128):
-        for x in range(8):
-            with T.block("B"):
-                vi, vj, vx = T.axis.remap("SSS", [i, j, x])
-                T.reads(A[vi, vj], const_1[vx], const_2[vx])
-                T.writes(B[vi, vj])
-                B[vi, vj] = A[vi, vj] * const_1[vx] + const_2[vx]
-    for i, j in T.grid(128, 128):
-        with T.block("C"):
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.reads(B[vi, vj])
-            T.writes(C[vi, vj])
-            C[vi, vj] = B[vi, vj] + 1.0
-
-
-@T.prim_func
-def cache_write_allocate_const_output(
-    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")
-):
-    B = T.alloc_buffer([128, 128], dtype="float32")
-    A_global = T.alloc_buffer([128, 128], dtype="float32")
-    C_global = T.alloc_buffer([128, 128], dtype="float16")
-    const_2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
-    const_1 = T.Buffer([8], dtype="float32", data=const_2)
-    const_2_1 = T.Buffer([8], dtype="float32", data=const_2)
-    const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
-    for ax0, ax1 in T.grid(128, 128):
-        with T.block("A_global"):
-            v0, v1 = T.axis.remap("SS", [ax0, ax1])
-            T.reads(A[v0, v1])
-            T.writes(A_global[v0, v1])
-            A_global[v0, v1] = A[v0, v1]
-    for i, j, x in T.grid(128, 128, 8):
-        with T.block("B"):
-            vi, vj, vx = T.axis.remap("SSS", [i, j, x])
-            T.reads(A_global[vi, vj], const_1[vx], const_2_1[vx])
-            T.writes(B[vi, vj])
-            B[vi, vj] = A_global[vi, vj] * const_1[vx] + const_2_1[vx]
-    for i, j in T.grid(128, 128):
-        with T.block("C"):
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.reads(B[vi, vj])
-            T.writes(C_global[vi, vj])
-            C_global[vi, vj] = B[vi, vj] + T.float32(1)
-    for ax0, ax1 in T.grid(128, 128):
-        with T.block("C_global"):
-            v0, v1 = T.axis.remap("SS", [ax0, ax1])
-            T.reads(C_global[v0, v1])
-            T.writes(C[v0, v1])
-            C[v0, v1] = C_global[v0, v1]
-
-
 def test_cache_read_elementwise(use_block_name):
     sch = tir.Schedule(elementwise, debug_mask="all")
     block_b = sch.get_block("B")
@@ -1493,14 +1432,79 @@ def test_cache_write_fail_invalid_storage_scope(use_block_name):
         sch.cache_write(block_b, 0, "test_scope")
 
 
-def test_cache_write_allocate_const():
-    sch = tir.Schedule(cache_write_allocate_const)
[email protected]("use_decl_buffer", [True, False])
+def test_cache_write_allocate_const(use_decl_buffer):
+    def apply_decl_buffer(*args, **kwargs):
+        if use_decl_buffer:
+            return T.decl_buffer(*args, **kwargs)
+        else:
+            return T.Buffer(*args, **kwargs)
+
+    @T.prim_func
+    def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")):
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        const1 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+        const1_buf = apply_decl_buffer([8], dtype="float32", data=const1)
+        const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+        const2_buf = apply_decl_buffer([8], dtype="float32", data=const2)
+        for i, j in T.grid(128, 128):
+            for x in range(8):
+                with T.block("B"):
+                    vi, vj, vx = T.axis.remap("SSS", [i, j, x])
+                    T.reads(A[vi, vj], const1_buf[vx], const2_buf[vx])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] * const1_buf[vx] + const2_buf[vx]
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(B[vi, vj])
+                T.writes(C[vi, vj])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    @T.prim_func
+    def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")):
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        A_global = T.alloc_buffer([128, 128], dtype="float32")
+        C_global = T.alloc_buffer([128, 128], dtype="float16")
+        const1 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+        const1_buf = apply_decl_buffer([8], dtype="float32", data=const1)
+        const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+        const2_buf = apply_decl_buffer([8], dtype="float32", data=const2)
+        for ax0, ax1 in T.grid(128, 128):
+            with T.block("A_global"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(A[v0, v1])
+                T.writes(A_global[v0, v1])
+                A_global[v0, v1] = A[v0, v1]
+        for i, j, x in T.grid(128, 128, 8):
+            with T.block("B"):
+                vi, vj, vx = T.axis.remap("SSS", [i, j, x])
+                T.reads(A_global[vi, vj], const1_buf[vx], const2_buf[vx])
+                T.writes(B[vi, vj])
+                B[vi, vj] = A_global[vi, vj] * const1_buf[vx] + const2_buf[vx]
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(B[vi, vj])
+                T.writes(C_global[vi, vj])
+                C_global[vi, vj] = B[vi, vj] + T.float32(1)
+        for ax0, ax1 in T.grid(128, 128):
+            with T.block("C_global"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(C_global[v0, v1])
+                T.writes(C[v0, v1])
+                C[v0, v1] = C_global[v0, v1]
+
+    sch = tir.Schedule(before)
     block_b = sch.get_block("B")
     block_c = sch.get_block("C")
     sch.cache_read(block_b, 0, "global")
     sch.cache_write(block_c, 0, "global")
-    tvm.ir.assert_structural_equal(cache_write_allocate_const_output, sch.mod["main"])
-    verify_trace_roundtrip(sch=sch, mod=cache_write_allocate_const)
+
+    after = sch.mod["main"]
+
+    tvm.ir.assert_structural_equal(expected, after)
+    verify_trace_roundtrip(sch=sch, mod=before)
 
 
 def test_reindex_cache_read():

Reply via email to