This is an automated email from the ASF dual-hosted git repository. spectrometerHBH pushed a commit to branch tir-bench in repository https://gitbox.apache.org/repos/asf/tvm.git
commit aef7bd7e97c0e2c1c3b11363ae1c36b8ae5ed92d Author: Hongyi Jin <[email protected]> AuthorDate: Thu May 28 02:06:34 2026 -0400 feat(gemm_async): accept Layout F C operand for M=64 MMAs (#648) * feat(gemm_async): accept Layout F C operand for M=64 MMAs The tcgen05 gemm_async dispatch hardcoded the expected C-operand layout to ``S[(M, N) : (1@TLane, 1@TCol)]`` (Layout D, M=128 identity), and rejected any other layout via the ``assert_structural_equal`` on the sliced layout. That blocked the canonical pairing introduced in #646: C buffer allocated as ``tmem_pool.alloc((64, N), datapath="F")`` → M=64 MMA write → ``.16x*b`` M=64 readback, which is the only way to get the M=64 fragment's logical rows 0..63 dense (no half-slab garbage) with a single PTX issue per readback. This change detects when the C buffer carries Layout F over its full (64, X) shape (as produced by the new ``tmem_pool.alloc(..., datapath="F")`` kwarg), by structurally comparing the buffer's layout against Layout F, and in that case asserts the slice against ``tmem_datapath_layout("F", 64, N)`` (the (4, 16, N) split-warp scatter form) instead of the Layout D identity. The is_2x2 path is untouched. Tests: 16 gemm_async tests + 141 tmem_16xnb tests pass with the change. M=128 Layout D and 2x2 paths take the same code path as before. A smoke test (BF16 16x16x16 GEMM into a Layout F (64, 64) buffer, read back via ``.16x256b.x8`` M=64) verifies the new pairing end-to-end on B200. Used by: kernel-evolution/gdn/prefill/v0 M=64 readback alignment. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> * test(gemm_async): add Layout F C operand round-trip for M=64 MMA Companion test to the previous commit's gemm_async dispatch change. Issues an M=64 fp16 GEMM (fp32 accumulator) into a Layout F TMEM buffer (built from ``tmem_datapath_layout("F", 64, N)``), reads back via ``.16x256b`` M=64, and asserts against the numpy reference. Without the dispatch change, the test fails to compile because the C-operand layout check rejects Layout F. 
--- .../tile_primitive/cuda/gemm_async/tcgen05.py | 44 +++++++- .../cuda/gemm_async/test_gemm_async.py | 122 ++++++++++++++++++++- 2 files changed, 163 insertions(+), 3 deletions(-) diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py b/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py index 4e89155973..a439355a97 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py @@ -30,7 +30,16 @@ from tvm.arith.analyzer import Analyzer from tvm.runtime import DataType from tvm.script import tirx as Tx from tvm.tirx import PrimFunc -from tvm.tirx.layout import ComposeLayout, Iter, R, S, TCol, TileLayout, TLane +from tvm.tirx.layout import ( + ComposeLayout, + Iter, + R, + S, + TCol, + TileLayout, + TLane, + tmem_datapath_layout, +) from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, register_dispatch from tvm.tirx.operator.tile_primitive.ops import KernelReplacePoint from tvm.tirx.stmt import AllocBuffer, Evaluate, SeqStmt, TilePrimitiveCall @@ -292,6 +301,25 @@ def _choose_mma_tile(M, N, cta_group, MMA_N_MIN): return M_mma, N_mma +def _layout_matches_datapath_f(tmem_buf) -> bool: + """Return True if ``tmem_buf.layout`` structurally equals Layout F (M=64 + scattered) over the buffer's full (64, X) shape — i.e. the buffer was + allocated via ``tmem_pool.alloc((64, X), datapath="F")``. + + Used by the C-operand layout check to accept M=64 MMA writes into Layout + F C buffers (the canonical pairing for M=64 outputs that are read back + via ``.16x*b`` M=64; see PTX ISA §9.7.16.10.5). 
+ """ + if tmem_buf.layout is None or int(tmem_buf.shape[0]) != 64: + return False + try: + expected = tmem_datapath_layout("F", 64, tmem_buf.shape[1]).canonicalize() + tvm.ir.assert_structural_equal(tmem_buf.layout.canonicalize(), expected) + return True + except (AssertionError, ValueError): + return False + + def gemm_async_tcgen05_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc: """Schedule an asynchronous GEMM operation using tcgen05.mma (Blackwell Tensor Core). @@ -628,11 +656,23 @@ def gemm_async_tcgen05_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) - ) # Check C's sliced layout, allow offset. - # 4x1 layout: (M, N):(1@TLane, 1@TCol) + # 4x1 layout (Layout D, M=128 identity): (M, N):(1@TLane, 1@TCol) # 2x2 layout: (M, 2, N//2):(1@TLane, 64@TLane, 1@TCol) + # Layout F (M=64 scatter): the full TMEM buffer is shape (64, X) with the + # scattered row→lane mapping from tmem_datapath_layout("F", 64, X). When + # the user allocates with ``tmem_pool.alloc(..., datapath="F")`` and slices + # the full row range, the slice layout structurally matches Layout F over + # (M=64, N) — assert against that base instead of the Layout D identity. 
if is_2x2: N_half = N // 2 base = TileLayout(S[(M, 2, N_half) : (1 @ TLane, 64 @ TLane, 1 @ TCol)]) + elif ( + M == 64 + and int(C_buffer.shape[0]) == 64 + and C_buffer.layout is not None + and _layout_matches_datapath_f(C_buffer) + ): + base = tmem_datapath_layout("F", 64, N) else: base = TileLayout(S[(M, N) : (1 @ TLane, 1 @ TCol)]) expected_c_layout = TileLayout.from_iters( diff --git a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py index 729aee5862..c498520045 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py @@ -31,7 +31,7 @@ import tvm import tvm.testing from tvm.ir.type import PointerType, PrimType from tvm.script import tirx as Tx -from tvm.tirx.layout import S, TCol, TileLayout, TLane +from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg from tvm.tirx.operator.tile_primitive.cuda.gemm_async import sf_tmem_layout from tvm.tirx.operator.tile_primitive.cuda.tma_utils import ( @@ -299,6 +299,126 @@ def test_gemm_tcgen05_cta_group_1(task): np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) + +def test_gemm_tcgen05_cta_group_1_layout_f_m64(): + """M=64 MMA with C operand allocated as Layout F (datapath="F"). + + Exercises the new ``gemm_async`` path that accepts C buffers tagged + Layout F — written by an M=64 MMA in their canonical scattered + row->lane mapping (PTX ISA §9.7.16.10.5), read back via the + ``.16x256b`` M=64 atom (one PTX issue covering all 64 logical rows + densely). Without the dispatch change this kernel fails to compile + because the C-operand layout check asserts Layout D identity. 
+ """ + M, N, K = 64, 64, 64 + A_dtype, B_dtype, C_dtype = "float16", "float16", "float32" + A_shape, B_shape, C_shape = (M, K), (N, K), (M, N) + A_layout = mma_shared_layout(A_dtype, 3, A_shape) + B_layout = mma_shared_layout(B_dtype, 3, B_shape) + + # The C TMEM buffer carries Layout F over its full (64, N) shape; that's + # what gemm_async structurally matches against to accept the M=64 write. + from tvm.tirx.layout import tmem_datapath_layout + c_layout = tmem_datapath_layout("F", 64, N) + + # fmt: off + @Tx.prim_func + def gemm_layout_f(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: + A = Tx.match_buffer(A_ptr, A_shape, A_dtype) + B = Tx.match_buffer(B_ptr, B_shape, B_dtype) + C = Tx.match_buffer(C_ptr, C_shape, C_dtype) + + Tx.device_entry() + warp_id = Tx.warp_id([4]) + cta_id = Tx.cta_id([1]) + wg_id = Tx.warpgroup_id([1]) + tid_in_wg = Tx.thread_id_in_wg([128]) + lane_id = Tx.lane_id([32]) + + A_smem = Tx.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout) + B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) + tmem_addr = Tx.alloc_shared([1], "uint32") + tma_mbar = Tx.alloc_shared([1], "uint64") + mma_mbar = Tx.alloc_shared([1], "uint64") + + if tid_in_wg == 0: + with Tx.thread(): + Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + Tx.ptx.fence.proxy_async("shared::cta") + Tx.cuda.cta_sync() + + if warp_id == 0: + with Tx.warp(): + Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=64, cta_group=1) + Tx.cuda.cta_sync() + # Layout F C operand — the path under test. 
+ tmem = Tx.decl_buffer((64, N), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=c_layout) # noqa: E501 + + if tid_in_wg == 0: + with Tx.thread(): + tma_args = Tx.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) + Tx.copy_async(A_smem[:, :], A[:, :], **tma_args) + Tx.copy_async(B_smem[:, :], B[:, :], **tma_args) + Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), (M * K + N * K) * 2) + Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + Tx.cuda.cta_sync() + + if tid_in_wg == 0: + with Tx.thread(): + Tx.gemm_async(tmem[0:64, 0:N], A_smem[:, :], B_smem[:, :], dispatch="tcgen05") # noqa: E501 + Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) + Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) + Tx.cuda.cta_sync() + Tx.ptx.tcgen05.fence.after_thread_sync() + + # Read back via .16x256b M=64 (the canonical pairing). + reg = Tx.alloc_local(32, dtype="float32") + reg_view = reg.view(64, N, layout=tcgen05_atom_layout("16x256b", (64, N), "float32")) + if wg_id == 0: + with Tx.warpgroup(): + Tx.copy_async(reg_view[:, :], tmem[0:64, 0:N]) + Tx.ptx.tcgen05.wait.ld() + Tx.cuda.cta_sync() + + # Per-(reg -> row, col) decomposition for .16x256b M=64 fp32 (BT=64 -> rep=8): + # r = v0p + 2*va + 4*vb, v0p in {0,1}, va in {0,1}, vb in [0, 8) + # row = (lane_id >> 2) + 8*va + 16*warp_id + # col = v0p + ((lane_id & 3) << 1) + 8*vb + for vb in Tx.unroll(8): + for va in Tx.unroll(2): + for v0p in Tx.unroll(2): + r: Tx.let = v0p + 2 * va + 4 * vb + row: Tx.let = (lane_id >> 2) + 8 * va + 16 * warp_id + col: Tx.let = v0p + ((lane_id & 3) << 1) + 8 * vb + C[row, col] = reg[r] + + if warp_id == 0: + with Tx.warp(): + Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=64, cta_group=1) + # fmt: on + + dev = tvm.cuda(0) + np.random.seed(0) + target = tvm.target.Target("cuda") + with target: + mod = tvm.compile(tvm.IRModule({"main": gemm_layout_f}), target=target, tir_pipeline="tirx") + + A_np = 
np.random.randn(*A_shape).astype(A_dtype) + B_np = np.random.randn(*B_shape).astype(B_dtype) + C_np = np.zeros(C_shape, dtype=C_dtype) + A_tvm = tvm.runtime.tensor(A_np, dev) + B_tvm = tvm.runtime.tensor(B_np, dev) + C_tvm = tvm.runtime.tensor(C_np, dev) + mod["main"](A_tvm, B_tvm, C_tvm) + + C_ref = A_np.astype(np.float32) @ B_np.astype(np.float32).T + np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-2, rtol=1e-2) + + + + @pytest.mark.parametrize( "task", [
