This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 25b8a0798e [Hopper TMA] Add intrinsic to create barriers for 
synchronization (#15684)
25b8a0798e is described below

commit 25b8a0798e6308b21191ead0739eac54d376806e
Author: Adam Straw <[email protected]>
AuthorDate: Mon Sep 11 08:09:53 2023 -0700

    [Hopper TMA] Add intrinsic to create barriers for synchronization (#15684)
    
    This PR adds an intrinsic to create barriers that can be used with existing 
barrier intrinsics for synchronization. The prior method of barrier allocation 
was to use alloc_buffer e.g. as follows barrier = T.alloc_buffer([1], "uint64", 
scope="shared") and then pass the pointer and offset to that barrier allocation 
for use in the barrier intrinsics. This was a functional interface, but also 
caused problems with alignment of other non-barrier shared memory allocations. 
See removed workar [...]
    
    * [Hopper TMA] Add intrinsic to create barriers for synchronization
    
    * pad barrier alignment to avoid runtime alignment errors
    
    * CHECK vs ICHECK; barrier alignment = 16 plus comment
    
    * CHECK_EQ compile error
---
 include/tvm/tir/builtin.h                          | 21 ++++--
 python/tvm/script/ir_builder/tir/ir.py             |  2 +
 python/tvm/tir/__init__.py                         |  1 +
 python/tvm/tir/op.py                               | 75 ++++++++++------------
 src/target/source/codegen_cuda.cc                  | 61 ++++++++++++------
 src/target/source/codegen_cuda.h                   |  8 +++
 src/target/source/ptx.cc                           | 30 ++++-----
 src/target/source/ptx.h                            | 36 ++++-------
 src/tir/op/builtin.cc                              |  7 ++
 tests/python/unittest/test_tir_op_types.py         | 20 +++---
 tests/python/unittest/test_tir_ptx_cp_async.py     | 44 ++++---------
 .../test_tir_transform_inject_ptx_async_copy.py    | 14 ++--
 12 files changed, 159 insertions(+), 160 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 0d6d98e255..65012c6c0f 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -663,8 +663,7 @@ TVM_DLL const Op& ptx_cp_async();
  *                   Var global_ptr,
  *                   Expr global_offset,
  *                   size_t bytes,
- *                   Var barrier_ptr,
- *                   Expr barrier_offset);
+ *                   int barrier_id);
  */
 TVM_DLL const Op& ptx_cp_async_bulk();
 
@@ -681,7 +680,7 @@ TVM_DLL const Op& ptx_wait_group();
 /*!
  * \brief tvm intrinsics for ptx async copy barrier using 
cp.async.mbarrier.arrive
  *
- * ptx_cp_async_barrier(Var barrier_ptr, Expr barrier_offset)
+ * ptx_cp_async_barrier(int barrier_id)
  *
  */
 TVM_DLL const Op& ptx_cp_async_barrier();
@@ -689,7 +688,7 @@ TVM_DLL const Op& ptx_cp_async_barrier();
 /*!
  * \brief tvm intrinsics for ptx barrier initialization of thread count using 
mbarrier.init
  *
- * ptx_init_barrier_thread_count(Var barrier_ptr, Expr barrier_offset, int 
thread_count)
+ * ptx_init_barrier_thread_count(int barrier_id, int thread_count)
  *
  */
 TVM_DLL const Op& ptx_init_barrier_thread_count();
@@ -697,7 +696,7 @@ TVM_DLL const Op& ptx_init_barrier_thread_count();
 /*!
  * \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
  *
- * ptx_arrive_barrier(Var barrier_ptr, Expr barrier_offset)
+ * ptx_arrive_barrier(int barrier_id)
  *
  */
 TVM_DLL const Op& ptx_arrive_barrier();
@@ -705,7 +704,7 @@ TVM_DLL const Op& ptx_arrive_barrier();
 /*!
  * \brief tvm intrinsic for ptx barrier arrival with expect tx using 
mbarrier.arrive.expect_tx
  *
- * ptx_arrive_barrier_expect_tx(Var barrier_ptr, Expr barrier_offset, int 
byte_count)
+ * ptx_arrive_barrier_expect_tx(int barrier_id, int byte_count)
  *
  */
 TVM_DLL const Op& ptx_arrive_barrier_expect_tx();
