This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 16d0a7edae [TIRX][CUDA] Framework support for FA4, CLC intrinsics, and nvfp4 tcgen05 GEMM (#19785)
16d0a7edae is described below

commit 16d0a7edae0d52ccf0b947656310359d965e79d9
Author: Bohan Hou <[email protected]>
AuthorDate: Tue Jun 16 03:52:14 2026 -0700

    [TIRX][CUDA] Framework support for FA4, CLC intrinsics, and nvfp4 tcgen05 GEMM (#19785)
---
 python/tvm/backend/cuda/lang/pipeline.py           |  11 +-
 python/tvm/backend/cuda/lang/tile_scheduler.py     | 135 ++++++++++++++++-
 python/tvm/backend/cuda/op.py                      |  79 +++++++++-
 .../tvm/backend/cuda/operator/intrinsics/sync.py   | 100 +++++++++++-
 .../tile_primitive/copy_async/tcgen05_ldst.py      |  35 +++--
 .../operator/tile_primitive/elementwise/reg.py     |  67 ++++++++
 python/tvm/backend/cuda/script.py                  |   6 +
 python/tvm/support/nvcc.py                         |  76 ++++++++--
 python/tvm/tirx/script/builder/external_kernel.py  |   2 +-
 src/backend/cuda/op/target_builtin.cc              |   6 +
 src/target/llvm/codegen_llvm.cc                    |  17 +++
 src/target/llvm/codegen_llvm.h                     |   3 +
 src/tirx/ir/layout/tile_slice.cc                   |   6 +-
 tests/python/codegen/test_target_codegen_llvm.py   |  39 +++++
 tests/python/tirx/codegen/test_codegen_cuda.py     |  11 ++
 tests/python/tirx/codegen/test_codegen_nvshmem.py  |   3 +
 tests/python/tirx/codegen/test_cuda_copy.py        |  11 ++
 tests/python/tirx/codegen/test_cuda_cta_reduce.py  |  13 ++
 tests/python/tirx/codegen/test_cuda_warp_reduce.py |  13 ++
 tests/python/tirx/conftest.py                      |  40 +++++
 .../tile_primitive/cuda/copy/test_fallback.py      |   5 +
 .../tile_primitive/cuda/copy/test_gmem_smem.py     |   4 +
 .../operator/tile_primitive/cuda/copy/test_reg.py  |   5 +
 .../tile_primitive/cuda/copy_async/test_ldgsts.py  |   3 +
 .../tile_primitive/cuda/copy_async/test_tmem.py    |   7 +
 .../cuda/copy_async/test_tmem_16xnb.py             | 144 ++++++++++++++++++
 .../tile_primitive/cuda/elementwise/test_binary.py |  13 ++
 .../tile_primitive/cuda/elementwise/test_fma.py    |  15 ++
 .../tile_primitive/cuda/elementwise/test_unary.py  | 168 ++++++++++++++++++++-
 .../cuda/gemm_async/test_gemm_async.py             |  23 +++
 .../cuda/permute_layout/test_permute_layout.py     |   7 +
 .../cuda/reduction/test_reduction.py               |  23 +++
 tests/python/tirx/test_buffer_print.py             |   4 +
 tests/python/tirx/test_control_flow.py             |   8 +
 tests/python/tirx/test_layout.py                   |  35 +++++
 tests/scripts/task_python_unittest.sh              |   1 +
 36 files changed, 1096 insertions(+), 42 deletions(-)

diff --git a/python/tvm/backend/cuda/lang/pipeline.py b/python/tvm/backend/cuda/lang/pipeline.py
index ee86090398..40fd40c3fa 100644
--- a/python/tvm/backend/cuda/lang/pipeline.py
+++ b/python/tvm/backend/cuda/lang/pipeline.py
@@ -110,7 +110,7 @@ class MBarrier:
         T.ptx.mbarrier.try_wait(self.buf.ptr_to([stage]), phase ^ 
self.phase_offset)
 
     @T.inline
-    def arrive(self, stage, cta_id=None, pred=None):
+    def arrive(self, stage, cta_id=None, pred=None, count=None):
         # Default: local-CTA arrive — emits the simple
         # ``mbarrier.arrive.shared.b64`` form. To arrive on a remote
         # CTA's mbarrier in a cluster kernel, callers must pass
@@ -119,11 +119,18 @@ class MBarrier:
         # the cross-CTA path was both surprising (``bar.arrive(stage)``
         # silently ``mapa`` ed across the cluster) and a per-call cost
         # of ~3 PTX ops on every single-CTA kernel.
+        #
+        # ``count`` (cross-CTA path only) emits the explicit arrival-count
+        # operand, i.e. ``mbarrier.arrive.shared::cluster.b64 _, [addr], 
count``.
+        # When ``None`` the implicit count-of-1 form is emitted. Passing
+        # ``count=1`` is semantically identical but spells the count 
explicitly.
         if cta_id is None:
             T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]))
         else:
             actual_pred = True if pred is None else pred
-            T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, 
pred=actual_pred)
+            T.ptx.mbarrier.arrive(
+                self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred, 
count=count
+            )
 
     def ptr_to(self, idx):
         return self.buf.ptr_to(idx)
diff --git a/python/tvm/backend/cuda/lang/tile_scheduler.py b/python/tvm/backend/cuda/lang/tile_scheduler.py
index 3fd27f25ee..c6154f2462 100644
--- a/python/tvm/backend/cuda/lang/tile_scheduler.py
+++ b/python/tvm/backend/cuda/lang/tile_scheduler.py
@@ -20,6 +20,7 @@ These classes emit TIR via @T.inline. Decorate with @T.meta_class so that
 instances are automatically treated as meta values inside @T.prim_func.
 """
 
+from tvm.backend.cuda.lang.pipeline import Pipeline, PipelineState
 from tvm.script import tirx as T
 
 
@@ -753,13 +754,20 @@ class FlashAttentionLPTScheduler(BaseTileScheduler):
     """
 
     def __init__(
-        self, prefix: str, num_batches: int, num_heads: int, num_m_blocks: 
int, l2_swizzle: int
+        self,
+        prefix: str,
+        num_batches: int,
+        num_heads: int,
+        num_m_blocks: int,
+        l2_swizzle: int,
+        num_ctas: int | None = None,
     ):
         super().__init__(prefix)
         self._num_batches = num_batches
         self._num_heads = num_heads
         self._num_m_blocks = num_m_blocks
         self._l2_swizzle = l2_swizzle
+        self._num_ctas = num_ctas
         self._total_tasks = num_batches * num_heads * num_m_blocks
 
         # Derived constants for L2 swizzle
@@ -807,10 +815,131 @@ class FlashAttentionLPTScheduler(BaseTileScheduler):
 
     @T.inline
     def next_tile(self):
-        """Advance to next tile by striding by num_ctas."""
-        self.linear_idx = self._total_tasks
+        """Advance to the next tile.
+
+        Single-tile mode (``num_ctas=None``, the default): each CTA owns one
+        task; terminate. Persistent mode (``num_ctas=N``): stride by N, like
+        :class:`FlashAttentionLinearScheduler`, while keeping the LPT + L2
+        swizzle index mapping.
+        """
+        if self._num_ctas is None:
+            self.linear_idx = self._total_tasks
+        else:
+            self.linear_idx = self.linear_idx + self._num_ctas
+            self.update_current_m_n_idx(self.linear_idx)
     # fmt: on
 
     def valid(self):
         """Check if there are more tiles to process."""
         return self.linear_idx < self._total_tasks
+
+
+class _CLCWorker(ClusterPersistentScheduler2D):
+    """Per-role CLC handle: IS-A ClusterPersistentScheduler2D (so m_idx / 
n_idx work as
+    usual) plus the role-local barrier phase and handshake. A coord-free role 
(e.g. an
+    MMA warp consuming whatever a loader staged) arms the loop with reset() 
not init().
+    """
+
+    def __init__(self, clc, prefix):
+        super().__init__(
+            prefix,
+            num_m_tiles=clc._num_m_tiles,
+            num_n_tiles=clc._num_n_tiles,
+            num_clusters=clc._num_m_tiles * clc._num_n_tiles,
+            l2_group_size=clc._l2_group_size,
+        )
+        self._clc = clc
+        self._sa = PipelineState(1, 0)
+        self._done = T.local_scalar("int32")
+        self._nxt = T.local_scalar("uint32")
+
+    @T.inline
+    def reset(self):
+        self._done = 0
+
+    @T.inline
+    def init(self, cluster_id):
+        # Explicit base call: TVMScript's parser has no zero-arg super().
+        ClusterPersistentScheduler2D.init(self, cluster_id)
+        self._done = 0
+
+    def valid(self):
+        return self._done == 0
+
+    @T.inline
+    def consume(self):
+        # Single-elected-thread scope: wait for the handle, decode, release 
the slot.
+        self._clc.sched_arr.full.wait(0, self._sa.phase)
+        self._sa.advance()
+        self._nxt = 
T.ptx.clc_query_cancel(T.address_of(self._clc.clc_handle[0]))
+        self._clc.sched_fin.empty.arrive(0, cta_id=0, pred=True)
+
+    @T.inline
+    def consume_wg(self, wg_id, warp_id, lane_id):
+        # Warpgroup scope: all threads decode; one elected lane releases the 
slot.
+        self._clc.sched_arr.full.wait(0, self._sa.phase)
+        self._sa.advance()
+        self._nxt = 
T.ptx.clc_query_cancel(T.address_of(self._clc.clc_handle[0]))
+        T.cuda.warpgroup_sync(wg_id + 1)
+        if (warp_id == 0) & (lane_id == 0):
+            self._clc.sched_fin.empty.arrive(0, cta_id=0, pred=True)
+
+    @T.inline
+    def advance_coords(self):
+        if self._nxt != 0xFFFFFFFF:
+            self.update_current_m_n_idx(self._nxt // self._clc._cta_group)
+
+    @T.inline
+    def mark_done_if_drained(self):
+        if self._nxt == 0xFFFFFFFF:
+            self._done = 1
+
+
[email protected]_class
+class ClusterLaunchControlScheduler:
+    """Blackwell Cluster Launch Control (CLC) tile scheduler.
+
+    A scheduler warp runs ``run_scheduler`` (issues ``try_cancel`` to steal 
the next
+    cluster); worker roles each take a ``worker()`` handle and pull the stolen 
tile
+    through the shared smem handshake. Owns the CLC smem: the 16B response 
handle, the
+    arrival barrier (handle ready), and the finished barrier (slot consumed;
+    ``finish_arrivals`` arrivals per round). Tile-coord mapping is delegated to
+    ``ClusterPersistentScheduler2D`` (group-major L2 ordering).
+    """
+
+    def __init__(self, pool, num_m_tiles, num_n_tiles, l2_group_size, 
cta_group, finish_arrivals):
+        self._num_m_tiles = num_m_tiles
+        self._num_n_tiles = num_n_tiles
+        self._l2_group_size = l2_group_size
+        self._cta_group = cta_group
+        self.sched_arr = Pipeline(pool, 1, full="tma", empty="mbar", 
init_empty=1)
+        self.sched_fin = Pipeline(pool, 1, full="mbar", empty="mbar", 
init_empty=finish_arrivals)
+        self.clc_handle = pool.alloc((4,), "uint32", align=16)
+        self._s_done = T.local_scalar("int32")
+        self._s_nxt = T.local_scalar("uint32")
+
+    def worker(self, prefix):
+        return _CLCWorker(self, prefix)
+
+    @T.inline
+    def run_scheduler(self, cbx):
+        # cta0 drives try_cancel; both CTAs expect_bytes + consume the handle 
so the
+        # finished-barrier count is met and the slot can be reissued.
+        if T.ptx.elect_sync():
+            sa = PipelineState(1, 0)
+            sf = PipelineState(1, 1)
+            self._s_done = 0
+            while self._s_done == 0:
+                if cbx == 0:
+                    self.sched_fin.empty.wait(0, sf.phase)
+                    sf.advance()
+                    T.ptx.clc_try_cancel(
+                        T.address_of(self.clc_handle[0]), 
T.address_of(self.sched_arr.full.buf[0])
+                    )
+                self.sched_arr.full.arrive(0, 16)  # expect_bytes for the 16B 
handle
+                self.sched_arr.full.wait(0, sa.phase)
+                sa.advance()
+                self._s_nxt = 
T.ptx.clc_query_cancel(T.address_of(self.clc_handle[0]))
+                self.sched_fin.empty.arrive(0, cta_id=0, pred=True)
+                if self._s_nxt == 0xFFFFFFFF:
+                    self._s_done = 1
diff --git a/python/tvm/backend/cuda/op.py b/python/tvm/backend/cuda/op.py
index e76d5fbe24..9570e26662 100644
--- a/python/tvm/backend/cuda/op.py
+++ b/python/tvm/backend/cuda/op.py
@@ -653,12 +653,12 @@ def ptx_mbarrier_init(bar, thread_count):
     return call_intrin("", "tirx.ptx_mbarrier_init", bar, thread_count)
 
 
-def ptx_mbarrier_arrive(bar, cta_id=None, pred=None):
+def ptx_mbarrier_arrive(bar, cta_id=None, pred=None, count=None):
     """TVM intrinsic to call
         mbarrier.arrive.shared::cta.b64
     or
         @p mapa.shared::cluster.u32
-        @p mbarrier.arrive.shared::cluster.b64
+        @p mbarrier.arrive.shared::cluster.b64 [, count]
 
     Parameters
     ----------
@@ -670,11 +670,29 @@ def ptx_mbarrier_arrive(bar, cta_id=None, pred=None):
 
     pred : Optional[PrimExpr]
         The predicate to guard the operation.
+
+    count : Optional[PrimExpr]
+        Explicit arrival count operand for the cross-CTA (cluster) form. When
+        ``None`` the implicit count-of-1 form is emitted; when given, emits
+        ``mbarrier.arrive.shared::cluster.b64 _, [addr], count``.
     """
     if cta_id is None and pred is None:
         return call_intrin("", "tirx.ptx_mbarrier_arrive", bar)
     assert cta_id is not None and pred is not None
-    return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred)
+    if count is None:
+        return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred)
+    return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred, 
count)
+
+
+def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count):
+    """Cross-CTA ``mbarrier.arrive`` on CTA ``cta_id`` with an explicit count.
+
+    Convenience for an already-elected thread: emits
+    ``@p mapa.shared::cluster.u32`` + ``@p mbarrier.arrive.shared::cluster.b64 
_,
+    [addr], count`` with the guard defaulted to 1.
+    """
+    return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, True, 
count)
+
 
 
 def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None):
