This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 78ba385fcb [BugTIR]fix error merging shared memory for ptx_cp_async (#16800)
78ba385fcb is described below
commit 78ba385fcb38aa6181c35cccb5d316543d5f59ac
Author: Wei Tao <[email protected]>
AuthorDate: Sun Mar 31 06:09:10 2024 +0800
[BugTIR]fix error merging shared memory for ptx_cp_async (#16800)
* [BugTIR]fix error merging shared memory for ptx_cp_async
* run black format
* fix get dtype of ptx_cp_async
* get correct offset of ptx_cp_async
* black format
---
.../transforms/merge_shared_memory_allocations.cc | 26 ++++++++++++++++++
...form_merge_dynamic_shared_memory_allocations.py | 31 ++++++++++++++++++++++
2 files changed, 57 insertions(+)
diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc
index c79b9c1f93..bd9ff37151 100644
--- a/src/tir/transforms/merge_shared_memory_allocations.cc
+++ b/src/tir/transforms/merge_shared_memory_allocations.cc
@@ -25,6 +25,7 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
@@ -170,6 +171,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
StmtExprVisitor::VisitExpr_(op);
}
}
+
void VisitExpr_(const VarNode* buf) final {
// Directly reference to the variable count as a read.
auto it = alloc_info_.find(buf);
@@ -180,6 +182,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
}
}
}
+
template <typename T>
void VisitNewScope(const T* op) {
scope_.push_back(StmtEntry());
@@ -200,6 +203,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
ICHECK_NE(end_index, 0U);
linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
}
+
void VisitStmt_(const AttrStmtNode* op) final {
// Only record the outer most thread extent.
if (op->attr_key == attr::thread_extent && !in_thread_env_) {
@@ -214,6 +218,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
}
}
+
void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }
void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
@@ -392,6 +397,27 @@ class SharedMemoryRewriter : public StmtExprMutator {
PrimExpr extent = this->VisitExpr(op->args[3]);
return Call(op->dtype, op->op,
{op->args[0], merged_buf_var_, extra_offset + offset,
extent, op->args[4]});
+ } else if (op->op.same_as(builtin::ptx_cp_async())) {
+ ICHECK((op->args.size() == 5U) || (op->args.size() == 6U));
+ DataType dtype = op->dtype;
+ Var buffer = Downcast<Var>(op->args[0]);
+ if (!IsAppropriateSharedMemory(buffer)) {
+ return StmtExprMutator::VisitExpr_(op);
+ }
+ PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
+ PrimExpr offset = this->VisitExpr(op->args[1]);
+ // the dst shared memory is a byte buffer generated by merging shared
memory.
+ // we need to multiply the offset index by the byte size of the original
value dtype, to get
+ // the correct offset of merged shared buffer.
+ int index_factor = dtype.bytes();
+ if (op->args.size() == 5)
+ return Call(dtype, op->op,
+ {merged_buf_var_, mul(extra_offset + offset,
PrimExpr(index_factor)),
+ op->args[2], op->args[3], op->args[4]});
+ else
+ return Call(dtype, op->op,
+ {merged_buf_var_, mul(extra_offset + offset,
PrimExpr(index_factor)),
+ op->args[2], op->args[3], op->args[4], op->args[5]});
} else {
return StmtExprMutator::VisitExpr_(op);
}
diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
index 8661843d39..9bb0aaf6e8 100644
--- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
+++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -513,5 +513,36 @@ class TestSimpleAllocReuse(tvm.testing.CompareBeforeAfter):
return func
+class TestAsyncCopy(tvm.testing.CompareBeforeAfter):
+ """Test async copy in shared memory."""
+
+ transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+
+ def before(self):
+ @T.prim_func
+ def func(A: T.buffer((128)), B: T.buffer((128))):
+ A_sh_data = T.allocate([128], "float32", "shared.dyn")
+ B_sh_data = T.allocate([128], "float32", "shared.dyn")
+ A_sh = T.buffer([128], data=A_sh_data, scope="shared.dyn")
+ B_sh = T.buffer([128], data=B_sh_data, scope="shared.dyn")
+ threadIdx_x = T.launch_thread("threadIdx.x", 128)
+ T.ptx_cp_async("float32", A_sh.data, threadIdx_x, A.data,
threadIdx_x, 512)
+ T.ptx_cp_async("float32", B_sh.data, threadIdx_x, B.data,
threadIdx_x, 512)
+
+ return func
+
+ def expected(self):
+ @T.prim_func
+ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,),
"float32")):
+ threadIdx_x = T.launch_thread("threadIdx.x", 128)
+ buf_dyn_shmem = T.allocate([1024], "uint8", "shared.dyn")
+ T.ptx_cp_async("float32", buf_dyn_shmem, threadIdx_x * 4, A.data,
threadIdx_x, 512)
+ T.ptx_cp_async(
+ "float32", buf_dyn_shmem, (128 + threadIdx_x) * 4, B.data,
threadIdx_x, 512
+ )
+
+ return func
+
+
if __name__ == "__main__":
tvm.testing.main()