@@ -713,11 +712,19 @@ TVM_DLL const Op& ptx_arrive_barrier_expect_tx();
 /*!
  * \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
  *
- * ptx_wait_barrier(Var barrier_ptr, Expr barrier_offset)
+ * ptx_wait_barrier(int barrier_id)
  *
  */
 TVM_DLL const Op& ptx_wait_barrier();
 
+/*!
+ * \brief tvm intrinsics to create N barriers
+ *
+ * create_barriers(int barrier_count)
+ *
+ */
+TVM_DLL const Op& create_barriers();
+
 /*!
  * \brief tvm intrinsic for storing the result of PTX MMA into a destination 
pointer.
  *        For example, if each thread in a warp of size 32 has 4 elements from 
the result of
diff --git a/python/tvm/script/ir_builder/tir/ir.py 
b/python/tvm/script/ir_builder/tir/ir.py
index 337e060895..5471288878 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1849,6 +1849,7 @@ ptx_init_barrier_thread_count = 
_op_wrapper(_tir_op.ptx_init_barrier_thread_coun
 ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
 ptx_arrive_barrier_expect_tx = 
_op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
 ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
+create_barriers = _op_wrapper(_tir_op.create_barriers)
 assume = _op_wrapper(_tir_op.assume)
 undef = _op_wrapper(_tir_op.undef)
 TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
@@ -2125,6 +2126,7 @@ __all__ = [
     "ptx_arrive_barrier",
     "ptx_arrive_barrier_expect_tx",
     "ptx_wait_barrier",
+    "create_barriers",
     "mma_store",
     "mma_fill",
     "vectorlow",
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 762fcb599f..f0500290b8 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -71,6 +71,7 @@ from .op import (
     ptx_arrive_barrier,
     ptx_arrive_barrier_expect_tx,
     ptx_wait_barrier,
+    create_barriers,
 )
 from .op import vectorlow, vectorhigh, vectorcombine
 from .op import infinity, reinterpret
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index cb9227e8f2..30e2a29487 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1369,7 +1369,7 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, 
global_ptr, global_offset, by
 
 
 def ptx_cp_async_bulk(
-    dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, 
barrier_ptr, barrier_offset
+    dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, 
barrier_id
 ):
     """TVM intrinsic for ptx async copy from global to shared memory using 
cp.async.bulk
     
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
@@ -1394,11 +1394,8 @@ def ptx_cp_async_bulk(
     bytes : int
         The data size to copy.
 
-    barrier_ptr : Var
-        The barrier shared memory pointer variable.
-
     barrier_id : int
-        The offset of the barrier shared memory pointer.
+        The index of the barrier in the array allocated by create_barriers.
 
     Returns
     -------
@@ -1413,8 +1410,7 @@ def ptx_cp_async_bulk(
         global_ptr,
         global_offset,
         bytes,
-        barrier_ptr,
-        barrier_offset,
+        barrier_id,
     )
 
 
@@ -1447,37 +1443,31 @@ def ptx_wait_group(num):
     return call_intrin("", "tir.ptx_wait_group", num)
 
 
-def ptx_cp_async_barrier(barrier_ptr, barrier_offset):
+def ptx_cp_async_barrier(barrier_id):
     """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
     
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
 
     Parameters
     ----------
-    barrier_ptr : Var
-        The barrier shared memory pointer variable.
-
     barrier_id : int
-        The offset of the barrier shared memory pointer.
+        The index of the barrier in the array allocated by create_barriers.
 
     Returns
     -------
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tir.ptx_cp_async_barrier", barrier_ptr, 
barrier_offset)
+    return call_intrin("", "tir.ptx_cp_async_barrier", barrier_id)
 
 
-def ptx_init_barrier_thread_count(barrier_ptr, barrier_offset, thread_count):
+def ptx_init_barrier_thread_count(barrier_id, thread_count):
     """TVM intrinsic for ptx barrier initialization of thread count using 
mbarrier.init
     
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
 
     Parameters
     ----------
-    barrier_ptr : Var
-        The barrier shared memory pointer variable.
-
     barrier_id : int
-        The offset of the barrier shared memory pointer.
+        The index of the barrier in the array allocated by create_barriers.
 
     thread_count : int
         Number of threads expected to arrive at the barrier.
@@ -1487,43 +1477,35 @@ def ptx_init_barrier_thread_count(barrier_ptr, 
barrier_offset, thread_count):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin(
-        "", "tir.ptx_init_barrier_thread_count", barrier_ptr, barrier_offset, 
thread_count
-    )
+    return call_intrin("", "tir.ptx_init_barrier_thread_count", barrier_id, 
thread_count)
 
 
-def ptx_arrive_barrier(barrier_ptr, barrier_offset):
+def ptx_arrive_barrier(barrier_id):
     """TVM intrinsic for ptx barrier arrival using mbarrier.arrive
     
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
 
     Parameters
     ----------
-    barrier_ptr : Var
-        The barrier shared memory pointer variable.
-
     barrier_id : int
-        The offset of the barrier shared memory pointer.
+        The index of the barrier in the array allocated by create_barriers.
 
     Returns
     -------
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tir.ptx_arrive_barrier", barrier_ptr, 
barrier_offset)
+    return call_intrin("", "tir.ptx_arrive_barrier", barrier_id)
 
 
-def ptx_arrive_barrier_expect_tx(barrier_ptr, barrier_offset, byte_count):
+def ptx_arrive_barrier_expect_tx(barrier_id, byte_count):
     """TVM intrinsic for ptx barrier arrival with expect tx using 
mbarrier.arrive.expect_tx
     
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
     
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation
 
     Parameters
     ----------
-    barrier_ptr : Var
-        The barrier shared memory pointer variable.
-
     barrier_id : int
-        The offset of the barrier shared memory pointer.
+        The index of the barrier in the array allocated by create_barriers.
 
     byte_count : int
         Increases the tx count of the mbarrier object to track completion of
@@ -1534,29 +1516,40 @@ def ptx_arrive_barrier_expect_tx(barrier_ptr, 
barrier_offset, byte_count):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin(
-        "", "tir.ptx_arrive_barrier_expect_tx", barrier_ptr, barrier_offset, 
byte_count
-    )
+    return call_intrin("", "tir.ptx_arrive_barrier_expect_tx", barrier_id, 
byte_count)
 
 
-def ptx_wait_barrier(barrier_ptr, barrier_offset):
+def ptx_wait_barrier(barrier_id):
     """TVM intrinsic for ptx barrier wait using mbarrier.try_wait
     
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
 
     Parameters
     ----------
-    barrier_ptr : Var
-        The barrier shared memory pointer variable.
-
     barrier_id : int
-        The offset of the barrier shared memory pointer.
+        The index of the barrier in the array allocated by create_barriers.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("", "tir.ptx_wait_barrier", barrier_id)
+
+
+def create_barriers(barrier_count):
+    """TVM intrinsic to create N barriers
+
+    Parameters
+    ----------
+    barrier_count : int
+        The number of barriers to create.
 
     Returns
     -------
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tir.ptx_wait_barrier", barrier_ptr, barrier_offset)
+    return call_intrin("", "tir.create_barriers", barrier_count)
 
 
 def vectorlow(dtype, vec):
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index d880b978b5..7639ce6065 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -968,10 +968,10 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, 
std::ostream& os) {
     std::string src = this->PrintExpr(op->args[2]);
     std::string src_offset = this->PrintExpr(op->args[3]);
     std::string size = this->PrintExpr(op->args[4]);
-    std::string barrier_ptr = this->PrintExpr(op->args[5]);
-    std::string barrier_offset = this->PrintExpr(op->args[6]);
-    this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, 
size, barrier_ptr,
-                                        barrier_offset);
+    int barrier_id = Downcast<IntImm>(op->args[5])->value;
+    CHECK(barrier_id < barrier_count_);
+    std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + 
"]";
+    this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, 
size, barrier);
   } else if (op->op.same_as(builtin::ptx_commit_group())) {
     this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
   } else if (op->op.same_as(builtin::ptx_wait_group())) {
@@ -979,31 +979,50 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, 
std::ostream& os) {
     this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << 
";\");\n\n";
   } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
     need_cast_smem_ptr_to_int_ = true;
-    std::string barrier_ptr = this->PrintExpr(op->args[0]);
-    std::string barrier_offset = this->PrintExpr(op->args[1]);
-    this->stream << PrintCpAsyncBarrierAsm(barrier_ptr, barrier_offset);
+    int barrier_id = Downcast<IntImm>(op->args[0])->value;
+    CHECK(barrier_id < barrier_count_);
+    std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + 
"]";
+    this->stream << PrintCpAsyncBarrierAsm(barrier);
   } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
     need_cast_smem_ptr_to_int_ = true;
-    std::string barrier_ptr = this->PrintExpr(op->args[0]);
-    std::string barrier_offset = this->PrintExpr(op->args[1]);
-    std::string thread_count = this->PrintExpr(op->args[2]);
-    this->stream << PrintInitBarrierThreadCountAsm(barrier_ptr, 
barrier_offset, thread_count);
+    int barrier_id = Downcast<IntImm>(op->args[0])->value;
+    CHECK(barrier_id < barrier_count_);
+    std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + 
"]";
+    std::string thread_count = this->PrintExpr(op->args[1]);
+    this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
   } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
     need_cast_smem_ptr_to_int_ = true;
-    std::string barrier_ptr = this->PrintExpr(op->args[0]);
-    std::string barrier_offset = this->PrintExpr(op->args[1]);
-    this->stream << PrintArriveBarrierAsm(barrier_ptr, barrier_offset);
+    int barrier_id = Downcast<IntImm>(op->args[0])->value;
+    CHECK(barrier_id < barrier_count_);
+    std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + 
"]";
+    this->stream << PrintArriveBarrierAsm(barrier);
   } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
     need_cast_smem_ptr_to_int_ = true;
-    std::string barrier_ptr = this->PrintExpr(op->args[0]);
-    std::string barrier_offset = this->PrintExpr(op->args[1]);
-    std::string byte_count = this->PrintExpr(op->args[2]);
-    this->stream << PrintArriveBarrierExpectTxAsm(barrier_ptr, barrier_offset, 
byte_count);
+    int barrier_id = Downcast<IntImm>(op->args[0])->value;
+    CHECK(barrier_id < barrier_count_);
+    std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + 
"]";
+    std::string byte_count = this->PrintExpr(op->args[1]);
+    this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
   } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
     need_cast_smem_ptr_to_int_ = true;
-    std::string barrier_ptr = this->PrintExpr(op->args[0]);
-    std::string barrier_offset = this->PrintExpr(op->args[1]);
-    this->stream << PrintWaitBarrierAsm(barrier_ptr, barrier_offset);
+    int barrier_id = Downcast<IntImm>(op->args[0])->value;
+    CHECK(barrier_id < barrier_count_);
+    std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + 
"]";
+    this->stream << PrintWaitBarrierAsm(barrier);
+  } else if (op->op.same_as(builtin::create_barriers())) {
+    CHECK_EQ(barrier_count_, -1);
+    int barrier_count = Downcast<IntImm>(op->args[0])->value;
+    // pad barrier alignment to avoid runtime alignment errors
+    CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
+    int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
+    if (barrier_count % barrier_alignment_count != 0) {
+      barrier_count = ((barrier_count / barrier_alignment_count) + 1) * 
barrier_alignment_count;
+    }
+    barrier_count_ = barrier_count;
+    this->stream << "__shared__ __align__(" << barrier_alignment_bytes_ << ") 
uint64_t "
+                 << barrier_name_ << "[" << barrier_count << "];\n";
+    this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { " << 
barrier_name_
+                 << "[i] = 0; }\n";
   } else if (op->op.same_as(builtin::ptx_ldg32())) {
     /*
     asm volatile (
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 797ac99363..bc7b34b500 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -109,6 +109,14 @@ class CodeGenCUDA final : public CodeGenC {
   // Op attribute map
   OpAttrMap<bool> op_need_warp_shuffle_ = 
Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
 
+  // The name of the barrier array in shared memory
+  const std::string barrier_name_ = "barrier";
+  // The size of the barrier array in shared memory
+  int barrier_count_ = -1;
+  // The alignment of the barrier array in shared memory
+  // Set to 16 to maintain minimum alignment requirements for async bulk copy
+  const int barrier_alignment_bytes_ = 16;
+
   std::unordered_map<const VarNode*, std::string> fragment_shapes;
   std::unordered_map<const VarNode*, std::string> fragment_layouts;
   friend void PrintConst(const FloatImmNode* op, std::ostream& os, 
CodeGenCUDA* p);
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index dd7c7cb7c4..ed6125e74c 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -713,8 +713,7 @@ std::string PrintCpAsyncBulkAsm(const std::string& 
shared_ptr,
                                 const std::string& shared_elem_offset,
                                 const std::string& global_ptr,
                                 const std::string& global_elem_offset, const 
std::string& bytes,
-                                const std::string& barrier_ptr,
-                                const std::string& barrier_elem_offset) {
+                                const std::string& barrier) {
   std::string asm_code = R"(
   {
     unsigned int smem_addr_int = cast_smem_ptr_to_int({smem_addr});
@@ -731,13 +730,12 @@ std::string PrintCpAsyncBulkAsm(const std::string& 
shared_ptr,
   replacer.register_rule("{smem_addr}", shared_ptr + " + " + 
shared_elem_offset);
   replacer.register_rule("{global_ptr}", global_ptr + " + " + 
global_elem_offset);
   replacer.register_rule("{bytes}", bytes);
-  replacer.register_rule("{barrier}", barrier_ptr + " + " + 
barrier_elem_offset);
+  replacer.register_rule("{barrier}", "&" + barrier);
   asm_code = replacer.rewrite(asm_code);
   return asm_code;
 }
 
-std::string PrintCpAsyncBarrierAsm(const std::string& barrier_ptr,
-                                   const std::string& barrier_elem_offset) {
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier) {
   std::string predicated_asm_code = R"(
   {
     unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
@@ -749,13 +747,12 @@ std::string PrintCpAsyncBarrierAsm(const std::string& 
barrier_ptr,
 )";
 
   Replacer replacer;
-  replacer.register_rule("{barrier}", barrier_ptr + " + " + 
barrier_elem_offset);
+  replacer.register_rule("{barrier}", "&" + barrier);
   predicated_asm_code = replacer.rewrite(predicated_asm_code);
   return predicated_asm_code;
 }
 
-std::string PrintInitBarrierThreadCountAsm(const std::string& barrier_ptr,
-                                           const std::string& 
barrier_elem_offset,
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
                                            const std::string& thread_count) {
   std::string predicated_asm_code = R"(
   {
@@ -769,14 +766,13 @@ std::string PrintInitBarrierThreadCountAsm(const 
std::string& barrier_ptr,
 )";
 
   Replacer replacer;
-  replacer.register_rule("{barrier}", barrier_ptr + " + " + 
barrier_elem_offset);
+  replacer.register_rule("{barrier}", "&" + barrier);
   replacer.register_rule("{thread_count}", thread_count);
   predicated_asm_code = replacer.rewrite(predicated_asm_code);
   return predicated_asm_code;
 }
 
-std::string PrintArriveBarrierAsm(const std::string& barrier_ptr,
-                                  const std::string& barrier_elem_offset) {
+std::string PrintArriveBarrierAsm(const std::string& barrier) {
   std::string predicated_asm_code = R"(
   {
     unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
@@ -788,13 +784,12 @@ std::string PrintArriveBarrierAsm(const std::string& 
barrier_ptr,
 )";
 
   Replacer replacer;
-  replacer.register_rule("{barrier}", barrier_ptr + " + " + 
barrier_elem_offset);
+  replacer.register_rule("{barrier}", "&" + barrier);
   predicated_asm_code = replacer.rewrite(predicated_asm_code);
   return predicated_asm_code;
 }
 
-std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier_ptr,
-                                          const std::string& 
barrier_elem_offset,
+std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier,
                                           const std::string& byte_count) {
   std::string predicated_asm_code = R"(
   {
@@ -808,14 +803,13 @@ std::string PrintArriveBarrierExpectTxAsm(const 
std::string& barrier_ptr,
 )";
 
   Replacer replacer;
-  replacer.register_rule("{barrier}", barrier_ptr + " + " + 
barrier_elem_offset);
+  replacer.register_rule("{barrier}", "&" + barrier);
   replacer.register_rule("{byte_count}", byte_count);
   predicated_asm_code = replacer.rewrite(predicated_asm_code);
   return predicated_asm_code;
 }
 
-std::string PrintWaitBarrierAsm(const std::string& barrier_ptr,
-                                const std::string& barrier_elem_offset) {
+std::string PrintWaitBarrierAsm(const std::string& barrier) {
   std::string predicated_asm_code = R"(
   {
     unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
@@ -828,7 +822,7 @@ std::string PrintWaitBarrierAsm(const std::string& 
barrier_ptr,
 )";
 
   Replacer replacer;
-  replacer.register_rule("{barrier}", barrier_ptr + " + " + 
barrier_elem_offset);
+  replacer.register_rule("{barrier}", "&" + barrier);
   predicated_asm_code = replacer.rewrite(predicated_asm_code);
   return predicated_asm_code;
 }
diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h
index a73180d40b..13d2f3cefc 100644
--- a/src/target/source/ptx.h
+++ b/src/target/source/ptx.h
@@ -115,60 +115,48 @@ std::string PrintPredicatedCpAsyncAssembly(const 
std::string& shared_ptr,
  * \param global_ptr: The pointer to the global memory.
  * \param global_elem_offset: The offset into the global memory.
  * \param bytes: The number of bytes to copy.
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  */
 std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr,
                                 const std::string& shared_elem_offset,
                                 const std::string& global_ptr,
                                 const std::string& global_elem_offset, const 
