This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0f38ef2d6e [Bugfix][TIR] Fix cache_read update buffer region (#16742)
0f38ef2d6e is described below
commit 0f38ef2d6e6ecb7d1b8e164582f417b15b8f4e9a
Author: albert qing <[email protected]>
AuthorDate: Wed Mar 20 22:52:36 2024 +0800
[Bugfix][TIR] Fix cache_read update buffer region (#16742)
Prior to this commit, the cache_read primitive might not update the block's
read buffer regions properly when there is a nested buffer access.
This commit fixes the bug and adds a cache_read unit test.
Co-authored-by: qsqqsqqsq-intellif <[email protected]>
---
src/tir/schedule/primitive/cache_read_write.cc | 7 ++--
.../test_tir_schedule_cache_read_write.py | 41 ++++++++++++++++++++++
2 files changed, 45 insertions(+), 3 deletions(-)
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index a687624bac..eac5500a19 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -958,9 +958,10 @@ class CacheReadRewriter : public StmtExprMutator {
// Otherwise, update read regions and match_buffers
// Only make this change if the block is one of the specified consumers.
if (is_consumer) {
- Array<BufferRegion> reads = update_access_regions(block->reads);
- Array<MatchBufferRegion> match_buffers = update_match_buffers(block->match_buffers);
- if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
+ // Use the updated block stmt
+ Array<BufferRegion> reads = update_access_regions(stmt->reads);
+ Array<MatchBufferRegion> match_buffers = update_match_buffers(stmt->match_buffers);
+ if (!reads.same_as(stmt->reads) || !match_buffers.same_as(stmt->match_buffers)) {
ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
n->reads = std::move(reads);
n->match_buffers = std::move(match_buffers);
diff --git a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
index 345c7368ce..1fda0f4321 100644
--- a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
+++ b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
@@ -488,6 +488,19 @@ def cache_read_nested_seq_target(
C[vi, vj] = A_global[vi, vj] * T.float32(2)
+@T.prim_func
+def nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
+ A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
+ B = T.match_buffer(var_B, T.int64(1), dtype="int32")
+ C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
+ for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
+ with T.block("C"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[B[v_ax0], v_ax1], B[v_ax0])
+ T.writes(C[v_ax0, v_ax1])
+ C[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]
+
+
########## Expected function after cache_read ##########
@@ -831,6 +844,26 @@ def cache_inplace_buffer(data_io: T.Buffer(64, "int32")) -> None:
data_io[v0] = data_io_global_1[v0]
+@T.prim_func
+def cache_read_nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
+ A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
+ B = T.match_buffer(var_B, T.int64(1), dtype="int32")
+ C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
+ B_global = T.alloc_buffer((T.int64(1),), "int32")
+ for ax0 in range(T.int64(1)):
+ with T.block("B_global"):
+ v0 = T.axis.spatial(T.int64(1), ax0)
+ T.reads(B[v0])
+ T.writes(B_global[v0])
+ B_global[v0] = B[v0]
+ for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
+ with T.block("C"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[B_global[v_ax0], v_ax1], B_global[v_ax0])
+ T.writes(C[v_ax0, v_ax1])
+ C[v_ax0, v_ax1] = A[B_global[v_ax0], v_ax1]
+
+
########## Expected function after cache_write ##########
@@ -1358,6 +1391,14 @@ def test_cache_read_non_int32_shape(use_block_name):
verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64)
+def test_cache_read_nested_buffer_access(use_block_name):
+ sch = tir.Schedule(nested_buffer_access, debug_mask="all")
+ block_c = "C" if use_block_name else sch.get_block("C")
+ sch.cache_read(block_c, 1, "global")
+
+ assert_structural_equal_ignore_global_symbol(cache_read_nested_buffer_access, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=nested_buffer_access)
+
+
def test_cache_read_fail_multi_producer(use_block_name):
sch = tir.Schedule(func_multi_producer, debug_mask="all")
block_b = "B" if use_block_name else sch.get_block("B")