This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b9652a2db0 [Hopper TMA] CUDA codegen for async copy with barrier synchronization (#15616)
b9652a2db0 is described below

commit b9652a2db0441408b0cffc6f6c0331e213ce6253
Author: Adam Straw <[email protected]>
AuthorDate: Thu Aug 24 19:07:36 2023 -0700

    [Hopper TMA] CUDA codegen for async copy with barrier synchronization (#15616)
    
    [Codegen] CUDA async copy with barrier synchronization
---
 include/tvm/tir/builtin.h                          |  32 ++++++
 python/tvm/script/ir_builder/tir/ir.py             |   8 ++
 python/tvm/tir/__init__.py                         |  11 ++-
 python/tvm/tir/op.py                               |  80 +++++++++++++++
 src/target/source/codegen_cuda.cc                  |  39 ++++++++
 src/target/source/codegen_cuda.h                   |   2 +
 src/target/source/ptx.cc                           |  93 ++++++++++++++----
 src/target/source/ptx.h                            |  26 +++++
 src/tir/op/builtin.cc                              |   9 ++
 tests/python/unittest/test_tir_op_types.py         |  20 ++++
 .../test_tir_transform_inject_ptx_async_copy.py    | 108 ++++++++++++++-------
 11 files changed, 372 insertions(+), 56 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index e8bcc028fc..b5c04f760d 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -663,6 +663,38 @@ TVM_DLL const Op& ptx_cp_async();
 TVM_DLL const Op& ptx_commit_group();
 TVM_DLL const Op& ptx_wait_group();
 
+/*!
+ * \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive
+ *
+ * ptx_cp_async_barrier(barrier_array, barrier_id)
+ *
+ */
+TVM_DLL const Op& ptx_cp_async_barrier();
+
+/*!
+ * \brief tvm intrinsics for ptx barrier initialization of thread count using mbarrier.init
+ *
+ * ptx_init_barrier_thread_count(barrier_array, barrier_id, thread_count)
+ *
+ */
+TVM_DLL const Op& ptx_init_barrier_thread_count();
+
+/*!
+ * \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
+ *
+ * ptx_arrive_barrier(barrier_array, barrier_id)
+ *
+ */
+TVM_DLL const Op& ptx_arrive_barrier();
+
+/*!
+ * \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
+ *
+ * ptx_wait_barrier(barrier_array, barrier_id)
+ *
+ */
+TVM_DLL const Op& ptx_wait_barrier();
+
 /*!
  * \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
  *        For example, if each thread in a warp of size 32 has 4 elements from the result of
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index efea9f1aea..d7bebbacee 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1844,6 +1844,10 @@ tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down
 tvm_warp_activemask = _tir_op.tvm_warp_activemask
 ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
 ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier)
+ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count)
+ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
+ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
 assume = _op_wrapper(_tir_op.assume)
 undef = _op_wrapper(_tir_op.undef)
 TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
@@ -2113,6 +2117,10 @@ __all__ = [
     "ptx_cp_async",
     "ptx_wait_group",
     "ptx_commit_group",
+    "ptx_cp_async_barrier",
+    "ptx_init_barrier_thread_count",
+    "ptx_arrive_barrier",
+    "ptx_wait_barrier",
     "mma_store",
     "mma_fill",
     "vectorlow",
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 5eb1059d27..84c5753337 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -60,7 +60,16 @@ from .op import (
     tvm_fill_fragment,
 )
 from .op import ptx_mma, ptx_mma_sp, mma_store, mma_fill
-from .op import ptx_ldmatrix, ptx_cp_async, ptx_commit_group, ptx_wait_group
+from .op import (
+    ptx_ldmatrix,
+    ptx_cp_async,
+    ptx_commit_group,
+    ptx_wait_group,
+    ptx_cp_async_barrier,
+    ptx_init_barrier_thread_count,
+    ptx_arrive_barrier,
+    ptx_wait_barrier,
+)
 from .op import vectorlow, vectorhigh, vectorcombine
 from .op import infinity, reinterpret
 from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 378be84621..7e1c520cc4 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1397,6 +1397,86 @@ def ptx_wait_group(num):
     return call_intrin("", "tir.ptx_wait_group", num)
 
 
+def ptx_cp_async_barrier(barrier_arr, barrier_id):
+    """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
+
+    Parameters
+    ----------
+    barrier_arr : string
+        The name of the barrier array in shared memory
+    barrier_id : int
+        Index into the barrier array
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("", "tir.ptx_cp_async_barrier", barrier_arr, barrier_id)
+
+
+def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count):
+    """TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
+
+    Parameters
+    ----------
+    barrier_arr : string
+        The name of the barrier array in shared memory
+    barrier_id : int
+        Index into the barrier array
+    thread_count : int
+        Number of threads expected to arrive at the barrier
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        "", "tir.ptx_init_barrier_thread_count", barrier_arr, barrier_id, thread_count
+    )
+
+
+def ptx_arrive_barrier(barrier_arr, barrier_id):
+    """TVM intrinsic for ptx barrier arrival using mbarrier.arrive
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
+
+    Parameters
+    ----------
+    barrier_arr : string
+        The name of the barrier array in shared memory
+    barrier_id : int
+        Index into the barrier array
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("", "tir.ptx_arrive_barrier", barrier_arr, barrier_id)
+
+
+def ptx_wait_barrier(barrier_arr, barrier_id):
+    """TVM intrinsic for ptx barrier wait using mbarrier.try_wait
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
+
+    Parameters
+    ----------
+    barrier_arr : string
+        The name of the barrier array in shared memory
+    barrier_id : int
+        Index into the barrier array
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("", "tir.ptx_wait_barrier", barrier_arr, barrier_id)
+
+
 def vectorlow(dtype, vec):
     """Get the low level half of the vector
 
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 6c02348191..edbe8be030 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -141,6 +141,18 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <mma.h>\n";
   }
 
+  if (need_cast_smem_ptr_to_int_) {
+    decl_stream << "__forceinline__ __device__ unsigned int\n";
+    decl_stream << "cast_smem_ptr_to_int(const void* const smem_ptr)\n";
+    decl_stream << "{\n";
+    decl_stream << "  unsigned int smem_int;\n";
+    decl_stream << "  asm volatile (\"{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; "
+                   "cvt.u32.u64 %0, smem_int; }\"\n";
+    decl_stream << "    : \"=r\"(smem_int) : \"l\"(smem_ptr));\n";
+    decl_stream << "  return smem_int;\n";
+    decl_stream << "}\n";
+  }
+
   decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n";
   decl_stream << "     (__CUDACC_VER_MAJOR__ > 11))\n";
   decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n";
@@ -873,6 +885,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
       os << "}\n";
     } else {
       std::string smem_elem_offset = this->PrintExpr(op->args[6]);
+      need_cast_smem_ptr_to_int_ = true;
       this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
                                               smem_ptr, smem_elem_offset);
     }
@@ -941,6 +954,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     std::string src_offset = this->PrintExpr(op->args[3]);
     std::string size = this->PrintExpr(op->args[4]);
     // use size of argument list to indicate whether or not to use predicated cp.async
+    need_cast_smem_ptr_to_int_ = true;
     if (op->args.size() == 5) {
       this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
     } else {
@@ -952,6 +966,31 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
   } else if (op->op.same_as(builtin::ptx_wait_group())) {
     int n = Downcast<IntImm>(op->args[0])->value;
     this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n";
+  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
+    need_cast_smem_ptr_to_int_ = true;
+    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+    std::string barrier_id = this->PrintExpr(op->args[1]);
+    std::string barrier = barriers_arr + "[" + barrier_id + "]";
+    this->stream << PrintCpAsyncBarrierAsm(barrier);
+  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
+    need_cast_smem_ptr_to_int_ = true;
+    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+    std::string barrier_id = this->PrintExpr(op->args[1]);
+    std::string barrier = barriers_arr + "[" + barrier_id + "]";
+    std::string thread_count = this->PrintExpr(op->args[2]);
+    this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
+  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
+    need_cast_smem_ptr_to_int_ = true;
+    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+    std::string barrier_id = this->PrintExpr(op->args[1]);
+    std::string barrier = barriers_arr + "[" + barrier_id + "]";
+    this->stream << PrintArriveBarrierAsm(barrier);
+  } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
+    need_cast_smem_ptr_to_int_ = true;
+    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+    std::string barrier_id = this->PrintExpr(op->args[1]);
+    std::string barrier = barriers_arr + "[" + barrier_id + "]";
+    this->stream << PrintWaitBarrierAsm(barrier);
   } else if (op->op.same_as(builtin::ptx_ldg32())) {
     /*
     asm volatile (
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 7de6ae05e8..797ac99363 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -104,6 +104,8 @@ class CodeGenCUDA final : public CodeGenC {
   bool need_math_constants_h_{false};
   // whether need mma.h
   bool need_mma_h_{false};
+  // whether need cast_smem_ptr_to_int helper function
+  bool need_cast_smem_ptr_to_int_{false};
   // Op attribute map
   OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
 
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index feffc3d304..6ff57f43bd 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -603,12 +603,7 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
   CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16.";
   std::string asm_code = R"(
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)({smem_addr}))
-    );
+    unsigned int addr = cast_smem_ptr_to_int({smem_addr});
     __asm__ __volatile__(
       "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
       "{templates};\n"
@@ -638,12 +633,7 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
                                  const std::string& global_elem_offset, const std::string& bytes) {
   std::string asm_code = R"(
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)({smem_addr}))
-    );
+    unsigned int addr = cast_smem_ptr_to_int({smem_addr});
     __asm__ __volatile__(
       #if TVM_ENABLE_L2_PREFETCH
         "cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
@@ -674,12 +664,7 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
       << "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async";
   std::string predicated_asm_code = R"(
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
-      : "=r"(addr)
-      : "l"((void *)({smem_addr}))
-    );
+    unsigned int addr = cast_smem_ptr_to_int({smem_addr});
     int pred_guard = (int){pred_guard};
     __asm__ __volatile__(
         "{  .reg .pred p;"
@@ -724,5 +709,77 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
   return predicated_asm_code;
 }
 
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier) {
+  std::string predicated_asm_code = R"(
+  {
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    __asm__ __volatile__(
+      "cp.async.mbarrier.arrive.shared.b64 [%0];"
+      :: "r" (barrier_addr_int)
+    );
+  }
+)";
+
+  Replacer replacer;
+  replacer.register_rule("{barrier}", barrier);
+  predicated_asm_code = replacer.rewrite(predicated_asm_code);
+  return predicated_asm_code;
+}
+
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
+                                           const std::string& thread_count) {
+  std::string predicated_asm_code = R"(
+  {
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    int thread_count = {thread_count};
+    __asm__ __volatile__(
+      "mbarrier.init.shared.b64 [%0], %1;"
+      :: "r"(barrier_addr_int), "r"(thread_count)
+    );
+  }
+)";
+
+  Replacer replacer;
+  replacer.register_rule("{barrier}", barrier);
+  replacer.register_rule("{thread_count}", thread_count);
+  predicated_asm_code = replacer.rewrite(predicated_asm_code);
+  return predicated_asm_code;
+}
+
+std::string PrintArriveBarrierAsm(const std::string& barrier) {
+  std::string predicated_asm_code = R"(
+  {
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    __asm__ __volatile__(
+      "{ .reg .b64 state; mbarrier.arrive.shared.b64 state, [%0]; }"
+      :: "r"(barrier_addr_int)
+    );
+  }
+)";
+
+  Replacer replacer;
+  replacer.register_rule("{barrier}", barrier);
+  predicated_asm_code = replacer.rewrite(predicated_asm_code);
+  return predicated_asm_code;
+}
+
+std::string PrintWaitBarrierAsm(const std::string& barrier) {
+  std::string predicated_asm_code = R"(
+  {
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    constexpr int phase_bit = 0;
+    __asm__ __volatile__(
+      "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; @P bra.uni DONE; bra.uni WAIT; DONE: }"
+      :: "r"(barrier_addr_int), "r"(phase_bit)
+    );
+  }
+)";
+
+  Replacer replacer;
+  replacer.register_rule("{barrier}", barrier);
+  predicated_asm_code = replacer.rewrite(predicated_asm_code);
+  return predicated_asm_code;
+}
+
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h
index 1e49b57c17..18519d85f6 100644
--- a/src/target/source/ptx.h
+++ b/src/target/source/ptx.h
@@ -108,6 +108,32 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
                                            const std::string& bytes,
                                            const std::string& predicate_value);
 
+/*!
+ * \brief Print ptx async copy barrier using cp.async.mbarrier.arrive
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ */
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier);
+
+/*!
+ * \brief Print ptx barrier initialization of thread count using mbarrier.init
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ * \param thread_count: The number of threads expected to arrive at the barrier
+ */
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
+                                           const std::string& thread_count);
+
+/*!
+ * \brief Print ptx barrier arrival using mbarrier.arrive
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ */
+std::string PrintArriveBarrierAsm(const std::string& barrier);
+
+/*!
+ * \brief Print ptx barrier wait using mbarrier.try_wait
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ */
+std::string PrintWaitBarrierAsm(const std::string& barrier);
+
 }  // namespace codegen
 }  // namespace tvm
 
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index c855904284..0ca61b4099 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -290,6 +290,15 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
 TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async_barrier)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_init_barrier_thread_count)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_wait_barrier)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(mma_store)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py
