This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0f38ef2d6e [Bugfix][TIR] Fix cache_read update buffer region (#16742)
0f38ef2d6e is described below

commit 0f38ef2d6e6ecb7d1b8e164582f417b15b8f4e9a
Author: albert qing <[email protected]>
AuthorDate: Wed Mar 20 22:52:36 2024 +0800

    [Bugfix][TIR] Fix cache_read update buffer region (#16742)
    
    Prior to this commit, the cache_read primitive might not properly update
    the block's read buffer regions when there is a nested buffer access.
    
    This commit fixes this bug and adds a cache_read unit test.
    
    Co-authored-by: qsqqsqqsq-intellif <[email protected]>
---
 src/tir/schedule/primitive/cache_read_write.cc     |  7 ++--
 .../test_tir_schedule_cache_read_write.py          | 41 ++++++++++++++++++++++
 2 files changed, 45 insertions(+), 3 deletions(-)

diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index a687624bac..eac5500a19 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -958,9 +958,10 @@ class CacheReadRewriter : public StmtExprMutator {
       // Otherwise, update read regions and match_buffers
       // Only make this change if the block is one of the specified consumers.
       if (is_consumer) {
-        Array<BufferRegion> reads = update_access_regions(block->reads);
-        Array<MatchBufferRegion> match_buffers = update_match_buffers(block->match_buffers);
-        if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
+        // Use the updated block stmt
+        Array<BufferRegion> reads = update_access_regions(stmt->reads);
+        Array<MatchBufferRegion> match_buffers = update_match_buffers(stmt->match_buffers);
+        if (!reads.same_as(stmt->reads) || !match_buffers.same_as(stmt->match_buffers)) {
           ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
           n->reads = std::move(reads);
           n->match_buffers = std::move(match_buffers);
diff --git a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
index 345c7368ce..1fda0f4321 100644
--- a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
+++ b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
@@ -488,6 +488,19 @@ def cache_read_nested_seq_target(
             C[vi, vj] = A_global[vi, vj] * T.float32(2)
 
 
+@T.prim_func
+def nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
+    A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
+    B = T.match_buffer(var_B, T.int64(1), dtype="int32")
+    C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
+    for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
+        with T.block("C"):
+            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+            T.reads(A[B[v_ax0], v_ax1], B[v_ax0])
+            T.writes(C[v_ax0, v_ax1])
+            C[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]
+
+
 ########## Expected function after cache_read ##########
 
 
@@ -831,6 +844,26 @@ def cache_inplace_buffer(data_io: T.Buffer(64, "int32")) -> None:
             data_io[v0] = data_io_global_1[v0]
 
 
+@T.prim_func
+def cache_read_nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
+    A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
+    B = T.match_buffer(var_B, T.int64(1), dtype="int32")
+    C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
+    B_global = T.alloc_buffer((T.int64(1),), "int32")
+    for ax0 in range(T.int64(1)):
+        with T.block("B_global"):
+            v0 = T.axis.spatial(T.int64(1), ax0)
+            T.reads(B[v0])
+            T.writes(B_global[v0])
+            B_global[v0] = B[v0]
+    for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
+        with T.block("C"):
+            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+            T.reads(A[B_global[v_ax0], v_ax1], B_global[v_ax0])
+            T.writes(C[v_ax0, v_ax1])
+            C[v_ax0, v_ax1] = A[B_global[v_ax0], v_ax1]
+
+
 ########## Expected function after cache_write ##########
 
 
@@ -1358,6 +1391,14 @@ def test_cache_read_non_int32_shape(use_block_name):
     verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64)
 
 
+def test_cache_read_nested_buffer_access(use_block_name):
+    sch = tir.Schedule(nested_buffer_access, debug_mask="all")
+    block_c = "C" if use_block_name else sch.get_block("C")
+    sch.cache_read(block_c, 1, "global")
+    assert_structural_equal_ignore_global_symbol(cache_read_nested_buffer_access, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=nested_buffer_access)
+
+
 def test_cache_read_fail_multi_producer(use_block_name):
     sch = tir.Schedule(func_multi_producer, debug_mask="all")
     block_b = "B" if use_block_name else sch.get_block("B")

Reply via email to