@@ -706,7 +724,11 @@ def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None):
     """
     if cta_id is None and pred is None:
         return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, 
byte_count)
-    assert cta_id is not None and pred is not None
+    assert cta_id is not None
+    # Cross-CTA expect_tx from an already-elected thread: default the guard to 
1
+    # (the caller has elected a single lane), so callers can pass cta_id alone.
+    if pred is None:
+        pred = True
     return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, 
byte_count, cta_id, pred)
 
 
@@ -729,6 +751,23 @@ def ptx_mbarrier_try_wait(bar, phase):
     return call_intrin("", "tirx.ptx_mbarrier_try_wait", bar, phase)
 
 
+def ptx_mbarrier_try_wait_acquire_cluster(bar, phase):
+    """``mbarrier.try_wait.parity.acquire.cluster`` retry loop.
+
+    Cluster-scope acquire wait — used to wait on a barrier that a remote CTA in
+    the cluster arrives on (a group cluster wait).
+
+    Parameters
+    ----------
+    bar : Var
+        The pointer to barrier variable.
+
+    phase : int
+        The phase of the barrier.
+    """
+    return call_intrin("", "tirx.ptx_mbarrier_try_wait_acquire_cluster", bar, 
phase)
+
+
 def ptx_mbarrier_try_wait_once(bar, phase, ticks):
     """TVM intrinsic for one-shot non-blocking ``mbarrier.try_wait.parity``.
 
@@ -1261,6 +1300,38 @@ def ptx_barrier_cluster_wait(acquire=False, aligned=True):
     return call_intrin("", "tirx.ptx_barrier_cluster_wait", acquire, aligned)
 
 
+def ptx_clc_try_cancel(handle, mbar):
+    """TVM intrinsic to call clusterlaunchcontrol.try_cancel.
+
+    Async-requests cancelling the next cluster's launch (work-stealing): 
writes the
+    16B response handle to smem and signals ``mbar`` (complete_tx, multicast 
to both
+    cluster CTAs).
+
+    Parameters
+    ----------
+    handle : PrimExpr
+        Pointer to the 16B (uint4) smem response handle.
+
+    mbar : PrimExpr
+        Pointer to the mbarrier signalled when the handle lands.
+    """
+    return call_intrin("", "tirx.ptx_clc_try_cancel", handle, mbar)
+
+
+def ptx_clc_query_cancel(handle):
+    """TVM intrinsic to call clusterlaunchcontrol.query_cancel.
+
+    Decodes the response handle written by :func:`ptx_clc_try_cancel`. Returns 
the
+    cancelled cluster's first ``ctaid.x``, or ``0xFFFFFFFF`` when no work was 
stolen.
+
+    Parameters
+    ----------
+    handle : PrimExpr
+        Pointer to the 16B (uint4) smem response handle.
+    """
+    return call_intrin("uint32", "tirx.ptx_clc_query_cancel", handle)
+
+
 def ptx_elect_sync():
     """TVM intrinsic to call elect.sync"""
     return call_intrin("uint32", "tirx.ptx_elect_sync")
diff --git a/python/tvm/backend/cuda/operator/intrinsics/sync.py b/python/tvm/backend/cuda/operator/intrinsics/sync.py
index 0fcdb31a46..791d9cc981 100644
--- a/python/tvm/backend/cuda/operator/intrinsics/sync.py
+++ b/python/tvm/backend/cuda/operator/intrinsics/sync.py
@@ -168,6 +168,54 @@ device_intrinsic(
 )
 
 
+# =============================================================================
+# clusterlaunchcontrol.try_cancel / query_cancel — Blackwell Cluster Launch
+# Control (CLC) work-stealing, written from the PTX ISA spec (section
+# "clusterlaunchcontrol", PTX ISA 8.6). try_cancel async-requests cancelling 
the
+# next cluster's launch, writing a 16B response to smem + signalling mbar. 
query
+# decodes the response: on success it extracts the cancelled cluster's first
+# ctaid.x (via the get_first_ctaid::x form); a single uint32 is returned, with
+# 0xFFFFFFFF as the "no work stolen" sentinel (a device helper returns one 
scalar).
+# =============================================================================
+device_intrinsic(
+    "ptx_clc_try_cancel",
+    c_signature="(void* handle, void* mbar)",
+    body=(
+        "    unsigned int addr = (unsigned 
int)__cvta_generic_to_shared(handle);\n"
+        "    unsigned int bar = (unsigned 
int)__cvta_generic_to_shared(mbar);\n"
+        "    asm volatile(\n"
+        '        
"clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes"\n'
+        '        ".multicast::cluster::all.b128 [%0], [%1];\\n"\n'
+        '        :: "r"(addr), "r"(bar) : "memory");'
+    ),
+)
+
+
+device_intrinsic(
+    "ptx_clc_query_cancel",
+    c_signature="(void* handle)",
+    return_type="uint32_t",
+    tvm_return_type="uint32",
+    body=(
+        "    unsigned int addr = (unsigned 
int)__cvta_generic_to_shared(handle);\n"
+        "    unsigned int first_ctaid_x;\n"
+        "    asm volatile(\n"
+        '        "{\\n"\n'
+        '        ".reg .pred canceled;\\n"\n'
+        '        ".reg .b128 response;\\n"\n'
+        '        "ld.shared.b128 response, [%1];\\n"\n'
+        '        "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 
canceled, response;\\n"\n'
+        '        "mov.u32 %0, 0xffffffff;\\n"\n'
+        '        "@canceled 
clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128"\n'
+        '        " %0, response;\\n"\n'
+        '        "}\\n"\n'
+        '        : "=r"(first_ctaid_x) : "r"(addr) : "memory");\n'
+        '    asm volatile("fence.proxy.async.shared::cta;\\n" ::: "memory");\n'
+        "    return first_ctaid_x;"
+    ),
+)
+
+
 # =============================================================================
 # mbarrier.init.shared.b64 [addr], count ; — 1 form.
 # =============================================================================
@@ -208,7 +256,7 @@ device_intrinsic(
         '        "{\\n"\n'
         '        ".reg .pred p;\\n"\n'
         '        ".reg .b32 remAddr32;\\n"\n'
-        '        "setp.eq.u32 p, %2, 1;\\n"\n'
+        '        "setp.ne.s32 p, %2, 0;\\n"\n'
         '        "@p mapa.shared::cluster.u32  remAddr32, %0, %1;\\n"\n'
         '        "@p mbarrier.arrive.shared::cluster.b64  _, 
[remAddr32];\\n"\n'
         '        "}\\n"\n'
@@ -217,15 +265,38 @@ device_intrinsic(
 )
 
 
+# Same cross-CTA arrive, but with an explicit arrival-count operand
+# (``..., [remAddr32], count``). Matches the ``tma::cluster::arrive`` spelling.
+device_intrinsic(
+    "_ptx_mbarrier_arrive_remote_count",
+    helper_name="tvm_builtin_ptx_mbarrier_arrive_remote_count",
+    c_signature="(void* barrier, int cta_id, int pred, int count)",
+    body=(
+        "    unsigned int barrier_addr = __cvta_generic_to_shared(barrier);\n"
+        "    asm volatile(\n"
+        '        "{\\n"\n'
+        '        ".reg .pred p;\\n"\n'
+        '        ".reg .b32 remAddr32;\\n"\n'
+        '        "setp.ne.s32 p, %2, 0;\\n"\n'
+        '        "@p mapa.shared::cluster.u32  remAddr32, %0, %1;\\n"\n'
+        '        "@p mbarrier.arrive.shared::cluster.b64  _, [remAddr32], 
%3;\\n"\n'
+        '        "}\\n"\n'
+        '        :: "r"(barrier_addr), "r"(cta_id), "r"(pred), "r"(count) : 
"memory");'
+    ),
+)
+
+
 @register_codegen("ptx_mbarrier_arrive")
 def _codegen_mbarrier_arrive(*args):
-    """Dispatch by arg count: 1 -> local, 3 -> remote (cluster-mapped)."""
+    """Dispatch by arg count: 1 -> local, 3 -> remote, 4 -> remote+count."""
     if len(args) == 1:
         result = 
CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_local"](list(args))
     elif len(args) == 3:
         result = 
CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_remote"](list(args))
+    elif len(args) == 4:
+        result = 
CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_remote_count"](list(args))
     else:
