This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch tir-bench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 6053f82fcc3efef8ee355d1b0f794c06db6decb0
Author: Bohan Hou <[email protected]>
AuthorDate: Mon May 18 17:48:03 2026 -0400

    feat(op-dispatch): add warp ldmatrix/stmatrix dispatch for Tx.copy (#630)
    
    * feat(op-dispatch): add warp ldmatrix/stmatrix dispatch for Tx.copy
    
    Adds warp-cooperative variants for the Tx.copy primitive that lower to a
    single PTX ldmatrix.sync.aligned.m8n8.x{1,2,4}[.trans].b16 or stmatrix.
    The dispatcher picks num and (row, col) axes from the SMEM region's
    buffer-layout strides so it is invariant under consistent permutation of
    the SMEM tensor's shape/strides, and supports horizontal, vertical, and
    2x2-grid arrangements. SwizzleLayout XOR is honored automatically via
    buf.ptr_to(...). 24 unit tests cover all (num × trans × direction)
    configs, swizzle equivalence, and permutation invariance (byte-for-byte
    identical addresses across permuted layouts).
    
    * feat(op-dispatch): add wg-scope warp_stmatrix/ldmatrix + active-set check
    
    Extends the warp-cooperative ldmatrix/stmatrix dispatch to warpgroup scope:
    each Tx.copy at warpgroup scope emits 4 stmatrix instructions (one per warp)
    covering 4 per-warp tiles inside the SMEM region. The per-warp distribution
    is read from the LOCAL fragment's TileLayout shard (its ``wid_in_wg`` iter's
    position + extents of subsequent iters in the same dim), so it's invariant
    under the layout's internal structure choice.
    
    Also adds an active-thread-set check at both warp and wg scope: PTX
    ld/stmatrix require every lane of the participating warp(s) to be active,
    so an enclosing ``if Tx.filter(...)`` that narrows ``sctx.intra['laneid']``
    or ``sctx.intra['wid_in_wg']`` is rejected (laneid must be (32, 0); for wg
    scope wid_in_wg must also be (4, 0)).
    
    Tests: 28 cases — original warp coverage, new wg dispatch, layout-required
    rejection, and active-set narrowing rejection for both warp and wg.
    
    * docs(test): trim wg layout helper docstring
    
    * refactor(op-dispatch): require warp/wg-wide local view with matching extents
    
    Tx.copy semantics demand LHS/RHS region extents to match. Previously the
    warp/wg dispatcher accepted a per-thread local fragment (e.g. ``regs[0:8]``)
    against a warp-wide SMEM tile (e.g. ``D[0:8, 0:32]``) — the byte addresses
    happened to come out right but the extents did not match, so the call was
    not a well-formed Tx.copy.
    
    The dispatcher now requires:
      * Warp scope: local is a warp-wide VIEW whose shape equals the SMEM
        region extents (modulo unit dims), with laneid iters in its layout
        shard whose extents multiply to 32 (full warp).
      * Warpgroup scope: same but ALSO with a wid_in_wg iter (extent 4).
    
    Layouts are written bottom-up with ``.tile().tile()``:
      pure_m → tile laneid → (tile wid_in_wg)
    Helpers in the test file cover x4 horizontal / vertical / 2×2 arrangements.
    
    The impl decomposes the local view via ``local_buf.local()`` inside
    ``with Tx.warp():`` to get a per-thread fragment for the PTX intrinsics.
    
    Coverage: 18 tests for x4 (warp/wg, st/ld, swizzle, permutation
    invariance, all 3 arrangements, rejection paths). x1/x2 dropped pending
    their more complex lane→matrix mapping; PTX intrinsics still available.
    
    * docs(test): inline x4 layout helpers as direct TileLayout shards
---
 .../operator/tile_primitive/cuda/copy/__init__.py  |   1 +
 .../tile_primitive/cuda/copy/warp_matrix.py        | 717 +++++++++++++++++++++
 .../tile_primitive/cuda/test_ldstmatrix.py         | 564 ++++++++++++++++
 3 files changed, 1282 insertions(+)

diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/__init__.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy/__init__.py
index b1b1cc4591..0c236f3a0c 100644
--- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/__init__.py
+++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy/__init__.py
@@ -25,3 +25,4 @@ from .utils import (
     copy_default_impl,
 )
 from .vectorized import *
