This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5092c003de [BugTIR] fix thread_sync occurs in letstmt (#16454)
5092c003de is described below
commit 5092c003de6424332d9179ddd405540092b96742
Author: Wei Tao <[email protected]>
AuthorDate: Tue Jan 30 02:05:22 2024 +0800
[BugTIR] fix thread_sync occurs in letstmt (#16454)
* [BugTIR] fix thread_sync occurs in letstmt
* modify visit letstmt
* typo
* remove unnecessary clear
---
src/tir/transforms/storage_access.cc | 14 ++++
src/tir/transforms/storage_access.h | 1 +
.../test_tir_transform_thread_sync.py | 77 ++++++++++++++++++++++
3 files changed, 92 insertions(+)
diff --git a/src/tir/transforms/storage_access.cc
b/src/tir/transforms/storage_access.cc
index cbc7f07cae..8c7a7035de 100644
--- a/src/tir/transforms/storage_access.cc
+++ b/src/tir/transforms/storage_access.cc
@@ -94,6 +94,20 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode*
op) {
allow_append_ = false;
}
+void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) {
+ allow_append_ = true;
+ ICHECK_EQ(curr_stmt_.access.size(), 0U);
+ curr_stmt_.stmt = op;
+ this->VisitExpr(op->value);
+ // push to the scope
+ scope_.back().push_back(curr_stmt_);
+ // clear access entry.
+ curr_stmt_.access.clear();
+ allow_append_ = false;
+ // traverse body block
+ this->VisitStmt(op->body);
+}
+
void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::double_buffer_write) {
ICHECK(double_buffer_write_ == nullptr);
diff --git a/src/tir/transforms/storage_access.h
b/src/tir/transforms/storage_access.h
index 119e595f59..a0e03b35cd 100644
--- a/src/tir/transforms/storage_access.h
+++ b/src/tir/transforms/storage_access.h
@@ -84,6 +84,7 @@ class StorageAccessVisitor : public StmtExprVisitor {
void VisitExpr_(const BufferLoadNode* op) final;
void VisitStmt_(const BufferStoreNode* op) final;
void VisitStmt_(const EvaluateNode* op) final;
+ void VisitStmt_(const LetStmtNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
void VisitStmt_(const ForNode* op) final;
void VisitStmt_(const IfThenElseNode* op) final;
diff --git a/tests/python/tir-transform/test_tir_transform_thread_sync.py
b/tests/python/tir-transform/test_tir_transform_thread_sync.py
index 2cfc65aae0..5c43d8d96a 100644
--- a/tests/python/tir-transform/test_tir_transform_thread_sync.py
+++ b/tests/python/tir-transform/test_tir_transform_thread_sync.py
@@ -160,8 +160,85 @@ def test_sync_shared_dyn():
tvm.ir.assert_structural_equal(mod["main"], expected)
[email protected]_cuda
+def test_sync_let_stmt():
+ @T.prim_func(private=True)
+ def func(A: T.Buffer((16 * 512), "float32")):
+ blockIdx_x = T.launch_thread("blockIdx.x", 16)
+ A_shared = T.allocate([512], "float32", "shared")
+ in_thread_A_temp = T.allocate([1], "float32", "local")
+ cross_thread_A_temp = T.allocate([1], "float32", "local")
+ threadIdx_x = T.launch_thread("threadIdx.x", 128)
+ A_shared_1 = T.Buffer((512,), data=A_shared, scope="shared")
+ for ax0 in range(512):
+ A_shared_1[ax0] = A[blockIdx_x * 512 + ax0]
+ in_thread_A_temp_1 = T.Buffer((1,), data=in_thread_A_temp,
scope="local")
+ in_thread_A_temp_1[0] = T.float32(0)
+ with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as
A_temp:
+ in_thread_A_temp_1[0] = A_temp
+ with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128])
as A_temp:
+ in_thread_A_temp_1[0] = A_temp
+ with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256])
as A_temp:
+ in_thread_A_temp_1[0] = A_temp
+ with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384])
as A_temp:
+ in_thread_A_temp_1[0] = A_temp
+ cross_thread_A_temp_1 = T.Buffer((1,), data=cross_thread_A_temp,
scope="local")
+ with T.attr(
+ T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ in_thread_A_temp_1[0],
+ T.bool(True),
+ cross_thread_A_temp_1[0],
+ threadIdx_x,
+ )
+
+ @T.prim_func(private=True)
+ def expected(A: T.Buffer((8192,), "float32")):
+ blockIdx_x = T.launch_thread("blockIdx.x", 16)
+ A_shared_1 = T.allocate([512], "float32", "shared")
+ in_thread_A_temp_1 = T.allocate([1], "float32", "local")
+ cross_thread_A_temp_1 = T.allocate([1], "float32", "local")
+ threadIdx_x = T.launch_thread("threadIdx.x", 128)
+ A_shared_1_1 = T.Buffer((512,), data=A_shared_1, scope="shared")
+ for ax0 in range(512):
+ A_shared_1_1[ax0] = A[blockIdx_x * 512 + ax0]
+ in_thread_A_temp_1_1 = T.Buffer((1,), data=in_thread_A_temp_1,
scope="local")
+ in_thread_A_temp_1_1[0] = T.float32(0)
+ T.tvm_storage_sync("shared")
+ with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) as
A_temp:
+ in_thread_A_temp_1_1[0] = A_temp
+ with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x +
128]) as A_temp:
+ in_thread_A_temp_1_1[0] = A_temp
+ with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x +
256]) as A_temp:
+ in_thread_A_temp_1_1[0] = A_temp
+ with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x +
384]) as A_temp:
+ in_thread_A_temp_1_1[0] = A_temp
+ T.attr(
+ T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ )
+ cross_thread_A_temp_1_1 = T.Buffer((1,), data=cross_thread_A_temp_1,
scope="local")
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ in_thread_A_temp_1_1[0],
+ T.bool(True),
+ cross_thread_A_temp_1_1[0],
+ threadIdx_x,
+ )
+
+ mod = tvm.IRModule({"main": func})
+ mod = tvm.tir.transform.ThreadSync("shared")(mod)
+ tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
if __name__ == "__main__":
test_thread_storage_sync()
test_sync_else_branch()
test_sync_read_thread_id_independent_location()
test_sync_shared_dyn()
+ test_sync_let_stmt()