std::string& bytes,
-                                const std::string& barrier_ptr,
-                                const std::string& barrier_elem_offset);
+                                const std::string& barrier);
 
 /*!
  * \brief Print ptx async copy barrier using cp.async.mbarrier.arrive
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  */
-std::string PrintCpAsyncBarrierAsm(const std::string& barrier_ptr,
-                                   const std::string& barrier_elem_offset);
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier);
 
 /*!
  * \brief Print ptx barrier initialization of thread count using mbarrier.init
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  * \param thread_count: The number of threads expected to arrive at the 
barrier.
  */
-std::string PrintInitBarrierThreadCountAsm(const std::string& barrier_ptr,
-                                           const std::string& 
barrier_elem_offset,
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
                                            const std::string& thread_count);
 
 /*!
  * \brief Print ptx barrier arrival using mbarrier.arrive
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  */
-std::string PrintArriveBarrierAsm(const std::string& barrier_ptr,
-                                  const std::string& barrier_elem_offset);
+std::string PrintArriveBarrierAsm(const std::string& barrier);
 
 /*!
  * \brief Print ptx barrier arrival with expect tx operation using 
mbarrier.arrive.expect_tx
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  * \param byte_count: Increases the the tx count of the mbarrier object to 
track completion of
  * addtional async transactions.
  */
