This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 126f1ba7ff [S-TIR][CUDA] Fix legacy predicated cp.async zero fill 
(#19741)
126f1ba7ff is described below

commit 126f1ba7ff8359c792980705b8fe71e0a55e6f0f
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Jun 13 07:30:19 2026 -0400

    [S-TIR][CUDA] Fix legacy predicated cp.async zero fill (#19741)
    
    This fixes the legacy predicated `ptx.cp_async` codegen path used by
    `InjectPTXAsyncCopy` for `if_then_else(..., 0)` stores.
    
    The old inline CUDA emission zero-filled the shared-memory destination
    when the predicate was false. The TIRx helper-based legacy codegen only
    skipped the `cp.async`, leaving the destination slot unchanged. This
    restores the previous behavior by emitting an `@!p st.shared.*` zero
    store in the generated legacy predicated helper.
    
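    The CUDA source snapshot in
    `test_s_tir_transform_inject_ptx_async_copy.py` is updated to reflect
    the restored false-predicate zero-fill instruction and the current
    generated helper-based CUDA source.
    
    The commit also makes common subexpression elimination deterministic.
    `CSEPlanner`'s expression table is now an insertion-ordered
    `support::OrderedMap`, and `ComputePlan` no longer breaks equal-depth
    ties by structural hash. That hash identifies free variables by object
    identity, which varies between processes. The stable sort therefore
    keeps entries in discovery (program) order, so `cse_v` numbering is
    reproducible. With that fixed, the `xfail` marker on
    `test_cp_async_in_if_then_else` is removed.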
---
 .../tvm/tirx/operator/intrinsics/cuda/cp_async.py  |  14 +-
 src/tirx/transform/common_subexpr_elim.cc          |  26 ++-
 .../test_s_tir_transform_inject_ptx_async_copy.py  | 211 ++++++++-------------
 3 files changed, 112 insertions(+), 139 deletions(-)

diff --git a/python/tvm/tirx/operator/intrinsics/cuda/cp_async.py 
b/python/tvm/tirx/operator/intrinsics/cuda/cp_async.py
index 712c4672d4..2eeb0821d6 100644
--- a/python/tvm/tirx/operator/intrinsics/cuda/cp_async.py
+++ b/python/tvm/tirx/operator/intrinsics/cuda/cp_async.py
@@ -205,7 +205,8 @@ def codegen_ptx_cp_async(*args):
       bytes; offsets are pre-scaled by the pass) and the call is
       forwarded with default cache / predicate / fill_mode.
     * 6 args ``(dst_ptr, dst_offset, src_ptr, src_offset, cp_size,
-      predicate)`` — same as 5-arg form with an explicit predicate.
+      predicate)`` — same as 5-arg form with an explicit predicate,
+      zero-filling the destination when the predicate is false.
     * 8 args ``(dst_ptr, src_ptr, cp_size, cache_policy, has_cache_hint,
       prefetch_size, predicate, fill_mode)`` — the fork-native wrapper
       API.
@@ -269,6 +270,14 @@ def codegen_ptx_cp_async(*args):
             func_name = (
                 
f"ptx_cp_async_legacy_pred_{ca_or_cg}_{cp_size_v}_{dst_elem_bytes}_{src_elem_bytes}"
             )
+            if cp_size_v == 4:
+                zero_fill = '    " @!p st.shared.u32 [%0], {%4};\\n"\n'
+            elif cp_size_v == 8:
+                zero_fill = '    " @!p st.shared.v2.u32 [%0], {%4, %4};\\n"\n'
+            elif cp_size_v == 16:
+                zero_fill = '    " @!p st.shared.v4.u32 [%0], {%4, %4, %4, 
%4};\\n"\n'
+            else:
+                raise ValueError(f"unsupported legacy predicated cp.async 
size: {cp_size_v}")
             body = (
                 f"  uint8_t* dst_p = (uint8_t*)dst + dst_off{dst_scale};\n"
                 f"  uint8_t* src_p = (uint8_t*)src + src_off{src_scale};\n"
@@ -279,8 +288,9 @@ def codegen_ptx_cp_async(*args):
                 '    " setp.eq.u32 p, %3, 1;\\n"\n'
                 f'    " @p cp.async.{ca_or_cg}.shared.global'
                 ' [%0], [%1], %2;\\n"\n'
+                f"{zero_fill}"
                 '    "}\\n"\n'
-                f'    :: "r"(dst_addr), "l"(src_p), "n"({cp_size_v}), 
"r"(predicate)\n'
+                f'    :: "r"(dst_addr), "l"(src_p), "n"({cp_size_v}), 
"r"(predicate), "r"(0)\n'
                 "  );"
             )
             source_code = (
diff --git a/src/tirx/transform/common_subexpr_elim.cc 
b/src/tirx/transform/common_subexpr_elim.cc
index 9e7b2b1fb7..2221df9352 100644
--- a/src/tirx/transform/common_subexpr_elim.cc
+++ b/src/tirx/transform/common_subexpr_elim.cc
@@ -83,6 +83,7 @@
 #include <utility>
 #include <vector>
 
+#include "../../support/ordered_map.h"
 #include "../analysis/check_contains.h"
 
 namespace tvm {
@@ -239,8 +240,16 @@ class CSEPlanner : public StmtExprVisitor {
     int consumed{0};
   };
 
-  /*! \brief Expression table keyed by structural equality (ExprDeepEqual). */
-  using ExprTable = std::unordered_map<PrimExpr, ExprEntry, 
ffi::StructuralHash, ExprDeepEqual>;
+  /*!
+   * \brief Expression table keyed by structural equality (ExprDeepEqual).
+   *
+   * An insertion-ordered map so that iteration visits entries in discovery
+   * (program) order. This makes the plan — and hence cse_v numbering —
+   * deterministic. A plain unordered_map iterates in hash order, and
+   * StructuralHash hashes free variables by object identity, which varies
+   * between processes (ASLR).
+   */
+  using ExprTable = support::OrderedMap<PrimExpr, ExprEntry, 
ffi::StructuralHash, ExprDeepEqual>;
 
   // ------------------------------------------------------------------
   // Eligibility predicates
@@ -592,8 +601,9 @@ class CSEPlanner : public StmtExprVisitor {
    * \brief Convert the accumulated expression table into InsertBefore + 
ExprRemap tables.
    *
    * Algorithm (shallower-first with repr propagation):
-   *   1. Collect all entries and sort by expr_depth ascending (shallower 
first),
-   *      with structural hash as tie-breaker for determinism.
+   *   1. Collect all entries and sort by expr_depth ascending (shallower 
first).
+   *      The stable sort over the insertion-ordered table keeps entries of
+   *      equal depth in discovery (program) order, so the plan is 
deterministic.
    *   2. Compute independent occurrence counts from the DAG children.
    *      For each parent P with count >= 2, its children's consumed counts
    *      are incremented by `(P.count - 1) * multiplicity` (the Bind value
@@ -610,7 +620,8 @@ class CSEPlanner : public StmtExprVisitor {
    * \return A pair of (InsertBeforeTable, ExprRemapTable).
    */
   std::pair<InsertBeforeTable, ExprRemapTable> ComputePlan() {
-    // Step 1: Sort entries by depth ascending (shallower first), hash for 
determinism
+    // Step 1: Sort entries by depth ascending (shallower first). table_ 
iterates
+    // in discovery order, which the stable sort preserves among equal depths.
     std::vector<std::pair<PrimExpr, ExprEntry*>> all_entries;
     for (auto& kv : table_) {
       all_entries.push_back({kv.first, &kv.second});
@@ -619,10 +630,7 @@ class CSEPlanner : public StmtExprVisitor {
     std::stable_sort(
         all_entries.begin(), all_entries.end(),
         [](const std::pair<PrimExpr, ExprEntry*>& a, const std::pair<PrimExpr, 
ExprEntry*>& b) {
-          if (a.second->expr_depth != b.second->expr_depth)
-            return a.second->expr_depth < b.second->expr_depth;
-          ffi::StructuralHash hasher;
-          return hasher(a.first) < hasher(b.first);
+          return a.second->expr_depth < b.second->expr_depth;
         });
 
     // Step 2: Compute consumed counts in ExprEntry from the DAG.
diff --git 
a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py 
b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
index fde4e91501..428fb9b895 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
@@ -194,19 +194,8 @@ def test_inject_async_copy_shared_dyn():
 # `ptx_mbarrier_*` family instead.
 
 
-# Note: the expected output contains a dead CSE variable `cse_v1 = (i < 12)`.
-# CSE extracts (i < 12) before inject_ptx_async_copy runs, but the latter
-# replaces the original IfThenElse guards with new cast(int32, i < 12)
-# expressions for predicated async copies, leaving cse_v1 unused.
 expected_cuda_script = r"""#include <cuda.h>
-__forceinline__ __device__ unsigned int
-cast_smem_ptr_to_int(const void* const smem_ptr)
-{
-  unsigned int smem_int;
-  asm volatile ("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; 
cvt.u32.u64 %0, smem_int; }"
-    : "=r"(smem_int) : "l"(smem_ptr));
-  return smem_int;
-}
+#endif
 
 #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
      (__CUDACC_VER_MAJOR__ > 11))
@@ -215,132 +204,101 @@ cast_smem_ptr_to_int(const void* const smem_ptr)
 #define TVM_ENABLE_L2_PREFETCH 0
 #endif
 
-#ifdef __CUDACC_RTC__
-using int64_t = long long;
-using uint64_t = unsigned long long;
+#ifdef _WIN32
+  using uint = unsigned int;
+  using uchar = unsigned char;
+  using ushort = unsigned short;
+  using int64_t = long long;
+  using uint64_t = unsigned long long;
 #else
-#include <cstdint>
+  #define uint unsigned int
+  #define uchar unsigned char
+  #define ushort unsigned short
 #endif
-using uint = unsigned int;
-using uchar = unsigned char;
-using ushort = unsigned short;
-
-extern "C" __global__ void __launch_bounds__(16) main_kernel(float* 
__restrict__ A, float* __restrict__ B, float* __restrict__ C);
-extern "C" __global__ void __launch_bounds__(16) main_kernel(float* 
__restrict__ A, float* __restrict__ B, float* __restrict__ C) {
-  __shared__ float A_shared[64];
-  __shared__ float B_shared[64];
-  A_shared[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
-  B_shared[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
-__asm__ __volatile__("cp.async.commit_group;");
-
-
-  {
-    unsigned int addr = cast_smem_ptr_to_int(A_shared + (((int)threadIdx.x) + 
16));
-    __asm__ __volatile__(
-      #if TVM_ENABLE_L2_PREFETCH
-        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
-      #else
-        "cp.async.ca.shared.global [%0], [%1], %2;"
-      #endif
-        :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) * 14))), "n"(4)
-    );
-  }
 
-  {
-    unsigned int addr = cast_smem_ptr_to_int(B_shared + (((int)threadIdx.x) + 
16));
-    __asm__ __volatile__(
-      #if TVM_ENABLE_L2_PREFETCH
-        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
-      #else
-        "cp.async.ca.shared.global [%0], [%1], %2;"
-      #endif
-        :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4)
-    );
-  }
-__asm__ __volatile__("cp.async.commit_group;");
-
-
-  {
-    unsigned int addr = cast_smem_ptr_to_int(A_shared + (((int)threadIdx.x) + 
32));
-    __asm__ __volatile__(
-      #if TVM_ENABLE_L2_PREFETCH
-        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
-      #else
-        "cp.async.ca.shared.global [%0], [%1], %2;"
-      #endif
-        :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4)
-    );
-  }
+__forceinline__ __device__ void tvm_builtin_ptx_cp_async_wait_group_0() {
+    asm volatile("cp.async.wait_group 0;");
+}
 
-  {
-    unsigned int addr = cast_smem_ptr_to_int(B_shared + (((int)threadIdx.x) + 
32));
-    __asm__ __volatile__(
-      #if TVM_ENABLE_L2_PREFETCH
-        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
-      #else
-        "cp.async.ca.shared.global [%0], [%1], %2;"
-      #endif
-        :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4)
-    );
-  }
-__asm__ __volatile__("cp.async.commit_group;");
+__forceinline__ __device__ void tvm_builtin_ptx_cp_async_wait_group_1() {
+    asm volatile("cp.async.wait_group 1;");
+}
 
-  for (int i = 0; i < 13; ++i) {
-    bool cse_v1 = (i < 12);
-
-  {
-    unsigned int addr = cast_smem_ptr_to_int(A_shared + ((((i + 3) & 3) * 16) 
+ ((int)threadIdx.x)));
-    int pred_guard = (int)(i < 12);
-    __asm__ __volatile__(
-        "{  .reg .pred p;"
-        "  setp.ne.b32 p, %0, 0;"
-      #if TVM_ENABLE_L2_PREFETCH
-        " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;"
-      #else
-        " @p cp.async.ca.shared.global [%1], [%2], %3;"
-      #endif
-      "  @!p st.shared.u32 [%1], {%4};}"
-        :: "r"(pred_guard), "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) * 
14) + i) + 2))), "n"(4), "r"(0)
-    );
-  }
-__asm__ __volatile__("cp.async.commit_group;");
+__forceinline__ __device__ void tvm_builtin_ptx_cp_async_wait_group_2() {
+    asm volatile("cp.async.wait_group 2;");
+}
+
+__forceinline__ __device__ void tvm_builtin_ptx_cp_async_wait_group_5() {
+    asm volatile("cp.async.wait_group 5;");
+}
+
+__forceinline__ __device__ void ptx_cp_async_legacy_pred_ca_4_4_4(void* dst, 
int dst_off, void* src, int src_off, int predicate) {
+  uint8_t* dst_p = (uint8_t*)dst + dst_off * 4;
+  uint8_t* src_p = (uint8_t*)src + src_off * 4;
+  unsigned int dst_addr = __cvta_generic_to_shared(dst_p);
+  __asm__ __volatile__(
+    "{\n"
+    " .reg .pred p;\n"
+    " setp.eq.u32 p, %3, 1;\n"
+    " @p cp.async.ca.shared.global [%0], [%1], %2;\n"
+    " @!p st.shared.u32 [%0], {%4};\n"
+    "}\n"
+    :: "r"(dst_addr), "l"(src_p), "n"(4), "r"(predicate), "r"(0)
+  );
+}
 
-__asm__ __volatile__("cp.async.wait_group 5;");
+__forceinline__ __device__ void ptx_cp_async_legacy_ca_4_4_4(void* dst, int 
dst_off, void* src, int src_off) {
+  uint8_t* dst_p = (uint8_t*)dst + dst_off * 4;
+  uint8_t* src_p = (uint8_t*)src + src_off * 4;
+  unsigned int dst_addr = __cvta_generic_to_shared(dst_p);
+  asm volatile("cp.async.ca.shared.global [%0], [%1], %2;"
+    :: "r"(dst_addr), "l"(src_p), "n"(4));
+}
 
+__forceinline__ __device__ void tvm_builtin_ptx_cp_async_commit_group() {
+    asm volatile("cp.async.commit_group;");
+}
+extern "C" __global__ void __launch_bounds__(16) main_kernel(float* 
__restrict__ A_ptr, float* __restrict__ B_ptr, float* __restrict__ C_ptr);
+extern "C" __global__ void __launch_bounds__(16) main_kernel(float* 
__restrict__ A_ptr, float* __restrict__ B_ptr, float* __restrict__ C_ptr) {
+  __shared__ alignas(64) float A_shared_ptr[64];
+  __shared__ alignas(64) float B_shared_ptr[64];
+  A_shared_ptr[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
+  B_shared_ptr[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
+  tvm_builtin_ptx_cp_async_commit_group();
+  int cse_v1 = (((int)threadIdx.x) * 14);
+  int cse_v2 = (((int)threadIdx.x) + 16);
+  ptx_cp_async_legacy_ca_4_4_4(A_shared_ptr, (((int)threadIdx.x) + 16), A_ptr, 
(((int)threadIdx.x) * 14));
+  ptx_cp_async_legacy_ca_4_4_4(B_shared_ptr, (((int)threadIdx.x) + 16), B_ptr, 
(((int)threadIdx.x) * 14));
+  tvm_builtin_ptx_cp_async_commit_group();
+  int cse_v3 = (((int)threadIdx.x) + 32);
+  int cse_v6 = ((((int)threadIdx.x) * 14) + 1);
+  ptx_cp_async_legacy_ca_4_4_4(A_shared_ptr, (((int)threadIdx.x) + 32), A_ptr, 
((((int)threadIdx.x) * 14) + 1));
+  ptx_cp_async_legacy_ca_4_4_4(B_shared_ptr, (((int)threadIdx.x) + 32), B_ptr, 
((((int)threadIdx.x) * 14) + 1));
+  tvm_builtin_ptx_cp_async_commit_group();
+  int cse_v4 = (((int)threadIdx.x) * 16);
+  for (int i = 0; i < 13; ++i) {
+    int cse_v7 = (((((int)threadIdx.x) * 14) + i) + 2);
+    int cse_v9 = ((((i + 3) & 3) * 16) + ((int)threadIdx.x));
+    ptx_cp_async_legacy_pred_ca_4_4_4(A_shared_ptr, ((((i + 3) & 3) * 16) + 
((int)threadIdx.x)), A_ptr, (((((int)threadIdx.x) * 14) + i) + 2), (i < 12));
+    tvm_builtin_ptx_cp_async_commit_group();
+    tvm_builtin_ptx_cp_async_wait_group_5();
     __syncthreads();
-    C[((((int)threadIdx.x) * 16) + i)] = (A_shared[(((i & 3) * 16) + 
((int)threadIdx.x))] + B_shared[(((i & 3) * 16) + ((int)threadIdx.x))]);
+    int cse_v8 = (((i & 3) * 16) + ((int)threadIdx.x));
+    C_ptr[((((int)threadIdx.x) * 16) + i)] = (A_shared_ptr[(((i & 3) * 16) + 
((int)threadIdx.x))] + B_shared_ptr[(((i & 3) * 16) + ((int)threadIdx.x))]);
     __syncthreads();
-
-  {
-    unsigned int addr = cast_smem_ptr_to_int(B_shared + ((((i + 3) & 3) * 16) 
+ ((int)threadIdx.x)));
-    int pred_guard = (int)(i < 12);
-    __asm__ __volatile__(
-        "{  .reg .pred p;"
-        "  setp.ne.b32 p, %0, 0;"
-      #if TVM_ENABLE_L2_PREFETCH
-        " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;"
-      #else
-        " @p cp.async.ca.shared.global [%1], [%2], %3;"
-      #endif
-      "  @!p st.shared.u32 [%1], {%4};}"
-        :: "r"(pred_guard), "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) * 
14) + i) + 2))), "n"(4), "r"(0)
-    );
-  }
-__asm__ __volatile__("cp.async.commit_group;");
-
+    ptx_cp_async_legacy_pred_ca_4_4_4(B_shared_ptr, ((((i + 3) & 3) * 16) + 
((int)threadIdx.x)), B_ptr, (((((int)threadIdx.x) * 14) + i) + 2), (i < 12));
+    tvm_builtin_ptx_cp_async_commit_group();
   }
-__asm__ __volatile__("cp.async.wait_group 2;");
-
+  tvm_builtin_ptx_cp_async_wait_group_2();
   __syncthreads();
-  C[((((int)threadIdx.x) * 16) + 13)] = (A_shared[(((int)threadIdx.x) + 16)] + 
B_shared[(((int)threadIdx.x) + 16)]);
-__asm__ __volatile__("cp.async.wait_group 1;");
-
+  C_ptr[((((int)threadIdx.x) * 16) + 13)] = (A_shared_ptr[(((int)threadIdx.x) 
+ 16)] + B_shared_ptr[(((int)threadIdx.x) + 16)]);
+  tvm_builtin_ptx_cp_async_wait_group_1();
   __syncthreads();
-  C[((((int)threadIdx.x) * 16) + 14)] = (A_shared[(((int)threadIdx.x) + 32)] + 
B_shared[(((int)threadIdx.x) + 32)]);
-__asm__ __volatile__("cp.async.wait_group 0;");
-
+  C_ptr[((((int)threadIdx.x) * 16) + 14)] = (A_shared_ptr[(((int)threadIdx.x) 
+ 32)] + B_shared_ptr[(((int)threadIdx.x) + 32)]);
+  tvm_builtin_ptx_cp_async_wait_group_0();
   __syncthreads();
-  C[((((int)threadIdx.x) * 16) + 15)] = (A_shared[(((int)threadIdx.x) + 48)] + 
B_shared[(((int)threadIdx.x) + 48)]);
+  int cse_v5 = (((int)threadIdx.x) + 48);
+  C_ptr[((((int)threadIdx.x) * 16) + 15)] = (A_shared_ptr[(((int)threadIdx.x) 
+ 48)] + B_shared_ptr[(((int)threadIdx.x) + 48)]);
 }
 
 """
@@ -392,9 +350,6 @@ def postproc_if_missing_async_support():
         tvm.register_global_func(func_name, prev_postproc, override=True)
 
 
-# TODO(tlopex): fix CSE determinism (change unordered map to ordered map) and
-# remove this xfail; see #19741.
-@pytest.mark.xfail
 @tvm.testing.requires_cuda
 def test_cp_async_in_if_then_else(postproc_if_missing_async_support):
     @T.prim_func(s_tir=True)

Reply via email to