This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 78ba385fcb [BugTIR]fix error merging shared memory for ptx_cp_async (#16800)
78ba385fcb is described below

commit 78ba385fcb38aa6181c35cccb5d316543d5f59ac
Author: Wei Tao <[email protected]>
AuthorDate: Sun Mar 31 06:09:10 2024 +0800

    [BugTIR]fix error merging shared memory for ptx_cp_async (#16800)
    
    * [BugTIR]fix error merging shared memory for ptx_cp_async
    
    * run black format
    
    * fix get dtype of ptx_cp_async
    
    * get correct offset of ptx_cp_async
    
    * black format
---
 .../transforms/merge_shared_memory_allocations.cc  | 26 ++++++++++++++++++
 ...form_merge_dynamic_shared_memory_allocations.py | 31 ++++++++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc
index c79b9c1f93..bd9ff37151 100644
--- a/src/tir/transforms/merge_shared_memory_allocations.cc
+++ b/src/tir/transforms/merge_shared_memory_allocations.cc
@@ -25,6 +25,7 @@
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
@@ -170,6 +171,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
       StmtExprVisitor::VisitExpr_(op);
     }
   }
+
   void VisitExpr_(const VarNode* buf) final {
     // Directly reference to the variable count as a read.
     auto it = alloc_info_.find(buf);
@@ -180,6 +182,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
       }
     }
   }
+
   template <typename T>
   void VisitNewScope(const T* op) {
     scope_.push_back(StmtEntry());
@@ -200,6 +203,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
     ICHECK_NE(end_index, 0U);
     linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
   }
+
   void VisitStmt_(const AttrStmtNode* op) final {
     // Only record the outer most thread extent.
     if (op->attr_key == attr::thread_extent && !in_thread_env_) {
@@ -214,6 +218,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
       StmtExprVisitor::VisitStmt_(op);
     }
   }
+
   void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }
 
   void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
@@ -392,6 +397,27 @@ class SharedMemoryRewriter : public StmtExprMutator {
       PrimExpr extent = this->VisitExpr(op->args[3]);
       return Call(op->dtype, op->op,
                   {op->args[0], merged_buf_var_, extra_offset + offset, 
extent, op->args[4]});
+    } else if (op->op.same_as(builtin::ptx_cp_async())) {
+      ICHECK((op->args.size() == 5U) || (op->args.size() == 6U));
+      DataType dtype = op->dtype;
+      Var buffer = Downcast<Var>(op->args[0]);
+      if (!IsAppropriateSharedMemory(buffer)) {
+        return StmtExprMutator::VisitExpr_(op);
+      }
+      PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
+      PrimExpr offset = this->VisitExpr(op->args[1]);
+      // the dst shared memory is a byte buffer generated by merging shared memory.
+      // we need to multiply the offset index by the byte size of the original value dtype, to get
+      // the correct offset of merged shared buffer.
+      int index_factor = dtype.bytes();
+      if (op->args.size() == 5)
+        return Call(dtype, op->op,
+                    {merged_buf_var_, mul(extra_offset + offset, 
PrimExpr(index_factor)),
+                     op->args[2], op->args[3], op->args[4]});
+      else
+        return Call(dtype, op->op,
+                    {merged_buf_var_, mul(extra_offset + offset, 
PrimExpr(index_factor)),
+                     op->args[2], op->args[3], op->args[4], op->args[5]});
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
index 8661843d39..9bb0aaf6e8 100644
--- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
+++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -513,5 +513,36 @@ class TestSimpleAllocReuse(tvm.testing.CompareBeforeAfter):
         return func
 
 
+class TestAsyncCopy(tvm.testing.CompareBeforeAfter):
+    """Test async copy in shared memory."""
+
+    transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+
+    def before(self):
+        @T.prim_func
+        def func(A: T.buffer((128)), B: T.buffer((128))):
+            A_sh_data = T.allocate([128], "float32", "shared.dyn")
+            B_sh_data = T.allocate([128], "float32", "shared.dyn")
+            A_sh = T.buffer([128], data=A_sh_data, scope="shared.dyn")
+            B_sh = T.buffer([128], data=B_sh_data, scope="shared.dyn")
+            threadIdx_x = T.launch_thread("threadIdx.x", 128)
+            T.ptx_cp_async("float32", A_sh.data, threadIdx_x, A.data, 
threadIdx_x, 512)
+            T.ptx_cp_async("float32", B_sh.data, threadIdx_x, B.data, 
threadIdx_x, 512)
+
+        return func
+
+    def expected(self):
+        @T.prim_func
+        def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), 
"float32")):
+            threadIdx_x = T.launch_thread("threadIdx.x", 128)
+            buf_dyn_shmem = T.allocate([1024], "uint8", "shared.dyn")
+            T.ptx_cp_async("float32", buf_dyn_shmem, threadIdx_x * 4, A.data, 
threadIdx_x, 512)
+            T.ptx_cp_async(
+                "float32", buf_dyn_shmem, (128 + threadIdx_x) * 4, B.data, 
threadIdx_x, 512
+            )
+
+        return func
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to