This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 657880cdce [Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite 
schedule primitive (#16660)
657880cdce is described below

commit 657880cdcedd7e41e911c583a8e93b3053a6ad27
Author: Andrei Hutu <[email protected]>
AuthorDate: Thu Mar 7 10:49:03 2024 +0000

    [Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite schedule 
primitive (#16660)
    
    * [Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite schedule 
primitive
    
    When inserting a `cache_read` / `cache_write` stage, the 
`tir.AllocateConst` statement would be duplicated if its body was not a 
`tir.SeqStmt` node (e.g. `tir.For`), leading to compilation failures. This 
happened because `tir.AllocateConst` and `tir.DeclBuffer` statements are always 
re-attached to the statement's body after the `cache_read` / `cache_write` 
stage is inserted in it, but the stage was being combined with the whole 
statement (which already contains the `tir.AllocateConst`) instead of with just 
its body.
    
    This commit also adds a regression test in which the first `cache_read` stage 
is inserted into a statement whose body is a `tir.For`, while the second stage 
is inserted into a body that is a `tir.SeqStmt`.
    
    * Improve PrimFunc readability
    
    * Remove redundant `T.reads()`
---
 src/tir/schedule/primitive/cache_read_write.cc     |  4 +--
 .../test_tir_schedule_cache_read_write.py          | 40 ++++++++++++++++++++++
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/src/tir/schedule/primitive/cache_read_write.cc 
b/src/tir/schedule/primitive/cache_read_write.cc
index 3fbdf856b5..a687624bac 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -483,9 +483,9 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const 
Stmt& stage) {
     seq.insert(seq.begin() + pos, stage);
     body = SeqStmt(seq);
   } else if (pos == 0) {
-    body = SeqStmt({stage, stmt});
+    body = SeqStmt({stage, body});
   } else if (pos == 1) {
-    body = SeqStmt({stmt, stage});
+    body = SeqStmt({body, stage});
   } else {
     LOG(FATAL) << "Cannot insert at position " << pos
                << ".  When inserting adjacent to non-SeqStmt, "
diff --git a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py 
b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
index 840a18ae6a..345c7368ce 100644
--- a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
+++ b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
@@ -1379,6 +1379,46 @@ def 
test_cache_read_fail_invalid_storage_scope(use_block_name):
         sch.cache_read(block_b, 0, "test_scope")
 
 
+def test_cache_read_allocate_const():
+    @T.prim_func
+    def before(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
+        B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], 
"float32", [8])
+        B_buf = T.decl_buffer((8), dtype="float32", data=B)
+        for i in range(8):
+            with T.block("C"):
+                vi = T.axis.spatial(8, i)
+                C[vi] = A[vi] + B_buf[vi]
+
+    @T.prim_func
+    def expected(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
+        B_buf_global = T.alloc_buffer((8), dtype="float32")
+        A_global = T.alloc_buffer((8), dtype="float32")
+        B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], 
"float32", [8])
+        B_buf = T.decl_buffer((8), data=B)
+        for ax0 in range(8):
+            with T.block("A_global"):
+                v0 = T.axis.spatial(8, ax0)
+                A_global[v0] = A[v0]
+        for ax0 in range(8):
+            with T.block("B_buf_global"):
+                v0 = T.axis.spatial(8, ax0)
+                B_buf_global[v0] = B_buf[v0]
+        for i in range(8):
+            with T.block("C"):
+                vi = T.axis.spatial(8, i)
+                C[vi] = A_global[vi] + B_buf_global[vi]
+
+    sch = tir.Schedule(before)
+    block_c = sch.get_block("C")
+    sch.cache_read(block_c, 1, "global")
+    sch.cache_read(block_c, 0, "global")
+
+    after = sch.mod["main"]
+
+    assert_structural_equal_ignore_global_symbol(expected, after)
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
 def test_inplace_cache_read():
     sch = tvm.tir.Schedule(inplace_func, debug_mask="all")
     block = sch.get_block("copy_in")

Reply via email to