-std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier_ptr,
-                                          const std::string& 
barrier_elem_offset,
+std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier,
                                           const std::string& byte_count);
 
 /*!
  * \brief Print ptx barrier wait using mbarrier.try_wait
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  */
-std::string PrintWaitBarrierAsm(const std::string& barrier_ptr,
-                                const std::string& barrier_elem_offset);
+std::string PrintWaitBarrierAsm(const std::string& barrier);
 
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index a4116abf13..1b80959b57 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -297,15 +297,22 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
 
 TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async_barrier)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(ptx_init_barrier_thread_count)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier_expect_tx)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(ptx_wait_barrier)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(create_barriers)
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(mma_store)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque))
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
diff --git a/tests/python/unittest/test_tir_op_types.py 
b/tests/python/unittest/test_tir_op_types.py
index e4922e1e0c..7398ee781b 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -237,10 +237,7 @@ def test_op_ptx_cp_async():
 def test_op_ptx_cp_async_bulk():
     buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared")
     buffer_local = tir.decl_buffer([8], "float16", scope="local")
-    barrier = tir.decl_buffer([1], "uint64", scope="shared")
-    expr = tir.ptx_cp_async_bulk(
-        "float16", buffer_shared.data, 0, buffer_local.data, 0, 16, 
barrier.data, 0
-    )
+    expr = tir.ptx_cp_async_bulk("float16", buffer_shared.data, 0, 
buffer_local.data, 0, 16, 0)
     assert expr.op.name == "tir.ptx_cp_async_bulk"
 
 
