This is an automated email from the ASF dual-hosted git repository. spectrometerHBH pushed a commit to branch tir-bench in repository https://gitbox.apache.org/repos/asf/tvm.git
commit e3271628f450f4a534818524cfa7559a1dc099f0 Author: Hongyi Jin <[email protected]> AuthorDate: Wed May 27 16:23:47 2026 -0400 feat(tirx): add .16x{64,128,256}b tcgen05.ld/st dispatch + factory (#644) Adds a unified ``tcgen05_atom_layout(instr_shape, tensor_shape, dtype)`` factory and matching ``Tx.alloc_tcgen05_frag(...)`` wrapper so users can allocate per-thread register fragments for any of the supported PTX shape-1 atoms (.32x32b, .16x64b, .16x128b, .16x256b) with one call. ``Tx.copy_async`` inspects the local-side layout and dispatches to the right tcgen05 PTX atom. Layout derivation: per-shape (row, col) decomposition is derived from the CUTLASS DstLayout traits (3rdparty/cutlass/include/cute/atom/copy_traits_sm100.hpp). For .16x*b shapes (M=64 fragments) each warp's atom covers one 16-row slab of the warpgroup's 64-row fragment, driven by the PTX 9.7.16.8.1 access restriction that puts warp i on TMEM lanes i*32..i*32+31. For .32x32b (M=128) the factory returns the canonical (128, K):(1@tid_in_wg, 1) layout already accepted by the existing dispatch path. TMEM is kept dense for 16-bit dtypes (2 elements per 32-bit cell, matching the existing .32x32b convention) rather than going through .pack::16b / .unpack::16b — those would waste half the TMEM cell width. Bit-exact micro-tests (test_tmem_16xnb.py, 92 cases): - fp32 load through tcgen05_atom_layout: stage A via .32x32b.st (chunked for K>128 fp32 cols), load via .<instr_shape>.x<rep>, per-thread dump, host reconstructs (row, col) from the layout formula. - fp32 store round-trip (.<instr_shape>.st → .32x32b.ld). - 16-bit (fp16/bf16) self-consistent round-trip (.<instr_shape>.st → .<instr_shape>.ld preserves per-thread bits). - Explicit Tx.alloc_tcgen05_frag wrapper compile + PTX-emission check. Regression: full tests/python/tirx/operator/tile_primitive/cuda/ suite (955 tests) passes unchanged. 
Files - python/tvm/tirx/layout.py: factory ``tcgen05_atom_layout`` + supported rep tables + per-shape iter decompositions. - python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py: split copy_tmem_local_impl into ``_emit_32x32b_path`` (unchanged) and ``_emit_16xnb_path`` (new); structural-match local layout to dispatch. - python/tvm/tirx/script/builder/ir.py: ``alloc_tcgen05_frag`` wrapper. - tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py: bit-exact micro-test suite. --- python/tvm/tirx/layout.py | 200 ++++++ .../tile_primitive/cuda/copy_async/tcgen05_ldst.py | 178 +++++- python/tvm/tirx/script/builder/ir.py | 57 ++ .../cuda/copy_async/test_tmem_16xnb.py | 709 +++++++++++++++++++++ 4 files changed, 1140 insertions(+), 4 deletions(-) diff --git a/python/tvm/tirx/layout.py b/python/tvm/tirx/layout.py index ed55b4f802..e17a0d61f8 100644 --- a/python/tvm/tirx/layout.py +++ b/python/tvm/tirx/layout.py @@ -566,6 +566,7 @@ except NameError: # pragma: no cover __all__ = [] # type: ignore[var-annotated] __all__ += list(_AXIS_NAMES) __all__ += ["R", "S"] +__all__ += ["wg_local_layout", "tcgen05_atom_layout"] def wg_local_layout(cols, rows=128): @@ -577,6 +578,205 @@ def wg_local_layout(cols, rows=128): return TileLayout(S[(rows, cols) : (1 @ Axis.tid_in_wg, 1)]) +# Allowed (.shape, .num) combinations for tcgen05.ld/st atoms. +# Source: PTX ISA Table 49 (tcgen05-num-shapes-ld). +_TCGEN05_ATOM_REPS = { + "32x32b": (1, 2, 4, 8, 16, 32, 64, 128), + "16x64b": (1, 2, 4, 8, 16, 32, 64, 128), + "16x128b": (1, 2, 4, 8, 16, 32, 64), + "16x256b": (1, 2, 4, 8, 16, 32), +} + + +# Per-warp fp32-column factor for each instr_shape. For .16x*b atoms the +# warpgroup fragment is 64 rows × (factor * rep) fp32 cols; for .32x32b the +# fragment is 128 rows × (factor * rep) fp32 cols with factor=1. +_TCGEN05_COL_FACTOR_FP32 = {"32x32b": 1, "16x64b": 2, "16x128b": 4, "16x256b": 8} + +# Number of fragment rows per warpgroup for each instr_shape. 
+_TCGEN05_FRAG_ROWS = {"32x32b": 128, "16x64b": 64, "16x128b": 64, "16x256b": 64} + + +def tcgen05_atom_layout( + instr_shape: str, tensor_shape: tuple[int, int], dtype +) -> "TileLayout": + """Register-side ``TileLayout`` for ``tcgen05.ld``/``tcgen05.st`` ``.16x*`` atoms. + + Describes the per-warpgroup register tile that ``Tx.copy_async`` produces + when reading a TMEM fragment via ``tcgen05.{ld,st}.<instr_shape>.xN``. + ``rep`` (the ``.xN`` qualifier) is inferred from ``tensor_shape``. + + Fragment row count is determined by ``instr_shape``: ``.32x32b`` covers an + M=128 fragment (128 rows per warpgroup), and ``.16x{64,128,256}b`` covers + an M=64 fragment (64 rows per warpgroup). + + TMEM is kept **dense** for 16-bit dtypes: two 16-bit elements per 32-bit + TMEM cell (matching the existing ``.32x32b`` convention). The PTX op is + issued with the plain ``.b32`` form (no ``.pack::16b`` qualifier), and + the returned layout describes the per-thread register file with two + packed 16-bit elements per 32-bit register. + + Parameters + ---------- + instr_shape : str + The PTX atom's ``.shape`` qualifier. One of ``"32x32b"``, ``"16x64b"``, + ``"16x128b"``, ``"16x256b"``. + tensor_shape : tuple[int, int] + The logical fragment shape in **element units**. Must be + ``(frag_rows, K)`` where ``frag_rows`` is ``128`` for ``.32x32b`` and + ``64`` for the other shapes, and ``K`` is divisible by the per-warp + column factor for the chosen instr_shape and dtype:: + + K must be a power-of-two multiple of (factor_fp32 * elem_per_32b) + + where ``factor_fp32`` is ``1`` / ``2`` / ``4`` / ``8`` for ``.32x32b`` / + ``.16x64b`` / ``.16x128b`` / ``.16x256b``, and ``elem_per_32b`` is + ``1`` for fp32 and ``2`` for fp16/bf16. The inferred rep must be in PTX + Table 49's supported set for the chosen instr_shape. + dtype : str | tvm.DataType + Element dtype. ``"float32"``, ``"float16"``, or ``"bfloat16"``. + + Returns + ------- + TileLayout + A ``(64, K)``-shaped tile layout. 
The factory builds it as a sequence + of fine-grained iters describing the per-(lane, register) destination + position; ``.group([(64, K)])[0]`` flattens to two iters. + + Examples + -------- + ``tcgen05_atom_layout("16x64b", (64, 64), "float32")`` → ``.16x64b.x32`` (rep=32, fp32). + + ``tcgen05_atom_layout("16x128b", (64, 256), "float16")`` → ``.16x128b.x32`` (rep=32, + fp16; two fp16 elements packed per 32-bit register and per 32-bit TMEM cell). + """ + if instr_shape not in _TCGEN05_ATOM_REPS: + raise ValueError( + f"tcgen05_atom_layout instr_shape must be one of " + f"{list(_TCGEN05_ATOM_REPS)}, got {instr_shape!r}" + ) + bits = tvm.runtime.DataType(dtype).bits + if bits not in (16, 32): + raise ValueError( + f"tcgen05_atom_layout dtype must be a 32-bit or 16-bit type, " + f"got {dtype} ({bits} bits)" + ) + if len(tensor_shape) != 2: + raise ValueError( + f"tcgen05_atom_layout tensor_shape must be 2-D (rows, cols), got {tensor_shape!r}" + ) + rows, cols = tensor_shape + expected_rows = _TCGEN05_FRAG_ROWS[instr_shape] + if rows != expected_rows: + raise ValueError( + f"tcgen05_atom_layout {instr_shape!r} expects rows={expected_rows}, got {rows}" + ) + + elem_per_32b = 32 // bits + col_factor_elem = _TCGEN05_COL_FACTOR_FP32[instr_shape] * elem_per_32b + if cols % col_factor_elem != 0: + raise ValueError( + f"tcgen05_atom_layout cols={cols} not divisible by the per-rep column " + f"factor {col_factor_elem} for instr_shape={instr_shape!r} dtype={dtype}; " + f"valid cols are k * {col_factor_elem} for k in " + f"{_TCGEN05_ATOM_REPS[instr_shape]}" + ) + rep = cols // col_factor_elem + if rep not in _TCGEN05_ATOM_REPS[instr_shape]: + raise ValueError( + f"tcgen05_atom_layout inferred rep={rep} (from cols={cols}) is not in " + f"the PTX Table 49 supported set for {instr_shape}: " + f"{_TCGEN05_ATOM_REPS[instr_shape]}" + ) + + laneid = Axis.laneid + wid = Axis.wid_in_wg + N = rep + shape = instr_shape + # All m-strides below are written in fp32-reg units; we multiply by 
+ # elem_per_32b at the end and prepend a C_pack iter for the 16-bit case + # (each fp32 reg packs ``elem_per_32b`` elements at adjacent col positions). + + if shape == "32x32b": + # M=128 fragment, simple thread-rows layout: + # (rows=128, cols=K) : (1@tid_in_wg, 1) + # Each of 128 warpgroup threads owns one row; cols are contiguous in + # the per-thread storage. For 16-bit dtypes the K cols are packed two + # per 32-bit register (handled by the per-thread storage element count + # naturally — m-stride 1 in element units). + iters = [ + Iter(rows, 1, Axis.tid_in_wg), + Iter(cols, 1, "m"), + ] + return TileLayout.from_iters(iters, [], {}) + + if shape == "16x64b": + # Per-warp tile (fp32 view): (16 rows, 2N cols). Per-lane regs = N. + # Lane (t0, t1, t2): t0 = laneid & 1, t1 = (laneid >> 1) & 1, t2 = laneid >> 2. + # Row = t2 + 8*t0 + 16*wid_in_wg + # Col (fp32) = t1 + 2*r, r ∈ [0, N) + row_iters_fp32 = [ + (8, 4, laneid), # R_t2: laneid bits 2..4 → R bits 0..2 + (2, 1, laneid), # R_t0: laneid bit 0 → R bit 3 + (4, 1, wid), # R_w: wid_in_wg → R bits 4..5 + ] + col_iters_fp32 = [ + (2, 2, laneid), # C_t1: laneid bit 1 → C bit 0 + (N, 1, "m"), # C_r: register slot → C bits 1.. + ] + elif shape == "16x128b": + # Per-warp tile (fp32 view): (16 rows, 4N cols). Per-lane regs = 2N. + # Lane (t0, t1): t0 = laneid & 3, t1 = laneid >> 2. + # Reg r = ra + 2*rb, ra ∈ {0,1}, rb ∈ [0, N). + # Row = t1 + 8*ra + 16*wid_in_wg + # Col (fp32) = t0 + 4*rb + row_iters_fp32 = [ + (8, 4, laneid), # R_t1: laneid bits 2..4 → R bits 0..2 + (2, 1, "m"), # R_ra: reg bit 0 → R bit 3 + (4, 1, wid), # R_w + ] + col_iters_fp32 = [ + (4, 1, laneid), # C_t0: laneid bits 0..1 → C bits 0..1 + (N, 2, "m"), # C_rb: reg bits 1.. → C bits 2.. + ] + else: # 16x256b + # Per-warp tile (fp32 view): (16 rows, 8N cols). Per-lane regs = 4N. + # Lane (t0, t1) as for 16x128b. Reg r = v0p + 2*va + 4*vb. 
+ # Row = t1 + 8*va + 16*wid_in_wg + # Col (fp32) = v0p + 2*t0 + 8*vb + row_iters_fp32 = [ + (8, 4, laneid), # R_t1 + (2, 2, "m"), # R_va: reg bit 1 → R bit 3 + (4, 1, wid), # R_w + ] + col_iters_fp32 = [ + (2, 1, "m"), # C_v0p: reg bit 0 → C bit 0 + (4, 1, laneid), # C_t0 + (N, 4, "m"), # C_vb: reg bits 2.. → C bits 3.. + ] + + def _scale(iters): + out = [] + for ext, stride, axis in iters: + if axis == "m": + out.append((ext, stride * elem_per_32b, axis)) + else: + out.append((ext, stride, axis)) + return out + + row_iters = _scale(row_iters_fp32) + col_iters = _scale(col_iters_fp32) + + # For the 16-bit packed variant each fp32 register holds two adjacent + # column elements (low / high halves), so we prepend a C_pack iter of + # extent ``elem_per_32b`` and m-stride 1 to the column group. + if elem_per_32b > 1: + col_iters = [(elem_per_32b, 1, "m"), *col_iters] + + iters = [Iter(ext, stride, axis) for ext, stride, axis in row_iters + col_iters] + return TileLayout.from_iters(iters, [], {}) + + # ------------------------------------------------------------------ # Helper types to support `PrimExpr @ Axis` and `sum` for offsets # ------------------------------------------------------------------ diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py index 4700d4e0da..345eef4519 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py @@ -27,7 +27,7 @@ from tvm.arith import Analyzer from tvm.runtime import DataType from tvm.script import tirx as Tx from tvm.tirx import Buffer, PrimFunc -from tvm.tirx.layout import S, TCol, TileLayout, TLane, tid_in_wg +from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout, tid_in_wg from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, register_dispatch from tvm.tirx.stmt import 
TilePrimitiveCall @@ -35,6 +35,42 @@ from ..common import get_st_extent from ..copy import _is_valid_copy, _scope_allowed from ..exec_scope_utils import exec_scope_ok +# Per-warp fp32-column factor for each instr_shape (mirrors +# ``_TCGEN05_COL_FACTOR_FP32`` in ``tvm.tirx.layout``; .16x64b → 2, +# .16x128b → 4, .16x256b → 8). Source: PTX ISA Table 49. +_TCGEN05_COL_FACTOR_FP32 = {"16x64b": 2, "16x128b": 4, "16x256b": 8} + + +def _match_tcgen05_atom_layout(buf): + """Return ``(instr_shape, rep)`` if ``buf.layout`` matches a tcgen05 + ``.16x*b`` atom layout for some supported ``instr_shape``. + + The local buffer shape ``(64, K)`` together with the dtype determines the + candidate ``rep`` for each ``instr_shape``; we just probe the three shapes + and structurally compare. ``None`` if no atom layout matches. + """ + if len(buf.shape) != 2: + return None + rows, cols = int(buf.shape[0]), int(buf.shape[1]) + if rows != 64: + return None + dtype = buf.dtype + layout_c = buf.layout.canonicalize() + for shape in _TCGEN05_COL_FACTOR_FP32: + try: + cand = tcgen05_atom_layout(shape, (rows, cols), dtype).canonicalize() + except ValueError: + continue + try: + tvm.ir.assert_structural_equal(layout_c, cand) + except (AssertionError, ValueError): + continue + # Recover rep from cols (same arithmetic the factory uses). 
+ elem_per_32b = 32 // DataType(dtype).bits + rep = cols // (_TCGEN05_COL_FACTOR_FP32[shape] * elem_per_32b) + return shape, rep + return None + def copy_tmem_local_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: op_call = TilePrimitiveCall.downcast(op_call) @@ -56,11 +92,51 @@ def copy_tmem_local_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> P assert tmem_buf.layout is not None assert local_buf.layout is not None assert tmem_buf.dtype == local_buf.dtype + assert tmem_buf.allocated_addr is not None analyzer = Analyzer() elem_size = DataType(local_buf.dtype).bits elem_per_32b = 32 // elem_size assert len(local_buf.shape) == len(tmem_buf.shape) == 2 + + # Try the .16x* (M=64) path first by structural-matching the register-side + # layout against ``tcgen05_atom_layout(instr_shape, (64, K), dtype)``. The + # TMEM-side layout is the standard (128, W):(1@TLane, 1@TCol); the M=64 + # fragment lives at lanes 0..15 of each warp's accessible slab (per PTX + # 9.7.16.8.1), so each warp issues with row_offset=0 and collectively the + # 4 warps cover all 64 rows. + atom_match = _match_tcgen05_atom_layout(local_buf) + + if atom_match is not None: + shape, num = atom_match + return _emit_16xnb_path( + shape=shape, + num=num, + direction=direction, + tmem_buf=tmem_buf, + local_buf=local_buf, + tmem_region=tmem_region, + local_region=local_region, + elem_per_32b=elem_per_32b, + analyzer=analyzer, + ) + + # Fall through to the existing .32x32b (M=128) path. 
+ return _emit_32x32b_path( + direction=direction, + tmem_buf=tmem_buf, + local_buf=local_buf, + tmem_region=tmem_region, + local_region=local_region, + elem_per_32b=elem_per_32b, + analyzer=analyzer, + ) + + +def _emit_32x32b_path( + *, direction, tmem_buf, local_buf, tmem_region, local_region, elem_per_32b, analyzer +) -> PrimFunc: + """Original M=128 fragment path using ``tcgen05.{ld,st}.32x32b.xN``.""" # local: 128xWIDTH <-> tmem: 128xSHAPE[1] assert analyzer.can_prove_equal(local_buf.shape[0], 128) assert analyzer.can_prove_equal(tmem_buf.shape[0], 128) @@ -87,10 +163,7 @@ def copy_tmem_local_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> P # local layout TileLayout(S[(128, width) : (1 @ tid_in_wg, 1)]).canonicalize() - # tmem allocated addr is not None - assert tmem_buf.allocated_addr is not None tvm.ir.assert_structural_equal(tmem_buf.layout.canonicalize(), tmem_layout) - # tvm.ir.assert_structural_equal(local_buf.layout.canonicalize(), local_layout) # local: [0:128, 0:WIDTH] <-> tmem: [0:128, st:st+WIDTH] assert analyzer.can_prove_equal(tmem_st[0], 0) assert analyzer.can_prove_equal(tmem_extent[0], 128) @@ -121,6 +194,103 @@ def copy_tmem_local_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> P return impl +def _emit_16xnb_path( + *, + shape, + num, + direction, + tmem_buf, + local_buf, + tmem_region, + local_region, + elem_per_32b, + analyzer, +) -> PrimFunc: + """M=64 fragment path using ``tcgen05.{ld,st}.<shape>.x<num>`` (one of + ``.16x64b``, ``.16x128b``, ``.16x256b``). + + Each of the warpgroup's 4 warps issues the atom once with + ``row_offset=0``; the PTX TMEM access restriction places warp ``i`` on + TMEM lanes ``i*32..i*32+31``, of which the atom uses the first 16 to + cover one 16-row slab of the 64-row fragment. Collectively, the four + warps cover all 64 rows. 
+ """ + # Per-atom column footprint in fp32 columns: + # .16x64b → 2N .16x128b → 4N .16x256b → 8N + col_factor_fp32 = {"16x64b": 2, "16x128b": 4, "16x256b": 8}[shape] + # Per-thread register count (in 32-bit units): + # .16x64b.xN → N .16x128b.xN → 2N .16x256b.xN → 4N + regs_per_thread = {"16x64b": num, "16x128b": 2 * num, "16x256b": 4 * num}[shape] + # Logical column width that the local buffer view exposes (in element units). + width_elems = col_factor_fp32 * num * elem_per_32b + # Per-thread storage in element units (same total bits as the register vector). + per_thread_elems = regs_per_thread * elem_per_32b + + # Local-side: shape (64, K_cols) + assert analyzer.can_prove_equal(local_buf.shape[0], 64), ( + f".16x*b path expects local_buf rows=64, got {local_buf.shape[0]}" + ) + assert analyzer.can_prove_equal(local_buf.shape[1], width_elems), ( + f".16x*b path expects local_buf cols={width_elems}, got {local_buf.shape[1]}" + ) + + # TMEM-side: shape (128, W); the M=64 fragment occupies the first 16 lanes of + # each warp's 32-lane slab. + assert analyzer.can_prove_equal(tmem_buf.shape[0], 128), ( + f".16x*b path expects tmem_buf rows=128, got {tmem_buf.shape[0]}" + ) + tmem_layout = TileLayout(S[(128, tmem_buf.shape[1]) : (1 @ TLane, 1 @ TCol)]).canonicalize() + tvm.ir.assert_structural_equal(tmem_buf.layout.canonicalize(), tmem_layout) + + tmem_st, tmem_extent = get_st_extent(tmem_region) + local_st, local_extent = get_st_extent(local_region) + + # Local slice must be the full (64, K_cols) view. + assert analyzer.can_prove_equal(local_st[0], 0) + assert analyzer.can_prove_equal(local_extent[0], 64) + assert analyzer.can_prove_equal(local_extent[1], width_elems) + + # TMEM slice must start at row 0 (warp 0 of the WG is at lane 0) and span + # 64 rows (collectively the 4 warps' first 16-lane chunks). 
+ assert analyzer.can_prove_equal(tmem_st[0], 0) + assert analyzer.can_prove_equal(tmem_extent[0], 64) + assert analyzer.can_prove_equal(tmem_extent[1], width_elems) + + col_off = tmem_st[1] + assert analyzer.can_prove_equal(tvm.tirx.floormod(col_off, elem_per_32b), 0) + col_off_32b = tvm.tirx.floordiv(col_off, elem_per_32b) + local_col_off = local_st[1] + assert analyzer.can_prove_equal(tvm.tirx.floormod(local_col_off, elem_per_32b), 0) + local_col_off_elems = local_col_off + + is_load = direction == "tmem2local" + op = Tx.ptx.tcgen05.ld if is_load else Tx.ptx.tcgen05.st + # We intentionally do *not* emit ``.pack::16b`` / ``.unpack::16b`` for + # 16-bit dtypes. That qualifier would store one 16-bit element per 32-bit + # TMEM cell (LOW half only, HIGH half wasted) — fine for some CUTLASS + # epilogues but a 2x TMEM waste vs. the existing ``.32x32b`` convention, + # which packs two 16-bit elements per cell. By using the plain ``.b32`` + # form we keep TMEM dense (2 elements per 32-bit cell); the per-thread + # register file holds two packed 16-bit values per 32-bit register, and + # the layout factory's iters describe that packing. + + # fmt: off + @Tx.prim_func(check_well_formed=False) + def impl(): + with Tx.warp(): + # Per-thread 1-D flat view of the local storage, then a uint32 view + # for the register-pointer arguments of the PTX builtin. 
+ local_storage = local_buf.view(per_thread_elems, layout=TileLayout(S[per_thread_elems])) # noqa: E501 + local_32b = local_storage.view("uint32") + op( + tmem_buf.allocated_addr[0], + *[local_32b[local_col_off_elems // elem_per_32b + i] for i in range(regs_per_thread)], # noqa: E501 + shape=shape, num=num, row=0, col=col_off_32b, + ) + # fmt: on + return impl + + # === Variant: copy_async/tmem<->local (priority=10) === # # When: one buffer is in tmem (tensor memory, Blackwell SM100+) and the other diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py index da24e71a7d..0c6b8607b8 100644 --- a/python/tvm/tirx/script/builder/ir.py +++ b/python/tvm/tirx/script/builder/ir.py @@ -1868,6 +1868,62 @@ smem = alloc_shared tmem = functools.partial(alloc_buffer, scope="tmem") +def alloc_tcgen05_frag(instr_shape, tensor_shape, dtype): + """Allocate a register fragment for ``tcgen05.{ld,st}`` atoms. + + Sizes the per-thread storage, allocates ``local`` scope memory, and returns + a 2-D view of shape ``tensor_shape`` with a matching ``tcgen05_atom_layout``. + Pass the result to ``Tx.copy_async`` (with a ``(128, W)``-shaped TMEM + buffer) to trigger the corresponding dispatch path. + + Parameters + ---------- + instr_shape : str + ``"32x32b"`` (M=128 fragment, 128 row warpgroup tile, layout + ``(128, K):(1@tid_in_wg, 1)``); or ``"16x64b"`` / ``"16x128b"`` / + ``"16x256b"`` (M=64 fragments, 64 row warpgroup tile with the + per-shape per-lane register decomposition). + tensor_shape : tuple[int, int] + Logical fragment shape ``(frag_rows, K)`` in element units. ``frag_rows`` + is ``128`` for ``.32x32b`` and ``64`` for the ``.16x*b`` shapes. + dtype : str + ``"float32"``, ``"float16"``, or ``"bfloat16"``. + + Returns + ------- + Buffer + 2-D view of shape ``tensor_shape`` whose layout matches + ``tcgen05_atom_layout(instr_shape, tensor_shape, dtype)``. 
+ + Examples + -------- + M=128 readback (existing dispatch): + ``frag = Tx.alloc_tcgen05_frag("32x32b", (128, 64), "float32")`` + ``Tx.copy_async(frag[:, :], tmem[:, 0:64])`` + + M=64 readback (.16x64b dispatch): + ``frag = Tx.alloc_tcgen05_frag("16x64b", (64, 64), "float32")`` + ``Tx.copy_async(frag[:, :], tmem[0:64, 0:64])`` + """ + from tvm.tirx.layout import tcgen05_atom_layout # local import to avoid cycle + + rows, cols = tensor_shape + bits = DataType(dtype).bits + # Per-warpgroup total bits = 64 rows × K cols × bits. Divided across 128 + # threads gives per-thread bits; convert to element count. + per_thread_bits = (rows * cols * bits) // 128 + if per_thread_bits % bits != 0: + raise ValueError( + f"alloc_tcgen05_frag tensor_shape={tensor_shape} dtype={dtype!r} " + f"does not evenly divide across 128 threads" + ) + per_thread_elems = per_thread_bits // bits + + layout = tcgen05_atom_layout(instr_shape, tensor_shape, dtype) + flat = alloc_local((per_thread_elems,), dtype) + return flat.view(rows, cols, layout=layout) + + if TYPE_CHECKING: ScalarT = TypeVar("ScalarT") @@ -4021,6 +4077,7 @@ __all__ += [ "alloc_local", "alloc_scalar", "alloc_shared", + "alloc_tcgen05_frag", "cluster", "cluster_id", "cta", diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py new file mode 100644 index 0000000000..09a15f3bd1 --- /dev/null +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py @@ -0,0 +1,709 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, missing-function-docstring +"""Bit-exact tests for the ``.16x{64,128,256}b`` ``tcgen05.{ld,st}`` dispatch. + +For each ``(shape, rep, dtype, direction)`` we: + +1. Fill a (128, FULL_W) host buffer ``A`` with random values. +2. Stage ``A`` into TMEM via the existing ``.32x32b`` ld/st round-trip. +3. Issue the new ``.16x*b`` atom via ``Tx.copy_async`` to read a (64, K_cols) + fragment from TMEM into a register tile shaped by ``tcgen05_atom_layout``. +4. Dump the register tile to a ``(128, regs_per_thread)`` global buffer indexed + ``B[tid_in_wg, r]``. +5. Reconstruct the expected ``B[t, r]`` on the host from the per-(lane, reg) → + (frag_row, frag_col) formula. The M=64 fragment occupies TMEM lanes + ``warp_id * 32 + (0..15)``, so ``frag_row R`` maps to TMEM lane + ``(R // 16) * 32 + (R % 16)``. + +For the store direction we run the inverse: prefill the register tile via host → +``B`` → ``.32x32b.ld``-staged read, write to TMEM via the new ``.16x*b.st``, +then read TMEM back via ``.32x32b.ld`` into a (128, FULL_W) buffer and check +that the M=64 fragment's row positions hold the expected register data. 
+""" + +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm.script import tirx as Tx +from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout +from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg + + +# -------------------------------------------------------------------------- +# Shape metadata + host-side layout reconstruction +# -------------------------------------------------------------------------- + +# (.shape, .num) ranges supported by PTX Table 49. +_SHAPE_REPS = { + "32x32b": (1, 2, 4, 8, 16, 32, 64, 128), + "16x64b": (1, 2, 4, 8, 16, 32, 64, 128), + "16x128b": (1, 2, 4, 8, 16, 32, 64), + "16x256b": (1, 2, 4, 8, 16, 32), +} + +# Per-warp fp32 column span = factor * rep. +_COL_FACTOR_FP32 = {"32x32b": 1, "16x64b": 2, "16x128b": 4, "16x256b": 8} + +# Per-thread 32-bit register count = factor * rep. +_REGS_FACTOR = {"32x32b": 1, "16x64b": 1, "16x128b": 2, "16x256b": 4} + +# Per-warpgroup fragment row count. +_FRAG_ROWS = {"32x32b": 128, "16x64b": 64, "16x128b": 64, "16x256b": 64} + + +def _decompose_fp32(shape: str, t: int, r: int) -> tuple[int, int]: + """Return ``(frag_row, frag_col)`` in fp32 element units for the fp32 atom.""" + laneid = t & 31 + wid_in_wg = t >> 5 + if shape == "32x32b": + # M=128 fragment: each thread t owns full row t with N consecutive cols. + row = t + col = r + elif shape == "16x64b": + t0 = laneid & 1 + t1 = (laneid >> 1) & 1 + t2 = laneid >> 2 + row = t2 + 8 * t0 + 16 * wid_in_wg + col = t1 + 2 * r + elif shape == "16x128b": + t0 = laneid & 3 + t1 = laneid >> 2 + ra = r & 1 + rb = r >> 1 + row = t1 + 8 * ra + 16 * wid_in_wg + col = t0 + 4 * rb + elif shape == "16x256b": + t0 = laneid & 3 + t1 = laneid >> 2 + v0p = r & 1 + va = (r >> 1) & 1 + vb = r >> 2 + row = t1 + 8 * va + 16 * wid_in_wg + col = v0p + 2 * t0 + 8 * vb + else: + raise ValueError(shape) + return row, col + + +def _frag_row_to_tmem_lane(shape: str, R: int) -> int: + """Map fragment row R to its physical TMEM lane. 
+ + For ``.32x32b`` (M=128) the mapping is identity: row R lives at TMEM lane R. + For ``.16x*b`` (M=64) the fragment occupies the first 16 lanes of each + warp's 32-lane slab, so ``R`` ∈ [0, 64) lives at lane ``(R // 16) * 32 + (R % 16)``. + """ + if shape == "32x32b": + return R + return (R // 16) * 32 + (R % 16) + + +def _expected_reg_value_fp32( + A: np.ndarray, shape: str, rep: int, tmem_col_off: int, t: int, r: int +) -> np.uint32: + """fp32 path: return the bit-pattern (as uint32) that thread ``t`` register + ``r`` should hold after ``.<shape>.x<rep>`` reads ``A`` (staged into TMEM) at + column offset ``tmem_col_off``.""" + row, col = _decompose_fp32(shape, t, r) + tmem_lane = _frag_row_to_tmem_lane(shape, row) + val = np.float32(A[tmem_lane, tmem_col_off + col]) + return val.view(np.uint32) + + +def _expected_reg_value_16b( + A: np.ndarray, shape: str, rep: int, tmem_col_off: int, t: int, r: int, dtype_np +) -> np.uint32: + """16-bit path (fp16 / bf16 with .pack::16b): each fp32 register packs two + 16-bit elements at adjacent columns ``(2*col_fp32, 2*col_fp32 + 1)``.""" + row, col_fp32 = _decompose_fp32(shape, t, r) + tmem_lane = _frag_row_to_tmem_lane(shape, row) + lo = dtype_np(A[tmem_lane, tmem_col_off + 2 * col_fp32]) + hi = dtype_np(A[tmem_lane, tmem_col_off + 2 * col_fp32 + 1]) + lo_u16 = lo.view(np.uint16) + hi_u16 = hi.view(np.uint16) + return np.uint32(int(lo_u16) | (int(hi_u16) << 16)) + + +# -------------------------------------------------------------------------- +# Test 1: load direction +# -------------------------------------------------------------------------- + + +@pytest.mark.parametrize("shape", list(_SHAPE_REPS)) +@pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32]) # subset; full reps below +@pytest.mark.parametrize("dtype", ["float32"]) +def test_tcgen05_ld_16xnb_load_fp32(shape, rep, dtype): + """Bit-exact verification of ``tcgen05.<shape>.x<rep>.b32`` load.""" + if rep not in _SHAPE_REPS[shape]: + pytest.skip(f"rep {rep} not valid for {shape}") +
_run_load_test(shape, rep, dtype) + + +@pytest.mark.parametrize( + "shape, rep", + [ + ("16x64b", 64), + ("16x64b", 128), + ("16x128b", 64), + ], +) +def test_tcgen05_ld_16xnb_load_fp32_large_rep(shape, rep): + """High-rep entries that aren't in the parametrize-cross above.""" + _run_load_test(shape, rep, "float32") + + +@pytest.mark.parametrize("shape", list(_SHAPE_REPS)) +@pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype): + """Self-consistent round-trip for 16-bit pack::16b path. + + The fp32 ``test_tcgen05_ld_16xnb_load_fp32`` already validates the + ``(lane, reg) → (frag_row, frag_col)`` mapping bit-exactly against the + standard ``.32x32b`` staging. For the 16-bit case the staging convention + differs (``.32x32b.st`` packs two fp16 per 32-bit TMEM cell, whereas + ``.16x*b.ld.pack::16b`` reads two fp16 from the LOW halves of adjacent + 32-bit cells), so we instead verify the new dispatch round-trips + per-thread data via ``.16x*b.st.unpack::16b`` → ``.16x*b.ld.pack::16b``. + A bit-exact round-trip is sufficient evidence that the per-thread + register-layout matches between the load and store atom families. + """ + if rep not in _SHAPE_REPS[shape]: + pytest.skip(f"rep {rep} not valid for {shape}") + _run_roundtrip_16b(shape, rep, dtype) + + +def _run_roundtrip_16b(shape: str, rep: int, dtype: str): + bits = tvm.runtime.DataType(dtype).bits + assert bits == 16 + elem_per_32b = 2 + K_cols_fp32 = _COL_FACTOR_FP32[shape] * rep + K_cols_elem = K_cols_fp32 * elem_per_32b + regs_per_thread = _REGS_FACTOR[shape] * rep + per_thread_elems = regs_per_thread * elem_per_32b + frag_rows = _FRAG_ROWS[shape] + + # The 16-bit round-trip writes and reads exclusively through .16x*b atoms, + # so the TMEM column footprint is whatever ``K_cols_fp32`` says — no + # .32x32b staging constraint applies here.
+ tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32)) + stage_width_elem = tmem_col_width_32b * elem_per_32b + atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_elem), dtype) + + @Tx.prim_func + def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: + # Per-thread input/output: A[tid_in_wg, i] feeds register slot i of the + # warpgroup-collective fragment; B[tid_in_wg, i] is what comes back + # after a .16x*b.st → .16x*b.ld round-trip. + A = Tx.match_buffer(A_ptr, (128, per_thread_elems), dtype) + B = Tx.match_buffer(B_ptr, (128, per_thread_elems), dtype) + + Tx.device_entry() + warp_id = Tx.warp_id([128 // 32]) + Tx.cta_id([2]) + wg_id = Tx.warpgroup_id([1]) + Tx.warp_id_in_wg([4]) + Tx.lane_id([32]) + tid_in_wg = Tx.thread_id([128]) + + tmem_addr = Tx.alloc_shared([1], "uint32") + + if wg_id == 0: + with Tx.warpgroup(): + if warp_id == 0: + with Tx.warp(): + Tx.ptx.tcgen05.alloc( + Tx.address_of(tmem_addr), + n_cols=tmem_col_width_32b, + cta_group=1, + ) + + Tx.tvm_storage_sync("shared") + + tmem = Tx.decl_buffer( + (128, stage_width_elem), + dtype, + scope="tmem", + allocated_addr=tmem_addr[0], + layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]), + ) + + # Load per-thread A → reg_in + reg_in = Tx.alloc_local((per_thread_elems,), dtype) + with Tx.thread(): + for i in range(per_thread_elems): + reg_in[i] = A[tid_in_wg, i] + Tx.cuda.cta_sync() + + # reg_in -> TMEM via .<shape>.x<rep>.st.unpack::16b + frag_in = reg_in.view(frag_rows, K_cols_elem, layout=atom_view) + Tx.copy_async(tmem[0:frag_rows, 0:K_cols_elem], frag_in[:, :]) + Tx.ptx.tcgen05.wait.st() + Tx.cuda.cta_sync() + + # TMEM -> reg_out via .<shape>.x<rep>.ld.pack::16b + reg_out = Tx.alloc_local((per_thread_elems,), dtype) + frag_out = reg_out.view(frag_rows, K_cols_elem, layout=atom_view) + Tx.copy_async(frag_out[:, :], tmem[0:frag_rows, 0:K_cols_elem]) + Tx.ptx.tcgen05.wait.ld() + Tx.cuda.cta_sync() + + # reg_out -> B + with Tx.thread(): + for i in range(per_thread_elems): 
+ B[tid_in_wg, i] = reg_out[i] + + if warp_id == 0: + with Tx.warp(): + Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + Tx.ptx.tcgen05.dealloc( + tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1 + ) + + target = tvm.target.Target("cuda") + with target: + mod = tvm.IRModule({"main": kernel}) + mod = tvm.compile(mod, target=target, tir_pipeline="tirx") + A_np = tvm.testing.generate_random_array(dtype, (128, per_thread_elems)) + B_np = np.zeros((128, per_thread_elems), dtype=dtype) + DEV = tvm.cuda(0) + A = tvm.runtime.tensor(A_np, DEV) + B = tvm.runtime.tensor(B_np, DEV) + mod(A, B) + # Round-trip should preserve every per-thread bit pattern. + A_view = A.numpy().view(np.uint16) + B_view = B.numpy().view(np.uint16) + np.testing.assert_array_equal(B_view, A_view) + + +def _next_pow2(x: int) -> int: + if x <= 1: + return 1 + return 1 << (x - 1).bit_length() + + +def _run_load_test(shape: str, rep: int, dtype: str): + """Stage A into TMEM via .32x32b, then read it back as the fragment via + .<shape>.x<rep> (through ``Tx.alloc_tcgen05_frag``), and compare each + thread's registers against the expected layout-derived value.""" + bits = tvm.runtime.DataType(dtype).bits + elem_per_32b = 32 // bits + # Per-warp fp32 col span × number of warps in one warpgroup covers the + # fragment column footprint. The TMEM allocation is sized for the same + # element-column count. + K_cols_fp32 = _COL_FACTOR_FP32[shape] * rep + K_cols_elem = K_cols_fp32 * elem_per_32b + regs_per_thread = _REGS_FACTOR[shape] * rep # 32-bit register count + per_thread_elems = regs_per_thread * elem_per_32b + frag_rows = _FRAG_ROWS[shape] + + tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32)) + + # Staging via .32x32b caps at num=128 (= 128 fp32 cols) per atom call. For + # configs whose K_cols_fp32 exceeds 128 we split the stage into multiple + # chunks of CHUNK_FP32 fp32 cols each. 
+ CHUNK_FP32 = 128 + chunk_elem = CHUNK_FP32 * elem_per_32b + num_chunks = tmem_col_width_32b // CHUNK_FP32 if tmem_col_width_32b > CHUNK_FP32 else 1 + chunk_width_32b = tmem_col_width_32b if num_chunks == 1 else CHUNK_FP32 + chunk_width_elem = chunk_width_32b * elem_per_32b + stage_width_elem = tmem_col_width_32b * elem_per_32b + + # Vector length for global<->local copies (in elements). + VEC_LEN = 128 // bits + if stage_width_elem % VEC_LEN != 0: + pytest.skip(f"stage_width_elem {stage_width_elem} % VEC_LEN {VEC_LEN} != 0") + + g_layout = TileLayout( + S[(128, stage_width_elem // VEC_LEN, VEC_LEN) : (stage_width_elem, VEC_LEN, 1)] + ) + chunk_view = TileLayout(S[(128, chunk_width_elem) : (1 @ axis_tid_in_wg, 1)]) + # The factory + wrapper both go through ``tcgen05_atom_layout``; we use it + # explicitly here so that ``frag_local`` has the canonical layout that + # ``Tx.copy_async`` matches when dispatching to the right atom path. + atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_elem), dtype) + + @Tx.prim_func + def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: + # A is the host data we stage into TMEM via the standard .32x32b path. + A = Tx.match_buffer(A_ptr, (128, stage_width_elem), dtype) + # B is a per-thread register dump: B[tid_in_wg, reg_idx_in_elements]. 
+ B = Tx.match_buffer(B_ptr, (128, per_thread_elems), dtype) + + A_flat = A.view(-1) + + Tx.device_entry() + warp_id = Tx.warp_id([128 // 32]) + Tx.cta_id([2]) + wg_id = Tx.warpgroup_id([1]) + Tx.warp_id_in_wg([4]) + Tx.lane_id([32]) + tid_in_wg = Tx.thread_id([128]) + + tmem_addr = Tx.alloc_shared([1], "uint32") + + if wg_id == 0: + with Tx.warpgroup(): + if warp_id == 0: + with Tx.warp(): + Tx.ptx.tcgen05.alloc( + Tx.address_of(tmem_addr), + n_cols=tmem_col_width_32b, + cta_group=1, + ) + + Tx.tvm_storage_sync("shared") + + tmem = Tx.decl_buffer( + (128, stage_width_elem), + dtype, + scope="tmem", + allocated_addr=tmem_addr[0], + layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]), + ) + + # Per-thread chunk staging buffer (CHUNK_FP32 fp32 worth). + stage_reg = Tx.alloc_local((chunk_width_elem,), dtype) + stage_local = stage_reg.view(128, chunk_width_elem, layout=chunk_view) + + # Walk chunks: A[:, ck:ck+chunk] -> stage_reg -> TMEM[:, ck:ck+chunk] + for chunk_idx in range(num_chunks): + col_off_elem = chunk_idx * chunk_width_elem + with Tx.thread(): + for i in range(chunk_width_elem // VEC_LEN): + # Each thread's row offset in A_flat: stage_width_elem; within + # the row, this chunk starts at col_off_elem and each vector + # picks up VEC_LEN elements at slot i. + g_offset = Tx.meta_var( + tid_in_wg * stage_width_elem + + col_off_elem + + i * VEC_LEN + ) + Tx.copy( + stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN], + A_flat[g_offset : g_offset + VEC_LEN], + ) + Tx.cuda.cta_sync() + Tx.copy_async( + tmem[:, col_off_elem : col_off_elem + chunk_width_elem], + stage_local[:, :], + ) + Tx.ptx.tcgen05.wait.st() + Tx.cuda.cta_sync() + + # TMEM[0:frag_rows, 0:K_cols] -> frag_local via .<shape>.x<rep>.ld. + # Use ``tcgen05_atom_layout`` so dispatch matches the new path + # (or stays on .32x32b for instr_shape="32x32b"). Keep the flat + # ``frag_reg`` for the per-thread dump below. 
+ frag_reg = Tx.alloc_local((per_thread_elems,), dtype) + frag_local = frag_reg.view(frag_rows, K_cols_elem, layout=atom_view) + Tx.copy_async(frag_local[:, :], tmem[0:frag_rows, 0:K_cols_elem]) + Tx.ptx.tcgen05.wait.ld() + Tx.cuda.cta_sync() + + # Dump per-thread regs to B[tid_in_wg, :] + with Tx.thread(): + for i in range(per_thread_elems): + B[tid_in_wg, i] = frag_reg[i] + + if warp_id == 0: + with Tx.warp(): + Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + Tx.ptx.tcgen05.dealloc( + tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1 + ) + + target = tvm.target.Target("cuda") + with target: + mod = tvm.IRModule({"main": kernel}) + mod = tvm.compile(mod, target=target, tir_pipeline="tirx") + A_np = tvm.testing.generate_random_array(dtype, (128, stage_width_elem)) + B_np = np.zeros((128, per_thread_elems), dtype=dtype) + DEV = tvm.cuda(0) + A = tvm.runtime.tensor(A_np, DEV) + B = tvm.runtime.tensor(B_np, DEV) + mod(A, B) + B_out = B.numpy() + + # Build expected B_out from the layout. + if bits == 32: + # Each register slot in B[t, r] holds a single fp32; compare bit-exactly. + B_expected = np.zeros((128, per_thread_elems), dtype=np.uint32) + for t in range(128): + for r in range(regs_per_thread): + B_expected[t, r] = _expected_reg_value_fp32(A_np, shape, rep, 0, t, r) + B_view = B_out.view(np.uint32) + np.testing.assert_array_equal(B_view, B_expected) + else: + # B[t, :] holds per_thread_elems 16-bit values; each fp32 register packs + # two of them in (low, high) order. Compare bit-exactly via uint32 view. + dtype_np = np.float16 if dtype == "float16" else np.dtype("bfloat16") + if dtype == "bfloat16": + # numpy doesn't have a stable bfloat16 dtype across versions; use ml_dtypes. 
+ try: + from ml_dtypes import bfloat16 as _bf16 # noqa: PLC0415 + + dtype_np = _bf16 + except ImportError: + pytest.skip("bfloat16 verification needs ml_dtypes") + B_view = B_out.view(np.uint32).reshape(128, regs_per_thread) + B_expected = np.zeros((128, regs_per_thread), dtype=np.uint32) + for t in range(128): + for r in range(regs_per_thread): + B_expected[t, r] = _expected_reg_value_16b( + A_np, shape, rep, 0, t, r, dtype_np + ) + np.testing.assert_array_equal(B_view, B_expected) + + +# -------------------------------------------------------------------------- +# Test 2: store direction (mirror of test 1, with .st instead of .ld) +# -------------------------------------------------------------------------- + + [email protected]("shape", list(_SHAPE_REPS)) [email protected]("rep", [1, 4, 16]) [email protected]("dtype", ["float32"]) +def test_tcgen05_st_16xnb_store(shape, rep, dtype): + """Round-trip test: write the M=64 fragment via .<shape>.x<rep>.st then read + via the standard .32x32b path; verify the host-known fragment data ends up + at the expected TMEM lane positions. + + Only fp32 here — the 16-bit case has a different staging convention + (pack::16b reads/writes the LOW halves of adjacent cells, not low/high of + one cell) and is covered by ``test_tcgen05_16xnb_roundtrip_16b`` via a + self-consistent .16x*b.st → .16x*b.ld loop. 
+ """ + if rep not in _SHAPE_REPS[shape]: + pytest.skip(f"rep {rep} not valid for {shape}") + bits = tvm.runtime.DataType(dtype).bits + elem_per_32b = 32 // bits + K_cols_fp32 = _COL_FACTOR_FP32[shape] * rep + K_cols_elem = K_cols_fp32 * elem_per_32b + regs_per_thread = _REGS_FACTOR[shape] * rep + per_thread_elems = regs_per_thread * elem_per_32b + frag_rows = _FRAG_ROWS[shape] + + tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32)) + if tmem_col_width_32b > 128: + pytest.skip(f"tmem_col_width_32b {tmem_col_width_32b} > 128 not supported by .32x32b staging") # noqa: E501 + stage_width_elem = tmem_col_width_32b * elem_per_32b + VEC_LEN = 128 // bits + if stage_width_elem % VEC_LEN != 0: + pytest.skip(f"stage_width_elem {stage_width_elem} % VEC_LEN {VEC_LEN} != 0") + + g_layout = TileLayout( + S[(128, stage_width_elem // VEC_LEN, VEC_LEN) : (stage_width_elem, VEC_LEN, 1)] + ) + stage_view = TileLayout(S[(128, stage_width_elem) : (1 @ axis_tid_in_wg, 1)]) + atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_elem), dtype) + + @Tx.prim_func + def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: + # A[tid_in_wg, i] is the i-th per-thread element to feed into the atom store. + A = Tx.match_buffer(A_ptr, (128, per_thread_elems), dtype) + # B[lane, col] is the TMEM-staged readout after the round-trip. 
+ B = Tx.match_buffer(B_ptr, (128, stage_width_elem), dtype) + B_flat = B.view(-1) + + Tx.device_entry() + warp_id = Tx.warp_id([128 // 32]) + Tx.cta_id([2]) + wg_id = Tx.warpgroup_id([1]) + Tx.warp_id_in_wg([4]) + Tx.lane_id([32]) + tid_in_wg = Tx.thread_id([128]) + + tmem_addr = Tx.alloc_shared([1], "uint32") + + if wg_id == 0: + with Tx.warpgroup(): + if warp_id == 0: + with Tx.warp(): + Tx.ptx.tcgen05.alloc( + Tx.address_of(tmem_addr), + n_cols=tmem_col_width_32b, + cta_group=1, + ) + + Tx.tvm_storage_sync("shared") + + tmem = Tx.decl_buffer( + (128, stage_width_elem), + dtype, + scope="tmem", + allocated_addr=tmem_addr[0], + layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]), + ) + + # Load per-thread A → frag_reg + frag_reg = Tx.alloc_local((per_thread_elems,), dtype) + with Tx.thread(): + for i in range(per_thread_elems): + frag_reg[i] = A[tid_in_wg, i] + Tx.cuda.cta_sync() + + # frag_local -> TMEM via .<shape>.x<rep>.st + frag_local = frag_reg.view(frag_rows, K_cols_elem, layout=atom_view) + Tx.copy_async(tmem[0:frag_rows, 0:K_cols_elem], frag_local[:, :]) + Tx.ptx.tcgen05.wait.st() + Tx.cuda.cta_sync() + + # TMEM -> readout via .32x32b.ld + stage_reg = Tx.alloc_local((stage_width_elem,), dtype) + stage_local = stage_reg.view(128, stage_width_elem, layout=stage_view) + Tx.copy_async(stage_local[:, :], tmem[:, :]) + Tx.ptx.tcgen05.wait.ld() + Tx.cuda.cta_sync() + + # readout -> B (full 128×stage_width_elem dump) + with Tx.thread(): + for i in range(stage_width_elem // VEC_LEN): + g_offset = Tx.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) + Tx.copy( + B_flat[g_offset : g_offset + VEC_LEN], + stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN], + ) + + if warp_id == 0: + with Tx.warp(): + Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + Tx.ptx.tcgen05.dealloc( + tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1 + ) + + target = tvm.target.Target("cuda") + with target: + mod = tvm.IRModule({"main": kernel}) + mod = tvm.compile(mod, 
target=target, tir_pipeline="tirx") + A_np = tvm.testing.generate_random_array(dtype, (128, per_thread_elems)) + B_np = np.zeros((128, stage_width_elem), dtype=dtype) + DEV = tvm.cuda(0) + A = tvm.runtime.tensor(A_np, DEV) + B = tvm.runtime.tensor(B_np, DEV) + mod(A, B) + B_out = B.numpy() + + # Build expected TMEM staging: only rows that the M=64 fragment writes to + # should match A's per-thread data; other rows are untouched (we set B_np to + # zero and the .32x32b.ld reads whatever the TMEM allocator left, which may + # be arbitrary, so only check the fragment positions). + if bits == 32: + view = B_out.view(np.uint32) + for t in range(128): + for r in range(regs_per_thread): + row, col = _decompose_fp32(shape, t, r) + tmem_lane = _frag_row_to_tmem_lane(shape, row) + expected = np.float32(A_np[t, r]).view(np.uint32) + assert view[tmem_lane, col] == expected, ( + f"{shape}.x{rep} {dtype}: thread {t} reg {r} → " + f"(row={row}, col={col}) tmem_lane={tmem_lane} got " + f"{view[tmem_lane, col]:#x} want {expected:#x}" + ) + else: + # 16-bit: each fp32 reg packs two 16-bit elements at adjacent TMEM cols. + view = B_out.view(np.uint16) + for t in range(128): + for r in range(regs_per_thread): + row, col_fp32 = _decompose_fp32(shape, t, r) + tmem_lane = _frag_row_to_tmem_lane(shape, row) + lo = np.float16(A_np[t, 2 * r]).view(np.uint16) if dtype == "float16" else None + # bfloat16 (numpy) lacks a clean .view(uint16); skip in store mode + # for now to keep this test path bit-exact only for float16. 
+ if dtype != "float16": + pytest.skip("16b store check restricted to float16") + hi = np.float16(A_np[t, 2 * r + 1]).view(np.uint16) + assert view[tmem_lane, 2 * col_fp32] == lo, ( + f"{shape}.x{rep} {dtype}: t={t} r={r} lo " + f"({tmem_lane=}, {col_fp32=}) got {view[tmem_lane, 2 * col_fp32]:#x} " + f"want {lo:#x}" + ) + assert view[tmem_lane, 2 * col_fp32 + 1] == hi + + +# -------------------------------------------------------------------------- +# Wrapper test: exercise Tx.alloc_tcgen05_frag directly (compile-only smoke). +# -------------------------------------------------------------------------- + + [email protected]( + "shape, frag_rows, K_cols", + [ + ("32x32b", 128, 32), # .32x32b.x32 fp32: simple thread-rows layout + ("32x32b", 128, 64), # .32x32b.x64 fp32 + ("16x64b", 64, 64), # .16x64b.x32 fp32 + ("16x128b", 64, 64), # .16x128b.x16 fp32 + ("16x256b", 64, 64), # .16x256b.x8 fp32 + ], +) +def test_alloc_tcgen05_frag_wrapper_compiles(shape, frag_rows, K_cols): + """Ensure Tx.alloc_tcgen05_frag yields a buffer that ``Tx.copy_async`` accepts + and lowers to the correct tcgen05 atom for each supported instr_shape.""" + + @Tx.prim_func + def kernel(A_ptr: Tx.handle) -> None: + Tx.match_buffer(A_ptr, (128, K_cols), "float32") + Tx.device_entry() + warp_id = Tx.warp_id([4]) + Tx.cta_id([2]) + wg_id = Tx.warpgroup_id([1]) + Tx.warp_id_in_wg([4]) + Tx.lane_id([32]) + Tx.thread_id([128]) + + tmem_addr = Tx.alloc_shared([1], "uint32") + if wg_id == 0: + with Tx.warpgroup(): + if warp_id == 0: + with Tx.warp(): + Tx.ptx.tcgen05.alloc( + Tx.address_of(tmem_addr), n_cols=max(32, K_cols), cta_group=1 + ) + Tx.tvm_storage_sync("shared") + tmem = Tx.decl_buffer( + (128, K_cols), + "float32", + scope="tmem", + allocated_addr=tmem_addr[0], + layout=TileLayout(S[(128, K_cols) : (1 @ TLane, 1 @ TCol)]), + ) + # One-liner: wrapper handles per-thread storage + layout. 
+ frag = Tx.alloc_tcgen05_frag(shape, (frag_rows, K_cols), "float32") + Tx.copy_async(frag[:, :], tmem[0:frag_rows, 0:K_cols]) + Tx.ptx.tcgen05.wait.ld() + if warp_id == 0: + with Tx.warp(): + Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + Tx.ptx.tcgen05.dealloc( + tmem_addr[0], n_cols=max(32, K_cols), cta_group=1 + ) + + target = tvm.target.Target("cuda") + with target: + mod = tvm.IRModule({"main": kernel}) + mod = tvm.compile(mod, target=target, tir_pipeline="tirx") + # Compiles cleanly + the generated CUDA contains the expected PTX shape. + src = mod.mod.imports[0].inspect_source() + assert shape in src, ( + f"expected .{shape}.x? in generated PTX, but `{shape}` not found in CUDA source" + ) + + +if __name__ == "__main__": + tvm.testing.main()