+from .warp_matrix import *
diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/warp_matrix.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy/warp_matrix.py
new file mode 100644
index 0000000000..aace360aca
--- /dev/null
+++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy/warp_matrix.py
@@ -0,0 +1,717 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CUDA copy dispatch: warp-cooperative ldmatrix / stmatrix (PTX m8n8 b16).
+
+Registered ops: copy (variants warp_ldmatrix, warp_stmatrix,
+warpgroup_ldmatrix, warpgroup_stmatrix; priority=25).
+
+Each ``Tx.copy(SMEM[region], LOCAL[region])`` (or the reverse) lowers to
+``ldmatrix.sync.aligned.m8n8.x{1,2,4}[.trans].b16`` / ``stmatrix`` PTX: one
+instruction at warp scope, and one per warp (four in total) at warpgroup
+scope. The user does the outer iteration.
+
+Dispatcher contract (called once per ``Tx.copy``):
+
+  * direction: ``local`` ↔ ``shared::cta`` determines st vs ld.
+  * ``num`` (1/2/4) and the (row, col) dims are inferred from the per-warp
+    SMEM tile's two non-unit dims. The dim with the larger buffer-layout
+    stride is the row dim (ties are broken by the smaller extent), so the
+    choice is invariant under a consistent permutation of the SMEM tensor's
+    shape/strides. Supported per-warp tiles: horizontal ``8 × (num*8)``
+    (x1/x2/x4), vertical ``(num*8) × 8`` (x2/x4), and a ``16 × 16`` 2×2
+    grid (x4).
+  * ``trans`` is read from the kwarg ``trans=True/False`` on ``Tx.copy``.
+    (Auto-inferring trans from the local fragment's TileLayout is future
+    work — for now the user passes the flag explicitly, matching the
+    existing fp8 callsite.)
+  * The local operand is a warp-wide (warp scope) or warpgroup-wide
+    (warpgroup scope) VIEW whose non-unit extents match the SMEM region's.
+    Its TileLayout shard encodes the thread distribution: ``laneid`` iters
+    whose extents multiply to 32 and, at warpgroup scope, a ``wid_in_wg``
+    iter of extent 4 that selects each warp's sub-tile. The emitted code
+    works on each thread's fragment (``local_buf.local()``), which must
+    occupy exactly ``num`` b32 registers (``num*32`` bits).
+  * Per-lane SMEM address: lane k provides row ``k % 8`` of matrix
+    ``k // 8 if num>1 else 0``. The address is computed via
+    ``smem_buf.ptr_to([...])`` so any ``SwizzleLayout`` / ``ComposeLayout`` on
+    the SMEM buffer is honored automatically — exact equivalence with the
+    hand-written XOR-swizzled form in fp8_blockwise_gemm depends on this.
+
+Bottom-line invariants:
+
+  * On the fp8 ``stmatrix.x4.trans`` callsite, the generated PTX is identical
+    to the hand-written form (same per-lane addr expression, same swizzle XOR
+    via the SMEM buffer's layout, same num/trans).
+  * Any input that the dispatcher cannot prove correctly maps to one of the
+    PTX-defined fragment shapes is rejected — falls through to scalar/vec
+    variants, never silently mis-emitted.
+"""
+
+from __future__ import annotations
+
+import math
+import re
+from dataclasses import dataclass
+from typing import Optional
+
+from tvm.runtime import DataType
+from tvm.script import tirx as Tx
+from tvm.tirx import Buffer, BufferRegion, IntImm, PrimFunc
+from tvm.tirx.layout import Axis, ComposeLayout, TileLayout
+from tvm.tirx.operator.tile_primitive import DispatchContext, fail, predicate, 
register_dispatch
+from tvm.tirx.stmt import TilePrimitiveCall
+
+from ..exec_scope_utils import exec_scope_ok
+from .utils import _scope_allowed
+
+
+# ---------- helpers ---------------------------------------------------------
+
+
+def _as_int(x):
+    """Coerce *x* to a Python ``int`` when its value is statically known.
+
+    Plain ints are returned unchanged, ``IntImm`` nodes are unwrapped, and any
+    other object exposing an integer ``value`` attribute is unwrapped too.
+    Everything else (symbolic expressions, ``None``, ...) yields ``None``.
+    """
+    if isinstance(x, int):
+        return x
+    if isinstance(x, IntImm):
+        return int(x.value)
+    value = getattr(x, "value", None)
+    return int(value) if isinstance(value, int) else None
+
+
+def _region_st_ext(region: BufferRegion):
+    """Split ``region`` into parallel ``(starts, extents)`` lists, one per dim."""
+    starts = []
+    extents = []
+    for rng in region.region:
+        starts.append(rng.min)
+        extents.append(rng.extent)
+    return starts, extents
+
+
+def _direction(op_call: TilePrimitiveCall) -> Optional[str]:
+    """Classify the copy direction from the operands' storage scopes.
+
+    Returns ``'st'`` for ``local`` → ``shared*``, ``'ld'`` for ``shared*`` →
+    ``local``, and ``None`` for any other scope pairing.
+    """
+    call = TilePrimitiveCall.downcast(op_call)
+    src_scope = call.src.buffer.scope()
+    dst_scope = call.dst.buffer.scope()
+    if src_scope == "local":
+        return "st" if dst_scope.startswith("shared") else None
+    if dst_scope == "local" and src_scope.startswith("shared"):
+        return "ld"
+    return None
+
+
+def _buffer_per_dim_stride(buf: Buffer) -> Optional[list[int]]:
+    """Return one static element stride per buffer dim, or None.
+
+    Unwraps ``buf.layout`` (a ``TileLayout``, or the tile part of a
+    ``ComposeLayout`` such as Swizzle∘Tile), groups its shard by
+    ``buf.shape`` and reports, for each dim, the **minimum** shard stride
+    among that dim's iters (0 for a dim that owns no iters). None is returned
+    for a missing/unsupported layout, a failed grouping, or any non-static
+    stride.
+
+    Only the relative order of these strides is consumed (row vs. col in
+    ``_infer_arrangement``). A swizzle XORs bits of the final offset but does
+    not change which dim has the larger macro-stride, so the inner
+    TileLayout's strides are enough.
+
+    NOTE(review): the minimum equals the "+1 along this dim" step only when
+    the dim's fastest-varying iter also has its smallest stride — confirm
+    for layouts that interleave a dim's iters non-monotonically.
+    """
+    layout = buf.layout
+    if layout is None:
+        return None
+    tile_layout = layout.tile_layout if isinstance(layout, ComposeLayout) else layout
+    if not isinstance(tile_layout, TileLayout) or not getattr(tile_layout, "shard", None):
+        return None
+    try:
+        grouped, seps = tile_layout.group(list(buf.shape))
+    except Exception:  # noqa: BLE001
+        return None
+    result: list[int] = []
+    for dim in range(len(buf.shape)):
+        lo, hi = int(seps[dim]), int(seps[dim + 1])
+        if lo == hi:
+            # The dim owns no shard iters.
+            result.append(0)
+            continue
+        dim_strides = [_as_int(grouped.shard[k].stride) for k in range(lo, hi)]
+        if any(s is None for s in dim_strides):
+            return None
+        result.append(min(dim_strides))
+    return result
+
+
+# Arrangement constants: how the ``num`` 8×8 matrices are laid out in the
+# 2D SMEM region (per-warp tile). ``_infer_arrangement`` returns one of these
+# and ``_emit`` switches on it when building each lane's SMEM index.
+_HORIZONTAL = "horizontal"  # 8 × (num*8): matrices side-by-side along col_dim
+_VERTICAL = "vertical"      # (num*8) × 8: matrices stacked along row_dim
+_GRID_2X2 = "2x2"           # 16 × 16: 4 matrices in a 2×2 grid (num=4 only)
+
+
+def _has_full_laneid_iters(local_buf: Buffer) -> bool:
+    """Return True iff ``local_buf``'s layout shard spans a full warp of lanes.
+
+    Every ``laneid`` iter in the shard must have a static extent, and the
+    product of those extents must be exactly 32. Buffers without a layout,
+    or whose layout exposes no (or an empty) ``shard``, are rejected.
+    """
+    layout = local_buf.layout
+    if layout is None:
+        return False
+    shard = getattr(layout, "shard", None)
+    if not shard:
+        return False
+    lanes = 1
+    for it in (s for s in shard if s.axis.name == "laneid"):
+        extent = _as_int(it.extent)
+        if extent is None:
+            return False
+        lanes *= extent
+    return lanes == 32
+
+
+def _wg_distribution_from_layout(
+    local_buf: Buffer, smem_ext_i: list[int]
+) -> Optional[tuple[int, int, list[int]]]:
+    """Read the warp distribution from the local fragment's layout.
+
+    The local view must be wg-wide (its shape equals the SMEM region's
+    non-unit extents, in order) and carry exactly one ``wid_in_wg`` iter of
+    extent 4 in its shard. Returns ``(wg_axis_dim, wg_step,
+    per_warp_extents)``, or None if any requirement fails:
+
+      * ``wg_axis_dim`` — the SMEM dim corresponding to the local dim that
+        owns the ``wid_in_wg`` iter (via ``layout.group(buf.shape)``).
+      * ``wg_step`` — the **shape-coord step** corresponding to ``warp_id +=
+        1`` along that dim: the product of the extents of the iters that
+        follow ``wid_in_wg`` within the same dim's group (iters within a dim
+        are taken as ordered slowest-to-fastest).
+      * ``per_warp_extents`` — ``smem_ext_i`` with ``wg_axis_dim`` reduced
+        from ``4 * wg_step`` to ``wg_step`` (the per-warp tile).
+
+    Note: this is shape-coord units, not linear stride. The dispatcher adds
+    ``warp_id * wg_step`` to ``smem_idx[wg_axis_dim]`` directly.
+    """
+    layout = local_buf.layout
+    if layout is None or not getattr(layout, "shard", None):
+        return None
+
+    # Group FIRST and locate ``wid_in_wg`` in the *grouped* shard: ``seps``
+    # and the step computation below index ``grouped.shard``, and grouping
+    # may split iters, so positions in the raw ``layout.shard`` need not
+    # line up with them.
+    try:
+        grouped, seps = layout.group(list(local_buf.shape))
+    except Exception:  # noqa: BLE001
+        return None
+    wid_positions = [i for i, it in enumerate(grouped.shard) if it.axis.name == "wid_in_wg"]
+    if len(wid_positions) != 1:
+        # Absent, duplicated, or split across dims → too complex for now.
+        return None
+    wid_pos = wid_positions[0]
+    if _as_int(grouped.shard[wid_pos].extent) != 4:
+        return None
+
+    # Find which local-buffer dim the wid iter belongs to.
+    wid_local_dim = next(
+        (d for d in range(len(local_buf.shape)) if int(seps[d]) <= wid_pos < int(seps[d + 1])),
+        None,
+    )
+    if wid_local_dim is None:
+        return None
+
+    # Per-warp shape step = product of extents of subsequent iters in the
+    # same dim (faster-changing axes), all in the SAME shape-dim segment.
+    wg_step = 1
+    for i in range(wid_pos + 1, int(seps[wid_local_dim + 1])):
+        e = _as_int(grouped.shard[i].extent)
+        if e is None:
+            return None
+        wg_step *= e
+
+    # Map local dim → SMEM dim by aligning non-unit dims one-to-one. The
+    # local view's shape must equal the SMEM region's non-unit extents.
+    local_shape_i = [_as_int(s) for s in local_buf.shape]
+    if None in local_shape_i:
+        return None
+    smem_non_unit = [(i, e) for i, e in enumerate(smem_ext_i) if e != 1]
+    if [e for _, e in smem_non_unit] != local_shape_i:
+        return None
+    wid_smem_dim = smem_non_unit[wid_local_dim][0]
+
+    # 4 warps × wg_step must exactly cover the SMEM dim's extent.
+    if smem_ext_i[wid_smem_dim] != wg_step * 4:
+        return None
+
+    per_warp = list(smem_ext_i)
+    per_warp[wid_smem_dim] = wg_step
+    return wid_smem_dim, wg_step, per_warp
+
+
+def _infer_arrangement(
+    smem_ext_i: list[int], smem_strides: list[int]
+) -> Optional[tuple[int, int, int, str]]:
+    """Match the per-warp SMEM tile to an m8n8.x{1,2,4} arrangement.
+
+    Exactly two dims of ``smem_ext_i`` may be non-unit. Between them, the one
+    with the larger entry in ``smem_strides`` is the row dim and the other is
+    the col dim. If the strides tie, the dim with the smaller extent is taken
+    as the row dim (PTX m8n8 has 8 rows per matrix). If the extents also tie,
+    the choice is ambiguous and None is returned.
+
+    Returns ``(num, row_dim, col_dim, arrangement)`` with ``arrangement``:
+
+      * ``_HORIZONTAL`` — rows=8, cols ∈ {8,16,32}: num=cols/8 matrices
+        side-by-side along col_dim (lane k: row += k%8, col += matrix_id*8).
+      * ``_VERTICAL`` — rows ∈ {16,32}, cols=8: num=rows/8 matrices stacked
+        along row_dim (lane k: row += matrix_id*8 + k%8).
+      * ``_GRID_2X2`` — rows=cols=16: four matrices in a 2×2 grid
+        (lane k: row += (matrix_id//2)*8 + k%8, col += (matrix_id%2)*8).
+
+    Any other shape yields None.
+    """
+    candidates = [i for i, e in enumerate(smem_ext_i) if e != 1]
+    if len(candidates) != 2:
+        return None
+    a, b = candidates
+    if smem_strides[a] != smem_strides[b]:
+        # Larger stride → row direction.
+        row_dim, col_dim = (a, b) if smem_strides[a] > smem_strides[b] else (b, a)
+    elif smem_ext_i[a] != smem_ext_i[b]:
+        # Tied strides: the smaller extent is the row direction.
+        row_dim, col_dim = (a, b) if smem_ext_i[a] < smem_ext_i[b] else (b, a)
+    else:
+        # Degenerate square with equal strides; caller can reshape instead.
+        return None
+
+    rows, cols = smem_ext_i[row_dim], smem_ext_i[col_dim]
+    if rows == 8 and cols in (8, 16, 32):
+        return cols // 8, row_dim, col_dim, _HORIZONTAL
+    if cols == 8 and rows in (16, 32):
+        return rows // 8, row_dim, col_dim, _VERTICAL
+    if rows == cols == 16:
+        return 4, row_dim, col_dim, _GRID_2X2
+    return None
+
+
+@dataclass
+class _Bound:
+    """Validated dispatch state for one ``Tx.copy``.
+
+    Built by ``_bind_common`` once the copy has been proven to map onto
+    ``ld/stmatrix.m8n8.x{num}``; ``_emit`` reads it to build the per-lane
+    SMEM index and pick the PTX intrinsic.
+    """
+
+    num: int  # matrices per instruction: 1, 2 or 4 (the .x{num} suffix)
+    trans: bool  # emit the .trans variant
+    direction: str  # "st" or "ld"
+    smem_region: BufferRegion
+    local_region: BufferRegion
+    row_dim: int  # SMEM buffer dim with LARGER stride (or smaller extent on tie)
+    col_dim: int  # SMEM buffer dim with SMALLER stride
+    arrangement: str  # one of _HORIZONTAL / _VERTICAL / _GRID_2X2
+    local_elements_per_b32: int  # local elements packed into one b32 register
+    # Warpgroup-scope fields. ``wg`` is False for warp-scope binds.
+    wg: bool = False
+    # The SMEM dim along which the 4 warps walk (each warp adds
+    # ``warp_id * wg_step`` to this dim on top of the per-stamp offset).
+    wg_axis_dim: int = -1
+    wg_step: int = 0
+
+
+def _try_bind(op_call: TilePrimitiveCall, sctx: DispatchContext, want_direction: str):
+    """Validate and bind dispatcher state for **warp** scope.
+
+    Rejects non-warp exec scopes and warps whose active lane set has been
+    narrowed, then defers to ``_bind_common``. Returns ``_Bound`` on success
+    or a short error string on rejection.
+    """
+    if not sctx.is_warp:
+        return f"exec_scope is {sctx.scope_kind!r}, not 'warp'"
+    err = _check_full_active_set(sctx, is_wg=False)
+    if err is not None:
+        return err
+    return _bind_common(op_call, want_direction, is_wg=False)
+
+
+def _try_bind_wg(op_call: TilePrimitiveCall, sctx: DispatchContext, want_direction: str):
+    """Validate and bind dispatcher state for **warpgroup** scope.
+
+    Rejects non-warpgroup exec scopes and narrowed active sets (laneid or
+    wid_in_wg), then defers to ``_bind_common``. Returns ``_Bound`` (with
+    ``wg=True`` and the warp-walk fields populated) or an error string.
+    """
+    if not sctx.is_warpgroup:
+        return f"exec_scope is {sctx.scope_kind!r}, not 'warpgroup'"
+    err = _check_full_active_set(sctx, is_wg=True)
+    if err is not None:
+        return err
+    return _bind_common(op_call, want_direction, is_wg=True)
+
+
+def _check_full_active_set(sctx: DispatchContext, *, is_wg: bool) -> Optional[str]:
+    """Verify the active thread set is the FULL warp/warpgroup.
+
+    PTX ldmatrix/stmatrix requires every lane of the participating warp to be
+    active (32-lane sync). If an enclosing ``if Tx.filter(...)`` narrowed the
+    active set, ``sctx.intra`` reports the reduced extent — we reject those
+    cases here.
+
+    Each ``sctx.intra[name]`` entry is unpacked as ``(extent, offset)``; both
+    must be static. For warp scope: laneid must be (32, 0). For warpgroup
+    scope: laneid (32, 0) AND wid_in_wg (4, 0).
+
+    Returns None when the set is full, otherwise a rejection message.
+    """
+    required = {"laneid": 32}
+    if is_wg:
+        required["wid_in_wg"] = 4
+    for axis_name, expected in required.items():
+        if axis_name not in sctx.intra:
+            return f"sctx.intra missing {axis_name!r} (scope_kind={sctx.scope_kind!r})"
+        extent_raw, offset_raw = sctx.intra[axis_name]
+        extent = _as_int(extent_raw)
+        offset = _as_int(offset_raw)
+        if extent is None or offset is None:
+            return f"non-static active range for {axis_name}: ({extent_raw}, {offset_raw})"
+        if extent != expected or offset != 0:
+            return (
+                f"active {axis_name} range is [{offset}, {offset + extent}); "
+                f"ldmatrix/stmatrix needs the full [0, {expected}) — an enclosing "
+                "if/filter has narrowed the warp"
+            )
+    return None
+
+
+def _bind_common(op_call: TilePrimitiveCall, want_direction: str, *, is_wg: bool):
+    """Scope-independent validation shared by ``_try_bind`` / ``_try_bind_wg``.
+
+    Proves that ``op_call`` maps onto one ``ld/stmatrix.m8n8.x{num}.b16`` per
+    warp. Returns a populated ``_Bound`` on success, or a short rejection
+    string (surfaced as the predicate's failure reason) otherwise.
+
+    Checks, in order:
+      B1  SMEM dtype is 16-bit and the local dtype has the same width, so
+          each lane's share of the extent-matched local view is exactly
+          ``num`` b32 registers.
+      B2  SMEM extents and per-dim layout strides are static; the local
+          view's layout distributes over a full warp (``laneid`` iters
+          totalling 32) and, at warpgroup scope, carries a ``wid_in_wg``
+          iter; the per-warp tile matches an m8n8 arrangement.
+      B3  The local region's non-unit extents equal the SMEM region's, and
+          the local region is the whole local view.
+    """
+    direction = _direction(op_call)
+    if direction != want_direction:
+        return f"direction {direction} != {want_direction}"
+
+    op_call = TilePrimitiveCall.downcast(op_call)
+    smem_region = op_call.dst if direction == "st" else op_call.src
+    local_region = op_call.src if direction == "st" else op_call.dst
+
+    smem_buf: Buffer = smem_region.buffer
+    local_buf: Buffer = local_region.buffer
+
+    # B1: SMEM dtype 16-bit (PTX .b16); local dtype of the same width.
+    smem_bits = DataType(smem_buf.dtype).bits
+    if smem_bits != 16:
+        return f"SMEM dtype must be 16-bit (b16), got {smem_buf.dtype}"
+    local_bits = DataType(local_buf.dtype).bits
+    if 32 % local_bits != 0:
+        return f"local dtype bits {local_bits} must evenly divide 32 (b32 reg unit)"
+    # B3 forces the local view to hold as many ELEMENTS as the b16 SMEM
+    # region, while ld/stmatrix moves exactly ``num`` b32 registers per lane
+    # (= the SMEM tile's bytes). Any other local width would make the
+    # instruction read/write past (or short of) each thread's fragment — a
+    # bit-level reinterpretation, not a value conversion — so reject it.
+    if local_bits != smem_bits:
+        return (
+            f"local dtype {local_buf.dtype} ({local_bits}-bit) must match the "
+            f"SMEM element width ({smem_bits}-bit) for ldmatrix/stmatrix"
+        )
+    elements_per_b32 = 32 // local_bits
+
+    # B2: SMEM region extents + buffer strides.
+    _, smem_ext = _region_st_ext(smem_region)
+    smem_ext_i = [_as_int(e) for e in smem_ext]
+    if None in smem_ext_i:
+        return f"SMEM extents must be compile-time integers, got {smem_ext}"
+    smem_strides = _buffer_per_dim_stride(smem_buf)
+    if smem_strides is None:
+        return f"could not determine static per-dim strides from SMEM layout {smem_buf.layout}"
+
+    if is_wg:
+        # WG: local is a wg-wide view; layout carries laneid iters totalling
+        # 32 AND a wid_in_wg iter. Per-warp SMEM dim/step come from the
+        # wid_in_wg iter's position in the shard.
+        if not _has_full_laneid_iters(local_buf):
+            return (
+                f"warpgroup local fragment must carry laneid iters totaling "
+                f"extent 32 in its layout shard; got layout={local_buf.layout}"
+            )
+        wg_info = _wg_distribution_from_layout(local_buf, smem_ext_i)
+        if wg_info is None:
+            return (
+                f"warpgroup local fragment must be a wg-wide view (shape matching "
+                f"SMEM region) with a wid_in_wg iter in its layout shard; "
+                f"got shape={list(local_buf.shape)} layout={local_buf.layout}"
+            )
+        wg_axis_dim, wg_step, per_warp_ext = wg_info
+        inferred = _infer_arrangement(per_warp_ext, smem_strides)
+    else:
+        # Warp: local is a warp-wide view; layout carries laneid iters whose
+        # extents multiply to 32 (full warp). Whole SMEM region is the per-warp
+        # tile (no warp-walk).
+        if not _has_full_laneid_iters(local_buf):
+            return (
+                f"warp local fragment must be a warp-wide view (shape matching "
+                f"SMEM region) with laneid iters totaling extent 32 in its "
+                f"layout shard; got shape={list(local_buf.shape)} layout={local_buf.layout}"
+            )
+        per_warp_ext = smem_ext_i
+        wg_axis_dim = -1
+        wg_step = 0
+        inferred = _infer_arrangement(smem_ext_i, smem_strides)
+
+    if inferred is None:
+        return (
+            f"per-warp tile {per_warp_ext} (strides {smem_strides}) doesn't match any "
+            "m8n8.x{1,2,4} arrangement"
+        )
+    num, row_dim, col_dim, arrangement = inferred
+
+    # B3: local fragment is a warp- or wg-wide VIEW. Its logical extents must
+    # equal the SMEM region extents (matching ``Tx.copy`` semantics — both
+    # sides describe the same region size).
+    local_st, local_ext = _region_st_ext(local_region)
+    local_ext_i = [_as_int(e) for e in local_ext]
+    if None in local_ext_i:
+        return f"local extents must be compile-time integers, got {local_ext}"
+    smem_non_unit = sorted([e for e in smem_ext_i if e != 1])
+    local_non_unit = sorted([e for e in local_ext_i if e != 1])
+    if smem_non_unit != local_non_unit:
+        return (
+            f"local region {local_ext_i} non-unit extents must match SMEM "
+            f"region {smem_ext_i} (got {local_non_unit} vs {smem_non_unit})"
+        )
+    # ``_emit`` lowers through the per-thread fragment of the WHOLE local
+    # buffer (``local_buf.local()``, addressed from element 0) and never reads
+    # the region's start, so a strict sub-region would silently move the
+    # wrong registers. Require the region to be the entire view.
+    local_shape_i = [_as_int(s) for s in local_buf.shape]
+    local_st_i = [_as_int(s) for s in local_st]
+    if any(s != 0 for s in local_st_i) or local_ext_i != local_shape_i:
+        return (
+            f"local region (starts {local_st_i}, extents {local_ext_i}) must cover "
+            f"the whole local view of shape {list(local_buf.shape)}"
+        )
+
+    cfg = op_call.config or {}
+    trans = bool(cfg.get("trans", False))
+
+    return _Bound(
+        num=num,
+        trans=trans,
+        direction=direction,
+        smem_region=smem_region,
+        local_region=local_region,
+        row_dim=row_dim,
+        col_dim=col_dim,
+        arrangement=arrangement,
+        local_elements_per_b32=elements_per_b32,
+        wg=is_wg,
+        wg_axis_dim=wg_axis_dim,
+        wg_step=wg_step,
+    )
+
+
+def _sm_version(sctx: DispatchContext) -> int:
+    """Parse the numeric SM version from ``sctx.target.arch`` (0 if unknown).
+
+    A leading ``sm_<digits>`` prefix is required (``"sm_90a"`` → 90). A
+    missing, empty or unparseable arch yields 0, so every ``min_sm`` gate in
+    the predicates fails closed.
+    """
+    arch = getattr(sctx.target, "arch", "") or ""
+    match = re.match(r"sm_(\d+)", arch)
+    if match is None:
+        return 0
+    return int(match.group(1))
+
+
+def _make_predicate(want_direction: str, min_sm: int, *, wg: bool = False):
+    """Build a dispatch predicate ``(op_call, sctx) -> (ok, reason)``.
+
+    The predicate runs the warp bind (or the warpgroup bind when ``wg``) for
+    ``want_direction`` and then requires the target's SM version to be at
+    least ``min_sm``. On failure it returns ``(False, <reason string>)``.
+    """
+    bind = _try_bind_wg if wg else _try_bind
+    instr = "stmatrix" if want_direction == "st" else "ldmatrix"
+
+    def _pred(op_call, sctx):
+        bound = bind(op_call, sctx, want_direction)
+        if isinstance(bound, str):
+            return False, bound
+        sm = _sm_version(sctx)
+        if sm >= min_sm:
+            return True, None
+        return False, f"{instr} requires sm_{min_sm}+, got sm_{sm}"
+
+    return _pred
+
+
+# ---------- impl ------------------------------------------------------------
+
+
+def _impl(op_call: TilePrimitiveCall, sctx: DispatchContext, want_direction: str) -> PrimFunc:
+    """Warp-scope lowering entry point; delegates to ``_emit``."""
+    return _emit(op_call, sctx, want_direction, is_wg=False)
+
+
+def _impl_wg(op_call: TilePrimitiveCall, sctx: DispatchContext, want_direction: str) -> PrimFunc:
+    """Warpgroup-scope lowering entry point; delegates to ``_emit``."""
+    return _emit(op_call, sctx, want_direction, is_wg=True)
+
+
+def _emit(
+    op_call: TilePrimitiveCall, sctx: DispatchContext, want_direction: str, *, is_wg: bool
+) -> PrimFunc:
+    """Lower a validated ``Tx.copy`` to one ldmatrix/stmatrix per warp.
+
+    Re-runs the warp (or warpgroup) bind and calls ``fail`` with its message
+    on rejection. It then builds a TVMScript ``impl`` that computes each
+    lane's SMEM address through ``smem_buf.ptr_to(...)``, which applies the
+    SMEM buffer's layout. That address and the thread's local fragment
+    (``local_buf.local()``) are passed to ``Tx.ptx.stmatrix`` /
+    ``Tx.ptx.ldmatrix``. At warpgroup scope the body runs under
+    ``with Tx.warp():`` and each warp additionally offsets ``wg_axis_dim``
+    by ``warp_id * wg_step``.
+
+    Only the SMEM region's start is used. The local operand is always
+    addressed from element 0 of the per-thread fragment of the whole local
+    buffer.
+    """
+    res = (_try_bind_wg if is_wg else _try_bind)(op_call, sctx, want_direction)
+    if isinstance(res, str):
+        fail(res)
+    b: _Bound = res
+
+    smem_buf = b.smem_region.buffer
+    local_buf = b.local_region.buffer
+    smem_st, _ = _region_st_ext(b.smem_region)
+
+    tid_x = sctx.launch_params["threadIdx.x"]
+    num = b.num
+    trans = b.trans
+    row_dim = b.row_dim
+    col_dim = b.col_dim
+    arrangement = b.arrangement
+    wg_axis_dim = b.wg_axis_dim
+    wg_step = b.wg_step
+
+    # Python-level closures: build PrimExpr index lists at parse time. Index
+    # mutation must live outside the prim_func body — TVM Script treats
+    # ``list[i] = ...`` inside a func as a BufferStore.
+    #
+    # Per-arrangement lane → SMEM (row_dim, col_dim) offsets on the
+    # PER-WARP tile:
+    #   horizontal (8 × num*8): row += k%8; col += matrix_id*8.
+    #   vertical   (num*8 × 8): row += matrix_id*8 + k%8; col += 0.
+    #   2x2       (16 × 16):    row += (matrix_id//2)*8 + k%8;
+    #                           col += (matrix_id%2)*8.
+    #
+    # At warpgroup scope (``is_wg=True``), an additional warp-walk offset
+    # ``warp_id * wg_step`` is layered onto ``wg_axis_dim`` (the dim along
+    # which the 4 warps line up their per-warp tiles).
+    def _make_smem_idx(row_in_matrix, matrix_id, warp_id):
+        idx = list(smem_st)
+        if arrangement == _HORIZONTAL:
+            idx[row_dim] = smem_st[row_dim] + row_in_matrix
+            idx[col_dim] = smem_st[col_dim] + matrix_id * 8
+        elif arrangement == _VERTICAL:
+            idx[row_dim] = smem_st[row_dim] + matrix_id * 8 + row_in_matrix
+        else:  # _GRID_2X2 (num=4 only)
+            idx[row_dim] = smem_st[row_dim] + (matrix_id // 2) * 8 + row_in_matrix
+            idx[col_dim] = smem_st[col_dim] + (matrix_id % 2) * 8
+        if is_wg:
+            idx[wg_axis_dim] = idx[wg_axis_dim] + warp_id * wg_step
+        return idx
+
+    elements_per_b32 = b.local_elements_per_b32
+
+    # ldmatrix takes one destination handle per b32 register: register r
+    # starts at local element ``r * elements_per_b32``.
+    def _make_ld_handles(local_per_thread):
+        return tuple(
+            local_per_thread.ptr_to([r * elements_per_b32]) for r in range(num)
+        )
+
+    if b.direction == "st":
+        if is_wg:
+            # fmt: off
+            @Tx.prim_func(check_well_formed=False)
+            def impl():
+                with Tx.warp():
+                    warp_id = Tx.meta_var((tid_x // 32) % 4)
+                    lane_id = Tx.meta_var(tid_x % 32)
+                    row_in_matrix = Tx.meta_var(lane_id % 8)
+                    matrix_id = Tx.meta_var(lane_id // 8 if num > 1 else 0)
+                    local_per_thread = local_buf.local()
+                    Tx.ptx.stmatrix(
+                        smem_buf.ptr_to(_make_smem_idx(row_in_matrix, matrix_id, warp_id)),
+                        local_per_thread.ptr_to([0]),
+                        num=num,
+                        trans=trans,
+                    )
+            # fmt: on
+            return impl
+
+        # fmt: off
+        @Tx.prim_func(check_well_formed=False)
+        def impl():
+            lane_id = Tx.meta_var(tid_x % 32)
+            row_in_matrix = Tx.meta_var(lane_id % 8)
+            matrix_id = Tx.meta_var(lane_id // 8 if num > 1 else 0)
+            local_per_thread = local_buf.local()
+            Tx.ptx.stmatrix(
+                smem_buf.ptr_to(_make_smem_idx(row_in_matrix, matrix_id, 0)),
+                local_per_thread.ptr_to([0]),
+                num=num,
+                trans=trans,
+            )
+        # fmt: on
+        return impl
+
+    if is_wg:
+        # fmt: off
+        @Tx.prim_func(check_well_formed=False)
+        def impl():
+            with Tx.warp():
+                warp_id = Tx.meta_var((tid_x // 32) % 4)
+                lane_id = Tx.meta_var(tid_x % 32)
+                row_in_matrix = Tx.meta_var(lane_id % 8)
+                matrix_id = Tx.meta_var(lane_id // 8 if num > 1 else 0)
+                local_per_thread = local_buf.local()
+                Tx.ptx.ldmatrix(
+                    trans, num, "b16",
+                    smem_buf.ptr_to(_make_smem_idx(row_in_matrix, matrix_id, warp_id)),
+                    *_make_ld_handles(local_per_thread),
+                )
+        # fmt: on
+        return impl
+
+    # fmt: off
+    @Tx.prim_func(check_well_formed=False)
+    def impl():
+        lane_id = Tx.meta_var(tid_x % 32)
+        row_in_matrix = Tx.meta_var(lane_id % 8)
+        matrix_id = Tx.meta_var(lane_id // 8 if num > 1 else 0)
+        local_per_thread = local_buf.local()
+        Tx.ptx.ldmatrix(
+            trans, num, "b16",
+            smem_buf.ptr_to(_make_smem_idx(row_in_matrix, matrix_id, 0)),
+            *_make_ld_handles(local_per_thread),
+        )
+    # fmt: on
+    return impl
+
+
+# ---------- registration ----------------------------------------------------
+
+
+# (src_scope, dst_scope) pairs passed to the ``storage_scope`` predicate:
+# stmatrix copies local → shared, ldmatrix copies shared → local.
+# NOTE(review): ``"shared*"`` presumably acts as a wildcard pattern inside
+# ``_scope_allowed`` — confirm against its implementation.
+_STMATRIX_PAIRS = [("local", "shared*"), ("local", "shared::cta")]
+_LDMATRIX_PAIRS = [("shared*", "local"), ("shared::cta", "local")]
+
+
+@register_dispatch(
+    "copy",
+    "cuda",
+    variant="warp_stmatrix",
+    priority=25,
+    when=[
+        predicate("storage_scope", _scope_allowed, allowed_pairs=_STMATRIX_PAIRS),
+        predicate("exec_scope", exec_scope_ok, expected_scopes=["warp"]),
+        predicate("stmatrix_compat", _make_predicate("st", min_sm=90)),
+    ],
+)
+def copy_schedule_warp_stmatrix(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc:
+    """Lower a warp-scope ``local`` → shared ``Tx.copy`` to one ``stmatrix`` (sm_90+)."""
+    return _impl(op_call, sctx, want_direction="st")
+
+
+@register_dispatch(
+    "copy",
+    "cuda",
+    variant="warp_ldmatrix",
+    priority=25,
+    when=[
+        predicate("storage_scope", _scope_allowed, allowed_pairs=_LDMATRIX_PAIRS),
+        predicate("exec_scope", exec_scope_ok, expected_scopes=["warp"]),
+        predicate("ldmatrix_compat", _make_predicate("ld", min_sm=75)),
+    ],
+)
+def copy_schedule_warp_ldmatrix(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc:
+    """Lower a warp-scope shared → ``local`` ``Tx.copy`` to one ``ldmatrix`` (sm_75+)."""
+    return _impl(op_call, sctx, want_direction="ld")
+
+
+@register_dispatch(
+    "copy",
+    "cuda",
+    variant="warpgroup_stmatrix",
+    priority=25,
+    when=[
+        predicate("storage_scope", _scope_allowed, allowed_pairs=_STMATRIX_PAIRS),
+        predicate("exec_scope", exec_scope_ok, expected_scopes=["warpgroup"]),
+        predicate("wg_stmatrix_compat", _make_predicate("st", min_sm=90, wg=True)),
+    ],
+)
+def copy_schedule_warpgroup_stmatrix(
+    op_call: TilePrimitiveCall, sctx: DispatchContext
+) -> PrimFunc:
+    """Lower a warpgroup-scope ``local`` → shared ``Tx.copy`` to one ``stmatrix``
+    per warp (sm_90+)."""
+    return _impl_wg(op_call, sctx, want_direction="st")
+
+
+@register_dispatch(
+    "copy",
+    "cuda",
+    variant="warpgroup_ldmatrix",
+    priority=25,
+    when=[
+        predicate("storage_scope", _scope_allowed, allowed_pairs=_LDMATRIX_PAIRS),
+        predicate("exec_scope", exec_scope_ok, expected_scopes=["warpgroup"]),
+        predicate("wg_ldmatrix_compat", _make_predicate("ld", min_sm=75, wg=True)),
+    ],
+)
+def copy_schedule_warpgroup_ldmatrix(
+    op_call: TilePrimitiveCall, sctx: DispatchContext
+) -> PrimFunc:
+    """Lower a warpgroup-scope shared → ``local`` ``Tx.copy`` to one ``ldmatrix``
+    per warp (sm_75+)."""
+    return _impl_wg(op_call, sctx, want_direction="ld")
+
+
+# Public dispatch entry points; ``copy/__init__.py`` re-exports them via
+# ``from .warp_matrix import *``.
+__all__ = [
+    "copy_schedule_warp_ldmatrix",
+    "copy_schedule_warp_stmatrix",
+    "copy_schedule_warpgroup_ldmatrix",
+    "copy_schedule_warpgroup_stmatrix",
+]
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py b/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py
new file mode 100644
index 0000000000..1d0816bf9e
--- /dev/null
+++ b/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py
@@ -0,0 +1,564 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring
+
+"""Tests for the warp/warpgroup ldmatrix/stmatrix dispatcher under
+``Tx.copy``. Covers x4 in horizontal / vertical / 2×2 arrangements at warp
+and warpgroup scope. Each Tx.copy must have matching LHS/RHS region
+extents — the local fragment is a warp- or wg-wide VIEW with thread
+distribution encoded in its TileLayout (laneid / wid_in_wg iters).
+x1/x2 variants are TODO (their lane→matrix mapping is more involved).
+"""
+
+from __future__ import annotations
+
+import re
+
+import pytest
+
+import tvm
+from tvm.script import tirx as Tx
+from tvm.tirx.layout import Axis, Iter, S, TileLayout
+from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode, 
mma_shared_layout
+
+
+# ---------------------------------------------------------------------------
+# Layout helpers: pure-m → tile laneid (→ tile wid_in_wg)
+# ---------------------------------------------------------------------------
+
+
+def _x4_h_warp_layout():
+    """x4 horizontal warp-wide view (8, 32):
+      row = laneid % 8;  col = (laneid // 8) * 8 + per-thread 0..7"""
+    return TileLayout(S[(8, 4, 8) : (
+        1 @ Axis.laneid,
+        8 @ Axis.laneid,
+        1,
+    )])
+
+
+def _x4_h_wg_layout():
+    """x4 horizontal wg-wide view (8, 128):
+      row = laneid % 8;  col = wid_in_wg*32 + (laneid//8)*8 + per-thread 
0..7"""
+    return TileLayout(S[(8, 4, 4, 8) : (
+        1 @ Axis.laneid,
+        1 @ Axis.wid_in_wg,
+        8 @ Axis.laneid,
+        1,
+    )])
+
+
+def _x4_v_warp_layout():
+    """x4 vertical warp-wide view (32, 8): row = laneid; col = per-thread 
0..7"""
+    return TileLayout(S[(32, 8) : (1 @ Axis.laneid, 1)])
+
+
+def _x4_2x2_warp_layout():
+    """x4 2×2 grid warp-wide view (16, 16):
+      row = (laneid//16)*8 + laneid%8;  col = ((laneid//8)%2)*8 + per-thread 0..7
+
+    NOTE(review): the other helpers here behave as if iters are read in
+    row-major order over the view shape, with the outer iter listed first.
+    Under that reading, the list below puts lane_low (extent 8) outside the
+    lane-bit-4 iter (extent 2) within the 16 rows. That gives
+    ``row = (laneid % 8) * 2 + laneid // 16``, i.e. lane_low has row stride 2,
+    not the stride 1 claimed above.
+
+    Swapping the first two iters would match the formula. The warp-scope tests
+    using this helper only check that stmatrix is emitted, so they would not
+    catch the difference. TODO confirm ``TileLayout.from_iters`` ordering.
+    """
+    return TileLayout.from_iters([
+        Iter(8, 1, Axis.laneid),    # lane_low → row stride 1
+        Iter(2, 16, Axis.laneid),   # row-block (lane bit 4) → row stride 8
+        Iter(2, 8, Axis.laneid),    # col-block (lane bit 3) → col stride 8
+        Iter(8, 1, Axis.m),         # per-thread → col stride 1
+    ])
+
+
+_SM100A = tvm.target.Target({"kind": "cuda", "arch": "sm_100a"})
+
+
+def _compile_get_cuda(prim_func) -> str:
+    with _SM100A:
+        mod = tvm.compile(
+            tvm.IRModule({"main": prim_func}), target=_SM100A, 
tir_pipeline="tirx"
+        )
+    return mod.mod.imports[0].inspect_source()
+
+
+# ---------------------------------------------------------------------------
+# warp scope x4: stmatrix / ldmatrix × {non-trans, trans}
+# ---------------------------------------------------------------------------
+
+
[email protected]("trans", [False, True])
+def test_warp_stmatrix_x4(trans):
+    layout = _x4_h_warp_layout()
+
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (8, 32), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(8, 32) : (32, 1)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(8, 32, layout=layout)
+                    Tx.copy(D[0:8, 0:32], regs_warp[0:8, 0:32], trans=trans)
+
+    src = _compile_get_cuda(f)
+    expected = f"stmatrix.sync.aligned.m8n8.x4{'.trans' if trans else 
''}.shared.b16"
+    assert expected in src
+    assert "& 7" in src and ">> 3" in src
+
+
[email protected]("trans", [False, True])
+def test_warp_ldmatrix_x4(trans):
+    layout = _x4_h_warp_layout()
+
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (8, 32), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(8, 32) : (32, 1)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(8, 32, layout=layout)
+                    Tx.copy(regs_warp[0:8, 0:32], D[0:8, 0:32], trans=trans)
+
+    src = _compile_get_cuda(f)
+    expected = f"ldmatrix.sync.aligned.m8n8.x4{'.trans' if trans else 
''}.shared.b16"
+    assert expected in src
+    assert "& 7" in src and ">> 3" in src
+
+
+# ---------------------------------------------------------------------------
+# Swizzle: 128B SwizzleLayout XOR honored
+# ---------------------------------------------------------------------------
+
+
+def test_warp_stmatrix_swizzle_128b():
+    shape = (1, 8, 128)
+    sw_layout = mma_shared_layout("bfloat16", SwizzleMode.SWIZZLE_128B_ATOM, 
shape)
+    layout = _x4_h_warp_layout()
+
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(shape, "bfloat16", scope="shared", 
layout=sw_layout)
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(8, 32, layout=layout)
+                    Tx.copy(D[0, 0:8, 0:32], regs_warp[0:8, 0:32], trans=True)
+
+    src = _compile_get_cuda(f)
+    assert "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16" in src
+    assert bool(re.search(r"\^.*threadIdx|threadIdx.*\^", src))
+
+
+def test_warp_ldmatrix_swizzle_128b():
+    shape = (1, 8, 128)
+    sw_layout = mma_shared_layout("bfloat16", SwizzleMode.SWIZZLE_128B_ATOM, 
shape)
+    layout = _x4_h_warp_layout()
+
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(shape, "bfloat16", scope="shared", 
layout=sw_layout)
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(8, 32, layout=layout)
+                    Tx.copy(regs_warp[0:8, 0:32], D[0, 0:8, 0:32], trans=False)
+
+    src = _compile_get_cuda(f)
+    assert "ldmatrix.sync.aligned.m8n8.x4.shared.b16" in src
+    assert bool(re.search(r"\^.*threadIdx|threadIdx.*\^", src))
+
+
+# ---------------------------------------------------------------------------
+# Permutation invariance: rebuilding the SMEM with permuted shape/strides
+# gives identical per-lane addresses (3 arrangements × byte-equal check)
+# ---------------------------------------------------------------------------
+
+
+def _stmatrix_line(src):
+    for line in src.split("\n"):
+        if "ptx_stmatrix_m8n8" in line and "D_ptr[" in line:
+            return line.strip()
+    return None
+
+
+def _assert_permute_same_addr(f_ref, f_perm, expected_inst):
+    src_ref = _compile_get_cuda(f_ref)
+    src_perm = _compile_get_cuda(f_perm)
+    assert expected_inst in src_ref
+    assert expected_inst in src_perm
+    a_ref = _stmatrix_line(src_ref)
+    a_perm = _stmatrix_line(src_perm)
+    assert a_ref is not None and a_perm is not None
+    assert a_ref == a_perm, f"\n  ref:  {a_ref}\n  perm: {a_perm}"
+
+
+def test_permutation_invariance_horizontal():
+    """Horizontal x4. ``f_ref`` uses SMEM (2, 8, 32) with strides (256, 32, 1).
+    ``f_perm`` swaps the last two dims together with their strides, giving
+    (2, 32, 8) with strides (256, 1, 32), so ``D_perm[b, c, r]`` and
+    ``D_ref[b, r, c]`` are the same element. Both copies cover the same
+    physical bytes, so the emitted stmatrix call line must be identical.
+
+    NOTE(review): the perm copy pairs an SMEM region of extents (32, 8) with a
+    local region of extents (8, 32). This relies on the dispatcher matching
+    axes by SMEM stride rather than by position. Confirm that is the intended
+    ``Tx.copy`` contract; the module docstring asks for matching extents.
+    """
+    layout = _x4_h_warp_layout()
+
+    @Tx.prim_func
+    def f_ref():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (2, 8, 32), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(2, 8, 32) : (256, 32, 1)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(8, 32, layout=layout)
+                    Tx.copy(D[0, 0:8, 0:32], regs_warp[0:8, 0:32], trans=True)
+
+    @Tx.prim_func
+    def f_perm():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (2, 32, 8), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(2, 32, 8) : (256, 1, 32)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(8, 32, layout=layout)
+                    Tx.copy(D[0, 0:32, 0:8], regs_warp[0:8, 0:32], trans=True)
+
+    _assert_permute_same_addr(
+        f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+    )
+
+
+def test_permutation_invariance_vertical():
+    """Vertical x4. ``f_ref`` uses SMEM (32, 8) with strides (8, 1). ``f_perm``
+    uses (8, 32) with strides (1, 8), so ``D_perm[c, r]`` and ``D_ref[r, c]``
+    are the same element. Both copies touch the same physical bytes and must
+    produce an identical stmatrix call line.
+
+    NOTE(review): as in the horizontal case, the perm copy pairs SMEM extents
+    (8, 32) with local extents (32, 8). Confirm that stride-based axis matching
+    is the intended ``Tx.copy`` contract.
+    """
+    layout = _x4_v_warp_layout()
+
+    @Tx.prim_func
+    def f_ref():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (32, 8), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(32, 8) : (8, 1)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(32, 8, layout=layout)
+                    Tx.copy(D[0:32, 0:8], regs_warp[0:32, 0:8], trans=False)
+
+    @Tx.prim_func
+    def f_perm():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (8, 32), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(8, 32) : (1, 8)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(32, 8, layout=layout)
+                    Tx.copy(D[0:8, 0:32], regs_warp[0:32, 0:8], trans=False)
+
+    _assert_permute_same_addr(
+        f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.shared.b16"
+    )
+
+
+def test_permutation_invariance_2x2():
+    """2×2 x4. ``f_ref`` uses SMEM (16, 16) with strides (16, 1). ``f_perm``
+    keeps the (16, 16) shape but swaps the strides to (1, 16), so
+    ``D_perm[i, j]`` aliases ``D_ref[j, i]``. The test expects an identical
+    stmatrix call line for both.
+
+    NOTE(review): because the tile is square, this is a genuine transpose
+    rather than a relabelling visible in the region shapes. Under element-wise
+    ``Tx.copy`` semantics, ``f_perm`` should store ``regs_warp`` transposed
+    relative to ``f_ref``. Yet the test expects the same addresses, with no
+    ``.trans`` in either. This only holds if the dispatcher picks (row, col)
+    from SMEM strides instead of index order. Confirm this is intended and not
+    a silent data transpose.
+    """
+    layout = _x4_2x2_warp_layout()
+
+    @Tx.prim_func
+    def f_ref():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (16, 16), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(16, 16) : (16, 1)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(16, 16, layout=layout)
+                    Tx.copy(D[0:16, 0:16], regs_warp[0:16, 0:16], trans=False)
+
+    @Tx.prim_func
+    def f_perm():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (16, 16), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(16, 16) : (1, 16)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(16, 16, layout=layout)
+                    Tx.copy(D[0:16, 0:16], regs_warp[0:16, 0:16], trans=False)
+
+    _assert_permute_same_addr(
+        f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.shared.b16"
+    )
+
+
+# ---------------------------------------------------------------------------
+# Arrangement coverage (vertical / 2×2 dispatch reaches PTX emit)
+# ---------------------------------------------------------------------------
+
+
+def test_warp_vertical():
+    """Vertical x4 arrangement (a (32, 8) tile, four 8x8 matrices stacked by
+    row) reaches PTX emission. Only checks that the non-trans stmatrix x4
+    instruction is present and that the source references ``threadIdx.x``."""
+    layout = _x4_v_warp_layout()
+
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (32, 8), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(32, 8) : (8, 1)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(32, 8, layout=layout)
+                    Tx.copy(D[0:32, 0:8], regs_warp[0:32, 0:8], trans=False)
+
+    src = _compile_get_cuda(f)
+    assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in src
+    assert "threadIdx.x" in src
+
+
+def test_warp_2x2():
+    """2×2-grid x4 arrangement (a (16, 16) tile) reaches PTX emission. Only
+    checks that the non-trans stmatrix x4 instruction is present and that the
+    source references ``threadIdx.x``."""
+    layout = _x4_2x2_warp_layout()
+
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (16, 16), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(16, 16) : (16, 1)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    regs_warp = regs.view(16, 16, layout=layout)
+                    Tx.copy(D[0:16, 0:16], regs_warp[0:16, 0:16], trans=False)
+
+    src = _compile_get_cuda(f)
+    assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in src
+    assert "threadIdx.x" in src
+
+
+# ---------------------------------------------------------------------------
+# Warpgroup-scope x4
+# ---------------------------------------------------------------------------
+
+
+def test_wg_stmatrix_x4_trans():
+    """Warpgroup-scope copy (128 threads) of an (8, 128) bf16 register view
+    into row-major SMEM lowers to ``stmatrix ... .x4.trans``.
+
+    The per-thread storage is 4 x uint32 reinterpreted as bfloat16. That is
+    presumably 8 bf16 values, matching the 8 per-thread elements of
+    ``_x4_h_wg_layout``. Only the presence of the instruction is checked, not
+    how many are emitted.
+    """
+    wg_layout = _x4_h_wg_layout()
+
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([128])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (8, 128), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(8, 128) : (128, 1)]),
+                )
+                with Tx.warpgroup():
+                    regs = Tx.alloc_buffer((4,), "uint32", scope="local")
+                    regs_wg = regs.view("bfloat16").view(8, 128, layout=wg_layout)
+                    Tx.copy(D[0:8, 0:128], regs_wg[0:8, 0:128], trans=True)
+
+    src = _compile_get_cuda(f)
+    assert "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16" in src
+
+
+# ---------------------------------------------------------------------------
+# Rejection cases
+# ---------------------------------------------------------------------------
+
+
+def test_reject_extent_mismatch():
+    """Local region extents don't match SMEM region — Tx.copy semantically
+    invalid, dispatcher rejects."""
+    layout = _x4_h_warp_layout()
+
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (8, 32), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(8, 32) : (32, 1)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                    # Raw per-thread fragment, no warp-wide view.
+                    Tx.copy(D[0:8, 0:32], regs[0:8], trans=True)
+
+    with pytest.raises(Exception) as excinfo:
+        _compile_get_cuda(f)
+    assert "warp_stmatrix" in str(excinfo.value)
+
+
+def test_reject_non_b16_smem():
+    """float32 SMEM and registers cannot use the ``m8n8 ... .b16`` forms this
+    dispatcher emits. The warp_stmatrix rejection message must mention ``b16``.
+
+    NOTE(review): the copy also pairs an (8, 32) SMEM region with the raw (8,)
+    per-thread fragment (no warp-wide view). That is the same extent mismatch
+    exercised by ``test_reject_extent_mismatch``. So this test only isolates
+    the dtype check if the error log reports the b16 failure alongside the
+    extent failure. Consider giving ``regs`` a warp-wide view so that dtype is
+    the only violated condition.
+    """
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (8, 32), "float32", scope="shared",
+                    layout=TileLayout(S[(8, 32) : (32, 1)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((8,), "float32", scope="local")
+                    Tx.copy(D[0:8, 0:32], regs[0:8], trans=True)
+
+    with pytest.raises(Exception) as excinfo:
+        _compile_get_cuda(f)
+    s = str(excinfo.value)
+    assert "warp_stmatrix" in s and "b16" in s
+
+
+def test_reject_wrong_smem_shape():
+    """8×40 doesn't decompose into any m8n8.x{1,2,4} arrangement."""
+    layout = _x4_h_warp_layout()
+
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (8, 40), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(8, 40) : (40, 1)]),
+                )
+                with Tx.warp():
+                    regs = Tx.alloc_buffer((10,), "bfloat16", scope="local")
+                    regs_warp = regs.view(8, 40, layout=TileLayout(S[(8, 40) : 
(40, 1)]))
+                    Tx.copy(D[0:8, 0:40], regs_warp[0:8, 0:40], trans=True)
+
+    with pytest.raises(Exception) as excinfo:
+        _compile_get_cuda(f)
+    assert "warp_stmatrix" in str(excinfo.value)
+
+
+def test_reject_warp_filtered_lanes():
+    """``if Tx.filter(lane_id, 0, 16)`` narrows the active set, but stmatrix
+    requires all 32 lanes. The warp_stmatrix rejection message must name
+    ``laneid`` and report that it was narrowed."""
+    layout = _x4_h_warp_layout()
+
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (8, 32), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(8, 32) : (32, 1)]),
+                )
+                with Tx.warp():
+                    lane_id = Tx.lane_id([32])
+                    if Tx.filter(lane_id, 0, 16):
+                        regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                        regs_warp = regs.view(8, 32, layout=layout)
+                        Tx.copy(D[0:8, 0:32], regs_warp[0:8, 0:32], trans=True)
+
+    with pytest.raises(Exception) as excinfo:
+        _compile_get_cuda(f)
+    s = str(excinfo.value)
+    assert "warp_stmatrix" in s and "laneid" in s and "narrow" in s
+
+
+def test_reject_wg_filtered_warps():
+    """``if Tx.filter(warp_id, 0, 2)`` at wg scope narrows the active set to 2
+    warps, but the stmatrix wg dispatcher needs all 4. The warpgroup_stmatrix
+    rejection message must name ``wid_in_wg`` and report that it was narrowed."""
+    wg_layout = _x4_h_wg_layout()
+
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([128])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (8, 128), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(8, 128) : (128, 1)]),
+                )
+                with Tx.warpgroup():
+                    warp_id = Tx.warp_id_in_wg([4])
+                    if Tx.filter(warp_id, 0, 2):
+                        regs = Tx.alloc_buffer((4,), "uint32", scope="local")
+                        regs_wg = regs.view("bfloat16").view(8, 128, layout=wg_layout)
+                        Tx.copy(D[0:8, 0:128], regs_wg[0:8, 0:128], trans=True)
+
+    with pytest.raises(Exception) as excinfo:
+        _compile_get_cuda(f)
+    s = str(excinfo.value)
+    assert "warpgroup_stmatrix" in s and "wid_in_wg" in s and "narrow" in s
+
+
+def test_reject_non_warp_scope():
+    """Tx.copy at cta scope (no warp/wg wrap) — warp_stmatrix dispatcher must
+    not fire. The dispatch error log must list warp_stmatrix as rejected."""
+    @Tx.prim_func
+    def f():
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                D = Tx.alloc_buffer(
+                    (8, 32), "bfloat16", scope="shared",
+                    layout=TileLayout(S[(8, 32) : (32, 1)]),
+                )
+                regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
+                Tx.copy(D[0:8, 0:32], regs[0:8], trans=True)
+
+    try:
+        src = _compile_get_cuda(f)
+        assert "stmatrix" not in src
+    except Exception as e:
+        s = str(e)
+        assert "warp_stmatrix" in s

Reply via email to