-        raise ValueError(f"ptx_mbarrier_arrive expects 1 or 3 args, got 
{len(args)}")
+        raise ValueError(f"ptx_mbarrier_arrive expects 1, 3, or 4 args, got 
{len(args)}")
     return result[0] if isinstance(result, tuple) else result
 
 
@@ -252,7 +323,7 @@ device_intrinsic(
         '        "{\\n"\n'
         '        ".reg .pred p;\\n"\n'
         '        ".reg .b32 remAddr32;\\n"\n'
-        '        "setp.eq.u32 p, %2, 1;\\n"\n'
+        '        "setp.ne.s32 p, %2, 0;\\n"\n'
         '        "@p mapa.shared::cluster.u32  remAddr32, %0, %1;\\n"\n'
         '        "@p mbarrier.arrive.expect_tx.shared::cluster.b64  _, 
[remAddr32], %3;\\n"\n'
         '        "}\\n"\n'
@@ -303,6 +374,27 @@ device_intrinsic(
 )
 
 
+# mbarrier.try_wait.parity.acquire.cluster — cluster-scope acquire wait used 
for
+# cross-CTA barrier handshakes (e.g. the tmem-finished handoff).
+device_intrinsic(
+    "ptx_mbarrier_try_wait_acquire_cluster",
+    c_signature="(void* barrier, int phase)",
+    body=(
+        "    unsigned int barrier_addr_int = 
__cvta_generic_to_shared(barrier);\n"
+        "    asm volatile(\n"
+        '        "{\\n"\n'
+        '        ".reg .pred                P1;\\n"\n'
+        '        "LAB_WAIT_AC:\\n"\n'
+        '        "mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1, 
[%0], %1;\\n"\n'
+        '        "@P1                       bra.uni DONE_AC;\\n"\n'
+        '        "bra.uni                   LAB_WAIT_AC;\\n"\n'
+        '        "DONE_AC:\\n"\n'
+        '        "}\\n"\n'
+        '        :: "r"(barrier_addr_int), "r"(phase) : "memory");'
+    ),
+)
+
+
 # =============================================================================
 # mbarrier.try_wait.parity — ONE-SHOT non-blocking variant. Returns true
 # if the requested parity has already been reached, false otherwise.
diff --git a/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py
index ffd5e18a3a..081ea5a772 100644
--- a/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py
+++ b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py
@@ -369,20 +369,24 @@ def _emit_16xnb_path(
     tmem_st, tmem_extent = get_st_extent(tmem_region)
     local_st, local_extent = get_st_extent(local_region)
 
-    # Local slice must be the full (frag_rows, K_cols) view.
+    # Rows must span the full frag. The COLUMN extent may be a sub-multiple of
+    # the atom's full width ``width_elems`` — i.e. a per-chunk column slice of 
a
+    # wider frag (e.g. an epilogue that loads one big (128, MMA_N) frag in
+    # EPI_TILE-wide chunks). The atom layout maps consecutive columns to
+    # consecutive registers within each slab, so a column slice occupies a
+    # contiguous register window; we emit ``num_eff`` (the slice's atom rep) at
+    # the slab base + the column's register offset. When the slice IS the full
+    # atom (the common case), num_eff == num and reg offset == 0 (no change).
     assert analyzer.can_prove_equal(local_st[0], 0)
     assert analyzer.can_prove_equal(local_extent[0], frag_rows)
-    assert analyzer.can_prove_equal(local_extent[1], width_elems)
-
-    # TMEM slice must start at row 0 and span ``frag_rows`` rows. For Layout
-    # F the buffer is already (64, W) so frag_rows=64 covers the full slice;
-    # for Layout D + frag_rows=64 the slice reads the *first* half-slab and
-    # the rest of the buffer's 128 rows is invisible to this atom. For
-    # Layout D + frag_rows=128 the slice covers all 128 physical lanes via
-    # two PTX issues (row=0 + row=16).
     assert analyzer.can_prove_equal(tmem_st[0], 0)
     assert analyzer.can_prove_equal(tmem_extent[0], frag_rows)
-    assert analyzer.can_prove_equal(tmem_extent[1], width_elems)
+    # local and tmem column slices must match and divide the atom's full width.
+    assert analyzer.can_prove_equal(local_extent[1], tmem_extent[1])
+    slice_w = int(local_extent[1])
+    assert width_elems % slice_w == 0, f"slice width {slice_w} must divide 
atom width {width_elems}"
+    num_eff = num * slice_w // width_elems
+    regs_eff = regs_per_thread_per_slab * slice_w // width_elems
     del tmem_rows  # only used for the structural check above
 
     col_off = tmem_st[1]
@@ -410,13 +414,18 @@ def _emit_16xnb_path(
         # for the register-pointer arguments of the PTX builtin.
         local_storage = local_buf.view(per_thread_elems, 
layout=TileLayout(S[per_thread_elems]))
         local_32b = local_storage.view("uint32")
-        local_reg_base = local_col_off_elems // elem_per_32b
+        # Register offset of the column slice within each slab. The old
+        # ``local_col_off // elem_per_32b`` is only correct when the slice IS 
the
+        # full atom; in general consecutive columns advance registers at the 
rate
+        # (regs_per_thread_per_slab / width_elems). For a full-atom load the
+        # offset is 0 either way, so existing callers are unaffected.
+        local_reg_base = local_col_off_elems * regs_per_thread_per_slab // 
width_elems
         for slab in range(n_slabs):
             reg_base = slab * regs_per_thread_per_slab
             op(
                 tmem_buf.allocated_addr[0],
-                *[local_32b[local_reg_base + reg_base + i] for i in 
range(regs_per_thread_per_slab)],  # noqa: E501
-                shape=shape, num=num, row=slab * 16, col=col_off_32b,
+                *[local_32b[local_reg_base + reg_base + i] for i in 
range(regs_eff)],
+                shape=shape, num=num_eff, row=slab * 16, col=col_off_32b,
             )
     # fmt: on
     return impl
diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py
index eddf9f3d8e..64d77a21cf 100644
--- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py
+++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py
@@ -45,8 +45,10 @@ from ..common import get_st_extent
 from ..copy._common import _carve_tail, _verify_s_tail_contig
 from ..layout_utils import get_sublayout_from_region, layout_signature
 from ._common import (
+    _TID_AXIS_FOR_SCOPE,
     _all_threads_active,
     _tensor_shape_of,
+    _thread_cnt,
     align_operands_to_anchor,
     buffer_regions,
     compute_dtype_of,
@@ -67,6 +69,68 @@ def _validate_anchor_layout(anchor_br) -> tuple[bool, str | None]:
     return True, None
 
 
+def _validate_scope_level_anchor(anchor_br, sctx: DispatchContext) -> 
tuple[bool, str | None]:
+    """For warp/warpgroup/cta scope, require dst to be scope-level: after
+    canonicalizing with the target its thread axes are the scope's intra-thread
+    axis (laneid/tid_in_wg/tx) and, sorted by stride, tile a complete ``T:1``
+    chain over all ``T`` threads of the scope. Rejects thread-local 
``.local()``
+    views; thread scope is exempt.
+    """
+    scope = sctx.scope_kind
+    if scope == "thread":
+        return True, None
+    expected_axis = _TID_AXIS_FOR_SCOPE.get(scope)
+    if expected_axis is None:
+        return True, None
+    expected_cnt = _thread_cnt(sctx)
+
+    # Canonicalize the sliced anchor with the target so warp/lane axes fuse.
+    st, ext = get_st_extent(anchor_br)
+    sliced = get_sublayout_from_region(anchor_br.buffer.layout, 
anchor_br.buffer.shape, st, ext)
+    with sctx.target:
+        canon = sliced.canonicalize() if hasattr(sliced, "canonicalize") else 
sliced
+    shard = getattr(canon, "shard", None)
+    if shard is None:
+        return False, f"{scope}-scope op operand layout is not a TileLayout 
after slicing"
+
+    thread_iters = [it for it in shard if it.axis.is_thread()]
+    if not thread_iters:
+        return (
+            False,
+            f"{scope}-scope op needs a {scope}-level operand whose layout 
carries "
+            f"thread axes ({expected_axis} composing to {expected_cnt}:1); got 
a "
+            f"thread-local view with no thread axes — pass the {scope}-level 
tensor, "
+            f"not its `.local()` (per-thread) view",
+        )
+    bad = sorted({it.axis.name for it in thread_iters if it.axis.name != 
expected_axis})
+    if bad:
+        return (
+            False,
+            f"{scope}-scope op operand carries thread axes {bad}; after "
+            f"canonicalization a {scope}-level layout must use only 
{expected_axis!r}",
+        )
+    # Sorted by stride the thread iters must tile a complete chain 1, e0,
+    # e0*e1, ... up to the scope thread count — i.e. cover all T threads with
+    # no gap or overlap (extents alone would miss gaps/overlaps).
+    running = 1
+    for it in sorted(thread_iters, key=lambda i: int(i.stride)):
+        stride, extent = int(it.stride), int(it.extent)
+        if stride != running:
+            return (
+                False,
+                f"{scope}-scope op operand thread axes do not tile a complete "
+                f"{expected_cnt}:1 (sorted by stride: expected {running}, got 
{stride})",
+            )
+        running *= extent
+    if running != expected_cnt:
+        return (
+            False,
+            f"{scope}-scope op operand thread axes span {running} threads, not 
the "
+            f"full {expected_cnt} of the {scope}",
+        )
+    return True, None
+
+
 def _check_layout_operands_agree(plan) -> tuple[bool, str | None]:
     """Replica sigs must match across non-trivial-layout operands.
 
@@ -133,6 +197,9 @@ def is_reg_ewise(spec):
         ok3, reason3 = _validate_anchor_layout(anchor)
         if not ok3:
             return False, reason3
+        ok_scope, reason_scope = _validate_scope_level_anchor(anchor, sctx)
+        if not ok_scope:
+            return False, reason_scope
         # Shape compat (NumPy-style broadcast): anchor's tensor shape is the
         # result shape; every operand must broadcast TO anchor.
         anchor_tshape = _tensor_shape_of(anchor.region)
diff --git a/python/tvm/backend/cuda/script.py b/python/tvm/backend/cuda/script.py
index a1148f9b67..a46aa7e7e4 100644
--- a/python/tvm/backend/cuda/script.py
+++ b/python/tvm/backend/cuda/script.py
@@ -53,6 +53,8 @@ class PTXNamespace:
         self.stmatrix = _op_wrapper(_cuda_op.ptx_stmatrix)
         self.setmaxnreg: Callable[..., Any] = 
_op_wrapper(_cuda_op.ptx_setmaxnreg)
         self.elect_sync: Callable[..., Any] = 
_op_wrapper(_cuda_op.ptx_elect_sync)
+        self.clc_try_cancel = _op_wrapper(_cuda_op.ptx_clc_try_cancel)
+        self.clc_query_cancel = _op_wrapper(_cuda_op.ptx_clc_query_cancel)
         self.fetch_register: Callable[..., Any] = 
_op_wrapper(_cuda_op.ptx_fetch_register)
         self.ld = _op_wrapper(_cuda_op.ptx_ld)
         self.ld_acquire = _op_wrapper(_cuda_op.ptx_ld_acquire)
@@ -276,6 +278,9 @@ class MbarrierNamespace:
         self.init = _op_wrapper(_cuda_op.ptx_mbarrier_init)
         self.try_wait = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait)
         self.try_wait_once = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait_once)
+        self.try_wait_acquire_cluster = _op_wrapper(
+            _cuda_op.ptx_mbarrier_try_wait_acquire_cluster
+        )
         self.arrive = MbarrierArriveNamespace()
 
 
@@ -284,6 +289,7 @@ class MbarrierArriveNamespace:
 
     def __init__(self):
         self.expect_tx = _op_wrapper(_cuda_op.ptx_mbarrier_arrive_expect_tx)
+        self.cluster_count = 
_op_wrapper(_cuda_op.ptx_mbarrier_arrive_cluster_count)
 
     def __call__(self, *args, **kwds):
         return _op_wrapper(_cuda_op.ptx_mbarrier_arrive)(*args, **kwds)
diff --git a/python/tvm/support/nvcc.py b/python/tvm/support/nvcc.py
index ea5939fcef..b421042fb3 100644
--- a/python/tvm/support/nvcc.py
+++ b/python/tvm/support/nvcc.py
@@ -32,7 +32,7 @@ from . import utils
 
 
 def compile_cuda(
-    code, target_format=None, arch=None, options=None, path_target=None, 
compiler="nvcc"
+    code, target_format=None, arch=None, options=None, path_target=None, 
compiler="nvrtc"
 ):
     """Compile CUDA code with NVCC or NVRTC.
 
@@ -54,7 +54,7 @@ def compile_cuda(
         Output file.
 
     compiler : str, optional
-        Compiler backend: "nvcc" or "nvrtc".
+        Compiler backend: "nvrtc" (default) or "nvcc".
         This can be set by the TVM_CUDA_COMPILE_MODE environment variable.
 
     Returns
@@ -191,7 +191,7 @@ def _compile_cuda_nvcc(
         "--expt-extended-lambda",
         "--use_fast_math",
         "--ptxas-options=-v",  # printing out number of registers
-        
"--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",
  # printing out number of registers  # noqa: E501
+        
f"--ptxas-options=--verbose,--register-usage-level={os.environ.get('TVM_CUDA_PTXAS_REG_LEVEL',
 '10')},--warn-on-local-memory-usage",  # noqa: E501
     ]
 
     major, _ = 
parse_compute_version(get_target_compute_version(Target.current(allow_none=True)))
@@ -342,14 +342,23 @@ def _compile_cuda_nvrtc(
         line for line in code.splitlines() if line.strip() not in 
headers_to_strip
     )
 
-    # NVRTC compiles device code and does not include the host-side cuda.h.
-    # CUtensorMap is a host-side structure, to reference and use it in device 
code,
-    # we must forward-declare it for NVRTC.
+    # NVRTC compiles device code and does not include the host-side cuda.h
+    # (it is guarded behind ``#ifndef __CUDACC_RTC__`` in generated code and is
+    # stripped above), so the complete ``CUtensorMap_st`` layout that cuda.h
+    # normally provides is missing. TMA kernels take ``CUtensorMap`` by value 
as
+    # ``__grid_constant__`` params, which requires the complete type. Define 
the
+    # ``CUtensorMap_st`` tag with cuda.h's layout (64-byte aligned, 128 bytes)
+    # plus the typedef alias. This is compatible with cccl's 
``<cuda/barrier>``,
+    # which only forward-declares ``struct CUtensorMap_st;`` and re-typedefs 
the
+    # alias (a redundant typedef to the same type is legal in C++); defining 
the
+    # tag rather than ``struct CUtensorMap`` avoids the previous redefinition
+    # clash with that header.
     if "CUtensorMap" in code_filtered:
         code_filtered = (
-            "struct __align__(128) CUtensorMap {\n"
+            "struct alignas(64) CUtensorMap_st {\n"
             "  unsigned long long opaque[16];\n"
-            "};\n\n" + code_filtered
+            "};\n"
+            "typedef struct CUtensorMap_st CUtensorMap;\n\n" + code_filtered
         )
 
     # Add standard type definitions and compatibility macros that NVRTC 
doesn't provide.
@@ -371,6 +380,13 @@ using cuda::std::int64_t;
 #define __volatile__ volatile
 #endif
 
+// NVRTC does not pull in the host <math.h>, so INFINITY is undefined. Provide 
it
+// from libcu++ (same float +inf value nvcc's <math.h> yields).
+#include <cuda/std/limits>
+#ifndef INFINITY
+#define INFINITY (::cuda::std::numeric_limits<float>::infinity())
+#endif
+
 """
     code_filtered = nvrtc_preamble + code_filtered
 
@@ -406,6 +422,9 @@ namespace std {
     compile_opts = [
         f"--gpu-architecture={arch}".encode(),
         b"-default-device",
+        # nvcc enables 128-bit integers by default on Linux; NVRTC requires the
+        # flag to be passed explicitly for kernels that use __int128_t.
+        b"--device-int128",
     ]
 
     if use_nvshmem:
@@ -469,6 +488,21 @@ namespace std {
             ]
         )
 
+    # Define the vector-deprecation silencing macros as no-ops for every NVRTC
+    # compile. These live in vector_types.h, which the fp4/fp6/fp8 headers use
+    # but do not include; depending on the include chain NVRTC pulls in, the
+    # macro can be left undefined and trigger a bogus "declaration has no 
storage
+    # class" error. Defining them empty is harmless (they only gate host-side
+    # deprecation warnings) and matches what the NVSHMEM path already did.
+    compile_opts.extend(
+        [
+            b"-D__NV_SILENCE_DEPRECATION_BEGIN=",
+            b"-D__NV_SILENCE_DEPRECATION_END=",
+            b"-D__NV_SILENCE_HOST_DEPRECATION_BEGIN=",
+            b"-D__NV_SILENCE_HOST_DEPRECATION_END=",
+        ]
+    )
+
     compile_opts.extend(
         [
             b"-U__CUDA_NO_HALF_OPERATORS__",
@@ -481,6 +515,24 @@ namespace std {
         ]
     )
 
+    # Mirror the nvcc path's ptxas options. register-usage-level drives ptxas
+    # register allocation / instruction scheduling and is perf-relevant (FA4 
was
+    # tuned around it, hence the env-driven default); -v and
+    # --warn-on-local-memory-usage are diagnostic. NVRTC rejects -O3 and
+    # --register-usage-level as top-level flags but forwards them to its 
internal
+    # ptxas via --ptxas-options (ptxas already defaults to -O3). NB: unlike 
nvcc,
+    # NVRTC does not comma-split --ptxas-options, so each ptxas flag must be 
its
+    # own entry. The nvcc-only --expt-relaxed-constexpr / 
--expt-extended-lambda
+    # have no NVRTC equivalent and are intentionally not mirrored.
+    reg_level = os.environ.get("TVM_CUDA_PTXAS_REG_LEVEL", "10")
+    compile_opts.extend(
+        [
+            b"--ptxas-options=-v",
+            f"--ptxas-options=--register-usage-level={reg_level}".encode(),
+            b"--ptxas-options=--warn-on-local-memory-usage",
+        ]
+    )
+
     # Add user-provided options, filtering out nvcc-specific flags that nvrtc 
doesn't support
     if options:
         nvcc_only_prefixes = (
@@ -802,7 +854,7 @@ def tvm_callback_cuda_compile(code):
     Compile CUDA code using the configured backend (nvcc or nvrtc).
 
     This callback is invoked by TVM's C++ backend during CUDA module 
compilation.
-    By default, uses nvcc to generate fatbin.  The current target is fetched
+    By default, uses nvrtc to generate cubin.  The current target is fetched
     inside the callback (via ``tvm.target.Target.current(allow_none=True)``)
     so the caller does not need to push/pop a target scope around the
     invocation.
@@ -810,9 +862,9 @@ def tvm_callback_cuda_compile(code):
     Environment Variables
     ---------------------
     TVM_CUDA_COMPILE_MODE : str
-        Compiler backend: "nvcc" (default) or "nvrtc"
-        - "nvcc": Use nvcc subprocess, generates fatbin
+        Compiler backend: "nvrtc" (default) or "nvcc"
         - "nvrtc": Use NVRTC via cuda-bindings for faster JIT, generates cubin
+        - "nvcc": Use nvcc subprocess, generates fatbin
     TVM_KERNEL_DUMP : str
         If set, dump generated CUDA/intermediate files and append "-lineinfo" 
so profilers can
         correlate SASS back to the dumped source.
@@ -830,7 +882,7 @@ def tvm_callback_cuda_compile(code):
     # The current Target is fetched inside compile_cuda via
     # tvm.target.Target.current(allow_none=True) when arch is unset; the
     # caller no longer needs to push/pop a target scope.
-    compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc").lower()
+    compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvrtc").lower()
 
     if compiler == "nvrtc":
         return compile_cuda(code, target_format="cubin", compiler="nvrtc")
diff --git a/python/tvm/tirx/script/builder/external_kernel.py 
b/python/tvm/tirx/script/builder/external_kernel.py
index c1f5d58716..d56ed9ea03 100644
--- a/python/tvm/tirx/script/builder/external_kernel.py
+++ b/python/tvm/tirx/script/builder/external_kernel.py
@@ -159,7 +159,7 @@ class SourceKernel(BaseKernel):  # pylint: 
disable=too-few-public-methods
             target_format = "cubin" if use_nvshmem else "ptx"
             output_path = f"{temp_dir}/{kernel_name}.{target_format}"
 
-            compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc")
+            compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvrtc")
             nvcc.compile_cuda(
                 source_code,
                 target_format=target_format,
diff --git a/src/backend/cuda/op/target_builtin.cc 
b/src/backend/cuda/op/target_builtin.cc
index 005fe5b322..353c04b501 100644
--- a/src/backend/cuda/op/target_builtin.cc
+++ b/src/backend/cuda/op/target_builtin.cc
@@ -152,6 +152,9 @@ TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_arrive_expect_tx)
 TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
 
+TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait_acquire_cluster)
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
+
 TIRX_DEFINE_BUILTIN_FUNC(ptx_bar_arrive)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
 
@@ -497,6 +500,8 @@ const DeviceIntrinsicRegistration kDeviceIntrinsics[] = {
     TIRX_DEVICE_INTRIN_ALIAS(ptx_bar_sync, ptx, kOpaque),
     TIRX_DEVICE_INTRIN_ALIAS(ptx_barrier_cluster_arrive, ptx, kOpaque),
     TIRX_DEVICE_INTRIN_ALIAS(ptx_barrier_cluster_wait, ptx, kOpaque),
+    TIRX_DEVICE_INTRIN_ALIAS(ptx_clc_query_cancel, ptx, kOpaque),
+    TIRX_DEVICE_INTRIN_ALIAS(ptx_clc_try_cancel, ptx, kOpaque),
     TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async, ptx, kOpaque),
     TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk, ptx, kOpaque),
     TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_commit_group, ptx, kOpaque),
@@ -540,6 +545,7 @@ const DeviceIntrinsicRegistration kDeviceIntrinsics[] = {
     TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_init, ptx, kOpaque),
     TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_test_wait_parity, ptx, kOpaque),
     TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait, ptx, kOpaque),
+    TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait_acquire_cluster, ptx, 
kOpaque),
     TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait_once, ptx, kOpaque),
     TIRX_DEVICE_INTRIN_ALIAS(ptx_mma, ptx, kOpaque),
     TIRX_DEVICE_INTRIN_ALIAS(ptx_mma_legacy, ptx, kOpaque),
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 88a28ebccb..f32dcdde11 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -133,6 +133,8 @@ void CodeGenLLVM::Init(const std::string& module_name, 
LLVMTarget* llvm_target,
   builder_.reset(new IRBuilder(*ctx));
   module_.reset(new llvm::Module(module_name, *ctx));
   md_builder_.reset(new llvm::MDBuilder(*ctx));
+  functions_.clear();
+  function_symbol_owners_.clear();
   // types
   t_void_ = llvm::Type::getVoidTy(*ctx);
   t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx), 
GetGlobalAddressSpace());
@@ -260,6 +262,21 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const 
GlobalVar& gvar, cons
       llvm::FunctionType::get(GetLLVMType(func->ret_type), param_types, false);
 
   auto [symbol_name, linkage_type] = GetLinkage(gvar, func);
+  if (auto it = function_symbol_owners_.find(symbol_name); it != 
function_symbol_owners_.end()) {
+    constexpr const char* kFFISymbolPrefix = "__tvm_ffi_";
+    std::string user_symbol = symbol_name;
+    if (user_symbol.rfind(kFFISymbolPrefix, 0) == 0) {
+      user_symbol = 
user_symbol.substr(std::char_traits<char>::length(kFFISymbolPrefix));
+    }
+    TVM_FFI_THROW(InternalError) << "Duplicate PrimFunc global_symbol '" << 
user_symbol
+                                 << "' in LLVM codegen: IRModule keys '" << 
it->second
+                                 << "' and '" << gvar->name_hint
+                                 << "' both lower to the same exported symbol 
'" << symbol_name
+                                 << "'. "
+                                 << "Each exposed PrimFunc in one IRModule 
must have a unique "
+                                    "global_symbol.";
+  }
+  function_symbol_owners_[symbol_name] = gvar->name_hint;
 
   auto function = module_->getFunction(MakeStringRef(symbol_name));
   if (function == nullptr) {
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 8526b3f642..08396d596d 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -547,6 +547,9 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const 
PrimExpr&)>,
   // that function.
   std::unordered_map<const GlobalVarNode*, llvm::Function*> functions_;
 
+  // Map from the generated LLVM function symbol to the GlobalVar that owns it.
+  std::unordered_map<std::string, std::string> function_symbol_owners_;
+
   // Whether current function is restricted
   bool is_restricted_{true};
   // The analyzer information
diff --git a/src/tirx/ir/layout/tile_slice.cc b/src/tirx/ir/layout/tile_slice.cc
index 3f4db48379..ce1809ae99 100644
--- a/src/tirx/ir/layout/tile_slice.cc
+++ b/src/tirx/ir/layout/tile_slice.cc
@@ -144,7 +144,11 @@ ffi::Optional<TileLayout> SlicePerGroup(TileLayout layout, 
PrimExpr begin, PrimE
 ffi::Optional<Layout> TileLayoutNode::Slice(const Array<PrimExpr>& shape,
                                             const Region& region) const {
   arith::Analyzer analyzer;
-  auto [grouped_layout, seps] = Group(ffi::GetRef<TileLayout>(this), shape);
+  // Canonicalize the whole layout first so scope fusion (e.g. wid_in_wg+laneid
+  // -> tid_in_wg) runs globally; otherwise grouping can split sibling thread
+  // axes and SlicePerGroup's per-group fusion leaves an ill-formed mix.
+  TileLayout canon = this->Canonicalize().as<TileLayout>().value();
+  auto [grouped_layout, seps] = Group(canon, shape);
   std::vector<Iter> new_shard;
   ffi::Map<Axis, PrimExpr> new_offset;
   for (size_t i = 0; i < seps.size() - 1; ++i) {
diff --git a/tests/python/codegen/test_target_codegen_llvm.py 
b/tests/python/codegen/test_target_codegen_llvm.py
index 7c093f9be2..624d587b82 100644
--- a/tests/python/codegen/test_target_codegen_llvm.py
+++ b/tests/python/codegen/test_target_codegen_llvm.py
@@ -30,6 +30,47 @@ from tvm.target.codegen import llvm_get_intrinsic_name, llvm_lookup_intrinsic_id
 from tvm.testing import env
 
 
+@pytest.mark.skipif(not env.has_llvm(), reason="need llvm")
+def test_duplicate_primfunc_global_symbol_diagnostic():
+    """Two PrimFuncs sharing a global_symbol must raise InternalError naming both keys."""
+    @I.ir_module(s_tir=True)
+    class Module:
+        @T.prim_func(s_tir=True)
+        def first_unique_key(A: T.Buffer((1,), "float32")):
+            T.func_attr({"global_symbol": "dup_symbol", "tirx.noalias": True})
+            A[0] = T.float32(1)
+
+        @T.prim_func(s_tir=True)
+        def second_unique_key(A: T.Buffer((1,), "float32")):
+            T.func_attr({"global_symbol": "dup_symbol", "tirx.noalias": True})
+            A[0] = T.float32(2)
+
+    with pytest.raises(
+        tvm.error.InternalError, match="Duplicate PrimFunc global_symbol 'dup_symbol'"
+    ) as err:
+        tvm.compile(Module, target="llvm")
+    assert "first_unique_key" in str(err.value)
+    assert "second_unique_key" in str(err.value)
+
+
+@pytest.mark.skipif(not env.has_llvm(), reason="need llvm")
+def test_unique_primfunc_global_symbols_compile():
+    """Distinct global_symbols in one IRModule must compile without the duplicate error."""
+    @I.ir_module(s_tir=True)
+    class Module:
+        @T.prim_func(s_tir=True)
+        def first_unique_key(A: T.Buffer((1,), "float32")):
+            T.func_attr({"global_symbol": "dup_symbol_a", "tirx.noalias": True})
+            A[0] = T.float32(1)
+
+        @T.prim_func(s_tir=True)
+        def second_unique_key(A: T.Buffer((1,), "float32")):
+            T.func_attr({"global_symbol": "dup_symbol_b", "tirx.noalias": True})
+            A[0] = T.float32(2)
+
+    tvm.compile(Module, target="llvm")
+
+
 @pytest.mark.skipif(not env.has_llvm(), reason="need llvm")
 def test_llvm_intrin():
     @I.ir_module(s_tir=True)
diff --git a/tests/python/tirx/codegen/test_codegen_cuda.py 
b/tests/python/tirx/codegen/test_codegen_cuda.py
index f253d6d375..521a72f6d7 100644
--- a/tests/python/tirx/codegen/test_codegen_cuda.py
+++ b/tests/python/tirx/codegen/test_codegen_cuda.py
@@ -21,6 +21,7 @@ import pytest
 import tvm
 import tvm.testing
 from tvm.script import tirx as T
+from tvm.testing import env
 
 DEV = tvm.device("cuda")
 
@@ -118,6 +119,8 @@ def test_cuda_handle_uint64_reinterpret_codegen():
     assert "*(void* *)" not in src
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_cuda_atomic_add():
     @T.prim_func
     def main(A: T.Buffer((1,), "int32"), B: T.Buffer((1,), "float32")):
@@ -442,6 +445,8 @@ def test_cuda_atomic_cas():
     assert "tvm_builtin_cuda_atomic_cas" in src
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_cuda_func_call():
     def test_add_one():
         add_one = """
@@ -497,6 +502,8 @@ __device__ void print(int32_t a) {
     test_print()
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_warp_shuffle_xor_sync():
     # fmt: off
     @T.prim_func
@@ -532,6 +539,8 @@ def test_warp_shuffle_xor_sync():
     np.testing.assert_allclose(A.numpy(), A_ref)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("cp_size", [4, 8, 16])
 @pytest.mark.parametrize("cache_hint", ["", "evict_last"])
 @pytest.mark.parametrize("prefetch_size", [-1, 64, 128, 256])
@@ -575,6 +584,8 @@ def test_ptx_cp_async(cp_size, cache_hint, prefetch_size, 
predicate, fill_mode):
     print(src)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("trans", [False, True])
 @pytest.mark.parametrize("num", [1, 2, 4])
 def test_ptx_ldmatrix(trans, num):
diff --git a/tests/python/tirx/codegen/test_codegen_nvshmem.py 
b/tests/python/tirx/codegen/test_codegen_nvshmem.py
index ff9f17170d..d386907742 100644
--- a/tests/python/tirx/codegen/test_codegen_nvshmem.py
+++ b/tests/python/tirx/codegen/test_codegen_nvshmem.py
@@ -28,6 +28,7 @@ from tvm.runtime import ShapeTuple
 from tvm.runtime import disco as di
 from tvm.script import tirx as T
 from tvm.support.popen_pool import PopenWorker
+from tvm.testing import env
 
 NUM_WORKERS = 4
 
@@ -61,6 +62,8 @@ def create_nvshmem_array(sess, shape, dtype, 
init_data_fn=None, zero_out=True):
     return arr
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.skip(reason="nvshmem doesn't work with pytest")
 def test_codegen_nvshmem():
     def _test_func():
diff --git a/tests/python/tirx/codegen/test_cuda_copy.py 
b/tests/python/tirx/codegen/test_cuda_copy.py
index cb08f42473..047eb1f12c 100644
--- a/tests/python/tirx/codegen/test_cuda_copy.py
+++ b/tests/python/tirx/codegen/test_cuda_copy.py
@@ -21,6 +21,7 @@ import pytest
 
 import tvm
 from tvm.script import tirx as T
+from tvm.testing import env
 
 DEV = tvm.cuda(0)
 TARGET = tvm.target.Target("cuda")
@@ -34,6 +35,8 @@ def _build_and_run(func, *np_args):
     return (*tuple(a.numpy() for a in rt_args), mod)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_copy_128b():
     """copy_128b: copies 16 bytes (4 float32 elements) via uint4 load/store."""
 
@@ -63,6 +66,8 @@ def test_copy_128b():
     assert "tvm_builtin_copy_128b" in mod.mod.imports[0].inspect_source()
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_copy_64b():
     """copy_64b: copies 8 bytes (2 float32 elements) via uint2 load/store."""
 
@@ -92,6 +97,8 @@ def test_copy_64b():
     assert "tvm_builtin_copy_64b" in mod.mod.imports[0].inspect_source()
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_copy_32b():
     """copy_32b: copies 4 bytes (1 float32 element) via unsigned int 
load/store."""
 
@@ -121,6 +128,8 @@ def test_copy_32b():
     assert "tvm_builtin_copy_32b" in mod.mod.imports[0].inspect_source()
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_copy_16b():
     """copy_16b: copies 2 bytes (1 float16 element) via unsigned short 
load/store."""
 
@@ -150,6 +159,8 @@ def test_copy_16b():
     assert "tvm_builtin_copy_16b" in mod.mod.imports[0].inspect_source()
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_copy_8b():
     """copy_8b: copies 1 byte (1 uint8 element) via unsigned char 
load/store."""
 
diff --git a/tests/python/tirx/codegen/test_cuda_cta_reduce.py 
b/tests/python/tirx/codegen/test_cuda_cta_reduce.py
index 51b8f1099a..bf07da1b67 100644
--- a/tests/python/tirx/codegen/test_cuda_cta_reduce.py
+++ b/tests/python/tirx/codegen/test_cuda_cta_reduce.py
@@ -21,6 +21,7 @@ import pytest
 
 import tvm
 from tvm.script import tirx as T
+from tvm.testing import env
 
 DEV = tvm.cuda(0)
 TARGET = tvm.target.Target("cuda")
@@ -35,6 +36,8 @@ def _build_and_run(func, n):
     return out.numpy(), mod
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_cta_sum_4_warps():
     """CTA sum with 4 warps (128 threads): all threads get the same sum."""
     NUM_WARPS = 4
@@ -61,6 +64,8 @@ def test_cta_sum_4_warps():
     assert "cta_reduce_sum_4" in mod.mod.imports[0].inspect_source()
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_cta_sum_8_warps():
     """CTA sum with 8 warps (256 threads)."""
     NUM_WARPS = 8
@@ -86,6 +91,8 @@ def test_cta_sum_8_warps():
     np.testing.assert_allclose(result, np.full(N, expected))
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_cta_max_4_warps():
     """CTA max with 4 warps: all threads get the maximum value."""
     NUM_WARPS = 4
@@ -110,6 +117,8 @@ def test_cta_max_4_warps():
     np.testing.assert_allclose(result, np.full(N, float(N)))
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_cta_min_4_warps():
     """CTA min with 4 warps: all threads get the minimum value."""
     NUM_WARPS = 4
@@ -134,6 +143,8 @@ def test_cta_min_4_warps():
     np.testing.assert_allclose(result, np.full(N, 1.0))
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_cta_sum_1_warp():
     """CTA sum with 1 warp: degenerates to a pure warp reduce."""
     NUM_WARPS = 1
@@ -159,6 +170,8 @@ def test_cta_sum_1_warp():
     np.testing.assert_allclose(result, np.full(N, expected))
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("num_warps", [1, 2, 4, 8, 16])
 def test_cta_sum_all_warp_counts(num_warps):
     """Parametric test: cta_sum with various warp counts."""
diff --git a/tests/python/tirx/codegen/test_cuda_warp_reduce.py 
b/tests/python/tirx/codegen/test_cuda_warp_reduce.py
index df568a95e4..e5167a055c 100644
--- a/tests/python/tirx/codegen/test_cuda_warp_reduce.py
+++ b/tests/python/tirx/codegen/test_cuda_warp_reduce.py
@@ -21,6 +21,7 @@ import pytest
 
 import tvm
 from tvm.script import tirx as T
+from tvm.testing import env
 
 DEV = tvm.cuda(0)
 TARGET = tvm.target.Target("cuda")
@@ -35,6 +36,8 @@ def _build_and_run(func, n=32):
     return out.numpy(), mod
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_warp_sum_full():
     """Full warp sum (width=32): each lane gets the sum of all 32 values."""
 
@@ -57,6 +60,8 @@ def test_warp_sum_full():
     assert "warp_reduce_sum_32" in mod.mod.imports[0].inspect_source()
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_warp_sum_partial_8():
     """Partial warp sum (width=8): 4 groups of 8 lanes, each group sums 
independently."""
 
@@ -85,6 +90,8 @@ def test_warp_sum_partial_8():
     np.testing.assert_allclose(result, expected)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_warp_max_partial_4():
     """Partial warp max (width=4): 8 groups of 4 lanes."""
 
@@ -109,6 +116,8 @@ def test_warp_max_partial_4():
     np.testing.assert_allclose(result, expected)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_warp_min_full():
     """Full warp min (width=32)."""
 
@@ -129,6 +138,8 @@ def test_warp_min_full():
     np.testing.assert_allclose(result, np.full(32, 1.0))
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_warp_sum_partial_2():
     """Smallest partial warp sum (width=2): 16 pairs of adjacent lanes."""
 
@@ -155,6 +166,8 @@ def test_warp_sum_partial_2():
     np.testing.assert_allclose(result, expected)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("width", [2, 4, 8, 16, 32])
 def test_warp_sum_all_widths(width):
     """Parametric test: warp_sum with every valid width."""
diff --git a/tests/python/tirx/conftest.py b/tests/python/tirx/conftest.py
new file mode 100644
index 0000000000..fb8ba62f4f
--- /dev/null
+++ b/tests/python/tirx/conftest.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Suite-level hardware gate for the tirx tests.
+
+The tirx kernels and codegen paths target Blackwell (sm_100a) — they emit
+PTX/SASS (tcgen05, tmem, cp.async ``.async`` modifiers, fp8 conversions, ...)
+that ptxas/NVRTC reject for older targets, and many tests execute on the
+device. Running the suite on a CPU-only node or a pre-sm_100 GPU therefore
+fails at compile/run time rather than skipping. Gate the whole directory on a
+real sm_100a device so it skips cleanly where the hardware is absent and runs
+in full where it is present.
+"""
+
+from pathlib import Path
+
+import pytest
+
+from tvm.testing import env
+
+# Directory whose tests this conftest gates; resolved so it matches resolved item paths.
+_TIRX_DIR = Path(__file__).resolve().parent
+
+
+def pytest_collection_modifyitems(config, items):
+    """Skip the tirx tests unless ``env.has_cuda_compute(10)`` finds a device.
+
+    pytest invokes this hook once per session with *every* collected item, even
+    though it lives in a sub-directory conftest, so only items located under
+    ``_TIRX_DIR`` are marked. Without the filter, a broader run such as
+    ``pytest tests/python`` on a machine without the device would skip the
+    unrelated suites as well.
+    """
+    tirx_items = [item for item in items if _TIRX_DIR in item.path.resolve().parents]
+    if not tirx_items or env.has_cuda_compute(10):
+        return
+    skip = pytest.mark.skip(
+        reason="tirx suite requires a CUDA compute capability 10.0 (sm_100a) device"
+    )
+    for item in tirx_items:
+        item.add_marker(skip)
diff --git 
a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py 
b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py
index 75faf61366..1824b41eae 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py
@@ -32,6 +32,7 @@ import tvm
 import tvm.testing
 from tvm.script import tirx as T
 from tvm.script.tirx import tile as Tx
+from tvm.testing import env
 
 # Force the fallback dispatch to register before any test compiles a kernel.
 # Without this import, in fresh pytest workers the `copy/fallback` variant
@@ -128,6 +129,8 @@ def _build_round_trip_kernel(scope, n_threads, shape, 
dtype):
     return kernel
 
 
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >= 
9.0")
 @pytest.mark.parametrize(
     "scope,n_threads,shape,why",
     [
@@ -158,6 +161,8 @@ def test_fallback_round_trip(scope, n_threads, shape, why):
     np.testing.assert_array_equal(B.numpy(), A_np)
 
 
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >= 
9.0")
 def test_fallback_thread_scope():
     """``T.thread()`` — single thread, no gate. Either ``gmem_smem`` picks
     it up (n_elements % 1 == 0) or ``fallback`` does — both end up emitting
diff --git 
a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py 
b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py
index dc5a46a751..c31ca79db9 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py
@@ -103,6 +103,8 @@ TASKS = [
 ]
 
 
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >= 
9.0")
 @pytest.mark.parametrize(
     "scope,n_threads,shape",
     [pytest.param(*t, id=f"{t[0]}-{t[1]}-{'x'.join(map(str, t[2]))}") for t in 
TASKS],
@@ -194,6 +196,8 @@ def test_gmem_smem_roundtrip(scope, n_threads, shape, 
dtype):
         ),
     ],
 )
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >= 
9.0")
 @pytest.mark.parametrize(
     "dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16", 
"float32"]
 )
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py 
b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py
index 4516225303..26c4d5de9b 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py
@@ -35,6 +35,7 @@ import tvm
 import tvm.testing
 from tvm.script import tirx as T
 from tvm.script.tirx import tile as Tx
+from tvm.testing import env
 from tvm.tirx.layout import S, TileLayout, laneid, tid_in_wg, tx
 
 
@@ -228,6 +229,8 @@ def _expected(shape, dtype):
     return out
 
 
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >= 
9.0")
 @pytest.mark.parametrize("non_r_scope", ["shared", "global"])
 @pytest.mark.parametrize(
     "scope,n_threads,k",
@@ -287,6 +290,8 @@ def test_reg_roundtrip(scope, n_threads, k, dtype, 
non_r_scope):
         ),
     ],
 )
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >= 
9.0")
 @pytest.mark.parametrize(
     "dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16", 
"float32"]
 )
