This is an automated email from the ASF dual-hosted git repository.
lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 657880cdce [Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite
schedule primitive (#16660)
657880cdce is described below
commit 657880cdcedd7e41e911c583a8e93b3053a6ad27
Author: Andrei Hutu <[email protected]>
AuthorDate: Thu Mar 7 10:49:03 2024 +0000
[Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite schedule
primitive (#16660)
* [Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite schedule
primitive
When inserting a `cache_read` / `cache_write` stage, the
`tir.AllocateConst` statement would be duplicated if its body was not a
`tir.SeqStmt` node (e.g. `tir.For`), leading to compilation failures. This
happened because `tir.AllocateConst` and `tir.DeclBuffer` statements are always
re-attached to the statement's body after the `cache_read` / `cache_write`
stage is inserted in it, but the stage was being appended to the whole
statement (which already contains the `tir.AllocateConst`) instead of to its
body alone. The new `SeqStmt` is now built from `body` rather than `stmt`.
This commit also adds a test where the first `cache_read` stage is inserted
into a statement whose body is a `tir.For`, while the second stage is added to
a body that is `tir.SeqStmt` to check for regressions.
* Improve PrimFunc readability
* Remove redundant `T.reads()`
---
src/tir/schedule/primitive/cache_read_write.cc | 4 +--
.../test_tir_schedule_cache_read_write.py | 40 ++++++++++++++++++++++
2 files changed, 42 insertions(+), 2 deletions(-)
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 3fbdf856b5..a687624bac 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -483,9 +483,9 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
seq.insert(seq.begin() + pos, stage);
body = SeqStmt(seq);
} else if (pos == 0) {
- body = SeqStmt({stage, stmt});
+ body = SeqStmt({stage, body});
} else if (pos == 1) {
- body = SeqStmt({stmt, stage});
+ body = SeqStmt({body, stage});
} else {
LOG(FATAL) << "Cannot insert at position " << pos
<< ". When inserting adjacent to non-SeqStmt, "
diff --git a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
index 840a18ae6a..345c7368ce 100644
--- a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
+++ b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
@@ -1379,6 +1379,46 @@ def test_cache_read_fail_invalid_storage_scope(use_block_name):
sch.cache_read(block_b, 0, "test_scope")
+def test_cache_read_allocate_const():
+ @T.prim_func
+ def before(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
+ B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+ B_buf = T.decl_buffer((8), dtype="float32", data=B)
+ for i in range(8):
+ with T.block("C"):
+ vi = T.axis.spatial(8, i)
+ C[vi] = A[vi] + B_buf[vi]
+
+ @T.prim_func
+ def expected(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
+ B_buf_global = T.alloc_buffer((8), dtype="float32")
+ A_global = T.alloc_buffer((8), dtype="float32")
+ B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+ B_buf = T.decl_buffer((8), data=B)
+ for ax0 in range(8):
+ with T.block("A_global"):
+ v0 = T.axis.spatial(8, ax0)
+ A_global[v0] = A[v0]
+ for ax0 in range(8):
+ with T.block("B_buf_global"):
+ v0 = T.axis.spatial(8, ax0)
+ B_buf_global[v0] = B_buf[v0]
+ for i in range(8):
+ with T.block("C"):
+ vi = T.axis.spatial(8, i)
+ C[vi] = A_global[vi] + B_buf_global[vi]
+
+ sch = tir.Schedule(before)
+ block_c = sch.get_block("C")
+ sch.cache_read(block_c, 1, "global")
+ sch.cache_read(block_c, 0, "global")
+
+ after = sch.mod["main"]
+
+ assert_structural_equal_ignore_global_symbol(expected, after)
+ verify_trace_roundtrip(sch=sch, mod=before)
+
+
def test_inplace_cache_read():
sch = tvm.tir.Schedule(inplace_func, debug_mask="all")
block = sch.get_block("copy_in")