index 58954e7459..30e9ed2dfa 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -244,6 +244,26 @@ def test_op_ptx_wait_group():
     assert expr.op.name == "tir.ptx_wait_group"
 
 
+def test_op_ptx_cp_async_barrier():
+    expr = tir.ptx_cp_async_barrier("barrier", 0)
+    assert expr.op.name == "tir.ptx_cp_async_barrier"
+
+
+def ptx_init_barrier_thread_count():
+    expr = tir.ptx_init_barrier_thread_count("barrier", 0, 32)
+    assert expr.op.name == "tir.ptx_init_barrier_thread_count"
+
+
+def ptx_arrive_barrier():
+    expr = tir.ptx_arrive_barrier("barrier", 0)
+    assert expr.op.name == "tir.ptx_arrive_barrier"
+
+
+def ptx_wait_barrier():
+    expr = tir.ptx_wait_barrier("barrier", 0)
+    assert expr.op.name == "tir.ptx_wait_barrier"
+
+
 def test_tir_op_vectorlow():
     buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
     vec = buffer.vload([0, 0], dtype="int8x16")
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index b39fca72c8..5d866199e7 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -183,7 +183,71 @@ def test_inject_async_copy_shared_dyn():
     tvm.testing.assert_allclose(C_nd.numpy(), A_np + B_np)
 
 