diff --git 
a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py 
b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py
index b4d54d2b41..96f9283253 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py
@@ -24,6 +24,7 @@ import tvm
 import tvm.testing
 from tvm.script import tirx as T
 from tvm.script.tirx import tile as Tx
+from tvm.testing import env
 from tvm.tirx.layout import S, TileLayout
 
 
@@ -65,6 +66,8 @@ from tvm.tirx.layout import S, TileLayout
         ),
     ],
 )
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize(
     "dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16", 
"float32"]
 )
diff --git 
a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py 
b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py
index 0f910a4376..55e32339c7 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py
@@ -24,10 +24,13 @@ import tvm
 import tvm.testing
 from tvm.script import tirx as T
 from tvm.script.tirx import tile as Tx
+from tvm.testing import env
 from tvm.tirx.layout import S, TCol, TileLayout, TLane
 from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg
 
 
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
 @pytest.mark.parametrize("dtype", ["float16", "float32"])
 @pytest.mark.parametrize("width_32b", [4, 8, 16, 32])
 def test_copy_tmem2reg_async(dtype, width_32b):
@@ -132,6 +135,8 @@ def test_copy_tmem2reg_async(dtype, width_32b):
 # ----------------------------------------------------------------------------
 
 
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
 @pytest.mark.parametrize("dtype", ["uint8", "float16", "float32"])
 @pytest.mark.parametrize("width_32b", [2, 4, 8, 16, 32, 64, 128])
 @pytest.mark.parametrize("offset_32b", [0, 3, 10])