@@ -255,30 +252,35 @@ def test_op_ptx_wait_group():
 
 
 def test_op_ptx_cp_async_barrier():
-    expr = tir.ptx_cp_async_barrier("barrier", 0)
+    expr = tir.ptx_cp_async_barrier(0)
     assert expr.op.name == "tir.ptx_cp_async_barrier"
 
 
 def test_op_ptx_init_barrier_thread_count():
-    expr = tir.ptx_init_barrier_thread_count("barrier", 0, 32)
+    expr = tir.ptx_init_barrier_thread_count(0, 32)
     assert expr.op.name == "tir.ptx_init_barrier_thread_count"
 
 
 def test_op_ptx_arrive_barrier():
-    expr = tir.ptx_arrive_barrier("barrier", 0)
+    expr = tir.ptx_arrive_barrier(0)
     assert expr.op.name == "tir.ptx_arrive_barrier"
 
 
 def test_op_ptx_arrive_barrier_expect_tx():
-    expr = tir.ptx_arrive_barrier_expect_tx("barrier", 0, 32)
+    expr = tir.ptx_arrive_barrier_expect_tx(0, 32)
     assert expr.op.name == "tir.ptx_arrive_barrier_expect_tx"
 
 
 def test_op_ptx_wait_barrier():
-    expr = tir.ptx_wait_barrier("barrier", 0)
+    expr = tir.ptx_wait_barrier(0)
     assert expr.op.name == "tir.ptx_wait_barrier"
 
 
