This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 25b8a0798e [Hopper TMA] Add intrinsic to create barriers for
synchronization (#15684)
25b8a0798e is described below
commit 25b8a0798e6308b21191ead0739eac54d376806e
Author: Adam Straw <[email protected]>
AuthorDate: Mon Sep 11 08:09:53 2023 -0700
[Hopper TMA] Add intrinsic to create barriers for synchronization (#15684)
This PR adds an intrinsic to create barriers that can be used with existing
barrier intrinsics for synchronization. The prior method of barrier allocation
was to use alloc_buffer e.g. as follows barrier = T.alloc_buffer([1], "uint64",
scope="shared") and then pass the pointer and offset to that barrier allocation
for use in the barrier intrinsics. This was a functional interface, but also
caused problems with alignment of other non-barrier shared memory allocations.
See removed workaround [...]
* [Hopper TMA] Add intrinsic to create barriers for synchronization
* pad barrier alignment to avoid runtime alignment errors
* CHECK vs ICHECK; barrier alignment = 16 plus comment
* CHECK_EQ compile error
---
include/tvm/tir/builtin.h | 21 ++++--
python/tvm/script/ir_builder/tir/ir.py | 2 +
python/tvm/tir/__init__.py | 1 +
python/tvm/tir/op.py | 75 ++++++++++------------
src/target/source/codegen_cuda.cc | 61 ++++++++++++------
src/target/source/codegen_cuda.h | 8 +++
src/target/source/ptx.cc | 30 ++++-----
src/target/source/ptx.h | 36 ++++-------
src/tir/op/builtin.cc | 7 ++
tests/python/unittest/test_tir_op_types.py | 20 +++---
tests/python/unittest/test_tir_ptx_cp_async.py | 44 ++++---------
.../test_tir_transform_inject_ptx_async_copy.py | 14 ++--
12 files changed, 159 insertions(+), 160 deletions(-)
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 0d6d98e255..65012c6c0f 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -663,8 +663,7 @@ TVM_DLL const Op& ptx_cp_async();
* Var global_ptr,
* Expr global_offset,
* size_t bytes,
- * Var barrier_ptr,
- * Expr barrier_offset);
+ * int barrier_id);
*/
TVM_DLL const Op& ptx_cp_async_bulk();
@@ -681,7 +680,7 @@ TVM_DLL const Op& ptx_wait_group();
/*!
* \brief tvm intrinsics for ptx async copy barrier using
cp.async.mbarrier.arrive
*
- * ptx_cp_async_barrier(Var barrier_ptr, Expr barrier_offset)
+ * ptx_cp_async_barrier(int barrier_id)
*
*/
TVM_DLL const Op& ptx_cp_async_barrier();
@@ -689,7 +688,7 @@ TVM_DLL const Op& ptx_cp_async_barrier();
/*!
* \brief tvm intrinsics for ptx barrier initialization of thread count using
mbarrier.init
*
- * ptx_init_barrier_thread_count(Var barrier_ptr, Expr barrier_offset, int
thread_count)
+ * ptx_init_barrier_thread_count(int barrier_id, int thread_count)
*
*/
TVM_DLL const Op& ptx_init_barrier_thread_count();
@@ -697,7 +696,7 @@ TVM_DLL const Op& ptx_init_barrier_thread_count();
/*!
* \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
*
- * ptx_arrive_barrier(Var barrier_ptr, Expr barrier_offset)
+ * ptx_arrive_barrier(int barrier_id)
*
*/
TVM_DLL const Op& ptx_arrive_barrier();
@@ -705,7 +704,7 @@ TVM_DLL const Op& ptx_arrive_barrier();
/*!
* \brief tvm intrinsic for ptx barrier arrival with expect tx using
mbarrier.arrive.expect_tx
*
- * ptx_arrive_barrier_expect_tx(Var barrier_ptr, Expr barrier_offset, int
byte_count)
+ * ptx_arrive_barrier_expect_tx(int barrier_id, int byte_count)
*
*/
TVM_DLL const Op& ptx_arrive_barrier_expect_tx();
@@ -713,11 +712,19 @@ TVM_DLL const Op& ptx_arrive_barrier_expect_tx();
/*!
* \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
*
- * ptx_wait_barrier(Var barrier_ptr, Expr barrier_offset)
+ * ptx_wait_barrier(int barrier_id)
*
*/
TVM_DLL const Op& ptx_wait_barrier();
+/*!
+ * \brief tvm intrinsics to create N barriers
+ *
+ * create_barriers(int barrier_count)
+ *
+ */
+TVM_DLL const Op& create_barriers();
+
/*!
* \brief tvm intrinsic for storing the result of PTX MMA into a destination
pointer.
* For example, if each thread in a warp of size 32 has 4 elements from
the result of
diff --git a/python/tvm/script/ir_builder/tir/ir.py
b/python/tvm/script/ir_builder/tir/ir.py
index 337e060895..5471288878 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1849,6 +1849,7 @@ ptx_init_barrier_thread_count =
_op_wrapper(_tir_op.ptx_init_barrier_thread_coun
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
ptx_arrive_barrier_expect_tx =
_op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
+create_barriers = _op_wrapper(_tir_op.create_barriers)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
@@ -2125,6 +2126,7 @@ __all__ = [
"ptx_arrive_barrier",
"ptx_arrive_barrier_expect_tx",
"ptx_wait_barrier",
+ "create_barriers",
"mma_store",
"mma_fill",
"vectorlow",
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 762fcb599f..f0500290b8 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -71,6 +71,7 @@ from .op import (
ptx_arrive_barrier,
ptx_arrive_barrier_expect_tx,
ptx_wait_barrier,
+ create_barriers,
)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index cb9227e8f2..30e2a29487 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1369,7 +1369,7 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset,
global_ptr, global_offset, by
def ptx_cp_async_bulk(
- dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes,
barrier_ptr, barrier_offset
+ dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes,
barrier_id
):
"""TVM intrinsic for ptx async copy from global to shared memory using
cp.async.bulk
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
@@ -1394,11 +1394,8 @@ def ptx_cp_async_bulk(
bytes : int
The data size to copy.
- barrier_ptr : Var
- The barrier shared memory pointer variable.
-
barrier_id : int
- The offset of the barrier shared memory pointer.
+ The ID of the barrier shared memory pointer.
Returns
-------
@@ -1413,8 +1410,7 @@ def ptx_cp_async_bulk(
global_ptr,
global_offset,
bytes,
- barrier_ptr,
- barrier_offset,
+ barrier_id,
)
@@ -1447,37 +1443,31 @@ def ptx_wait_group(num):
return call_intrin("", "tir.ptx_wait_group", num)
-def ptx_cp_async_barrier(barrier_ptr, barrier_offset):
+def ptx_cp_async_barrier(barrier_id):
"""TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
Parameters
----------
- barrier_ptr : Var
- The barrier shared memory pointer variable.
-
barrier_id : int
- The offset of the barrier shared memory pointer.
+ The ID of the barrier shared memory pointer.
Returns
-------
call : PrimExpr
The call expression.
"""
- return call_intrin("", "tir.ptx_cp_async_barrier", barrier_ptr,
barrier_offset)
+ return call_intrin("", "tir.ptx_cp_async_barrier", barrier_id)
-def ptx_init_barrier_thread_count(barrier_ptr, barrier_offset, thread_count):
+def ptx_init_barrier_thread_count(barrier_id, thread_count):
"""TVM intrinsic for ptx barrier initialization of thread count using
mbarrier.init
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
Parameters
----------
- barrier_ptr : Var
- The barrier shared memory pointer variable.
-
barrier_id : int
- The offset of the barrier shared memory pointer.
+ The ID of the barrier shared memory pointer.
thread_count : int
Number of threads expected to arrive at the barrier.
@@ -1487,43 +1477,35 @@ def ptx_init_barrier_thread_count(barrier_ptr,
barrier_offset, thread_count):
call : PrimExpr
The call expression.
"""
- return call_intrin(
- "", "tir.ptx_init_barrier_thread_count", barrier_ptr, barrier_offset,
thread_count
- )
+ return call_intrin("", "tir.ptx_init_barrier_thread_count", barrier_id,
thread_count)
-def ptx_arrive_barrier(barrier_ptr, barrier_offset):
+def ptx_arrive_barrier(barrier_id):
"""TVM intrinsic for ptx barrier arrival using mbarrier.arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
Parameters
----------
- barrier_ptr : Var
- The barrier shared memory pointer variable.
-
barrier_id : int
- The offset of the barrier shared memory pointer.
+ The ID of the barrier shared memory pointer.
Returns
-------
call : PrimExpr
The call expression.
"""
- return call_intrin("", "tir.ptx_arrive_barrier", barrier_ptr,
barrier_offset)
+ return call_intrin("", "tir.ptx_arrive_barrier", barrier_id)
-def ptx_arrive_barrier_expect_tx(barrier_ptr, barrier_offset, byte_count):
+def ptx_arrive_barrier_expect_tx(barrier_id, byte_count):
"""TVM intrinsic for ptx barrier arrival with expect tx using
mbarrier.arrive.expect_tx
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation
Parameters
----------
- barrier_ptr : Var
- The barrier shared memory pointer variable.
-
barrier_id : int
- The offset of the barrier shared memory pointer.
+ The ID of the barrier shared memory pointer.
byte_count : int
Increases the tx count of the mbarrier object to track completion of
@@ -1534,29 +1516,40 @@ def ptx_arrive_barrier_expect_tx(barrier_ptr,
barrier_offset, byte_count):
call : PrimExpr
The call expression.
"""
- return call_intrin(
- "", "tir.ptx_arrive_barrier_expect_tx", barrier_ptr, barrier_offset,
byte_count
- )
+ return call_intrin("", "tir.ptx_arrive_barrier_expect_tx", barrier_id,
byte_count)
-def ptx_wait_barrier(barrier_ptr, barrier_offset):
+def ptx_wait_barrier(barrier_id):
"""TVM intrinsic for ptx barrier wait using mbarrier.try_wait
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
Parameters
----------
- barrier_ptr : Var
- The barrier shared memory pointer variable.
-
barrier_id : int
- The offset of the barrier shared memory pointer.
+ The ID of the barrier shared memory pointer.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("", "tir.ptx_wait_barrier", barrier_id)
+
+
+def create_barriers(barrier_count):
+ """TVM intrinsic to create N barriers
+
+ Parameters
+ ----------
+ barrier_count : int
+ The number of barriers to create.
Returns
-------
call : PrimExpr
The call expression.
"""
- return call_intrin("", "tir.ptx_wait_barrier", barrier_ptr, barrier_offset)
+ return call_intrin("", "tir.create_barriers", barrier_count)
def vectorlow(dtype, vec):
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index d880b978b5..7639ce6065 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -968,10 +968,10 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op,
std::ostream& os) {
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
- std::string barrier_ptr = this->PrintExpr(op->args[5]);
- std::string barrier_offset = this->PrintExpr(op->args[6]);
- this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset,
size, barrier_ptr,
- barrier_offset);
+ int barrier_id = Downcast<IntImm>(op->args[5])->value;
+ CHECK(barrier_id < barrier_count_);
+ std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) +
"]";
+ this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset,
size, barrier);
} else if (op->op.same_as(builtin::ptx_commit_group())) {
this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
} else if (op->op.same_as(builtin::ptx_wait_group())) {
@@ -979,31 +979,50 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op,
std::ostream& os) {
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n <<
";\");\n\n";
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
need_cast_smem_ptr_to_int_ = true;
- std::string barrier_ptr = this->PrintExpr(op->args[0]);
- std::string barrier_offset = this->PrintExpr(op->args[1]);
- this->stream << PrintCpAsyncBarrierAsm(barrier_ptr, barrier_offset);
+ int barrier_id = Downcast<IntImm>(op->args[0])->value;
+ CHECK(barrier_id < barrier_count_);
+ std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) +
"]";
+ this->stream << PrintCpAsyncBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
need_cast_smem_ptr_to_int_ = true;
- std::string barrier_ptr = this->PrintExpr(op->args[0]);
- std::string barrier_offset = this->PrintExpr(op->args[1]);
- std::string thread_count = this->PrintExpr(op->args[2]);
- this->stream << PrintInitBarrierThreadCountAsm(barrier_ptr,
barrier_offset, thread_count);
+ int barrier_id = Downcast<IntImm>(op->args[0])->value;
+ CHECK(barrier_id < barrier_count_);
+ std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) +
"]";
+ std::string thread_count = this->PrintExpr(op->args[1]);
+ this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
need_cast_smem_ptr_to_int_ = true;
- std::string barrier_ptr = this->PrintExpr(op->args[0]);
- std::string barrier_offset = this->PrintExpr(op->args[1]);
- this->stream << PrintArriveBarrierAsm(barrier_ptr, barrier_offset);
+ int barrier_id = Downcast<IntImm>(op->args[0])->value;
+ CHECK(barrier_id < barrier_count_);
+ std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) +
"]";
+ this->stream << PrintArriveBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
need_cast_smem_ptr_to_int_ = true;
- std::string barrier_ptr = this->PrintExpr(op->args[0]);
- std::string barrier_offset = this->PrintExpr(op->args[1]);
- std::string byte_count = this->PrintExpr(op->args[2]);
- this->stream << PrintArriveBarrierExpectTxAsm(barrier_ptr, barrier_offset,
byte_count);
+ int barrier_id = Downcast<IntImm>(op->args[0])->value;
+ CHECK(barrier_id < barrier_count_);
+ std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) +
"]";
+ std::string byte_count = this->PrintExpr(op->args[1]);
+ this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
} else if (op->op.same_as(builtin::ptx_wait_barrier())) {
need_cast_smem_ptr_to_int_ = true;
- std::string barrier_ptr = this->PrintExpr(op->args[0]);
- std::string barrier_offset = this->PrintExpr(op->args[1]);
- this->stream << PrintWaitBarrierAsm(barrier_ptr, barrier_offset);
+ int barrier_id = Downcast<IntImm>(op->args[0])->value;
+ CHECK(barrier_id < barrier_count_);
+ std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) +
"]";
+ this->stream << PrintWaitBarrierAsm(barrier);
+ } else if (op->op.same_as(builtin::create_barriers())) {
+ CHECK_EQ(barrier_count_, -1);
+ int barrier_count = Downcast<IntImm>(op->args[0])->value;
+ // pad barrier alignment to avoid runtime alignment errors
+ CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
+ int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
+ if (barrier_count % barrier_alignment_count != 0) {
+ barrier_count = ((barrier_count / barrier_alignment_count) + 1) *
barrier_alignment_count;
+ }
+ barrier_count_ = barrier_count;
+ this->stream << "__shared__ __align__(" << barrier_alignment_bytes_ << ")
uint64_t "
+ << barrier_name_ << "[" << barrier_count << "];\n";
+ this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { " <<
barrier_name_
+ << "[i] = 0; }\n";
} else if (op->op.same_as(builtin::ptx_ldg32())) {
/*
asm volatile (
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 797ac99363..bc7b34b500 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -109,6 +109,14 @@ class CodeGenCUDA final : public CodeGenC {
// Op attribute map
OpAttrMap<bool> op_need_warp_shuffle_ =
Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
+ // The name of the barrier array in shared memory
+ const std::string barrier_name_ = "barrier";
+ // The size of the barrier array in shared memory
+ int barrier_count_ = -1;
+ // The alignment of the barrier array in shared memory
+ // Set to 16 to maintain minimum alignment requirements for async bulk copy
+ const int barrier_alignment_bytes_ = 16;
+
std::unordered_map<const VarNode*, std::string> fragment_shapes;
std::unordered_map<const VarNode*, std::string> fragment_layouts;
friend void PrintConst(const FloatImmNode* op, std::ostream& os,
CodeGenCUDA* p);
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index dd7c7cb7c4..ed6125e74c 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -713,8 +713,7 @@ std::string PrintCpAsyncBulkAsm(const std::string&
shared_ptr,
const std::string& shared_elem_offset,
const std::string& global_ptr,
const std::string& global_elem_offset, const
std::string& bytes,
- const std::string& barrier_ptr,
- const std::string& barrier_elem_offset) {
+ const std::string& barrier) {
std::string asm_code = R"(
{
unsigned int smem_addr_int = cast_smem_ptr_to_int({smem_addr});
@@ -731,13 +730,12 @@ std::string PrintCpAsyncBulkAsm(const std::string&
shared_ptr,
replacer.register_rule("{smem_addr}", shared_ptr + " + " +
shared_elem_offset);
replacer.register_rule("{global_ptr}", global_ptr + " + " +
global_elem_offset);
replacer.register_rule("{bytes}", bytes);
- replacer.register_rule("{barrier}", barrier_ptr + " + " +
barrier_elem_offset);
+ replacer.register_rule("{barrier}", "&" + barrier);
asm_code = replacer.rewrite(asm_code);
return asm_code;
}
-std::string PrintCpAsyncBarrierAsm(const std::string& barrier_ptr,
- const std::string& barrier_elem_offset) {
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier) {
std::string predicated_asm_code = R"(
{
unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
@@ -749,13 +747,12 @@ std::string PrintCpAsyncBarrierAsm(const std::string&
barrier_ptr,
)";
Replacer replacer;
- replacer.register_rule("{barrier}", barrier_ptr + " + " +
barrier_elem_offset);
+ replacer.register_rule("{barrier}", "&" + barrier);
predicated_asm_code = replacer.rewrite(predicated_asm_code);
return predicated_asm_code;
}
-std::string PrintInitBarrierThreadCountAsm(const std::string& barrier_ptr,
- const std::string&
barrier_elem_offset,
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
const std::string& thread_count) {
std::string predicated_asm_code = R"(
{
@@ -769,14 +766,13 @@ std::string PrintInitBarrierThreadCountAsm(const
std::string& barrier_ptr,
)";
Replacer replacer;
- replacer.register_rule("{barrier}", barrier_ptr + " + " +
barrier_elem_offset);
+ replacer.register_rule("{barrier}", "&" + barrier);
replacer.register_rule("{thread_count}", thread_count);
predicated_asm_code = replacer.rewrite(predicated_asm_code);
return predicated_asm_code;
}
-std::string PrintArriveBarrierAsm(const std::string& barrier_ptr,
- const std::string& barrier_elem_offset) {
+std::string PrintArriveBarrierAsm(const std::string& barrier) {
std::string predicated_asm_code = R"(
{
unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
@@ -788,13 +784,12 @@ std::string PrintArriveBarrierAsm(const std::string&
barrier_ptr,
)";
Replacer replacer;
- replacer.register_rule("{barrier}", barrier_ptr + " + " +
barrier_elem_offset);
+ replacer.register_rule("{barrier}", "&" + barrier);
predicated_asm_code = replacer.rewrite(predicated_asm_code);
return predicated_asm_code;
}
-std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier_ptr,
- const std::string&
barrier_elem_offset,
+std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier,
const std::string& byte_count) {
std::string predicated_asm_code = R"(
{
@@ -808,14 +803,13 @@ std::string PrintArriveBarrierExpectTxAsm(const
std::string& barrier_ptr,
)";
Replacer replacer;
- replacer.register_rule("{barrier}", barrier_ptr + " + " +
barrier_elem_offset);
+ replacer.register_rule("{barrier}", "&" + barrier);
replacer.register_rule("{byte_count}", byte_count);
predicated_asm_code = replacer.rewrite(predicated_asm_code);
return predicated_asm_code;
}
-std::string PrintWaitBarrierAsm(const std::string& barrier_ptr,
- const std::string& barrier_elem_offset) {
+std::string PrintWaitBarrierAsm(const std::string& barrier) {
std::string predicated_asm_code = R"(
{
unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
@@ -828,7 +822,7 @@ std::string PrintWaitBarrierAsm(const std::string&
barrier_ptr,
)";
Replacer replacer;
- replacer.register_rule("{barrier}", barrier_ptr + " + " +
barrier_elem_offset);
+ replacer.register_rule("{barrier}", "&" + barrier);
predicated_asm_code = replacer.rewrite(predicated_asm_code);
return predicated_asm_code;
}
diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h
index a73180d40b..13d2f3cefc 100644
--- a/src/target/source/ptx.h
+++ b/src/target/source/ptx.h
@@ -115,60 +115,48 @@ std::string PrintPredicatedCpAsyncAssembly(const
std::string& shared_ptr,
* \param global_ptr: The pointer to the global memory.
* \param global_elem_offset: The offset into the global memory.
* \param bytes: The number of bytes to copy.
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
*/
std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr,
const std::string& shared_elem_offset,
const std::string& global_ptr,
const std::string& global_elem_offset, const
std::string& bytes,
- const std::string& barrier_ptr,
- const std::string& barrier_elem_offset);
+ const std::string& barrier);
/*!
* \brief Print ptx async copy barrier using cp.async.mbarrier.arrive
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
*/
-std::string PrintCpAsyncBarrierAsm(const std::string& barrier_ptr,
- const std::string& barrier_elem_offset);
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier);
/*!
* \brief Print ptx barrier initialization of thread count using mbarrier.init
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
* \param thread_count: The number of threads expected to arrive at the
barrier.
*/
-std::string PrintInitBarrierThreadCountAsm(const std::string& barrier_ptr,
- const std::string&
barrier_elem_offset,
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
const std::string& thread_count);
/*!
* \brief Print ptx barrier arrival using mbarrier.arrive
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
*/
-std::string PrintArriveBarrierAsm(const std::string& barrier_ptr,
- const std::string& barrier_elem_offset);
+std::string PrintArriveBarrierAsm(const std::string& barrier);
/*!
* \brief Print ptx barrier arrival with expect tx operation using
mbarrier.arrive.expect_tx
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
* \param byte_count: Increases the tx count of the mbarrier object to
track completion of
* additional async transactions.
*/
-std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier_ptr,
- const std::string&
barrier_elem_offset,
+std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier,
const std::string& byte_count);
/*!
* \brief Print ptx barrier wait using mbarrier.try_wait
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
*/
-std::string PrintWaitBarrierAsm(const std::string& barrier_ptr,
- const std::string& barrier_elem_offset);
+std::string PrintWaitBarrierAsm(const std::string& barrier);
} // namespace codegen
} // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index a4116abf13..1b80959b57 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -297,15 +297,22 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async_barrier)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(ptx_init_barrier_thread_count)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier_expect_tx)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(ptx_wait_barrier)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(create_barriers)
+ .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(mma_store)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
diff --git a/tests/python/unittest/test_tir_op_types.py
b/tests/python/unittest/test_tir_op_types.py
index e4922e1e0c..7398ee781b 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -237,10 +237,7 @@ def test_op_ptx_cp_async():
def test_op_ptx_cp_async_bulk():
buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared")
buffer_local = tir.decl_buffer([8], "float16", scope="local")
- barrier = tir.decl_buffer([1], "uint64", scope="shared")
- expr = tir.ptx_cp_async_bulk(
- "float16", buffer_shared.data, 0, buffer_local.data, 0, 16,
barrier.data, 0
- )
+ expr = tir.ptx_cp_async_bulk("float16", buffer_shared.data, 0,
buffer_local.data, 0, 16, 0)
assert expr.op.name == "tir.ptx_cp_async_bulk"
@@ -255,30 +252,35 @@ def test_op_ptx_wait_group():
def test_op_ptx_cp_async_barrier():
- expr = tir.ptx_cp_async_barrier("barrier", 0)
+ expr = tir.ptx_cp_async_barrier(0)
assert expr.op.name == "tir.ptx_cp_async_barrier"
def test_op_ptx_init_barrier_thread_count():
- expr = tir.ptx_init_barrier_thread_count("barrier", 0, 32)
+ expr = tir.ptx_init_barrier_thread_count(0, 32)
assert expr.op.name == "tir.ptx_init_barrier_thread_count"
def test_op_ptx_arrive_barrier():
- expr = tir.ptx_arrive_barrier("barrier", 0)
+ expr = tir.ptx_arrive_barrier(0)
assert expr.op.name == "tir.ptx_arrive_barrier"
def test_op_ptx_arrive_barrier_expect_tx():
- expr = tir.ptx_arrive_barrier_expect_tx("barrier", 0, 32)
+ expr = tir.ptx_arrive_barrier_expect_tx(0, 32)
assert expr.op.name == "tir.ptx_arrive_barrier_expect_tx"
def test_op_ptx_wait_barrier():
- expr = tir.ptx_wait_barrier("barrier", 0)
+ expr = tir.ptx_wait_barrier(0)
assert expr.op.name == "tir.ptx_wait_barrier"
+def test_op_create_barriers():
+ expr = tir.create_barriers(16)
+ assert expr.op.name == "tir.create_barriers"
+
+
def test_tir_op_vectorlow():
buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
vec = buffer.vload([0, 0], dtype="int8x16")
diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py
b/tests/python/unittest/test_tir_ptx_cp_async.py
index e6d3942ce5..d760023854 100644
--- a/tests/python/unittest/test_tir_ptx_cp_async.py
+++ b/tests/python/unittest/test_tir_ptx_cp_async.py
@@ -71,23 +71,13 @@ def ptx_cp_async_barrier(
T.launch_thread(bx, 1)
T.launch_thread(tx, 32)
with T.block():
- # Shared memory targets for cp.async.bulk must be 16 byte aligned
- # Problem: CUDA codegen does not support allocation alignment
- # Workaround: Ensure that `A_shared` occurs before `barrier` in
program order
- # by allocating and initializing `A_shared` before
`barrier`
- # which should result in `A_shared` being 16+ byte aligned
- # given it will be the first shared memory allocation
- # TODO(Straw) Add CUDA codegen support for allocation alignment
A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
- A_shared[0, 0] = 0
-
- barrier = T.alloc_buffer([1], "uint64", scope="shared")
- barrier[0] = 0
T.reads(A[0:32, 0:128])
T.writes(B[0:32, 0:128])
- T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32,
dtype=""))
+ T.evaluate(T.create_barriers(1, dtype=""))
+ T.evaluate(T.ptx_init_barrier_thread_count(0, 32, dtype=""))
for i in range(16):
T.evaluate(
@@ -96,9 +86,9 @@ def ptx_cp_async_barrier(
)
)
- T.evaluate(T.ptx_cp_async_barrier(barrier.data, 0, dtype=""))
- T.evaluate(T.ptx_arrive_barrier(barrier.data, 0, dtype=""))
- T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
+ T.evaluate(T.ptx_cp_async_barrier(0, dtype=""))
+ T.evaluate(T.ptx_arrive_barrier(0, dtype=""))
+ T.evaluate(T.ptx_wait_barrier(0, dtype=""))
for i in range(128):
B[tx, i] = A_shared[tx, i]
@@ -126,32 +116,20 @@ def ptx_cp_async_bulk(A: T.Buffer((32, 128), "float16"),
B: T.Buffer((32, 128),
T.launch_thread(bx, 1)
T.launch_thread(tx, 32)
with T.block():
- # Shared memory targets for cp.async.bulk must be 16 byte aligned
- # Problem: CUDA codegen does not support allocation alignment
- # Workaround: Ensure that `A_shared` occurs before `barrier` in
program order
- # by allocating and initializing `A_shared` before
`barrier`
- # which should result in `A_shared` being 16+ byte aligned
- # given it will be the first shared memory allocation
- # TODO(Straw) Add CUDA codegen support for allocation alignment
- A_shared = T.alloc_buffer([32, 128], "float16", scope="shared",
align=16)
- A_shared[0, 0] = 0
-
- barrier = T.alloc_buffer([1], "uint64", scope="shared")
- barrier[0] = 0
+ A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
T.reads(A[0:32, 0:128])
T.writes(B[0:32, 0:128])
- T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32,
dtype=""))
+ T.evaluate(T.create_barriers(1, dtype=""))
+ T.evaluate(T.ptx_init_barrier_thread_count(0, 32, dtype=""))
T.evaluate(
- T.ptx_cp_async_bulk(
- A_shared.data, tx * 128, A.data, tx * 128, 256, barrier.data,
0, dtype="float16"
- )
+ T.ptx_cp_async_bulk(A_shared.data, tx * 128, A.data, tx * 128,
256, 0, dtype="float16")
)
- T.evaluate(T.ptx_arrive_barrier_expect_tx(barrier.data, 0, 256,
dtype=""))
- T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
+ T.evaluate(T.ptx_arrive_barrier_expect_tx(0, 256, dtype=""))
+ T.evaluate(T.ptx_wait_barrier(0, dtype=""))
for i in range(128):
B[tx, i] = A_shared[tx, i]
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index ff70eeae81..61f0892a9c 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -193,21 +193,21 @@ def ptx_global_to_shared_copy_fp32x1_barrier(
T.launch_thread(bx, 1)
T.launch_thread(tx, 32)
with T.block():
- barrier = T.alloc_buffer([1], "uint64", scope="shared")
A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
+
T.reads(A[0:32, 0:128])
- T.writes(B[0:32, 0:128], barrier[0:1])
+ T.writes(B[0:32, 0:128])
- barrier[0] = 0
- T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32,
dtype=""))
+ T.evaluate(T.create_barriers(1, dtype=""))
+ T.evaluate(T.ptx_init_barrier_thread_count(0, 32, dtype=""))
T.attr("default", "async_scope", 1)
for i in T.serial(128):
A_shared[tx, i] = A[tx, i]
- T.evaluate(T.ptx_cp_async_barrier(barrier.data, 0, dtype=""))
- T.evaluate(T.ptx_arrive_barrier(barrier.data, 0, dtype=""))
- T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
+ T.evaluate(T.ptx_cp_async_barrier(0, dtype=""))
+ T.evaluate(T.ptx_arrive_barrier(0, dtype=""))
+ T.evaluate(T.ptx_wait_barrier(0, dtype=""))
for i in range(128):
B[tx, i] = A_shared[tx, i]