@@ -224,6 +229,8 @@ def test_copy_tmem2reg(dtype, width_32b, offset_32b):
         np.testing.assert_allclose(B.numpy(), A_np)
 
 
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
 @pytest.mark.parametrize("dtype", ["float16", "float32"])
 @pytest.mark.parametrize("width_32b", [4, 8, 16, 32])
 @pytest.mark.parametrize("local_offset_32b", [0, 2, 4])
diff --git 
a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py 
b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
index 4209359460..aac93c0252 100644
--- 
a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
+++ 
b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
@@ -43,6 +43,7 @@ import tvm
 import tvm.testing
 from tvm.script import tirx as T
 from tvm.script.tirx import tile as Tx
+from tvm.testing import env
 from tvm.tirx.layout import (
     S,
     TCol,
@@ -152,6 +153,8 @@ def _expected_reg_value_16b(
 # --------------------------------------------------------------------------
 
 
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
 @pytest.mark.parametrize("shape", list(_SHAPE_REPS))
 @pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32])  # subset; full reps 
below
 @pytest.mark.parametrize("dtype", ["float32"])
@@ -162,6 +165,8 @@ def test_tcgen05_ld_16xnb_load_fp32(shape, rep, dtype):
     _run_load_test(shape, rep, dtype)
 
 
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
 @pytest.mark.parametrize(
     "shape, rep",
     [
@@ -175,6 +180,8 @@ def test_tcgen05_ld_16xnb_load_fp32_large_rep(shape, rep):
     _run_load_test(shape, rep, "float32")
 
 
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
 @pytest.mark.parametrize("shape", list(_SHAPE_REPS))
 @pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32])
 @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@@ -201,6 +208,8 @@ def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype):
 # We only need to spot-check that the dispatch fires correctly and the per-
 # thread reg ↔ TMEM mapping round-trips bit-exactly — the M=64 sweep above
 # already covers the (lane, reg) decomposition, so a sparse rep set suffices.
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
 @pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"])
 @pytest.mark.parametrize("rep", [1, 2, 4])
 @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@@ -214,6 +223,8 @@ def test_tcgen05_16xnb_roundtrip_16b_M128(shape, rep, 
dtype):
 # with the scatter-encoded TileLayout that ``tmem_datapath_layout("F", ...)``
 # produces. ``.16x*b`` M=64 PTX has the matching scatter built in, so the
 # round-trip is bit-exact in the same way as Layout D + M=64.
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
 @pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"])
 @pytest.mark.parametrize("rep", [1, 2, 4])
 @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@@ -639,6 +650,8 @@ def _run_load_test(shape: str, rep: int, dtype: str):
 # --------------------------------------------------------------------------
 
 
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
 @pytest.mark.parametrize("shape", list(_SHAPE_REPS))
 @pytest.mark.parametrize("rep", [1, 4, 16])
 @pytest.mark.parametrize("dtype", ["float32"])
