This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b9652a2db0 [Hopper TMA] CUDA codegen for async copy with barrier synchronization (#15616)
b9652a2db0 is described below
commit b9652a2db0441408b0cffc6f6c0331e213ce6253
Author: Adam Straw <[email protected]>
AuthorDate: Thu Aug 24 19:07:36 2023 -0700
[Hopper TMA] CUDA codegen for async copy with barrier synchronization (#15616)
[Codegen] CUDA async copy with barrier synchronization
---
include/tvm/tir/builtin.h | 32 ++++++
python/tvm/script/ir_builder/tir/ir.py | 8 ++
python/tvm/tir/__init__.py | 11 ++-
python/tvm/tir/op.py | 80 +++++++++++++++
src/target/source/codegen_cuda.cc | 39 ++++++++
src/target/source/codegen_cuda.h | 2 +
src/target/source/ptx.cc | 93 ++++++++++++++----
src/target/source/ptx.h | 26 +++++
src/tir/op/builtin.cc | 9 ++
tests/python/unittest/test_tir_op_types.py | 20 ++++
.../test_tir_transform_inject_ptx_async_copy.py | 108 ++++++++++++++-------
11 files changed, 372 insertions(+), 56 deletions(-)
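
For context, a condensed TVMScript sketch of how the four new intrinsics are meant to compose, taken from the test added below in test_tir_transform_inject_ptx_async_copy.py (the barrier buffer, thread count of 32, and copy shape are simply those used by that test):

    barrier = T.alloc_buffer([1], "uint64", scope="shared")                  # mbarrier object lives in shared memory
    barrier[0] = 0
    T.evaluate(T.ptx_init_barrier_thread_count("barrier", 0, 32, dtype=""))  # mbarrier.init, expect 32 threads
    T.attr("default", "async_scope", 1)
    for i in T.serial(128):
        A_shared[tx, i] = A[tx, i]                                           # lowered to cp.async by InjectPTXAsyncCopy
    T.evaluate(T.ptx_cp_async_barrier("barrier", 0, dtype=""))               # cp.async.mbarrier.arrive
    T.evaluate(T.ptx_arrive_barrier("barrier", 0, dtype=""))                 # mbarrier.arrive
    T.evaluate(T.ptx_wait_barrier("barrier", 0, dtype=""))                   # spin on mbarrier.try_wait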
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index e8bcc028fc..b5c04f760d 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -663,6 +663,38 @@ TVM_DLL const Op& ptx_cp_async();
TVM_DLL const Op& ptx_commit_group();
TVM_DLL const Op& ptx_wait_group();
+/*!
+ * \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive
+ *
+ * ptx_cp_async_barrier(barrier_array, barrier_id)
+ *
+ */
+TVM_DLL const Op& ptx_cp_async_barrier();
+
+/*!
+ * \brief tvm intrinsics for ptx barrier initialization of thread count using mbarrier.init
+ *
+ * ptx_init_barrier_thread_count(barrier_array, barrier_id, thread_count)
+ *
+ */
+TVM_DLL const Op& ptx_init_barrier_thread_count();
+
+/*!
+ * \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
+ *
+ * ptx_arrive_barrier(barrier_array, barrier_id)
+ *
+ */
+TVM_DLL const Op& ptx_arrive_barrier();
+
+/*!
+ * \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
+ *
+ * ptx_wait_barrier(barrier_array, barrier_id)
+ *
+ */
+TVM_DLL const Op& ptx_wait_barrier();
+
/*!
 * \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
 *        For example, if each thread in a warp of size 32 has 4 elements from the result of
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index efea9f1aea..d7bebbacee 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1844,6 +1844,10 @@ tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down
tvm_warp_activemask = _tir_op.tvm_warp_activemask
ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier)
+ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count)
+ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
+ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
@@ -2113,6 +2117,10 @@ __all__ = [
"ptx_cp_async",
"ptx_wait_group",
"ptx_commit_group",
+ "ptx_cp_async_barrier",
+ "ptx_init_barrier_thread_count",
+ "ptx_arrive_barrier",
+ "ptx_wait_barrier",
"mma_store",
"mma_fill",
"vectorlow",
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 5eb1059d27..84c5753337 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -60,7 +60,16 @@ from .op import (
tvm_fill_fragment,
)
from .op import ptx_mma, ptx_mma_sp, mma_store, mma_fill
-from .op import ptx_ldmatrix, ptx_cp_async, ptx_commit_group, ptx_wait_group
+from .op import (
+ ptx_ldmatrix,
+ ptx_cp_async,
+ ptx_commit_group,
+ ptx_wait_group,
+ ptx_cp_async_barrier,
+ ptx_init_barrier_thread_count,
+ ptx_arrive_barrier,
+ ptx_wait_barrier,
+)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 378be84621..7e1c520cc4 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1397,6 +1397,86 @@ def ptx_wait_group(num):
return call_intrin("", "tir.ptx_wait_group", num)
+def ptx_cp_async_barrier(barrier_arr, barrier_id):
+ """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
+
+ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
+
+ Parameters
+ ----------
+ barrier_arr : string
+ The name of the barrier array in shared memory
+ barrier_id : int
+ Index into the barrier array
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("", "tir.ptx_cp_async_barrier", barrier_arr, barrier_id)
+
+
+def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count):
+ """TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
+
+ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
+
+ Parameters
+ ----------
+ barrier_arr : string
+ The name of the barrier array in shared memory
+ barrier_id : int
+ Index into the barrier array
+ thread_count : int
+ Number of threads expected to arrive at the barrier
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin(
+ "", "tir.ptx_init_barrier_thread_count", barrier_arr, barrier_id, thread_count
+ )
+
+
+def ptx_arrive_barrier(barrier_arr, barrier_id):
+ """TVM intrinsic for ptx barrier arrival using mbarrier.arrive
+
+ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
+
+ Parameters
+ ----------
+ barrier_arr : string
+ The name of the barrier array in shared memory
+ barrier_id : int
+ Index into the barrier array
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("", "tir.ptx_arrive_barrier", barrier_arr, barrier_id)
+
+
+def ptx_wait_barrier(barrier_arr, barrier_id):
+ """TVM intrinsic for ptx barrier wait using mbarrier.try_wait
+
+ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
+
+ Parameters
+ ----------
+ barrier_arr : string
+ The name of the barrier array in shared memory
+ barrier_id : int
+ Index into the barrier array
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("", "tir.ptx_wait_barrier", barrier_arr, barrier_id)
+
+
def vectorlow(dtype, vec):
"""Get the low level half of the vector
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 6c02348191..edbe8be030 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -141,6 +141,18 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <mma.h>\n";
}
+ if (need_cast_smem_ptr_to_int_) {
+ decl_stream << "__forceinline__ __device__ unsigned int\n";
+ decl_stream << "cast_smem_ptr_to_int(const void* const smem_ptr)\n";
+ decl_stream << "{\n";
+ decl_stream << " unsigned int smem_int;\n";
+ decl_stream << " asm volatile (\"{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; "
+ "cvt.u32.u64 %0, smem_int; }\"\n";
+ decl_stream << " : \"=r\"(smem_int) : \"l\"(smem_ptr));\n";
+ decl_stream << " return smem_int;\n";
+ decl_stream << "}\n";
+ }
+
decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n";
decl_stream << " (__CUDACC_VER_MAJOR__ > 11))\n";
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n";
@@ -873,6 +885,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
os << "}\n";
} else {
std::string smem_elem_offset = this->PrintExpr(op->args[6]);
+ need_cast_smem_ptr_to_int_ = true;
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
smem_ptr, smem_elem_offset);
}
@@ -941,6 +954,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
// use size of argument list to indicate whether or not to use predicated cp.async
+ need_cast_smem_ptr_to_int_ = true;
if (op->args.size() == 5) {
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
} else {
@@ -952,6 +966,31 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
} else if (op->op.same_as(builtin::ptx_wait_group())) {
int n = Downcast<IntImm>(op->args[0])->value;
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n";
+ } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
+ need_cast_smem_ptr_to_int_ = true;
+ std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+ std::string barrier_id = this->PrintExpr(op->args[1]);
+ std::string barrier = barriers_arr + "[" + barrier_id + "]";
+ this->stream << PrintCpAsyncBarrierAsm(barrier);
+ } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
+ need_cast_smem_ptr_to_int_ = true;
+ std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+ std::string barrier_id = this->PrintExpr(op->args[1]);
+ std::string barrier = barriers_arr + "[" + barrier_id + "]";
+ std::string thread_count = this->PrintExpr(op->args[2]);
+ this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
+ } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
+ need_cast_smem_ptr_to_int_ = true;
+ std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+ std::string barrier_id = this->PrintExpr(op->args[1]);
+ std::string barrier = barriers_arr + "[" + barrier_id + "]";
+ this->stream << PrintArriveBarrierAsm(barrier);
+ } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
+ need_cast_smem_ptr_to_int_ = true;
+ std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+ std::string barrier_id = this->PrintExpr(op->args[1]);
+ std::string barrier = barriers_arr + "[" + barrier_id + "]";
+ this->stream << PrintWaitBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_ldg32())) {
/*
asm volatile (
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 7de6ae05e8..797ac99363 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -104,6 +104,8 @@ class CodeGenCUDA final : public CodeGenC {
bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
+ // whether need cast_smem_ptr_to_int helper function
+ bool need_cast_smem_ptr_to_int_{false};
// Op attribute map
OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index feffc3d304..6ff57f43bd 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -603,12 +603,7 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16.";
std::string asm_code = R"(
{
- unsigned int addr;
- __asm__ __volatile__(
- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
- : "=r"(addr)
- : "l"((void *)({smem_addr}))
- );
+ unsigned int addr = cast_smem_ptr_to_int({smem_addr});
__asm__ __volatile__(
"ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
"{templates};\n"
@@ -638,12 +633,7 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
                                 const std::string& global_elem_offset, const std::string& bytes) {
std::string asm_code = R"(
{
- unsigned int addr;
- __asm__ __volatile__(
- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
- : "=r"(addr)
- : "l"((void *)({smem_addr}))
- );
+ unsigned int addr = cast_smem_ptr_to_int({smem_addr});
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
@@ -674,12 +664,7 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
<< "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async";
std::string predicated_asm_code = R"(
{
- unsigned int addr;
- __asm__ __volatile__(
- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
- : "=r"(addr)
- : "l"((void *)({smem_addr}))
- );
+ unsigned int addr = cast_smem_ptr_to_int({smem_addr});
int pred_guard = (int){pred_guard};
__asm__ __volatile__(
"{ .reg .pred p;"
@@ -724,5 +709,77 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
return predicated_asm_code;
}
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier) {
+ std::string predicated_asm_code = R"(
+ {
+ unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+ __asm__ __volatile__(
+ "cp.async.mbarrier.arrive.shared.b64 [%0];"
+ :: "r" (barrier_addr_int)
+ );
+ }
+)";
+
+ Replacer replacer;
+ replacer.register_rule("{barrier}", barrier);
+ predicated_asm_code = replacer.rewrite(predicated_asm_code);
+ return predicated_asm_code;
+}
+
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
+ const std::string& thread_count) {
+ std::string predicated_asm_code = R"(
+ {
+ unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+ int thread_count = {thread_count};
+ __asm__ __volatile__(
+ "mbarrier.init.shared.b64 [%0], %1;"
+ :: "r"(barrier_addr_int), "r"(thread_count)
+ );
+ }
+)";
+
+ Replacer replacer;
+ replacer.register_rule("{barrier}", barrier);
+ replacer.register_rule("{thread_count}", thread_count);
+ predicated_asm_code = replacer.rewrite(predicated_asm_code);
+ return predicated_asm_code;
+}
+
+std::string PrintArriveBarrierAsm(const std::string& barrier) {
+ std::string predicated_asm_code = R"(
+ {
+ unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+ __asm__ __volatile__(
+ "{ .reg .b64 state; mbarrier.arrive.shared.b64 state, [%0]; }"
+ :: "r"(barrier_addr_int)
+ );
+ }
+)";
+
+ Replacer replacer;
+ replacer.register_rule("{barrier}", barrier);
+ predicated_asm_code = replacer.rewrite(predicated_asm_code);
+ return predicated_asm_code;
+}
+
+std::string PrintWaitBarrierAsm(const std::string& barrier) {
+ std::string predicated_asm_code = R"(
+ {
+ unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+ constexpr int phase_bit = 0;
+ __asm__ __volatile__(
+ "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; @P bra.uni DONE; bra.uni WAIT; DONE: }"
+ :: "r"(barrier_addr_int), "r"(phase_bit)
+ );
+ }
+)";
+
+ Replacer replacer;
+ replacer.register_rule("{barrier}", barrier);
+ predicated_asm_code = replacer.rewrite(predicated_asm_code);
+ return predicated_asm_code;
+}
+
} // namespace codegen
} // namespace tvm
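
For reference, once the Replacer substitutes a concrete barrier expression, the string returned by PrintWaitBarrierAsm is plain CUDA inline asm; assuming the barrier expression is barrier[0] (as in the test added below), the emitted block would read roughly:

  {
    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&barrier[0]);
    constexpr int phase_bit = 0;
    __asm__ __volatile__(
      "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; @P bra.uni DONE; bra.uni WAIT; DONE: }"
      :: "r"(barrier_addr_int), "r"(phase_bit)
    );
  }

i.e. the thread spins on mbarrier.try_wait.parity until the barrier completes the phase indicated by phase_bit.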
diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h
index 1e49b57c17..18519d85f6 100644
--- a/src/target/source/ptx.h
+++ b/src/target/source/ptx.h
@@ -108,6 +108,32 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
const std::string& bytes,
const std::string& predicate_value);
+/*!
+ * \brief Print ptx async copy barrier using cp.async.mbarrier.arrive
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ */
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier);
+
+/*!
+ * \brief Print ptx barrier initialization of thread count using mbarrier.init
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ * \param thread_count: The number of threads expected to arrive at the barrier
+ */
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
+ const std::string& thread_count);
+
+/*!
+ * \brief Print ptx barrier arrival using mbarrier.arrive
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ */
+std::string PrintArriveBarrierAsm(const std::string& barrier);
+
+/*!
+ * \brief Print ptx barrier wait using mbarrier.try_wait
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ */
+std::string PrintWaitBarrierAsm(const std::string& barrier);
+
} // namespace codegen
} // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index c855904284..0ca61b4099 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -290,6 +290,15 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async_barrier)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_init_barrier_thread_count)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_wait_barrier)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(mma_store)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py
index 58954e7459..30e9ed2dfa 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -244,6 +244,26 @@ def test_op_ptx_wait_group():
assert expr.op.name == "tir.ptx_wait_group"
+def test_op_ptx_cp_async_barrier():
+ expr = tir.ptx_cp_async_barrier("barrier", 0)
+ assert expr.op.name == "tir.ptx_cp_async_barrier"
+
+
+def test_op_ptx_init_barrier_thread_count():
+ expr = tir.ptx_init_barrier_thread_count("barrier", 0, 32)
+ assert expr.op.name == "tir.ptx_init_barrier_thread_count"
+
+
+def test_op_ptx_arrive_barrier():
+ expr = tir.ptx_arrive_barrier("barrier", 0)
+ assert expr.op.name == "tir.ptx_arrive_barrier"
+
+
+def test_op_ptx_wait_barrier():
+ expr = tir.ptx_wait_barrier("barrier", 0)
+ assert expr.op.name == "tir.ptx_wait_barrier"
+
+
def test_tir_op_vectorlow():
buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
vec = buffer.vload([0, 0], dtype="int8x16")
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index b39fca72c8..5d866199e7 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -183,7 +183,71 @@ def test_inject_async_copy_shared_dyn():
tvm.testing.assert_allclose(C_nd.numpy(), A_np + B_np)
-expected_cuda_script = r"""
+@T.prim_func
+def ptx_global_to_shared_copy_fp32x1_barrier(
+ A: T.Buffer((32, 128), "float32"), B: T.Buffer((32, 128), "float32")
+) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ bx = T.env_thread("blockIdx.x")
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(bx, 1)
+ T.launch_thread(tx, 32)
+ with T.block():
+ barrier = T.alloc_buffer([1], "uint64", scope="shared")
+ A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
+ T.reads(A[0:32, 0:128])
+ T.writes(B[0:32, 0:128], barrier[0:1])
+
+ barrier[0] = 0
+ T.evaluate(T.ptx_init_barrier_thread_count("barrier", 0, 32, dtype=""))
+
+ T.attr("default", "async_scope", 1)
+ for i in T.serial(128):
+ A_shared[tx, i] = A[tx, i]
+
+ T.evaluate(T.ptx_cp_async_barrier("barrier", 0, dtype=""))
+ T.evaluate(T.ptx_arrive_barrier("barrier", 0, dtype=""))
+ T.evaluate(T.ptx_wait_barrier("barrier", 0, dtype=""))
+
+ for i in range(128):
+ B[tx, i] = A_shared[tx, i]
+
+
+@tvm.testing.requires_cuda
+def test_inject_async_copy_barrier():
+ dtype = "float32"
+ vec_size = 1
+ f = ptx_global_to_shared_copy_fp32x1_barrier
+
+ mod = tvm.IRModule.from_expr(f)
+ mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
+ mod = tvm.tir.transform.FlattenBuffer()(mod)
+ mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+
+ assert count_cp_async(mod["main"].body) == 1
+
+ if tvm.testing.is_ampere_or_newer():
+ with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
+ mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
+
+ A_np = np.random.rand(32, 128).astype(dtype)
+ B_np = np.zeros((32, 128)).astype(dtype)
+ dev = tvm.cuda(0)
+ A_nd = tvm.nd.array(A_np, device=dev)
+ B_nd = tvm.nd.array(B_np, device=dev)
+ mod(A_nd, B_nd)
+ tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+expected_cuda_script = r"""__forceinline__ __device__ unsigned int
+cast_smem_ptr_to_int(const void* const smem_ptr)
+{
+ unsigned int smem_int;
+ asm volatile ("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; cvt.u32.u64 %0, smem_int; }"
+ : "=r"(smem_int) : "l"(smem_ptr));
+ return smem_int;
+}
+
#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
(__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
@@ -214,12 +278,7 @@ __asm__ __volatile__("cp.async.commit_group;");
{
- unsigned int addr;
- __asm__ __volatile__(
- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
- : "=r"(addr)
- : "l"((void *)(A_shared + (((int)threadIdx.x) + 16)))
- );
+ unsigned int addr = cast_smem_ptr_to_int(A_shared + (((int)threadIdx.x) + 16));
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -231,12 +290,7 @@ __asm__ __volatile__("cp.async.commit_group;");
}
{
- unsigned int addr;
- __asm__ __volatile__(
- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
- : "=r"(addr)
- : "l"((void *)(B_shared + (((int)threadIdx.x) + 16)))
- );
+ unsigned int addr = cast_smem_ptr_to_int(B_shared + (((int)threadIdx.x) + 16));
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -250,12 +304,7 @@ __asm__ __volatile__("cp.async.commit_group;");
{
- unsigned int addr;
- __asm__ __volatile__(
- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
- : "=r"(addr)
- : "l"((void *)(A_shared + (((int)threadIdx.x) + 32)))
- );
+ unsigned int addr = cast_smem_ptr_to_int(A_shared + (((int)threadIdx.x) + 32));
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -267,12 +316,7 @@ __asm__ __volatile__("cp.async.commit_group;");
}
{
- unsigned int addr;
- __asm__ __volatile__(
- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
- : "=r"(addr)
- : "l"((void *)(B_shared + (((int)threadIdx.x) + 32)))
- );
+ unsigned int addr = cast_smem_ptr_to_int(B_shared + (((int)threadIdx.x) + 32));
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -288,12 +332,7 @@ __asm__ __volatile__("cp.async.commit_group;");
bool cse_var_1 = (i < 12);
{
- unsigned int addr;
- __asm__ __volatile__(
- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
- : "=r"(addr)
- : "l"((void *)(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))))
- );
+ unsigned int addr = cast_smem_ptr_to_int(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)));
int pred_guard = (int)cse_var_1;
__asm__ __volatile__(
"{ .reg .pred p;"
@@ -316,12 +355,7 @@ __asm__ __volatile__("cp.async.wait_group 5;");
__syncthreads();
{
- unsigned int addr;
- __asm__ __volatile__(
- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
- : "=r"(addr)
- : "l"((void *)(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))))
- );
+ unsigned int addr = cast_smem_ptr_to_int(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)));
int pred_guard = (int)cse_var_1;
__asm__ __volatile__(
"{ .reg .pred p;"