This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 126f1ba7ff [S-TIR][CUDA] Fix legacy predicated cp.async zero fill
(#19741)
126f1ba7ff is described below
commit 126f1ba7ff8359c792980705b8fe71e0a55e6f0f
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Jun 13 07:30:19 2026 -0400
[S-TIR][CUDA] Fix legacy predicated cp.async zero fill (#19741)
This fixes the legacy predicated `ptx.cp_async` codegen path used by
`InjectPTXAsyncCopy` for `if_then_else(..., 0)` stores.
The old inline CUDA emission zero-filled the shared-memory destination
when the predicate was false. The TIRx helper-based legacy codegen only
skipped the `cp.async`, leaving the destination slot unchanged. This
restores the previous behavior by emitting an `@!p st.shared.*` zero
store in the generated legacy predicated helper.
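The CUDA source snapshot in
`test_s_tir_transform_inject_ptx_async_copy.py` is updated to reflect
the restored false-predicate zero-fill instruction and the current
generated helper-based CUDA source.
This commit also makes the CSE plan deterministic: the expression table
in `common_subexpr_elim.cc` becomes an insertion-ordered map and the
structural-hash tie-breaker is dropped from the depth sort, so `cse_v`
numbering no longer varies between processes. With that in place, the
`xfail` marker on `test_cp_async_in_if_then_else` is removed.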
---
.../tvm/tirx/operator/intrinsics/cuda/cp_async.py | 14 +-
src/tirx/transform/common_subexpr_elim.cc | 26 ++-
.../test_s_tir_transform_inject_ptx_async_copy.py | 211 ++++++++-------------
3 files changed, 112 insertions(+), 139 deletions(-)
diff --git a/python/tvm/tirx/operator/intrinsics/cuda/cp_async.py
b/python/tvm/tirx/operator/intrinsics/cuda/cp_async.py
index 712c4672d4..2eeb0821d6 100644
--- a/python/tvm/tirx/operator/intrinsics/cuda/cp_async.py
+++ b/python/tvm/tirx/operator/intrinsics/cuda/cp_async.py
@@ -205,7 +205,8 @@ def codegen_ptx_cp_async(*args):
bytes; offsets are pre-scaled by the pass) and the call is
forwarded with default cache / predicate / fill_mode.
* 6 args ``(dst_ptr, dst_offset, src_ptr, src_offset, cp_size,
- predicate)`` — same as 5-arg form with an explicit predicate.
+ predicate)`` — same as 5-arg form with an explicit predicate,
+ zero-filling the destination when the predicate is false.
* 8 args ``(dst_ptr, src_ptr, cp_size, cache_policy, has_cache_hint,
prefetch_size, predicate, fill_mode)`` — the fork-native wrapper
API.
@@ -269,6 +270,14 @@ def codegen_ptx_cp_async(*args):
func_name = (
f"ptx_cp_async_legacy_pred_{ca_or_cg}_{cp_size_v}_{dst_elem_bytes}_{src_elem_bytes}"
)
+ if cp_size_v == 4:
+ zero_fill = ' " @!p st.shared.u32 [%0], {%4};\\n"\n'
+ elif cp_size_v == 8:
+ zero_fill = ' " @!p st.shared.v2.u32 [%0], {%4, %4};\\n"\n'
+ elif cp_size_v == 16:
+ zero_fill = ' " @!p st.shared.v4.u32 [%0], {%4, %4, %4,
%4};\\n"\n'
+ else:
+ raise ValueError(f"unsupported legacy predicated cp.async
size: {cp_size_v}")
body = (
f" uint8_t* dst_p = (uint8_t*)dst + dst_off{dst_scale};\n"
f" uint8_t* src_p = (uint8_t*)src + src_off{src_scale};\n"
@@ -279,8 +288,9 @@ def codegen_ptx_cp_async(*args):
' " setp.eq.u32 p, %3, 1;\\n"\n'
f' " @p cp.async.{ca_or_cg}.shared.global'
' [%0], [%1], %2;\\n"\n'
+ f"{zero_fill}"
' "}\\n"\n'
- f' :: "r"(dst_addr), "l"(src_p), "n"({cp_size_v}),
"r"(predicate)\n'
+ f' :: "r"(dst_addr), "l"(src_p), "n"({cp_size_v}),
"r"(predicate), "r"(0)\n'
" );"
)
source_code = (
diff --git a/src/tirx/transform/common_subexpr_elim.cc
b/src/tirx/transform/common_subexpr_elim.cc
index 9e7b2b1fb7..2221df9352 100644
--- a/src/tirx/transform/common_subexpr_elim.cc
+++ b/src/tirx/transform/common_subexpr_elim.cc
@@ -83,6 +83,7 @@
#include <utility>
#include <vector>
+#include "../../support/ordered_map.h"
#include "../analysis/check_contains.h"
namespace tvm {
@@ -239,8 +240,16 @@ class CSEPlanner : public StmtExprVisitor {
int consumed{0};
};
- /*! \brief Expression table keyed by structural equality (ExprDeepEqual). */
- using ExprTable = std::unordered_map<PrimExpr, ExprEntry,
ffi::StructuralHash, ExprDeepEqual>;
+ /*!
+ * \brief Expression table keyed by structural equality (ExprDeepEqual).
+ *
+ * An insertion-ordered map so that iteration visits entries in discovery
+ * (program) order. This makes the plan — and hence cse_v numbering —
+ * deterministic. A plain unordered_map iterates in hash order, and
+ * StructuralHash hashes free variables by object identity, which varies
+ * between processes (ASLR).
+ */
+ using ExprTable = support::OrderedMap<PrimExpr, ExprEntry,
ffi::StructuralHash, ExprDeepEqual>;
// ------------------------------------------------------------------
// Eligibility predicates
@@ -592,8 +601,9 @@ class CSEPlanner : public StmtExprVisitor {
* \brief Convert the accumulated expression table into InsertBefore +
ExprRemap tables.
*
* Algorithm (shallower-first with repr propagation):
- * 1. Collect all entries and sort by expr_depth ascending (shallower
first),
- * with structural hash as tie-breaker for determinism.
+ * 1. Collect all entries and sort by expr_depth ascending (shallower
first).
+ * The stable sort over the insertion-ordered table keeps entries of
+ * equal depth in discovery (program) order, so the plan is
deterministic.
* 2. Compute independent occurrence counts from the DAG children.
* For each parent P with count >= 2, its children's consumed counts
* are incremented by `(P.count - 1) * multiplicity` (the Bind value
@@ -610,7 +620,8 @@ class CSEPlanner : public StmtExprVisitor {
* \return A pair of (InsertBeforeTable, ExprRemapTable).
*/
std::pair<InsertBeforeTable, ExprRemapTable> ComputePlan() {
- // Step 1: Sort entries by depth ascending (shallower first), hash for
determinism
+ // Step 1: Sort entries by depth ascending (shallower first). table_
iterates
+ // in discovery order, which the stable sort preserves among equal depths.
std::vector<std::pair<PrimExpr, ExprEntry*>> all_entries;
for (auto& kv : table_) {
all_entries.push_back({kv.first, &kv.second});
@@ -619,10 +630,7 @@ class CSEPlanner : public StmtExprVisitor {
std::stable_sort(
all_entries.begin(), all_entries.end(),
[](const std::pair<PrimExpr, ExprEntry*>& a, const std::pair<PrimExpr,
ExprEntry*>& b) {
- if (a.second->expr_depth != b.second->expr_depth)
- return a.second->expr_depth < b.second->expr_depth;
- ffi::StructuralHash hasher;
- return hasher(a.first) < hasher(b.first);
+ return a.second->expr_depth < b.second->expr_depth;
});
// Step 2: Compute consumed counts in ExprEntry from the DAG.
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
index fde4e91501..428fb9b895 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
@@ -194,19 +194,8 @@ def test_inject_async_copy_shared_dyn():
# `ptx_mbarrier_*` family instead.
-# Note: the expected output contains a dead CSE variable `cse_v1 = (i < 12)`.
-# CSE extracts (i < 12) before inject_ptx_async_copy runs, but the latter
-# replaces the original IfThenElse guards with new cast(int32, i < 12)
-# expressions for predicated async copies, leaving cse_v1 unused.
expected_cuda_script = r"""#include <cuda.h>
-__forceinline__ __device__ unsigned int
-cast_smem_ptr_to_int(const void* const smem_ptr)
-{
- unsigned int smem_int;
- asm volatile ("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1;
cvt.u32.u64 %0, smem_int; }"
- : "=r"(smem_int) : "l"(smem_ptr));
- return smem_int;
-}
+#endif
#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
(__CUDACC_VER_MAJOR__ > 11))
@@ -215,132 +204,101 @@ cast_smem_ptr_to_int(const void* const smem_ptr)
#define TVM_ENABLE_L2_PREFETCH 0
#endif
-#ifdef __CUDACC_RTC__
-using int64_t = long long;
-using uint64_t = unsigned long long;
+#ifdef _WIN32
+ using uint = unsigned int;
+ using uchar = unsigned char;
+ using ushort = unsigned short;
+ using int64_t = long long;
+ using uint64_t = unsigned long long;
#else
-#include <cstdint>
+ #define uint unsigned int
+ #define uchar unsigned char
+ #define ushort unsigned short
#endif
-using uint = unsigned int;
-using uchar = unsigned char;
-using ushort = unsigned short;
-
-extern "C" __global__ void __launch_bounds__(16) main_kernel(float*
__restrict__ A, float* __restrict__ B, float* __restrict__ C);
-extern "C" __global__ void __launch_bounds__(16) main_kernel(float*
__restrict__ A, float* __restrict__ B, float* __restrict__ C) {
- __shared__ float A_shared[64];
- __shared__ float B_shared[64];
- A_shared[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
- B_shared[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
-__asm__ __volatile__("cp.async.commit_group;");
-
-
- {
- unsigned int addr = cast_smem_ptr_to_int(A_shared + (((int)threadIdx.x) +
16));
- __asm__ __volatile__(
- #if TVM_ENABLE_L2_PREFETCH
- "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
- #else
- "cp.async.ca.shared.global [%0], [%1], %2;"
- #endif
- :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) * 14))), "n"(4)
- );
- }
- {
- unsigned int addr = cast_smem_ptr_to_int(B_shared + (((int)threadIdx.x) +
16));
- __asm__ __volatile__(
- #if TVM_ENABLE_L2_PREFETCH
- "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
- #else
- "cp.async.ca.shared.global [%0], [%1], %2;"
- #endif
- :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4)
- );
- }
-__asm__ __volatile__("cp.async.commit_group;");
-
-
- {
- unsigned int addr = cast_smem_ptr_to_int(A_shared + (((int)threadIdx.x) +
32));
- __asm__ __volatile__(
- #if TVM_ENABLE_L2_PREFETCH
- "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
- #else
- "cp.async.ca.shared.global [%0], [%1], %2;"
- #endif
- :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4)
- );
- }
+__forceinline__ __device__ void tvm_builtin_ptx_cp_async_wait_group_0() {
+ asm volatile("cp.async.wait_group 0;");
+}
- {
- unsigned int addr = cast_smem_ptr_to_int(B_shared + (((int)threadIdx.x) +
32));
- __asm__ __volatile__(
- #if TVM_ENABLE_L2_PREFETCH
- "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
- #else
- "cp.async.ca.shared.global [%0], [%1], %2;"
- #endif
- :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4)
- );
- }
-__asm__ __volatile__("cp.async.commit_group;");
+__forceinline__ __device__ void tvm_builtin_ptx_cp_async_wait_group_1() {
+ asm volatile("cp.async.wait_group 1;");
+}
- for (int i = 0; i < 13; ++i) {
- bool cse_v1 = (i < 12);
-
- {
- unsigned int addr = cast_smem_ptr_to_int(A_shared + ((((i + 3) & 3) * 16)
+ ((int)threadIdx.x)));
- int pred_guard = (int)(i < 12);
- __asm__ __volatile__(
- "{ .reg .pred p;"
- " setp.ne.b32 p, %0, 0;"
- #if TVM_ENABLE_L2_PREFETCH
- " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;"
- #else
- " @p cp.async.ca.shared.global [%1], [%2], %3;"
- #endif
- " @!p st.shared.u32 [%1], {%4};}"
- :: "r"(pred_guard), "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) *
14) + i) + 2))), "n"(4), "r"(0)
- );
- }
-__asm__ __volatile__("cp.async.commit_group;");
+__forceinline__ __device__ void tvm_builtin_ptx_cp_async_wait_group_2() {
+ asm volatile("cp.async.wait_group 2;");
+}
+
+__forceinline__ __device__ void tvm_builtin_ptx_cp_async_wait_group_5() {
+ asm volatile("cp.async.wait_group 5;");
+}
+
+__forceinline__ __device__ void ptx_cp_async_legacy_pred_ca_4_4_4(void* dst,
int dst_off, void* src, int src_off, int predicate) {
+ uint8_t* dst_p = (uint8_t*)dst + dst_off * 4;
+ uint8_t* src_p = (uint8_t*)src + src_off * 4;
+ unsigned int dst_addr = __cvta_generic_to_shared(dst_p);
+ __asm__ __volatile__(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.eq.u32 p, %3, 1;\n"
+ " @p cp.async.ca.shared.global [%0], [%1], %2;\n"
+ " @!p st.shared.u32 [%0], {%4};\n"
+ "}\n"
+ :: "r"(dst_addr), "l"(src_p), "n"(4), "r"(predicate), "r"(0)
+ );
+}
-__asm__ __volatile__("cp.async.wait_group 5;");
+__forceinline__ __device__ void ptx_cp_async_legacy_ca_4_4_4(void* dst, int
dst_off, void* src, int src_off) {
+ uint8_t* dst_p = (uint8_t*)dst + dst_off * 4;
+ uint8_t* src_p = (uint8_t*)src + src_off * 4;
+ unsigned int dst_addr = __cvta_generic_to_shared(dst_p);
+ asm volatile("cp.async.ca.shared.global [%0], [%1], %2;"
+ :: "r"(dst_addr), "l"(src_p), "n"(4));
+}
+__forceinline__ __device__ void tvm_builtin_ptx_cp_async_commit_group() {
+ asm volatile("cp.async.commit_group;");
+}
+extern "C" __global__ void __launch_bounds__(16) main_kernel(float*
__restrict__ A_ptr, float* __restrict__ B_ptr, float* __restrict__ C_ptr);
+extern "C" __global__ void __launch_bounds__(16) main_kernel(float*
__restrict__ A_ptr, float* __restrict__ B_ptr, float* __restrict__ C_ptr) {
+ __shared__ alignas(64) float A_shared_ptr[64];
+ __shared__ alignas(64) float B_shared_ptr[64];
+ A_shared_ptr[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
+ B_shared_ptr[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
+ tvm_builtin_ptx_cp_async_commit_group();
+ int cse_v1 = (((int)threadIdx.x) * 14);
+ int cse_v2 = (((int)threadIdx.x) + 16);
+ ptx_cp_async_legacy_ca_4_4_4(A_shared_ptr, (((int)threadIdx.x) + 16), A_ptr,
(((int)threadIdx.x) * 14));
+ ptx_cp_async_legacy_ca_4_4_4(B_shared_ptr, (((int)threadIdx.x) + 16), B_ptr,
(((int)threadIdx.x) * 14));
+ tvm_builtin_ptx_cp_async_commit_group();
+ int cse_v3 = (((int)threadIdx.x) + 32);
+ int cse_v6 = ((((int)threadIdx.x) * 14) + 1);
+ ptx_cp_async_legacy_ca_4_4_4(A_shared_ptr, (((int)threadIdx.x) + 32), A_ptr,
((((int)threadIdx.x) * 14) + 1));
+ ptx_cp_async_legacy_ca_4_4_4(B_shared_ptr, (((int)threadIdx.x) + 32), B_ptr,
((((int)threadIdx.x) * 14) + 1));
+ tvm_builtin_ptx_cp_async_commit_group();
+ int cse_v4 = (((int)threadIdx.x) * 16);
+ for (int i = 0; i < 13; ++i) {
+ int cse_v7 = (((((int)threadIdx.x) * 14) + i) + 2);
+ int cse_v9 = ((((i + 3) & 3) * 16) + ((int)threadIdx.x));
+ ptx_cp_async_legacy_pred_ca_4_4_4(A_shared_ptr, ((((i + 3) & 3) * 16) +
((int)threadIdx.x)), A_ptr, (((((int)threadIdx.x) * 14) + i) + 2), (i < 12));
+ tvm_builtin_ptx_cp_async_commit_group();
+ tvm_builtin_ptx_cp_async_wait_group_5();
__syncthreads();
- C[((((int)threadIdx.x) * 16) + i)] = (A_shared[(((i & 3) * 16) +
((int)threadIdx.x))] + B_shared[(((i & 3) * 16) + ((int)threadIdx.x))]);
+ int cse_v8 = (((i & 3) * 16) + ((int)threadIdx.x));
+ C_ptr[((((int)threadIdx.x) * 16) + i)] = (A_shared_ptr[(((i & 3) * 16) +
((int)threadIdx.x))] + B_shared_ptr[(((i & 3) * 16) + ((int)threadIdx.x))]);
__syncthreads();
-
- {
- unsigned int addr = cast_smem_ptr_to_int(B_shared + ((((i + 3) & 3) * 16)
+ ((int)threadIdx.x)));
- int pred_guard = (int)(i < 12);
- __asm__ __volatile__(
- "{ .reg .pred p;"
- " setp.ne.b32 p, %0, 0;"
- #if TVM_ENABLE_L2_PREFETCH
- " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;"
- #else
- " @p cp.async.ca.shared.global [%1], [%2], %3;"
- #endif
- " @!p st.shared.u32 [%1], {%4};}"
- :: "r"(pred_guard), "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) *
14) + i) + 2))), "n"(4), "r"(0)
- );
- }
-__asm__ __volatile__("cp.async.commit_group;");
-
+ ptx_cp_async_legacy_pred_ca_4_4_4(B_shared_ptr, ((((i + 3) & 3) * 16) +
((int)threadIdx.x)), B_ptr, (((((int)threadIdx.x) * 14) + i) + 2), (i < 12));
+ tvm_builtin_ptx_cp_async_commit_group();
}
-__asm__ __volatile__("cp.async.wait_group 2;");
-
+ tvm_builtin_ptx_cp_async_wait_group_2();
__syncthreads();
- C[((((int)threadIdx.x) * 16) + 13)] = (A_shared[(((int)threadIdx.x) + 16)] +
B_shared[(((int)threadIdx.x) + 16)]);
-__asm__ __volatile__("cp.async.wait_group 1;");
-
+ C_ptr[((((int)threadIdx.x) * 16) + 13)] = (A_shared_ptr[(((int)threadIdx.x)
+ 16)] + B_shared_ptr[(((int)threadIdx.x) + 16)]);
+ tvm_builtin_ptx_cp_async_wait_group_1();
__syncthreads();
- C[((((int)threadIdx.x) * 16) + 14)] = (A_shared[(((int)threadIdx.x) + 32)] +
B_shared[(((int)threadIdx.x) + 32)]);
-__asm__ __volatile__("cp.async.wait_group 0;");
-
+ C_ptr[((((int)threadIdx.x) * 16) + 14)] = (A_shared_ptr[(((int)threadIdx.x)
+ 32)] + B_shared_ptr[(((int)threadIdx.x) + 32)]);
+ tvm_builtin_ptx_cp_async_wait_group_0();
__syncthreads();
- C[((((int)threadIdx.x) * 16) + 15)] = (A_shared[(((int)threadIdx.x) + 48)] +
B_shared[(((int)threadIdx.x) + 48)]);
+ int cse_v5 = (((int)threadIdx.x) + 48);
+ C_ptr[((((int)threadIdx.x) * 16) + 15)] = (A_shared_ptr[(((int)threadIdx.x)
+ 48)] + B_shared_ptr[(((int)threadIdx.x) + 48)]);
}
"""
@@ -392,9 +350,6 @@ def postproc_if_missing_async_support():
tvm.register_global_func(func_name, prev_postproc, override=True)
-# TODO(tlopex): fix CSE determinism (change unordered map to ordered map) and
-# remove this xfail; see #19741.
[email protected]
@tvm.testing.requires_cuda
def test_cp_async_in_if_then_else(postproc_if_missing_async_support):
@T.prim_func(s_tir=True)