@@ -853,5 +866,136 @@ def test_alloc_tcgen05_frag_wrapper_compiles(shape, 
frag_rows, K_cols):
     )
 
 
+# --------------------------------------------------------------------------
+# Test 3: column-slice loads of a wider frag
+#
+# An epilogue may allocate one wide ``(128, K)`` frag and load it from TMEM in
+# EPI_TILE-wide column chunks (``frag[:, c:c+w]``) so all loads are in flight
+# before a single ``wait.ld``. The ``.16x*b`` dispatch must emit each slice as
+# its own atom (``num_eff`` derived from the slice width) at the correct
+# per-slab register offset. We verify this is *bit-exact identical* to one
+# full-width load of the same frag — which the sweeps above already validate
+# against the layout-derived expectation. M=128 here exercises the 2-slab path
+# (the slice's two slabs live ``regs_per_thread_per_slab`` apart, not 
adjacent).
+# --------------------------------------------------------------------------
+
+
+def _run_sliced_vs_full_load(shape, full_rep, n_chunks):
+    dtype = "float32"
+    K_cols_fp32 = _COL_FACTOR_FP32[shape] * full_rep
+    assert K_cols_fp32 % n_chunks == 0
+    chunk_elem = K_cols_fp32 // n_chunks  # fp32: elem == fp32 col
+    frag_rows = 128  # M=128 => 2 slabs
+    per_thread_elems = _REGS_FACTOR[shape] * full_rep * 2  # *2 for the second 
slab
+
+    tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32))
+    stage_width_elem = tmem_col_width_32b
+    CHUNK_FP32 = 128
+    n_stage = tmem_col_width_32b // CHUNK_FP32 if tmem_col_width_32b > 
CHUNK_FP32 else 1
+    stage_w = tmem_col_width_32b if n_stage == 1 else CHUNK_FP32
+    VEC_LEN = 4  # 128-bit / fp32
+
+    atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_fp32), dtype)
+    stage_view = TileLayout(S[(128, stage_w) : (1 @ axis_tid_in_wg, 1)])
+
+    @T.prim_func
+    def kernel(A_ptr: T.handle, Bf_ptr: T.handle, Bs_ptr: T.handle) -> None:
+        A = T.match_buffer(A_ptr, (128, stage_width_elem), dtype)
+        Bf = T.match_buffer(Bf_ptr, (128, per_thread_elems), dtype)  # 
full-load dump
+        Bs = T.match_buffer(Bs_ptr, (128, per_thread_elems), dtype)  # 
sliced-load dump
+        A_flat = A.view(-1)
+
+        T.device_entry()
+        warp_id = T.warp_id([4])
+        T.cta_id([2])
+        wg_id = T.warpgroup_id([1])
+        T.warp_id_in_wg([4])
+        T.lane_id([32])
+        tid_in_wg = T.thread_id([128])
+
+        tmem_addr = T.alloc_shared([1], "uint32")
+        if wg_id == 0:
+            if warp_id == 0:
+                T.ptx.tcgen05.alloc(T.address_of(tmem_addr), 
n_cols=tmem_col_width_32b, cta_group=1)
+            T.tvm_storage_sync("shared")
+            tmem = T.decl_buffer(
+                (128, stage_width_elem),
+                dtype,
+                scope="tmem",
+                allocated_addr=tmem_addr[0],
+                layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ 
TCol)]),
+            )
+            # Stage A -> TMEM via the standard .32x32b path.
+            stage_reg = T.alloc_local((stage_w,), dtype)
+            stage_local = stage_reg.view(128, stage_w, layout=stage_view)
+            for ci in range(n_stage):
+                coff = ci * stage_w
+                for i in range(stage_w // VEC_LEN):
+                    g = T.meta_var(tid_in_wg * stage_width_elem + coff + i * 
VEC_LEN)
+                    Tx.copy(stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN], 
A_flat[g : g + VEC_LEN])
+                T.cuda.cta_sync()
+                Tx.wg.copy_async(tmem[:, coff : coff + stage_w], 
stage_local[:, :])
+            T.ptx.tcgen05.wait.st()
+            T.cuda.cta_sync()
+
+            # (a) one full-width load
+            ff = T.alloc_local((per_thread_elems,), dtype)
+            ffl = ff.view(frag_rows, K_cols_fp32, layout=atom_view)
+            Tx.wg.copy_async(ffl[:, :], tmem[0:frag_rows, 0:K_cols_fp32])
+            T.ptx.tcgen05.wait.ld()
+            T.cuda.cta_sync()
+            for i in range(per_thread_elems):
+                Bf[tid_in_wg, i] = ff[i]
+
+            # (b) the same frag loaded in n_chunks column slices
+            sf = T.alloc_local((per_thread_elems,), dtype)
+            sfl = sf.view(frag_rows, K_cols_fp32, layout=atom_view)
+            for ck in range(n_chunks):
+                lo = T.meta_var(ck * chunk_elem)
+                Tx.wg.copy_async(
+                    sfl[:, lo : lo + chunk_elem], tmem[0:frag_rows, lo : lo + 
chunk_elem]
+                )
+            T.ptx.tcgen05.wait.ld()
+            T.cuda.cta_sync()
+            for i in range(per_thread_elems):
+                Bs[tid_in_wg, i] = sf[i]
+
+            if warp_id == 0:
+                T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1)
+                T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=tmem_col_width_32b, 
cta_group=1)
+
+    target = tvm.target.Target("cuda")
+    with target:
+        mod = tvm.IRModule({"main": kernel})
+        mod = tvm.compile(mod, target=target, tir_pipeline="tirx")
+        A_np = tvm.testing.generate_random_array(dtype, (128, 
stage_width_elem))
+        Bf_np = np.zeros((128, per_thread_elems), dtype=dtype)
+        Bs_np = np.zeros((128, per_thread_elems), dtype=dtype)
+        DEV = tvm.cuda(0)
+        A = tvm.runtime.tensor(A_np, DEV)
+        Bf = tvm.runtime.tensor(Bf_np, DEV)
+        Bs = tvm.runtime.tensor(Bs_np, DEV)
+        mod(A, Bf, Bs)
+        # Sliced load must reproduce the full-width load bit-for-bit.
+        np.testing.assert_array_equal(Bs.numpy().view(np.uint32), 
Bf.numpy().view(np.uint32))
+
+
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
[email protected](
+    "full_rep, n_chunks",
+    [
+        (32, 8),  # 16x256b.x32 (256 fp32 cols) loaded in 8 chunks of 32 cols 
(nvfp4 EPI_TILE=32)
+        (32, 16),  # ...in 16 chunks of 16 cols (nvfp4 EPI_TILE=16)
+        (32, 4),  # ...in 4 chunks of 64 cols
+        (16, 8),  # 16x256b.x16 (128 fp32 cols) in 8 chunks of 16 cols
+        (16, 2),  # ...in 2 chunks of 64 cols
+    ],
+)
+def test_tcgen05_ld_16x256b_sliced_matches_full_M128(full_rep, n_chunks):
+    """Per-chunk column-slice load of a wide M=128 frag == full-width load."""
+    _run_sliced_vs_full_load("16x256b", full_rep, n_chunks)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git 
a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py 
b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py
index 1ce0d34ea6..8d39ba3556 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py
@@ -23,6 +23,7 @@ import tvm
 import tvm.testing
 from tvm.script import tirx as T
 from tvm.script.tirx import tile as Tx
+from tvm.testing import env
 from tvm.tirx.layout import S, TileLayout, wg_local_layout
 
 
@@ -67,6 +68,8 @@ from tvm.tirx.layout import S, TileLayout, wg_local_layout
         ),
     ],
 )
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"])
 @pytest.mark.parametrize("operands_type", ["region_region", "region_const", 
"const_region"])
 @pytest.mark.parametrize("dtype", ["float16"])
@@ -223,6 +226,8 @@ def test_binary_non_commutative_const_lhs_rejected(op_type):
             tvm.compile(mod, target=target, tir_pipeline="tirx")
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("exec_scope", ["warp", "warpgroup"])
 @pytest.mark.parametrize("op_type", ["add", "mul"])
 def test_binary_op_shared_subcta_scope(exec_scope, op_type):
@@ -276,6 +281,8 @@ def test_binary_op_shared_subcta_scope(exec_scope, op_type):
         tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-3)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("exec_scope", ["cta", "warpgroup", "warp"])
 @pytest.mark.parametrize("rhs_kind", ["region", "broadcast", "const"])
 @pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"])
@@ -392,6 +399,8 @@ def test_binary_op_local_subcta_trivial(exec_scope, 
rhs_kind, op_type):
         ),
     ],
 )
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("storage_scope", ["shared", "local"])
 @pytest.mark.parametrize("exec_scope", ["cta", "thread"])
 @pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"])
@@ -495,6 +504,8 @@ def test_binary_op_vectorized(input, storage_scope, 
exec_scope, op_type, dtype):
         tvm.testing.assert_allclose(A_ref, A.numpy(), atol=atol)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["add", "sub", "mul"])
 def test_binary_op_packed_f32x2_auto_dispatch(op_type):
     target = tvm.target.Target("cuda")
@@ -568,6 +579,8 @@ def test_binary_op_packed_f32x2_auto_dispatch(op_type):
         tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-3)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_name", ["add", "sub", "mul"])
 def test_binary_op_warpgroup_wg_local_layout(op_name):
     dtype = "float32"
diff --git 
a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py 
b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py
index aa0f5ced8f..02352638e4 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py
@@ -26,6 +26,7 @@ import tvm
 import tvm.testing
 from tvm.script import tirx as T
 from tvm.script.tirx import tile as Tx
+from tvm.testing import env
 from tvm.tirx.layout import S, TileLayout, wg_local_layout
 
 
