This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch tir-bench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit aef7bd7e97c0e2c1c3b11363ae1c36b8ae5ed92d
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu May 28 02:06:34 2026 -0400

    feat(gemm_async): accept Layout F C operand for M=64 MMAs (#648)
    
    * feat(gemm_async): accept Layout F C operand for M=64 MMAs
    
    The tcgen05 gemm_async dispatch hardcoded the expected C-operand layout as
    ``S[(M, N) : (1@TLane, 1@TCol)]`` (Layout D, M=128 identity) and rejected
    any other layout via an ``assert_structural_equal`` check on the sliced layout.
    That blocked the canonical pairing introduced in #646:
    
      C buffer allocated as ``tmem_pool.alloc((64, N), datapath="F")``  →
      M=64 MMA write  →  ``.16x*b`` M=64 readback
    
    which is the only way to get the M=64 fragment's logical rows 0..63 dense
    (no half-slab garbage) with a single PTX issue per readback.
    
    This change detects when the C buffer was tagged Layout F via the new
    ``tmem_pool.alloc(..., datapath="F")`` kwarg and asserts the slice against
    ``tmem_datapath_layout("F", 64, N)`` (the (4, 16, N) split-warp scatter
    form) instead of the Layout D identity. The is_2x2 path is untouched.
    
    Tests: 16 gemm_async tests + 141 tmem_16xnb tests pass with the change.
    M=128 Layout D and 2x2 paths take the same code path as before. A smoke
    test (BF16 16x16x16 GEMM into a Layout F (64, 64) buffer, read back via
    ``.16x256b.x8`` M=64) verifies the new pairing end-to-end on B200.
    
    Used by: kernel-evolution/gdn/prefill/v0 M=64 readback alignment.
    
    Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
    
    * test(gemm_async): add Layout F C operand round-trip for M=64 MMA
    
    Companion test to the previous commit's gemm_async dispatch change.
    Issues an M=64 fp16 GEMM into a Layout F TMEM buffer (built from
    ``tmem_datapath_layout("F", 64, N)``), reads back via ``.16x256b`` M=64,
    and asserts against the numpy reference. Without the dispatch change the
    test fails to compile because the C-operand layout check rejects Layout F.
---
 .../tile_primitive/cuda/gemm_async/tcgen05.py      |  44 +++++++-
 .../cuda/gemm_async/test_gemm_async.py             | 122 ++++++++++++++++++++-
 2 files changed, 163 insertions(+), 3 deletions(-)

diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py b/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py
index 4e89155973..a439355a97 100644
--- a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py
+++ b/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py
@@ -30,7 +30,16 @@ from tvm.arith.analyzer import Analyzer
 from tvm.runtime import DataType
 from tvm.script import tirx as Tx
 from tvm.tirx import PrimFunc
-from tvm.tirx.layout import ComposeLayout, Iter, R, S, TCol, TileLayout, TLane
+from tvm.tirx.layout import (
+    ComposeLayout,
+    Iter,
+    R,
+    S,
+    TCol,
+    TileLayout,
+    TLane,
+    tmem_datapath_layout,
+)
 from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, 
register_dispatch
 from tvm.tirx.operator.tile_primitive.ops import KernelReplacePoint
 from tvm.tirx.stmt import AllocBuffer, Evaluate, SeqStmt, TilePrimitiveCall
@@ -292,6 +301,25 @@ def _choose_mma_tile(M, N, cta_group, MMA_N_MIN):
     return M_mma, N_mma
 
 
+def _layout_matches_datapath_f(tmem_buf) -> bool:
+    """Return True if ``tmem_buf.layout`` structurally equals Layout F (M=64
+    scattered) over the buffer's full (64, X) shape — i.e. the buffer was
+    allocated via ``tmem_pool.alloc((64, X), datapath="F")``.
+
+    Used by the C-operand layout check to accept M=64 MMA writes into Layout
+    F C buffers (the canonical pairing for M=64 outputs that are read back
+    via ``.16x*b`` M=64; see PTX ISA §9.7.16.10.5).
+    """
+    if tmem_buf.layout is None or int(tmem_buf.shape[0]) != 64:
+        return False
+    try:
+        expected = tmem_datapath_layout("F", 64, 
tmem_buf.shape[1]).canonicalize()
+        tvm.ir.assert_structural_equal(tmem_buf.layout.canonicalize(), 
expected)
+        return True
+    except (AssertionError, ValueError):
+        return False
+
+
 def gemm_async_tcgen05_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) 