+def test_op_create_barriers():
+    expr = tir.create_barriers(16)
+    assert expr.op.name == "tir.create_barriers"
+
+
 def test_tir_op_vectorlow():
     buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
     vec = buffer.vload([0, 0], dtype="int8x16")
diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py 
b/tests/python/unittest/test_tir_ptx_cp_async.py
index e6d3942ce5..d760023854 100644
--- a/tests/python/unittest/test_tir_ptx_cp_async.py
+++ b/tests/python/unittest/test_tir_ptx_cp_async.py
@@ -71,23 +71,13 @@ def ptx_cp_async_barrier(
     T.launch_thread(bx, 1)
     T.launch_thread(tx, 32)
     with T.block():
-        # Shared memory targets for cp.async.bulk must be 16 byte aligned
-        # Problem: CUDA codegen does not support allocation alignment
-        # Workaround: Ensure that `A_shared` occurs before `barrier` in 
program order
-        #             by allocating and initializing `A_shared` before 
`barrier`
-        #             which should result in `A_shared` being 16+ byte aligned
-        #             given it will be the first shared memory allocation
-        # TODO(Straw) Add CUDA codegen support for allocation alignment
         A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
-        A_shared[0, 0] = 0
-
-        barrier = T.alloc_buffer([1], "uint64", scope="shared")
-        barrier[0] = 0
 
         T.reads(A[0:32, 0:128])
         T.writes(B[0:32, 0:128])
 
-        T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, 
dtype=""))
+        T.evaluate(T.create_barriers(1, dtype=""))
+        T.evaluate(T.ptx_init_barrier_thread_count(0, 32, dtype=""))
 
         for i in range(16):
             T.evaluate(
@@ -96,9 +86,9 @@ def ptx_cp_async_barrier(
                 )
             )
 