@@ -41,6 +42,8 @@ def _get_sm_version():
 # ---------------------------------------------------------------------------
 # FMA op: scalar scale + scalar bias
 # ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_fma_scalar_scalar():
     sm = _get_sm_version()
     if sm < 100:
@@ -78,6 +81,8 @@ def test_fma_scalar_scalar():
 # ---------------------------------------------------------------------------
 # FMA op: buffer scale + scalar bias (Horner pattern)
 # ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_fma_buffer_scale_scalar_bias():
     sm = _get_sm_version()
     if sm < 100:
@@ -119,6 +124,8 @@ def test_fma_buffer_scale_scalar_bias():
 # ---------------------------------------------------------------------------
 # Binary op with scalar broadcast (PrimExpr scalar, e.g. BufferLoad)
 # ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_mul_scalar_broadcast():
     sm = _get_sm_version()
     if sm < 100:
@@ -158,6 +165,8 @@ def test_mul_scalar_broadcast():
 # ---------------------------------------------------------------------------
 # Binary add with rounding mode
 # ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_add_rounding_mode():
     sm = _get_sm_version()
     if sm < 100:
@@ -199,6 +208,8 @@ def test_add_rounding_mode():
 # ---------------------------------------------------------------------------
 # FMA op: layout=None local buffer (no TileLayout)
 # ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_fma_no_layout():
     sm = _get_sm_version()
     if sm < 100:
@@ -238,6 +249,8 @@ def test_fma_no_layout():
 # ---------------------------------------------------------------------------
 # Binary sub with rounding mode (buffer-buffer)
 # ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_sub_buffer_buffer_rounding():
     sm = _get_sm_version()
     if sm < 100:
@@ -278,6 +291,8 @@ def test_sub_buffer_buffer_rounding():
         tvm.testing.assert_allclose(expected, A_dev.numpy(), atol=1e-6)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 def test_fma_warpgroup_wg_local_layout():
     rows, cols = 128, 8
     dtype = "float32"
diff --git 
a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py 
b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
index c20df63beb..fb70b37541 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
@@ -23,6 +23,7 @@ import tvm
 import tvm.testing
 from tvm.script import tirx as T
 from tvm.script.tirx import tile as Tx
+from tvm.testing import env
 from tvm.tirx.cuda.operator.tile_primitive.layout_utils import (
     cast_layout_supported_for_local as _cast_layout_supported_for_local,
 )
@@ -54,6 +55,8 @@ from tvm.tirx.layout import S, TileLayout, laneid, tid_in_wg, 
tx, warpid
         ),
     ],
 )
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["zero", "sqrt"])
 @pytest.mark.parametrize(
     "src_dtype,dst_dtype", [("float16", "float16"), ("float32", "float16"), 
("float32", "bfloat16")]
@@ -145,6 +148,8 @@ def test_unary_op_shared(input, op_type, src_dtype, 
dst_dtype):
             tvm.testing.assert_allclose(B_ref, B.numpy(), atol=1e-2, rtol=1e-2)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("exec_scope", ["warp", "warpgroup"])
 def test_unary_op_shared_subcta_scope(exec_scope):
     dtype = "float16"
@@ -209,6 +214,8 @@ def test_unary_op_shared_subcta_scope(exec_scope):
         ),
     ],
 )
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["sqrt", "exp"])
 @pytest.mark.parametrize("bias_type", ["const", "region"])
 @pytest.mark.parametrize(
@@ -432,6 +439,8 @@ def test_unary_op_shared_with_bias_scale(input, op_type, 
bias_type, src_dtype, d
         ),
     ],
 )
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["reciprocal", "exp", "exp2"])
 @pytest.mark.parametrize(
     "src_dtype,dst_dtype", [("float16", "float16"), ("float32", "float16"), 
("float32", "bfloat16")]
@@ -554,6 +563,8 @@ def test_unary_op_local(input, op_type, src_dtype, 
dst_dtype):
         ),
     ],
 )
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["sqrt", "exp"])
 @pytest.mark.parametrize("bias_type", ["const", "region"])
 @pytest.mark.parametrize(
@@ -682,6 +693,8 @@ def test_unary_op_local_with_bias_scale(input, op_type, 
bias_type, src_dtype, ds
         tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("shape", [(128, 8), (128, 4, 16), (128, 5, 5)])
 @pytest.mark.parametrize("op_type", ["fill"])
 @pytest.mark.parametrize("exec_scope", ["thread", "cta"])
@@ -740,6 +753,8 @@ def test_unary_op_vectorized(shape, op_type, exec_scope, 
storage_scope):
         tvm.testing.assert_allclose(A.numpy(), np.full(shape, value.value), 
atol=1e-2)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["zero", "sqrt", "reciprocal", "exp", 
"silu"])
 @pytest.mark.parametrize("dtype", ["float16"])
 def test_unary_op_local_thread_wise(op_type, dtype):
@@ -791,6 +806,8 @@ def test_unary_op_local_thread_wise(op_type, dtype):
         tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-2, rtol=1e-2)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("shape", [(8,), (16, 16), (5, 5)])
 @pytest.mark.parametrize("A_dtype", ["float16", "float32"])
 @pytest.mark.parametrize("B_dtype", ["float16", "float32"])
@@ -831,6 +848,8 @@ def test_cast_thread_local(shape, A_dtype, B_dtype):
         tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), 
("float32", "bfloat16")])
 def test_cast_warpgroup_local_view(A_dtype, B_dtype):
     """T.cast in warpgroup scope with offset (tid_in_wg + layout offset). 
Covers offset/tid_in_wg/warpgroup scope."""  # noqa: E501
@@ -884,6 +903,8 @@ def test_cast_warpgroup_local_view(A_dtype, B_dtype):
         tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), 
("float32", "bfloat16")])
 def test_cast_warpgroup_src_layout_to_flat_uses_vec2_intrinsic(A_dtype, 
B_dtype):
     """Regression: GEMM-epilogue cast pattern must emit the packed vec2 cuda 
intrinsic.
@@ -944,6 +965,8 @@ def 
test_cast_warpgroup_src_layout_to_flat_uses_vec2_intrinsic(A_dtype, B_dtype)
         tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), 
("float32", "bfloat16")])
 def test_cast_cta_local_view(A_dtype, B_dtype):
     """T.cast with view+layout in CTA scope (128 threads, 
register->register)."""
@@ -988,6 +1011,8 @@ def test_cast_cta_local_view(A_dtype, B_dtype):
         tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), 
("float32", "bfloat16")])
 @pytest.mark.parametrize("slice_start,slice_end", [(0, 4), (2, 6), (4, 8)])
 def test_cast_local_view_sliced(A_dtype, B_dtype, slice_start, slice_end):
@@ -1087,6 +1112,8 @@ def test_cast_layout_partition_and_validation():
             check(part)
 
 
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("slice_start,slice_end", [(0, 2), (2, 4)])
 def test_cast_mixed_axes_and_subregion(slice_start, slice_end):
     """Test cast with mixed axes and subregion."""
@@ -1095,7 +1122,7 @@ def test_cast_mixed_axes_and_subregion(slice_start, 
slice_end):
     LOCAL_LEN = 4
     full_shape = (8, N_WARPS, 4, LOCAL_LEN)
     g_layout = TileLayout(S[full_shape])
-    cast_layout = TileLayout(S[full_shape : (4 @ laneid, 2 @ warpid, 1 @ 
laneid, 1)])
+    cast_layout = TileLayout(S[full_shape : (4 @ laneid, 1 @ warpid, 1 @ 
laneid, 1)])
 
     A_ref = np.zeros(full_shape, dtype="float32")
     for j in range(full_shape[0]):
@@ -1207,8 +1234,12 @@ def test_cast_validate_extent_mismatch_rejected():
     target = tvm.target.Target("cuda")
     with target:
         mod = tvm.IRModule({"main": kernel})
+        # The mismatched dst also fails the scope-level check (thread axes 
don't
+        # span the full CTA), which fires first — either rejection is fine.
         with pytest.raises(
-            Exception, match="tile_local_valid|layout signature 
mismatch|thread part mismatch"
+            Exception,
+            match="tile_local_valid|layout signature mismatch|thread part 
mismatch"
+            "|do not tile a complete|not the full",
         ):
             tvm.compile(mod, target=target, tir_pipeline="tirx")
 
@@ -1277,5 +1308,138 @@ def test_cast_vec2_packed_dispatch(src_dtype, 
dst_dtype, intrinsic):
     ), f"expected packed vec2 cast {intrinsic}; got:\n{src[:2000]}"
 
 
+# -----------------------------------------------------------------------------
+# Scope-level operand check: a warp/wg/cta reg op needs a scope-level layout
+# (thread axes spanning all the scope's threads), not a thread-local .local().
+# -----------------------------------------------------------------------------
+_SL_ROWS, _SL_COLS = 128, 8
+
+
+def _sl_compile(fn):
+    target = tvm.target.Target("cuda")
+    with target:
+        tvm.compile(tvm.IRModule({"main": fn}), target=target, 
tir_pipeline="tirx")
+
+
+def test_cast_wg_rejects_thread_local_view():
+    """Tx.wg.cast on a .local() (thread-axis-stripped) view is rejected."""
+
+    @T.prim_func
+    def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
+        A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        T.device_entry()
+        _bx = T.cta_id([1])
+        _wg = T.warpgroup_id([1])
+        tid = T.thread_id_in_wg([_SL_ROWS])
+        src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+        dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+        src_row = src.local(_SL_COLS)
+        for i in T.serial(_SL_COLS):
+            src_row[i] = A[tid, i]
+        Tx.wg.cast(dst.local(), src.local())
+        dst_row = dst.local(_SL_COLS)
+        for i in T.serial(_SL_COLS):
+            B[tid, i] = dst_row[i]
+
+    with pytest.raises(Exception, match="thread-local view"):
+        _sl_compile(kernel)
+
+
+def test_cast_cta_rejects_thread_local_view():
+    """Tx.cta.cast on a .local() view is rejected (cta -> tx)."""
+
+    @T.prim_func
+    def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
+        A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        T.device_entry()
+        _bx = T.cta_id([1])
+        tx_var = T.thread_id([_SL_ROWS])
+        src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]))
+        dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]))
+        src_row = src.local(_SL_COLS)
+        for i in T.serial(_SL_COLS):
+            src_row[i] = A[tx_var, i]
+        Tx.cta.cast(dst.local(), src.local())
+        dst_row = dst.local(_SL_COLS)
+        for i in T.serial(_SL_COLS):
+            B[tx_var, i] = dst_row[i]
+
+    with pytest.raises(Exception, match="thread-local view"):
+        _sl_compile(kernel)
+
+
+def test_cast_wg_rejects_partial_thread_coverage():
+    """A tid_in_wg layout covering only 64 of the 128 wg threads is 
rejected."""
+    half = 64
+
+    @T.prim_func
+    def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
+        A = T.match_buffer(A_ptr, (half, _SL_COLS), "float32", 
layout=TileLayout(S[(half, _SL_COLS)]))
+        B = T.match_buffer(B_ptr, (half, _SL_COLS), "float16", 
layout=TileLayout(S[(half, _SL_COLS)]))
+        T.device_entry()
+        _bx = T.cta_id([1])
+        _wg = T.warpgroup_id([1])
+        tid = T.thread_id_in_wg([_SL_ROWS])
+        src = T.alloc_buffer((half, _SL_COLS), "float32", scope="local", 
layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+        dst = T.alloc_buffer((half, _SL_COLS), "float16", scope="local", 
layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+        src_row = src.local(_SL_COLS)
+        for i in T.serial(_SL_COLS):
+            src_row[i] = A[tid, i]
+        Tx.wg.cast(dst, src)
+        dst_row = dst.local(_SL_COLS)
+        for i in T.serial(_SL_COLS):
+            B[tid, i] = dst_row[i]
+
+    with pytest.raises(Exception, match="not the full 128"):
+        _sl_compile(kernel)
+
+
+def test_cast_wg_accepts_wg_level_layout():
+    """Tx.wg.cast on a wg-level (tid_in_wg-distributed) layout compiles."""
+
+    @T.prim_func
+    def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
+        A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        T.device_entry()
+        _bx = T.cta_id([1])
+        _wg = T.warpgroup_id([1])
+        tid = T.thread_id_in_wg([_SL_ROWS])
+        src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+        dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+        src_row = src.local(_SL_COLS)
+        for i in T.serial(_SL_COLS):
+            src_row[i] = A[tid, i]
+        Tx.wg.cast(dst, src)
+        dst_row = dst.local(_SL_COLS)
+        for i in T.serial(_SL_COLS):
+            B[tid, i] = dst_row[i]
+
+    _sl_compile(kernel)
+
+
+def test_cast_thread_accepts_local_view():
+    """thread scope is exempt: a thread-axis-free local tile still compiles."""
+
+    @T.prim_func
+    def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
+        A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        T.device_entry()
+        _bx = T.cta_id([1])
+        tx_var = T.thread_id([_SL_ROWS])
+        src = T.alloc_buffer((_SL_COLS,), "float32", scope="local", 
layout=TileLayout(S[(_SL_COLS,)]))
+        dst = T.alloc_buffer((_SL_COLS,), "float16", scope="local", 
layout=TileLayout(S[(_SL_COLS,)]))
+        for i in T.serial(_SL_COLS):
+            src[i] = A[tx_var, i]
+        Tx.cast(dst, src)
+        for i in T.serial(_SL_COLS):
+            B[tx_var, i] = dst[i]
+
+    _sl_compile(kernel)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git 
a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py 
b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
index e0a270e709..32ac00e39d 100644
--- 
a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
+++ 
b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
@@ -32,6 +32,7 @@ import tvm.testing
 from tvm.ir.type import PointerType, PrimType
 from tvm.script import tirx as T
 from tvm.script.tirx import tile as Tx
+from tvm.testing import env
 from tvm.tirx.cuda.operator.tile_primitive.gemm_async import sf_tmem_layout
 from tvm.tirx.cuda.operator.tile_primitive.tma_utils import (
     mma_atom_layout,
@@ -167,6 +168,8 @@ def pack_sf_fp8_uint32(sf_uint8, n_total=128):
     return packed
 
 
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
 @pytest.mark.parametrize(
     "task",
     [
@@ -293,6 +296,8 @@ def test_gemm_tcgen05_cta_group_1(task):
         np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3)
 
 
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >= 
10.0")
 def test_gemm_tcgen05_cta_group_1_layout_f_m64():
     """M=64 MMA with C operand allocated as Layout F (datapath="F").
 