-> PrimFunc:
     """Schedule an asynchronous GEMM operation using tcgen05.mma (Blackwell 
Tensor Core).
 
@@ -628,11 +656,23 @@ def gemm_async_tcgen05_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -
         )
 
     # Check C's sliced layout, allow offset.
-    # 4x1 layout: (M, N):(1@TLane, 1@TCol)
+    # 4x1 layout (Layout D, M=128 identity): (M, N):(1@TLane, 1@TCol)
     # 2x2 layout: (M, 2, N//2):(1@TLane, 64@TLane, 1@TCol)
+    # Layout F (M=64 scatter): the full TMEM buffer is shape (64, X) with the
+    # scattered row→lane mapping from tmem_datapath_layout("F", 64, X). When
+    # the user allocates with ``tmem_pool.alloc(..., datapath="F")`` and slices
+    # the full row range, the slice layout structurally matches Layout F over
+    # (M=64, N) — assert against that base instead of the Layout D identity.
     if is_2x2:
         N_half = N // 2
         base = TileLayout(S[(M, 2, N_half) : (1 @ TLane, 64 @ TLane, 1 @ 
TCol)])
+    elif (
+        M == 64
+        and int(C_buffer.shape[0]) == 64
+        and C_buffer.layout is not None
+        and _layout_matches_datapath_f(C_buffer)
+    ):
+        base = tmem_datapath_layout("F", 64, N)
     else:
         base = TileLayout(S[(M, N) : (1 @ TLane, 1 @ TCol)])
     expected_c_layout = TileLayout.from_iters(
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
index 729aee5862..c498520045 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
@@ -31,7 +31,7 @@ import tvm
 import tvm.testing
 from tvm.ir.type import PointerType, PrimType
 from tvm.script import tirx as Tx
-from tvm.tirx.layout import S, TCol, TileLayout, TLane
+from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout
 from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg
 from tvm.tirx.operator.tile_primitive.cuda.gemm_async import sf_tmem_layout
 from tvm.tirx.operator.tile_primitive.cuda.tma_utils import (
@@ -299,6 +299,126 @@ def test_gemm_tcgen05_cta_group_1(task):
         np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3)
 
 
+
+def test_gemm_tcgen05_cta_group_1_layout_f_m64():
+    """M=64 MMA with C operand allocated as Layout F (datapath="F").
+
+    Exercises the new ``gemm_async`` path that accepts C buffers tagged
+    Layout F — written by an M=64 MMA in their canonical scattered
+    row->lane mapping (PTX ISA §9.7.16.10.5), read back via the
+    ``.16x256b`` M=64 atom (one PTX issue covering all 64 logical rows
+    densely). Without the dispatch change this kernel fails to compile
+    because the C-operand layout check asserts Layout D identity.
+    """
+    M, N, K = 64, 64, 64
+    A_dtype, B_dtype, C_dtype = "float16", "float16", "float32"
+    A_shape, B_shape, C_shape = (M, K), (N, K), (M, N)
+    A_layout = mma_shared_layout(A_dtype, 3, A_shape)
+    B_layout = mma_shared_layout(B_dtype, 3, B_shape)
+
+    # The C TMEM buffer carries Layout F over its full (64, N) shape; that's
+    # what gemm_async structurally matches against to accept the M=64 write.
+    from tvm.tirx.layout import tmem_datapath_layout
+    c_layout = tmem_datapath_layout("F", 64, N)
+
+    # fmt: off
+    @Tx.prim_func
+    def gemm_layout_f(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> 
None:
+        A = Tx.match_buffer(A_ptr, A_shape, A_dtype)
+        B = Tx.match_buffer(B_ptr, B_shape, B_dtype)
+        C = Tx.match_buffer(C_ptr, C_shape, C_dtype)
+
+        Tx.device_entry()
+        warp_id = Tx.warp_id([4])
+        cta_id  = Tx.cta_id([1])
+        wg_id   = Tx.warpgroup_id([1])
+        tid_in_wg = Tx.thread_id_in_wg([128])
+        lane_id = Tx.lane_id([32])
+
+        A_smem = Tx.alloc_buffer(A_shape, A_dtype, scope="shared", 
layout=A_layout)
+        B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", 
layout=B_layout)
+        tmem_addr = Tx.alloc_shared([1], "uint32")
+        tma_mbar = Tx.alloc_shared([1], "uint64")
+        mma_mbar = Tx.alloc_shared([1], "uint64")
+
+        if tid_in_wg == 0:
+            with Tx.thread():
+                Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1)
+                Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1)
+        Tx.ptx.fence.proxy_async("shared::cta")
+        Tx.cuda.cta_sync()
+
+        if warp_id == 0:
+            with Tx.warp():
+                Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=64, 
cta_group=1)
+        Tx.cuda.cta_sync()
+        # Layout F C operand — the path under test.
+        tmem = Tx.decl_buffer((64, N), C_dtype, scope="tmem", 
allocated_addr=tmem_addr[0], layout=c_layout)  # noqa: E501
+
+        if tid_in_wg == 0:
+            with Tx.thread():
+                tma_args = Tx.meta_var({"dispatch": "tma", "mbar": 
tma_mbar.ptr_to([0])})
+                Tx.copy_async(A_smem[:, :], A[:, :], **tma_args)
+                Tx.copy_async(B_smem[:, :], B[:, :], **tma_args)
+                Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), (M * K 
+ N * K) * 2)
+        Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0)
+        Tx.cuda.cta_sync()
+
+        if tid_in_wg == 0:
+            with Tx.thread():
+                Tx.gemm_async(tmem[0:64, 0:N], A_smem[:, :], B_smem[:, :], 
dispatch="tcgen05")  # noqa: E501
+                Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1)
+        Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0)
+        Tx.cuda.cta_sync()
+        Tx.ptx.tcgen05.fence.after_thread_sync()
+
+        # Read back via .16x256b M=64 (the canonical pairing).
+        reg = Tx.alloc_local(32, dtype="float32")
+        reg_view = reg.view(64, N, layout=tcgen05_atom_layout("16x256b", (64, 
N), "float32"))
+        if wg_id == 0:
+            with Tx.warpgroup():
+                Tx.copy_async(reg_view[:, :], tmem[0:64, 0:N])
+                Tx.ptx.tcgen05.wait.ld()
+        Tx.cuda.cta_sync()
+
+        # Per-(reg -> row, col) decomposition for .16x256b M=64 fp32 (BT=64 -> 
rep=8):
+        #   r = v0p + 2*va + 4*vb,   v0p in {0,1}, va in {0,1}, vb in [0, 8)
+        #   row = (lane_id >> 2) + 8*va + 16*warp_id
+        #   col = v0p + ((lane_id & 3) << 1) + 8*vb
+        for vb in Tx.unroll(8):
+            for va in Tx.unroll(2):
+                for v0p in Tx.unroll(2):
+                    r: Tx.let = v0p + 2 * va + 4 * vb
+                    row: Tx.let = (lane_id >> 2) + 8 * va + 16 * warp_id
+                    col: Tx.let = v0p + ((lane_id & 3) << 1) + 8 * vb
+                    C[row, col] = reg[r]
+
+        if warp_id == 0:
+            with Tx.warp():
+                Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1)
+                Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=64, cta_group=1)
+    # fmt: on
+
+    dev = tvm.cuda(0)
+    np.random.seed(0)
+    target = tvm.target.Target("cuda")
+    with target:
+        mod = tvm.compile(tvm.IRModule({"main": gemm_layout_f}), 
target=target, tir_pipeline="tirx")
+
+    A_np = np.random.randn(*A_shape).astype(A_dtype)
+    B_np = np.random.randn(*B_shape).astype(B_dtype)
+    C_np = np.zeros(C_shape, dtype=C_dtype)
+    A_tvm = tvm.runtime.tensor(A_np, dev)
+    B_tvm = tvm.runtime.tensor(B_np, dev)
+    C_tvm = tvm.runtime.tensor(C_np, dev)
+    mod["main"](A_tvm, B_tvm, C_tvm)
+
+    C_ref = A_np.astype(np.float32) @ B_np.astype(np.float32).T
+    np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-2, rtol=1e-2)
+
+
+
+
 @pytest.mark.parametrize(
     "task",
     [

Reply via email to