-        T.evaluate(T.ptx_cp_async_barrier(barrier.data, 0, dtype=""))
-        T.evaluate(T.ptx_arrive_barrier(barrier.data, 0, dtype=""))
-        T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
+        T.evaluate(T.ptx_cp_async_barrier(0, dtype=""))
+        T.evaluate(T.ptx_arrive_barrier(0, dtype=""))
+        T.evaluate(T.ptx_wait_barrier(0, dtype=""))
 
         for i in range(128):
             B[tx, i] = A_shared[tx, i]
@@ -126,32 +116,20 @@ def ptx_cp_async_bulk(A: T.Buffer((32, 128), "float16"), 
B: T.Buffer((32, 128),
     T.launch_thread(bx, 1)
     T.launch_thread(tx, 32)
     with T.block():
-        # Shared memory targets for cp.async.bulk must be 16 byte aligned
-        # Problem: CUDA codegen does not support allocation alignment
-        # Workaround: Ensure that `A_shared` occurs before `barrier` in 
program order
-        #             by allocating and initializing `A_shared` before 
`barrier`
-        #             which should result in `A_shared` being 16+ byte aligned
-        #             given it will be the first shared memory allocation
-        # TODO(Straw) Add CUDA codegen support for allocation alignment
-        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared", 
align=16)
-        A_shared[0, 0] = 0
-
-        barrier = T.alloc_buffer([1], "uint64", scope="shared")
-        barrier[0] = 0
+        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
 
         T.reads(A[0:32, 0:128])
         T.writes(B[0:32, 0:128])
 
-        T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, 
dtype=""))
+        T.evaluate(T.create_barriers(1, dtype=""))
+        T.evaluate(T.ptx_init_barrier_thread_count(0, 32, dtype=""))
 
         T.evaluate(
-            T.ptx_cp_async_bulk(
-                A_shared.data, tx * 128, A.data, tx * 128, 256, barrier.data, 
0, dtype="float16"
-            )
+            T.ptx_cp_async_bulk(A_shared.data, tx * 128, A.data, tx * 128, 
256, 0, dtype="float16")
         )
 