@@ -405,6 +410,8 @@ def test_gemm_tcgen05_cta_group_1_layout_f_m64():
     np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-2, rtol=1e-2)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0")
 @pytest.mark.parametrize(
     "task",
     [
@@ -545,6 +552,8 @@ def test_gemm_tcgen05_cta_group_2(task):
         np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0")
 def test_gemm_tcgen05_cta_group_2_layout_b():
     """Test cta_group=2 with Layout B (2x2 datapath, M=128 total, 64 per CTA).
 
@@ -675,6 +684,8 @@ def test_gemm_tcgen05_cta_group_2_layout_b():
         np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0")
 @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
 @pytest.mark.parametrize(
     "task",
@@ -864,6 +875,8 @@ def test_gemm_block_scaled_fp8_cta_group_1(task):
         np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0")
 @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
 @pytest.mark.parametrize(
     "task",
@@ -1089,6 +1102,8 @@ def test_gemm_block_scaled_fp8_cta_group_2(task):
         np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0")
 @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
 def test_gemm_block_scaled_nvfp4_cta_group_1():
     """Test block-scaled nvfp4 GEMM with cta_group=1.
@@ -1258,6 +1273,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_1():
         np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0")
 @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
 def test_gemm_block_scaled_nvfp4_cta_group_2():
     """Test block-scaled nvfp4 GEMM with cta_group=2.
@@ -1462,6 +1479,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_2():
         np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0")
 @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
 def test_gemm_block_scaled_fp8_sf_id():
     """Test sf_id auto-derivation from layout for fp8 block-scaled MMA.
@@ -1681,6 +1700,8 @@ def test_gemm_block_scaled_fp8_sf_id():
             )
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0")
 @pytest.mark.parametrize(
     "task",
     [
@@ -1960,6 +1981,8 @@ def test_gemm_tcgen05_arbitrary_tiles(task):
         np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0")
 @pytest.mark.parametrize("k_lo,k_hi", [(0, 16), (0, 32), (16, 32), (16, 48), (32, 64)])
 def test_gemm_tcgen05_contiguous_kslice_partial_k(k_lo, k_hi):
     """A slice on the *contiguous* (K) axis of a swizzled gemm_async operand must
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py
index 67cc1e0bd6..0402719ba1 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py
@@ -43,6 +43,7 @@ import tvm
 import tvm.testing
 from tvm.script import tirx as T
 from tvm.script.tirx import tile as Tx
+from tvm.testing import env
 
 # Helpers exposed by the dispatcher module for direct algorithm tests.
 from tvm.tirx.cuda.operator.tile_primitive.permute_layout.warp_xor_swizzle import (
@@ -167,6 +168,8 @@ def _compile_and_run(prim_func, np_inputs):
     return [t.numpy() for t in tensors], mod.mod.imports[0].inspect_source()
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @needs_cuda
 @pytest.mark.parametrize(
     "name, pipe, blk, dtype",
@@ -231,6 +234,8 @@ def test_sf_blockwise_transpose(name, pipe, blk, dtype):
         np.testing.assert_array_equal(B_flat, ref, err_msg=f"{name} stage {s}")
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @needs_cuda
 def test_identity_passes_through_as_copy():
     """L_src == L_dst should still compile and produce a correct (identity) copy."""
@@ -255,6 +260,8 @@ def test_identity_passes_through_as_copy():
     np.testing.assert_array_equal(B_out, A_np)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @needs_cuda
 @pytest.mark.parametrize("dtype", ["uint32", "int32", "float32"])
 @pytest.mark.parametrize(
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py
index 0474ad2dc4..9031aa4f48 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py
@@ -21,6 +21,7 @@ import tvm
 import tvm.testing
 from tvm.script import tirx as T
 from tvm.script.tirx import tile as Tx
+from tvm.testing import env
 from tvm.tirx.layout import R, S, TileLayout, laneid, wg_local_layout
 
 
@@ -41,6 +42,8 @@ from tvm.tirx.layout import R, S, TileLayout, laneid, wg_local_layout
         ((32, 32), (32,), (-1,), (1, 1), (2,), (5, 8), (5,)),
     ],
 )
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["sum", "max", "min"])
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
 @pytest.mark.parametrize("accum", [False, True])
@@ -129,6 +132,8 @@ def test_reduction_shared(
         tvm.testing.assert_allclose(ref, B.numpy()[tuple(reduce_slice_dst)], atol=atol)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("exec_scope", ["warp", "warpgroup", "thread"])
 @pytest.mark.parametrize("op_type", ["sum", "max", "min"])
 @pytest.mark.parametrize("accum", [False, True])
@@ -264,6 +269,8 @@ def test_reduction_shared_subscope(exec_scope, op_type, accum):
         ((2, 3, 4), (3, 4), (0,)),
     ],
 )
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["sum", "max", "min"])
 @pytest.mark.parametrize("accum", [False, True])
 def test_reduction_local_thread_wise(src_shape, dst_shape, axes, op_type, accum):
@@ -367,6 +374,8 @@ def test_reduction_local_thread_wise(src_shape, dst_shape, axes, op_type, accum)
         ((4, 8), (1, 8), (1,), False, None),
     ],
 )
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["sum", "max", "min"])
 def test_reduction_local_view_basic(inner_dims, dst_dims, axes, accum, slice_end, op_type):
     """Test view-based local reduction with simple purely-local layouts."""
@@ -484,6 +493,8 @@ def test_reduction_local_view_basic(inner_dims, dst_dims, axes, accum, slice_end
         tvm.testing.assert_allclose(ref, B.numpy(), atol=1e-5)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("n_groups, n_warps", [(1, 1), (1, 4), (2, 8)])
 @pytest.mark.parametrize("op_type", ["sum", "max", "min"])
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
@@ -616,6 +627,8 @@ def test_reduction_local_view_complex(n_groups, n_warps, op_type, dtype, shuffle
         tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("reduction_len", [8, 16, 64, 128, 256, 7, 10, 15, 100])
 @pytest.mark.parametrize("op_type", ["max", "min"])
 @pytest.mark.parametrize("accum", [False, True])
@@ -685,6 +698,8 @@ def test_reduction_local_optimized_3input_maxmin(reduction_len, op_type, accum):
         tvm.testing.assert_allclose(B_ref, B.numpy()[0], atol=1e-5)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("reduction_len", [8, 16, 64, 128, 256, 9, 17, 63, 65, 100])
 @pytest.mark.parametrize("accum", [False, True])
 def test_reduction_local_optimized_packed_add_sum(reduction_len, accum):
@@ -746,6 +761,8 @@ def test_reduction_local_optimized_packed_add_sum(reduction_len, accum):
         tvm.testing.assert_allclose(B_ref, B.numpy()[0], atol=1e-4)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["sum", "max"])
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
 def test_reduction_op_warp_shuffle(op_type, dtype):
@@ -807,6 +824,8 @@ def test_reduction_op_warp_shuffle(op_type, dtype):
         tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_type", ["sum", "max"])
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
 def test_reduction_op_warp_shuffle_multi_elem(op_type, dtype):
@@ -875,6 +894,8 @@ def test_reduction_op_warp_shuffle_multi_elem(op_type, dtype):
         tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 def test_reduction_warp_shuffle_multi_warp_loop():
     """Test intra-warp + cross-warp reduction via T.sum in a for loop with multiple warps.
 
@@ -951,6 +972,8 @@ def test_reduction_warp_shuffle_multi_warp_loop():
         tvm.testing.assert_allclose(B_ref, B_dev.numpy(), atol=1e-3)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 @pytest.mark.parametrize("op_name", ["sum", "max"])
 def test_reduction_warpgroup_wg_local_layout(op_name):
     rows, cols = 128, 16
diff --git a/tests/python/tirx/test_buffer_print.py b/tests/python/tirx/test_buffer_print.py
index 211f4d3903..dbd0da8f84 100644
--- a/tests/python/tirx/test_buffer_print.py
+++ b/tests/python/tirx/test_buffer_print.py
@@ -18,10 +18,12 @@
 import re
 
 import numpy as np
+import pytest
 
 import tvm
 import tvm.testing
 from tvm.script import tirx as T
+from tvm.testing import env
 
 
 def generate_random_data(shape, dtype):
@@ -181,6 +183,8 @@ def verify_cuda_code_string(func, expected_var_name, expected_string_literal):
     )
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 def test_print():
     DEV = tvm.cuda()
     target = tvm.target.Target("cuda")
diff --git a/tests/python/tirx/test_control_flow.py b/tests/python/tirx/test_control_flow.py
index 1f905bd03c..9085c2b021 100644
--- a/tests/python/tirx/test_control_flow.py
+++ b/tests/python/tirx/test_control_flow.py
@@ -15,9 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import pytest
 
 import tvm
 from tvm.script import tirx as T
+from tvm.testing import env
 
 
 def run_test_break_continue(func, shape, expected):
@@ -32,6 +34,8 @@ def run_test_break_continue(func, shape, expected):
     np.testing.assert_allclose(arr.numpy(), expected)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 def test_break_continue1():
     # fmt: off
     @T.prim_func
@@ -53,6 +57,8 @@ def test_break_continue1():
     run_test_break_continue(func, (10,), expected)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 def test_break_continue2():
     # fmt: off
     @T.prim_func
@@ -79,6 +85,8 @@ def test_break_continue2():
     run_test_break_continue(func, (9,), expected)
 
 
[email protected]
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
 def test_break_continue3():
     # fmt: off
     @T.prim_func
diff --git a/tests/python/tirx/test_layout.py b/tests/python/tirx/test_layout.py
index e3711cb00c..0dcf212ce2 100644
--- a/tests/python/tirx/test_layout.py
+++ b/tests/python/tirx/test_layout.py
@@ -1733,5 +1733,40 @@ def test_slice_single_shard_skips_defensive_floormod():
     # we just assert offset is non-empty and structurally sane (not None).
 
 
+def test_slice_tcgen05_frag_layout_scope_consistent():
+    """Slicing a wid_in_wg+laneid frag layout (tcgen05 16x256b) must stay
+    scope-consistent: the sliced result canonicalizes to a single tid_in_wg
+    chain over the full 128 threads (regression for the per-group-fusion bug).
+    """
+    frag = TileLayout(
+        S[(4, 2, 2, 8, 4, 4, 2) : (1 @ wid_in_wg, 16, 2, 4 @ laneid, 4, 1 @ laneid, 1)]
+    )
+
+    def thread_chain(layout):
+        canon = layout.canonicalize()
+        names = {it.axis.name for it in canon.shard if it.axis.is_thread()}
+        titers = sorted(
+            ((int(it.stride), int(it.extent)) for it in canon.shard if it.axis.is_thread()),
+        )
+        running = 1
+        for stride, extent in titers:
+            assert stride == running, f"non-contiguous thread chain: {titers}"
+            running *= extent
+        return names, running
+
+    with tvm.target.Target("cuda"):
+        # Full-region slice and a column sub-slice must both canonicalize to a
+        # single tid_in_wg chain covering all 128 warpgroup threads.
+        full = frag.slice([128, 32], [(0, 128), (0, 32)])
+        names, total = thread_chain(full)
+        assert names == {"tid_in_wg"}, names
+        assert total == 128, total
+
+        col = frag.slice([128, 32], [(0, 128), (16, 32)])
+        names_c, total_c = thread_chain(col)
+        assert names_c == {"tid_in_wg"}, names_c
+        assert total_c == 128, total_c
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh
index ec052281ad..15bb51bdf7 100755
--- a/tests/scripts/task_python_unittest.sh
+++ b/tests/scripts/task_python_unittest.sh
@@ -55,6 +55,7 @@ TEST_FILES=(
   "tirx-analysis"
   "tirx-base"
   "tirx-transform"
+  "tirx"
   "tvmscript"
   "relax"
 )

Reply via email to