This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5092c003de [BugTIR] fix thread_sync occurs in letstmt (#16454)
5092c003de is described below

commit 5092c003de6424332d9179ddd405540092b96742
Author: Wei Tao <[email protected]>
AuthorDate: Tue Jan 30 02:05:22 2024 +0800

    [BugTIR] fix thread_sync occurs in letstmt (#16454)
    
    * [BugTIR] fix thread_sync occurs in letstmt
    
    * modify visit letstmt
    
    * typo
    
    * remove unnecessary clear
---
 src/tir/transforms/storage_access.cc               | 14 ++++
 src/tir/transforms/storage_access.h                |  1 +
 .../test_tir_transform_thread_sync.py              | 77 ++++++++++++++++++++++
 3 files changed, 92 insertions(+)

diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc
index cbc7f07cae..8c7a7035de 100644
--- a/src/tir/transforms/storage_access.cc
+++ b/src/tir/transforms/storage_access.cc
@@ -94,6 +94,20 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
   allow_append_ = false;
 }
 
+void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) {
+  allow_append_ = true;
+  ICHECK_EQ(curr_stmt_.access.size(), 0U);
+  curr_stmt_.stmt = op;  // only the let value (not the body) belongs to this entry
+  this->VisitExpr(op->value);
+  // Push the accesses gathered from the let value to the scope as `op`'s entry.
+  scope_.back().push_back(curr_stmt_);
+  // Clear the entry so nested statements start empty (they ICHECK this too).
+  curr_stmt_.access.clear();
+  allow_append_ = false;
+  // Traverse the body only after this LetStmt's entry is pushed and cleared.
+  this->VisitStmt(op->body);
+}
+
 void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == attr::double_buffer_write) {
     ICHECK(double_buffer_write_ == nullptr);
diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h
index 119e595f59..a0e03b35cd 100644
--- a/src/tir/transforms/storage_access.h
+++ b/src/tir/transforms/storage_access.h
@@ -84,6 +84,7 @@ class StorageAccessVisitor : public StmtExprVisitor {
   void VisitExpr_(const BufferLoadNode* op) final;
   void VisitStmt_(const BufferStoreNode* op) final;
   void VisitStmt_(const EvaluateNode* op) final;
+  void VisitStmt_(const LetStmtNode* op) final;
   void VisitStmt_(const AttrStmtNode* op) final;
   void VisitStmt_(const ForNode* op) final;
   void VisitStmt_(const IfThenElseNode* op) final;
diff --git a/tests/python/tir-transform/test_tir_transform_thread_sync.py b/tests/python/tir-transform/test_tir_transform_thread_sync.py
index 2cfc65aae0..5c43d8d96a 100644
--- a/tests/python/tir-transform/test_tir_transform_thread_sync.py
+++ b/tests/python/tir-transform/test_tir_transform_thread_sync.py
@@ -160,8 +160,90 @@ def test_sync_shared_dyn():
     tvm.ir.assert_structural_equal(mod["main"], expected)


+@tvm.testing.requires_cuda
+def test_sync_let_stmt():
+    """ThreadSync must see shared-memory reads inside a LetStmt value.
+
+    The let values read A_shared_1 after the loop writes it, so the pass is
+    expected to insert T.tvm_storage_sync("shared") before the first LetStmt.
+    """
+    @T.prim_func(private=True)
+    def func(A: T.Buffer((16 * 512), "float32")):
+        blockIdx_x = T.launch_thread("blockIdx.x", 16)
+        A_shared = T.allocate([512], "float32", "shared")
+        in_thread_A_temp = T.allocate([1], "float32", "local")
+        cross_thread_A_temp = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        A_shared_1 = T.Buffer((512,), data=A_shared, scope="shared")
+        for ax0 in range(512):
+            A_shared_1[ax0] = A[blockIdx_x * 512 + ax0]
+        in_thread_A_temp_1 = T.Buffer((1,), data=in_thread_A_temp, scope="local")
+        in_thread_A_temp_1[0] = T.float32(0)
+        with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp:
+            in_thread_A_temp_1[0] = A_temp
+        with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) as A_temp:
+            in_thread_A_temp_1[0] = A_temp
+        with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) as A_temp:
+            in_thread_A_temp_1[0] = A_temp
+        with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) as A_temp:
+            in_thread_A_temp_1[0] = A_temp
+        cross_thread_A_temp_1 = T.Buffer((1,), data=cross_thread_A_temp, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            T.tvm_thread_allreduce(
+                T.uint32(1),
+                in_thread_A_temp_1[0],
+                T.bool(True),
+                cross_thread_A_temp_1[0],
+                threadIdx_x,
+            )
+
+    @T.prim_func(private=True)
+    def expected(A: T.Buffer((8192,), "float32")):
+        blockIdx_x = T.launch_thread("blockIdx.x", 16)
+        A_shared_1 = T.allocate([512], "float32", "shared")
+        in_thread_A_temp_1 = T.allocate([1], "float32", "local")
+        cross_thread_A_temp_1 = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        A_shared_1_1 = T.Buffer((512,), data=A_shared_1, scope="shared")
+        for ax0 in range(512):
+            A_shared_1_1[ax0] = A[blockIdx_x * 512 + ax0]
+        in_thread_A_temp_1_1 = T.Buffer((1,), data=in_thread_A_temp_1, scope="local")
+        in_thread_A_temp_1_1[0] = T.float32(0)
+        T.tvm_storage_sync("shared")
+        with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) as A_temp:
+            in_thread_A_temp_1_1[0] = A_temp
+        with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) as A_temp:
+            in_thread_A_temp_1_1[0] = A_temp
+        with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256]) as A_temp:
+            in_thread_A_temp_1_1[0] = A_temp
+        with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384]) as A_temp:
+            in_thread_A_temp_1_1[0] = A_temp
+        T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        )
+        cross_thread_A_temp_1_1 = T.Buffer((1,), data=cross_thread_A_temp_1, scope="local")
+        T.tvm_thread_allreduce(
+            T.uint32(1),
+            in_thread_A_temp_1_1[0],
+            T.bool(True),
+            cross_thread_A_temp_1_1[0],
+            threadIdx_x,
+        )
+
+    mod = tvm.IRModule({"main": func})
+    mod = tvm.tir.transform.ThreadSync("shared")(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
 if __name__ == "__main__":
     test_thread_storage_sync()
     test_sync_else_branch()
     test_sync_read_thread_id_independent_location()
     test_sync_shared_dyn()
+    test_sync_let_stmt()

Reply via email to