-        T.evaluate(T.ptx_arrive_barrier_expect_tx(barrier.data, 0, 256, 
dtype=""))
-        T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
+        T.evaluate(T.ptx_arrive_barrier_expect_tx(0, 256, dtype=""))
+        T.evaluate(T.ptx_wait_barrier(0, dtype=""))
 
         for i in range(128):
             B[tx, i] = A_shared[tx, i]
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py 
b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index ff70eeae81..61f0892a9c 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -193,21 +193,21 @@ def ptx_global_to_shared_copy_fp32x1_barrier(
     T.launch_thread(bx, 1)
     T.launch_thread(tx, 32)
     with T.block():
-        barrier = T.alloc_buffer([1], "uint64", scope="shared")
         A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
+
         T.reads(A[0:32, 0:128])
-        T.writes(B[0:32, 0:128], barrier[0:1])
+        T.writes(B[0:32, 0:128])
 
-        barrier[0] = 0
-        T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, 
dtype=""))
+        T.evaluate(T.create_barriers(1, dtype=""))
+        T.evaluate(T.ptx_init_barrier_thread_count(0, 32, dtype=""))
 
         T.attr("default", "async_scope", 1)
         for i in T.serial(128):
             A_shared[tx, i] = A[tx, i]
 
-        T.evaluate(T.ptx_cp_async_barrier(barrier.data, 0, dtype=""))
-        T.evaluate(T.ptx_arrive_barrier(barrier.data, 0, dtype=""))
-        T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
+        T.evaluate(T.ptx_cp_async_barrier(0, dtype=""))
+        T.evaluate(T.ptx_arrive_barrier(0, dtype=""))
+        T.evaluate(T.ptx_wait_barrier(0, dtype=""))
 
         for i in range(128):
             B[tx, i] = A_shared[tx, i]


Reply via email to