-expected_cuda_script = r"""
+@T.prim_func
+def ptx_global_to_shared_copy_fp32x1_barrier(
+    A: T.Buffer((32, 128), "float32"), B: T.Buffer((32, 128), "float32")
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    bx = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(bx, 1)
+    T.launch_thread(tx, 32)
+    with T.block():
+        barrier = T.alloc_buffer([1], "uint64", scope="shared")
+        A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
+        T.reads(A[0:32, 0:128])
+        T.writes(B[0:32, 0:128], barrier[0:1])
+
+        barrier[0] = 0
+        T.evaluate(T.ptx_init_barrier_thread_count("barrier", 0, 32, dtype=""))
+
+        T.attr("default", "async_scope", 1)
+        for i in T.serial(128):
+            A_shared[tx, i] = A[tx, i]
+
+        T.evaluate(T.ptx_cp_async_barrier("barrier", 0, dtype=""))
+        T.evaluate(T.ptx_arrive_barrier("barrier", 0, dtype=""))
+        T.evaluate(T.ptx_wait_barrier("barrier", 0, dtype=""))
+
+        for i in range(128):
+            B[tx, i] = A_shared[tx, i]
+
+
+@tvm.testing.requires_cuda
+def test_inject_async_copy_barrier():
+    dtype = "float32"
+    vec_size = 1
+    f = ptx_global_to_shared_copy_fp32x1_barrier
+
+    mod = tvm.IRModule.from_expr(f)
+    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
+    mod = tvm.tir.transform.FlattenBuffer()(mod)
+    mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+
+    assert count_cp_async(mod["main"].body) == 1
+
+    if tvm.testing.is_ampere_or_newer():
+        with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
+            mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
+
+        A_np = np.random.rand(32, 128).astype(dtype)
+        B_np = np.zeros((32, 128)).astype(dtype)
+        dev = tvm.cuda(0)
+        A_nd = tvm.nd.array(A_np, device=dev)
+        B_nd = tvm.nd.array(B_np, device=dev)
+        mod(A_nd, B_nd)
+        tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+expected_cuda_script = r"""__forceinline__ __device__ unsigned int
+cast_smem_ptr_to_int(const void* const smem_ptr)
+{
+  unsigned int smem_int;
+  asm volatile ("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; cvt.u32.u64 %0, smem_int; }"
+    : "=r"(smem_int) : "l"(smem_ptr));
+  return smem_int;
+}
+
 #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
      (__CUDACC_VER_MAJOR__ > 11))
 #define TVM_ENABLE_L2_PREFETCH 1
@@ -214,12 +278,7 @@ __asm__ __volatile__("cp.async.commit_group;");
 
 
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)(A_shared + (((int)threadIdx.x) + 16)))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(A_shared + (((int)threadIdx.x) + 16));
     __asm__ __volatile__(
       #if TVM_ENABLE_L2_PREFETCH
         "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -231,12 +290,7 @@ __asm__ __volatile__("cp.async.commit_group;");
   }
 
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)(B_shared + (((int)threadIdx.x) + 16)))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(B_shared + (((int)threadIdx.x) + 16));
     __asm__ __volatile__(
       #if TVM_ENABLE_L2_PREFETCH
         "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -250,12 +304,7 @@ __asm__ __volatile__("cp.async.commit_group;");
 
 
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)(A_shared + (((int)threadIdx.x) + 32)))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(A_shared + (((int)threadIdx.x) + 32));
     __asm__ __volatile__(
       #if TVM_ENABLE_L2_PREFETCH
         "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -267,12 +316,7 @@ __asm__ __volatile__("cp.async.commit_group;");
   }
 
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)(B_shared + (((int)threadIdx.x) + 32)))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(B_shared + (((int)threadIdx.x) + 32));
     __asm__ __volatile__(
       #if TVM_ENABLE_L2_PREFETCH
         "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -288,12 +332,7 @@ __asm__ __volatile__("cp.async.commit_group;");
     bool cse_var_1 = (i < 12);
 
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
-      : "=r"(addr)
-      : "l"((void *)(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)));
     int pred_guard = (int)cse_var_1;
     __asm__ __volatile__(
         "{  .reg .pred p;"
@@ -316,12 +355,7 @@ __asm__ __volatile__("cp.async.wait_group 5;");
     __syncthreads();
 
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
-      : "=r"(addr)
-      : "l"((void *)(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)));
     int pred_guard = (int)cse_var_1;
     __asm__ __volatile__(
         "{  .reg .pred p;"

Reply via email to