This is an automated email from the ASF dual-hosted git repository. spectrometerHBH pushed a commit to branch tir-bench in repository https://gitbox.apache.org/repos/asf/tvm.git
commit f7f287f88f0487838e043a50cdc43999fd9b849d Author: Bohan Hou <[email protected]> AuthorDate: Sun May 17 21:50:11 2026 -0400 feat(op): add permute_layout primitive; remove permute_dims (#629) --- include/tvm/tirx/tirx_op.h | 7 - python/tvm/tirx/layout.py | 24 ++ .../tvm/tirx/operator/tile_primitive/__init__.py | 2 +- .../cuda/permute_dims/vectorized_last_2d.py | 151 -------- .../{permute_dims => permute_layout}/__init__.py | 2 +- .../cuda/permute_layout/warp_xor_swizzle.py | 389 +++++++++++++++++++ python/tvm/tirx/operator/tile_primitive/ops.py | 53 ++- python/tvm/tirx/script/builder/tirx.py | 47 ++- src/tirx/op/tirx.cc | 1 - .../tile_primitive/cuda/test_gemm_async.py | 35 +- .../tile_primitive/cuda/test_permute_dims.py | 152 -------- .../tile_primitive/cuda/test_permute_layout.py | 425 +++++++++++++++++++++ tests/python/tirx/test_op.py | 12 - 13 files changed, 940 insertions(+), 360 deletions(-) diff --git a/include/tvm/tirx/tirx_op.h b/include/tvm/tirx/tirx_op.h index 7da9e9af0e..18a40f1f5c 100644 --- a/include/tvm/tirx/tirx_op.h +++ b/include/tvm/tirx/tirx_op.h @@ -220,13 +220,6 @@ class DispatchContext : public ffi::ObjectRef { */ TVM_DLL const Op& cast(); -/*! - * \brief See pesudo code below: - * - * Tx.permute_dims(BufferRegion buffer, List order) - */ -TVM_DLL const Op& permute_dims(); - /*! * \brief See pesudo code below: * diff --git a/python/tvm/tirx/layout.py b/python/tvm/tirx/layout.py index d5c29faee8..e05d9a37f7 100644 --- a/python/tvm/tirx/layout.py +++ b/python/tvm/tirx/layout.py @@ -403,6 +403,30 @@ class Layout(Object): else: raise ValueError(f"Unsupported layout type: {type(self)}") + def broadcast(self, num: int, position: int = -1, axis: '"Axis" | str' = "m") -> "Layout": + """Insert a stride-0 broadcast dim of extent ``num`` at ``position``. + + ``position`` follows Python list-insert semantics (negative indices + count from the end; ``-1`` appends after the last shard dim). 
The + new dim has stride 0 — accessing along it doesn't move the byte + offset, so the same physical element is "seen" ``num`` times. + + Useful for layouts where a consumer reads the same SMEM datum + multiple times (e.g. ``sf_reuse`` over MMA-K steps). + """ + if isinstance(self, TileLayout): + if isinstance(axis, str): + axis = Axis.get(axis) + new_iter = Iter(num, 0, axis) + shard = list(self.shard) + insert_at = position if position >= 0 else len(shard) + 1 + position + shard.insert(insert_at, new_iter) + return TileLayout.from_iters(shard, self.replica, self.offset) + elif isinstance(self, ComposeLayout): + return ComposeLayout(self.swizzle, self.tile_layout.broadcast(num, position, axis)) + else: + raise ValueError(f"broadcast not supported for {type(self)}") + def pack(self, num: int) -> "Layout": """Pack the layout, where num contiguous elements in the layout are packed into a single element. diff --git a/python/tvm/tirx/operator/tile_primitive/__init__.py b/python/tvm/tirx/operator/tile_primitive/__init__.py index 345059bd68..23c76ff3fb 100644 --- a/python/tvm/tirx/operator/tile_primitive/__init__.py +++ b/python/tvm/tirx/operator/tile_primitive/__init__.py @@ -28,7 +28,7 @@ from .registry import DispatchContext from .cuda.copy import * from .cuda.reduction import * from .cuda.copy_async import * -from .cuda.permute_dims import * +from .cuda.permute_layout import * from .cuda.gemm_async import * from .cuda.elementwise import * from .trn import * diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/vectorized_last_2d.py b/python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/vectorized_last_2d.py deleted file mode 100644 index c468ed1d92..0000000000 --- a/python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/vectorized_last_2d.py +++ /dev/null @@ -1,151 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""CUDA permute_dims dispatch: vectorized_permute_dims_last_2d variant.""" - -import math - -from tvm.script import tirx as Tx -from tvm.tirx import Buffer, BufferRegion, PrimFunc -from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, register_dispatch -from tvm.tirx.stmt import TilePrimitiveCall - -from ..common import get_indices, get_st_extent - - -def validate_deepgemm_permute_dims(op_call: TilePrimitiveCall, sctx: DispatchContext) -> bool: - op_call = TilePrimitiveCall.downcast(op_call) - if isinstance(op_call.buffer, Buffer): - buffer: Buffer = op_call.buffer - extent = buffer.shape - elif isinstance(op_call.buffer, BufferRegion): - buffer: Buffer = op_call.buffer.buffer - st, extent = get_st_extent(op_call.buffer) - - order = op_call.order - if sctx.is_warp: - assert "threadIdx.y" not in sctx.launch_params and "threadIdx.z" not in sctx.launch_params - ndim = len(order) - expected_order = [*list(range(ndim - 2)), ndim - 1, ndim - 2] - if list(order) != expected_order: - return False - if not math.prod(extent[:-2]) == 1: - return False - strides = list(buffer.strides) - if not (strides == [] or (strides[-1] == 1 and strides[-2] == extent[-1])): - return False - return True - return False - - -def vectorized_permute_dims_last_2d_impl( - op_call: 
TilePrimitiveCall, sctx: DispatchContext -) -> PrimFunc | None: - op_call = TilePrimitiveCall.downcast(op_call) - if isinstance(op_call.buffer, Buffer): - buffer: Buffer = op_call.buffer - extent = shape = buffer.shape - st = [0] * len(extent) - elif isinstance(op_call.buffer, BufferRegion): - buffer: Buffer = op_call.buffer.buffer - shape = buffer.shape - st, extent = get_st_extent(op_call.buffer) - - M, N = extent[-2:] - vec_len = op_call.config.get("vec_len") - - if vec_len is None: - for vec_len in range(4, 0, -1): - if M % vec_len == 0: - break - - if not shape[-1] % vec_len == 0: - vec_len = 1 - if not (st[-2] * shape[-1] + st[-1]) % vec_len == 0: - vec_len = 1 - - # Thread and vectorization setup - if sctx.is_warp: - tid_x = sctx.launch_params["threadIdx.x"] - assert "threadIdx.y" not in sctx.launch_params and "threadIdx.z" not in sctx.launch_params - - # fmt: off - @Tx.prim_func - def impl(): - warp_size = Tx.meta_var(32) - lane_id = Tx.meta_var(tid_x % warp_size) - reg_trans = Tx.alloc_buffer((N // warp_size, M // vec_len, vec_len), buffer.dtype, scope="local") # noqa: E501 - for wi in Tx.unroll(0, N // warp_size): - for vi in Tx.unroll(0, M // vec_len): - for vec in Tx.unroll(vec_len): - old_index = Tx.meta_var(get_indices((vi * vec_len + vec) * N + wi * warp_size + lane_id, st, extent)) # noqa: E501 - reg_trans[wi, vi, vec] = buffer[tuple(old_index)] - Tx.cuda.warp_sync() - for wi in Tx.unroll(0, N // warp_size): - for vi in Tx.unroll(0, M // vec_len): - for vec in Tx.vectorized(vec_len): - new_index = Tx.meta_var(get_indices((wi * warp_size + lane_id) * M + vi * vec_len + vec, st, extent)) # noqa: E501 - buffer[tuple(new_index)] = reg_trans[wi, vi, vec] - Tx.cuda.warp_sync() - # fmt: on - else: - raise NotImplementedError - return impl - - -# === Variant: permute_dims/vectorized_permute_dims_last_2d (priority=20) === -# -# When: shared-memory buffer with TileLayout, permutation swaps only the last -# 2 dimensions (e.g. [0,1,3,2] for 4D), at warp scope. 
In-place transpose. -# -# Before (TilePrimitiveCall): -# with Tx.warp(): -# Tx.permute_dims(A_smem[0:64, 0:64], order=[1, 0]) -# # A_smem: shared float16 (64, 64), in-place transpose -# -# After (warp-level register-buffered transpose, vec_len=4): -# lane_id = threadIdx.x % 32 -# reg_trans = Tx.alloc_buffer((2, 16, 4), "float16", scope="local") -# # Phase 1: read rows into registers (each lane reads a column stripe) -# for wi in Tx.unroll(2): # N // warp_size -# for vi in Tx.unroll(16): # M // vec_len -# for vec in Tx.unroll(4): -# reg_trans[wi, vi, vec] = A_smem[(vi*4+vec)*64 + wi*32+lane_id] -# Tx.cuda.warp_sync() -# # Phase 2: write back transposed (column index becomes row) -# for wi in Tx.unroll(2): -# for vi in Tx.unroll(16): -# for vec in Tx.vectorized(4): -# A_smem[(wi*32+lane_id)*64 + vi*4+vec] = reg_trans[wi, vi, vec] -# Tx.cuda.warp_sync() -@register_dispatch( - "permute_dims", - "cuda", - variant="vectorized_permute_dims_last_2d", - priority=20, - when=[ - predicate( - "validate_deepgemm_permute_dims", - lambda op, sctx: ( - validate_deepgemm_permute_dims(op, sctx), - "validate_deepgemm_permute_dims failed", - ), - ) - ], -) -def permute_dims_dispatch(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: - return vectorized_permute_dims_last_2d_impl(op, sctx) diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/__init__.py b/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/__init__.py similarity index 95% rename from python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/__init__.py rename to python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/__init__.py index 172da2d78b..e406e9c3fd 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/__init__.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. 
-from .vectorized_last_2d import * +from .warp_xor_swizzle import * diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py b/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py new file mode 100644 index 0000000000..3868d21e72 --- /dev/null +++ b/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py @@ -0,0 +1,389 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CUDA permute_layout dispatch: warp register-staged in-place transpose with +optional per-lane XOR-swizzle to avoid SMEM bank conflicts on the write phase. + +The dispatcher reasons about the **layout's shard**, not the buffer's +declared shape (the two can differ — a buffer with ``shape=(PIPE, M, K)`` +may carry a layout whose shard has more dims internally, with grouping +mapping shard segments onto buffer dims). Concretely: + + src_sliced = src.layout.slice(src.shape, region).canonicalize() + dst_sliced = dst.layout.slice(dst.shape, region).canonicalize() + # If the two sliced shards have different structures (which is common — + # a linear layout collapses to 1D under canon while a transposed one + # keeps its multi-dim structure), regroup src to dst's shape. 
+ if src_sliced.shard != dst_sliced.shard: + src_sliced, _ = src_sliced.group(dst.shard.extents) + extent = [int(it.extent) for it in dst_sliced.shard] # iteration shape + src_str = [int(it.stride) for it in src_sliced.shard] + dst_str = [int(it.stride) for it in dst_sliced.shard] + +The algorithm: + + regs[P] + for r in 0..P: + j = r XOR ((lane >> SHIFT) & MASK) + i = lane + j * 32 # flat logical index + idx = decompose(i, extent) # iter multi-dim index + regs[r] = src[project(idx, src.shape, slice_starts)] + warp_sync() + for r in 0..P: + j = r XOR ((lane >> SHIFT) & MASK) + i = lane + j * 32 + idx = decompose(i, extent) + dst[project(idx, dst.shape, slice_starts)] = regs[r] + warp_sync() + +where ``project`` mixed-radix-folds the iter shard dims back onto the +buffer's iterated slice dims (so the emit's index matches buf.shape rank, +which TIR's BufferLoad/Store requires). + +SHIFT and MASK are chosen by simulating the bank pattern at the **shard +granularity** (where strides are affine), trying k = 0, 1, …, log2(P) +and picking the smallest k that makes both phases bank-conflict-free. + +Correctness rests on: + +* For each lane, ``r ↦ r XOR const`` is a bijection on ``[0, P)``. +* Therefore (lane, r) ↔ flat over [0, V). +* Both layouts are verified bijections on the slice (every logical + position has a unique byte offset under that layout). +* The mixed-radix projection from iter shard idx to buf coord is exactly + what TIR's BufferLoad does internally when buf.shape rank < shard rank + — so iter shard's strides and the buffer-indexed byte offset agree. 
+""" + +from __future__ import annotations + +import math + +from tvm.runtime import DataType +from tvm.script import tirx as Tx +from tvm.tirx import Buffer, BufferRegion, IntImm, PrimFunc +from tvm.tirx.layout import TileLayout, _flatten_coord +from tvm.tirx.operator.tile_primitive import DispatchContext, fail, register_dispatch +from tvm.tirx.stmt import TilePrimitiveCall + +from ..common import get_indices, get_st_extent + +# ---------- helpers ---------------------------------------------------------- + + +def _as_buffer_and_region(arg): + """Normalize a Buffer or BufferRegion to (buffer, start_list, extent_list).""" + if isinstance(arg, Buffer): + buf = arg + extent = list(buf.shape) + st = [0] * len(extent) + elif isinstance(arg, BufferRegion): + buf = arg.buffer + st, extent = get_st_extent(arg) + else: + raise TypeError(f"unexpected permute_layout arg type: {type(arg)}") + return buf, list(st), list(extent) + + +def _as_int(x): + """Return int(x) if x is int-like, else None.""" + if isinstance(x, int): + return x + if isinstance(x, IntImm): + return int(x.value) + if hasattr(x, "value") and isinstance(x.value, int): + return int(x.value) + try: + return int(x) + except (TypeError, ValueError): + return None + + +def _layout_shard_int(layout): + """Return (extents, strides) as int lists from a TileLayout's shard, or (None, None).""" + if not isinstance(layout, TileLayout): + return None, None + extents, strides = [], [] + for it in layout.shard: + e = _as_int(it.extent) + s = _as_int(it.stride) + if e is None or s is None: + return None, None + extents.append(e) + strides.append(s) + return extents, strides + + +def _decompose_row_major(i, extent): + out, rem = [], i + for e in reversed(extent): + out.append(rem % e) + rem //= e + return list(reversed(out)) + + +def _eval_offset(idx, strides): + return sum(i * s for i, s in zip(idx, strides)) + + +def _check_bijection(extent, strides): + """Iteration extents + strides define a bijection on [0, V)?""" + V = 
math.prod(extent) + seen = set() + for i in range(V): + off = _eval_offset(_decompose_row_major(i, extent), strides) + if off in seen: + return False + seen.add(off) + return len(seen) == V + + +def _bank_free(extent, strides, dtype_bytes, P, k): + """For every register slot r ∈ [0, P), do the 32 lanes hit 32 distinct banks?""" + T, BANKS, BANK_W = 32, 32, 4 + shift = 5 - k + mask = (1 << k) - 1 + for r in range(P): + seen = set() + for lane in range(T): + j = r ^ ((lane >> shift) & mask) + flat = lane + j * T + idx = _decompose_row_major(flat, extent) + off_bytes = _eval_offset(idx, strides) * dtype_bytes + bank = (off_bytes // BANK_W) % BANKS + if bank in seen: + return False + seen.add(bank) + return True + + +def _choose_xor_k(extent, src_strides, dst_strides, dtype_bytes, P): + max_k = int(math.log2(P)) if P > 0 else 0 + for k in range(max_k + 1): + if _bank_free(extent, src_strides, dtype_bytes, P, k) and _bank_free( + extent, dst_strides, dtype_bytes, P, k + ): + return k + return None + + +# ---------- validator + dispatch impl --------------------------------------- + + +def _gather(op_call): + op_call = TilePrimitiveCall.downcast(op_call) + dst_arg, src_arg = op_call.args[0], op_call.args[1] + src_buf, src_st, src_ext = _as_buffer_and_region(src_arg) + dst_buf, dst_st, dst_ext = _as_buffer_and_region(dst_arg) + return src_buf, src_st, src_ext, dst_buf, dst_st, dst_ext + + +def _why_reject(op_call, sctx): + if not sctx.is_warp: + return f"scope {sctx.scope_kind!r} is not 'warp'" + if "threadIdx.y" in sctx.launch_params or "threadIdx.z" in sctx.launch_params: + return "multi-dim threadIdx is not supported" + + src_buf, src_st, src_ext, dst_buf, dst_st, dst_ext = _gather(op_call) + + if src_buf.dtype != dst_buf.dtype: + return f"dtype mismatch: dst={dst_buf.dtype} vs src={src_buf.dtype}" + + src_ext_i = [_as_int(e) for e in src_ext] + dst_ext_i = [_as_int(e) for e in dst_ext] + if None in src_ext_i or None in dst_ext_i: + return "extents must be compile-time 
integers" + if src_ext_i != dst_ext_i: + return f"slice shape mismatch: src={src_ext_i} vs dst={dst_ext_i}" + + dtype_bytes = DataType(src_buf.dtype).bits // 8 + if dtype_bytes not in (1, 2, 4, 8, 16): + return f"unsupported dtype byte width: {dtype_bytes}" + + if not isinstance(src_buf.layout, TileLayout): + return "src buffer's layout is not a plain TileLayout" + if not isinstance(dst_buf.layout, TileLayout): + return "dst buffer's layout is not a plain TileLayout" + + # Slice + canonicalize both layouts. The result's shard describes the + # iteration domain; runtime starts (like ``ks``) are folded into the + # layout's offset, separate from the shard's affine part. + src_region = [(s, s + e) for s, e in zip(src_st, src_ext)] + dst_region = [(s, s + e) for s, e in zip(dst_st, dst_ext)] + src_sliced = src_buf.layout.slice(list(src_buf.shape), src_region) + dst_sliced = dst_buf.layout.slice(list(dst_buf.shape), dst_region) + if src_sliced is None or dst_sliced is None: + return "layout.slice failed" + src_sliced = src_sliced.canonicalize() + dst_sliced = dst_sliced.canonicalize() + + # Iteration shape: regroup dst onto the iterated buf dims; the result's + # shard may stay finer than iter_buf_extents (one buf dim ↔ several shard + # dims via seps), which is fine. Then regroup src to match dst's shard + # extents exactly so both phases share the same iteration index space. 
+ iter_buf_extents = [e for e in src_ext_i if e != 1] + try: + dst_grouped, dst_seps = dst_sliced.group(iter_buf_extents) + src_grouped, _ = src_sliced.group([int(it.extent) for it in dst_grouped.shard]) + except Exception as e: + return f"layout.group failed: {e}" + + dst_ext_, dst_str_ = _layout_shard_int(dst_grouped) + src_ext_, src_str_ = _layout_shard_int(src_grouped) + if dst_ext_ is None or src_ext_ is None: + return "regrouped layout shard contains non-integer extent/stride" + if src_ext_ != dst_ext_: + return f"src shard {src_ext_} doesn't match dst shard {dst_ext_} after regrouping" + + extent = dst_ext_ + V = math.prod(extent) + T = 32 + if V == 0 or V % T != 0: + return f"volume {V} not divisible by warp size {T}" + P = V // T + if P == 0 or (P & (P - 1)) != 0 or P > T: + return f"per-thread count {P} must be power of 2 in [1, {T}]" + + if not _check_bijection(extent, src_str_): + return "src layout (regrouped) is not a bijection on the slice" + if not _check_bijection(extent, dst_str_): + return "dst layout is not a bijection on the slice" + return None + + +def _impl(op_call, sctx): + src_buf, src_st, src_ext, dst_buf, dst_st, dst_ext = _gather(op_call) + src_ext_i = [_as_int(e) for e in src_ext] + dst_ext_i = [_as_int(e) for e in dst_ext] + + src_region = [(s, s + e) for s, e in zip(src_st, src_ext)] + dst_region = [(s, s + e) for s, e in zip(dst_st, dst_ext)] + src_sliced = src_buf.layout.slice(list(src_buf.shape), src_region).canonicalize() + dst_sliced = dst_buf.layout.slice(list(dst_buf.shape), dst_region).canonicalize() + + iter_buf_extents = [e for e in src_ext_i if e != 1] + dst_grouped, dst_seps = dst_sliced.group(iter_buf_extents) + src_grouped, _ = src_sliced.group([int(it.extent) for it in dst_grouped.shard]) + + extent, dst_str_ = _layout_shard_int(dst_grouped) + _, src_str_ = _layout_shard_int(src_grouped) + V = math.prod(extent) + P = V // 32 + dtype_bytes = DataType(src_buf.dtype).bits // 8 + + k_opt = _choose_xor_k(extent, src_str_, 
dst_str_, dtype_bytes, P) + if k_opt is None: + fail(f"no XOR-bits k ∈ [0, log2(P)={int(math.log2(P))}] makes both phases bank-free") + + shift = 5 - k_opt + mask = (1 << k_opt) - 1 + + iter_buf_dims = [i for i, e in enumerate(src_ext_i) if e != 1] + seps = list(dst_seps) + + def _project(iter_idx, st_list): + buf_idx = list(st_list) + for bi in range(len(seps) - 1): + lo, hi = seps[bi], seps[bi + 1] + flat = _flatten_coord(iter_idx[lo:hi], extent[lo:hi]) + buf_idx[iter_buf_dims[bi]] = st_list[iter_buf_dims[bi]] + flat + return tuple(buf_idx) + + tid_x = sctx.launch_params["threadIdx.x"] + dtype = src_buf.dtype + + # fmt: off + @Tx.prim_func + def impl(): + warp_size = Tx.meta_var(32) + lane_id = Tx.meta_var(tid_x % warp_size) + regs = Tx.alloc_buffer((P,), dtype, scope="local") + # Phase 1: read via L_src + for r in Tx.unroll(0, P): + j = Tx.meta_var(r ^ ((lane_id >> shift) & mask)) + flat = Tx.meta_var(lane_id + j * warp_size) + iter_idx = Tx.meta_var(get_indices(flat, [0] * len(extent), extent)) + src_idx = Tx.meta_var(_project(iter_idx, src_st)) + regs[r] = src_buf[tuple(src_idx)] + Tx.cuda.warp_sync() + # Phase 2: write via L_dst + for r in Tx.unroll(0, P): + j = Tx.meta_var(r ^ ((lane_id >> shift) & mask)) + flat = Tx.meta_var(lane_id + j * warp_size) + iter_idx = Tx.meta_var(get_indices(flat, [0] * len(extent), extent)) + dst_idx = Tx.meta_var(_project(iter_idx, dst_st)) + dst_buf[tuple(dst_idx)] = regs[r] + Tx.cuda.warp_sync() + # fmt: on + return impl + + +# === Variant: permute_layout/warp_xor_swizzle (priority=20) ============ +# +# When: warp scope; matching dst/src dtype + slice shape; both buffers carry +# a plain TileLayout; after slice + canonicalize (and regrouping src to dst's +# structure if needed), the iteration extents form a power-of-2 ≤32 elements +# per lane; both layouts are bijections on the slice; and there exists an +# XOR-bits ``k`` that makes both phases bank-conflict-free. 
+# +# Buffer ``shape`` rank does NOT need to equal layout ``shard`` rank — the +# dispatcher uses the layout shard for iteration (after slice+canon) and +# projects back onto ``buf.shape`` via mixed-radix grouping for the emit. +# +# Before (TilePrimitiveCall): +# with Tx.warp(): +# # SFA_smem: u32 (PIPE, BLK_SFA//32, 32), layout shard 4D +# # (PIPE, BLK_SFA//128, 4, 32) strides (BLK_SFA, 128, 32, 1) +# # SFA_post: same shape; layout shard 4D, strides (BLK_SFA, 128, 1, 4) +# Tx.permute_layout(SFA_post[ks, :, :], SFA_smem[ks, :, :]) +# +# After (BLK_SFA=128, P=4, k=2, shift=3): +# lane_id = threadIdx.x % 32 +# regs = Tx.alloc_buffer((4,), "uint32", scope="local") +# for r in Tx.unroll(4): +# j = r ^ ((lane_id >> 3) & 0x3) +# flat = lane_id + j * 32 +# (g, l) = decompose(flat, extent=[4, 32]) +# regs[r] = src[ks, g, l] +# Tx.cuda.warp_sync() +# for r in Tx.unroll(4): +# j = r ^ ((lane_id >> 3) & 0x3) +# flat = lane_id + j * 32 +# (g, l) = decompose(flat, extent=[4, 32]) +# dst[ks, g, l] = regs[r] +# Tx.cuda.warp_sync() +@register_dispatch( + "permute_layout", + "cuda", + variant="warp_xor_swizzle", + priority=20, +) +def permute_layout_dispatch(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc: + reason = _why_reject(op, sctx) + if reason is not None: + fail(reason) + return _impl(op, sctx) + + +__all__ = [ + "_bank_free", + "_check_bijection", + "_choose_xor_k", + "_decompose_row_major", + "_eval_offset", + "permute_layout_dispatch", +] diff --git a/python/tvm/tirx/operator/tile_primitive/ops.py b/python/tvm/tirx/operator/tile_primitive/ops.py index 7795e76dbf..97f16def6e 100644 --- a/python/tvm/tirx/operator/tile_primitive/ops.py +++ b/python/tvm/tirx/operator/tile_primitive/ops.py @@ -551,27 +551,60 @@ class ComposeOp(TilePrimitiveCall): ) -class PermuteDims(TilePrimitiveCall): - """Permute the tensor dimensions with given order.""" +def _register_permute_layout_op(): + """Register tirx.permute_layout dynamically (Python-only, no C++ rebuild). 
- op = get_tirx_op("permute_dims") + Mirrors the TIRX_DEFINE_DISPATCH_OP macro: marks the op as a TIRx op + and a dispatch op so the well-formed verifier and printer accept it. + """ + + tirx_name = "tirx.permute_layout" + try: + return Op.get(tirx_name) + except Exception: + from tvm.ir import _ffi_api as ir_ffi + from tvm.ir.op import register_op_attr + + ir_ffi.RegisterOp(tirx_name, "Permute the physical layout of a buffer in-place.") + register_op_attr(tirx_name, "TIsTIRxOp", True) + register_op_attr(tirx_name, "TIsDispatchOp", True) + register_op_attr(tirx_name, "TScriptPrinterName", "permute_layout") + return Op.get(tirx_name) + + +_register_permute_layout_op() - order = ArgProperty(1) + +class PermuteLayout(TilePrimitiveCall): + """Move data so the buffer's bytes are arranged under a different layout. + + Logical shape is preserved; only the byte placement changes. ``dst`` and + ``src`` carry their own TileLayouts; on lowering, the dispatcher reads + those layouts and emits a register-staged warp transpose, optionally + inserting a bank-conflict-avoiding XOR-swizzle on the per-lane register + slots. + + Args: ``permute_layout(dst_region, src_region)``. + ``dst`` and ``src`` may alias the same underlying SMEM (in-place). 
+ """ + + op = get_tirx_op("permute_layout") @property - def buffer(self) -> PrimExpr: - """Get the source expressions (inputs) of the operator.""" + def dst(self) -> PrimExpr: return self.args[0] + @property + def src(self) -> PrimExpr: + return self.args[1] + @property def srcs(self) -> list[PrimExpr]: - """Get the source expressions (inputs) of the operator.""" - return [self.buffer] + return [self.src] @property def dsts(self) -> list[PrimExpr]: - """Get the destination expressions (outputs) of the operator.""" - return [self.buffer] + return [self.dst] class GenericOp(TilePrimitiveCall): diff --git a/python/tvm/tirx/script/builder/tirx.py b/python/tvm/tirx/script/builder/tirx.py index efe79e1aa5..c73a8e615a 100644 --- a/python/tvm/tirx/script/builder/tirx.py +++ b/python/tvm/tirx/script/builder/tirx.py @@ -1325,33 +1325,50 @@ def reshape(buffer: Buffer, shape: list[PrimExpr]): ) -def permute_dims( - buffer: BufferRegion | Buffer, - order: list[PrimExpr | int], +def permute_layout( + dst: BufferRegion | Buffer, + src: BufferRegion | Buffer, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, **kwargs, ): - """Permute the tensor dimensions with given order. + """Move data so the buffer's bytes are arranged under a different layout. + Logical shape is preserved (``dst.shape == src.shape``); only the + byte placement changes (``dst.layout != src.layout``). ``dst`` and + ``src`` may alias the same SMEM (in-place) or be two distinct buffers. Parameters ---------- - buffer : Union[BufferRegion, Buffer] - The tensor to be permuted. + dst : Union[BufferRegion, Buffer] + Destination view (carries the target layout). + src : Union[BufferRegion, Buffer] + Source view (carries the current layout). + workspace : Dict[str, Buffer] + Optional workspace for the operator. + dispatch : Optional[str] + Force a specific dispatch variant by name. + """ - order : List[Union[PrimExpr, int]] - The permuting order. 
+ # Promote Buffer to BufferRegion covering the full extent, matching the + # convention used by ``Tx.<dynamic>`` fallback registration. + from tvm.tirx import Buffer as _TBuffer - workspace : Dict[str, Buffer] - The workspace of the operator. + def _to_region(b): + if isinstance(b, _TBuffer): + slices = [slice(None) for _ in range(len(b.shape))] + return b[slices] + return b - config : Dict[str, Any] - The scheduler configuration. - """ config = kwargs or {} return f_insert( - tirx_op.PermuteDims(buffer, order, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.PermuteLayout( + _to_region(dst), + _to_region(src), + workspace=workspace, + config=config, + dispatch=dispatch, + ) ) @@ -1379,7 +1396,7 @@ __all__ = [ "min", "minimum", "mul", - "permute_dims", + "permute_layout", "reciprocal", "reduce_negate", "select", diff --git a/src/tirx/op/tirx.cc b/src/tirx/op/tirx.cc index 2f205c7c3e..2e244e3f90 100644 --- a/src/tirx/op/tirx.cc +++ b/src/tirx/op/tirx.cc @@ -214,7 +214,6 @@ TIRX_DEFINE_DISPATCH_OP(select); TIRX_DEFINE_DISPATCH_OP(cast); TIRX_DEFINE_DISPATCH_OP(fma); TIRX_DEFINE_DISPATCH_OP(silu); -TIRX_DEFINE_DISPATCH_OP(permute_dims); /********************* Compose Ops **********************/ #define TIRX_DEFINE_COMPOSE_OP(OpName) \ diff --git a/tests/python/tirx/operator/tile_primitive/cuda/test_gemm_async.py b/tests/python/tirx/operator/tile_primitive/cuda/test_gemm_async.py index 164a903b96..df7389e2ed 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/test_gemm_async.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/test_gemm_async.py @@ -653,6 +653,7 @@ def test_gemm_block_scaled_fp8_cta_group_1(task): F32_BYTES = 4 F128_BYTES = 16 SF_smem_layout = TileLayout(S[(4, 32) : (32, 1)]) + SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off @Tx.prim_func @@ -673,6 +674,8 @@ def test_gemm_block_scaled_fp8_cta_group_1(task): B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) SFA_smem = 
Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout) + SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout) tmem_addr = Tx.alloc_shared([1], "uint32") tma_mbar = Tx.alloc_shared([1], "uint64") mma_mbar = Tx.alloc_shared([1], "uint64") @@ -715,8 +718,8 @@ def test_gemm_block_scaled_fp8_cta_group_1(task): # Transpose scale factors in shared memory if Tx.filter(warp_id, 0, 1): with Tx.warp(): - Tx.permute_dims(SFA_smem[:, :], [1, 0]) - Tx.permute_dims(SFB_smem[:, :], [1, 0]) + Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) + Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) Tx.cuda.cta_sync() # Copy SFA/SFB from shared to TMEM via tcgen05.cp, then issue MMA @@ -856,6 +859,7 @@ def test_gemm_block_scaled_fp8_cta_group_2(task): F32_BYTES = 4 F128_BYTES = 16 SF_smem_layout = TileLayout(S[(4, 32) : (32, 1)]) + SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off @Tx.prim_func @@ -877,6 +881,8 @@ def test_gemm_block_scaled_fp8_cta_group_2(task): B_smem = Tx.alloc_buffer(B_shape_per_cta, B_dtype, scope="shared", layout=B_layout) SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout) + SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout) tmem_addr = Tx.alloc_shared([1], "uint32") tma_mbar = Tx.alloc_shared([1], "uint64") mma_mbar = Tx.alloc_shared([1], "uint64") @@ -923,8 +929,8 @@ def test_gemm_block_scaled_fp8_cta_group_2(task): # Transpose scale factors (both CTAs) if Tx.filter(warp_id, 0, 1): with Tx.warp(): - Tx.permute_dims(SFA_smem[:, :], [1, 0]) - Tx.permute_dims(SFB_smem[:, :], [1, 0]) + Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) + 
Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) Tx.cuda.cta_sync() # Copy SFA/SFB from shared to TMEM via tcgen05.cp (both CTAs, cta_group=2) @@ -1057,6 +1063,7 @@ def test_gemm_block_scaled_nvfp4_cta_group_1(): F32_BYTES = 4 F128_BYTES = 16 SF_smem_layout = TileLayout(S[(4, 32) : (32, 1)]) + SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off @Tx.prim_func @@ -1080,6 +1087,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_1(): SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout) + SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout) tmem_addr = Tx.alloc_shared([1], "uint32") tma_mbar = Tx.alloc_shared([1], "uint64") mma_mbar = Tx.alloc_shared([1], "uint64") @@ -1122,8 +1131,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_1(): # Transpose scale factors in shared memory if Tx.filter(warp_id, 0, 1): with Tx.warp(): - Tx.permute_dims(SFA_smem[:, :], [1, 0]) - Tx.permute_dims(SFB_smem[:, :], [1, 0]) + Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) + Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) Tx.cuda.cta_sync() # Copy SFA/SFB from shared to TMEM via tcgen05.cp, then issue MMA @@ -1244,6 +1253,7 @@ def test_gemm_block_scaled_nvfp4_cta_group_2(): F32_BYTES = 4 F128_BYTES = 16 SF_smem_layout = TileLayout(S[(4, 32) : (32, 1)]) + SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off @Tx.prim_func @@ -1268,6 +1278,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_2(): SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout) + SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout) tmem_addr = Tx.alloc_shared([1], "uint32") tma_mbar = 
Tx.alloc_shared([1], "uint64") mma_mbar = Tx.alloc_shared([1], "uint64") @@ -1314,8 +1326,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_2(): # Transpose scale factors if Tx.filter(warp_id, 0, 1): with Tx.warp(): - Tx.permute_dims(SFA_smem[:, :], [1, 0]) - Tx.permute_dims(SFB_smem[:, :], [1, 0]) + Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) + Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) Tx.cuda.cta_sync() # Copy SFA/SFB from shared to TMEM via tcgen05.cp @@ -1456,6 +1468,7 @@ def test_gemm_block_scaled_fp8_sf_id(): F32_BYTES = 4 F128_BYTES = 16 SF_smem_layout = TileLayout(S[(4, 32) : (32, 1)]) + SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off @Tx.prim_func @@ -1476,6 +1489,8 @@ def test_gemm_block_scaled_fp8_sf_id(): B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout) + SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout) tmem_addr = Tx.alloc_shared([1], "uint32") tma_mbar = Tx.alloc_shared([1], "uint64") mma_mbar = Tx.alloc_shared([1], "uint64") @@ -1518,8 +1533,8 @@ def test_gemm_block_scaled_fp8_sf_id(): # Transpose scale factors in shared memory if Tx.filter(warp_id, 0, 1): with Tx.warp(): - Tx.permute_dims(SFA_smem[:, :], [1, 0]) - Tx.permute_dims(SFB_smem[:, :], [1, 0]) + Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) + Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) Tx.cuda.cta_sync() # Copy SF to TMEM, then single MMA call (schedule auto-derives sf_id per ki) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/test_permute_dims.py b/tests/python/tirx/operator/tile_primitive/cuda/test_permute_dims.py deleted file mode 100644 index 3cea1eb9d6..0000000000 --- a/tests/python/tirx/operator/tile_primitive/cuda/test_permute_dims.py +++ 
/dev/null @@ -1,152 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-function-docstring -import ml_dtypes -import numpy as np -import pytest - -import tvm -import tvm.testing -from tvm.script import tirx as Tx -from tvm.tirx.layout import S, TileLayout - -ml_dtypes_dict = { - "float8_e4m3fn": ml_dtypes.float8_e4m3fn, - "float8_e5m2": ml_dtypes.float8_e5m2, - "bfloat16": ml_dtypes.bfloat16, - "int4": ml_dtypes.int4, -} - - [email protected]( - "task", - [ - ( - (4, 32), # a_shape - TileLayout(S[4, 32]), # layoutA - tvm.cuda(0), - ), - ( - (4, 64), # a_shape - TileLayout(S[4, 64]), # layoutA - tvm.cuda(0), - ), - ( - (3, 64), # a_shape - TileLayout(S[3, 64]), # layoutA - tvm.cuda(0), - ), - ( - (9, 64), # a_shape - TileLayout(S[9, 64]), # layoutA - tvm.cuda(0), - ), - ], -) [email protected]("dtype", ["uint8", "float16", "int32"]) -def test_vectorized_permute_dims_2d(task, dtype): - a_shape, layoutA, dev = task - list(slice(None) for _ in range(len(a_shape))) - - # fmt: off - @Tx.prim_func - def permute_dims(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, a_shape, dtype, layout=layoutA) - - with Tx.kernel(): - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([32]) - with Tx.cta(): - with 
Tx.warp(): - Tx.permute_dims(A, [1, 0]) - # fmt: on - - target = tvm.target.Target("cuda") - with target: - mod = tvm.IRModule({"main": permute_dims}) - - mod = tvm.compile(mod, target=target, tir_pipeline="tirx") - print(mod.mod.imports[0].inspect_source()) - - np.random.seed(0) - A_np = tvm.testing.generate_random_array(dtype, a_shape) - - A = tvm.runtime.tensor(A_np, dev) - mod(A) - A_ref = np.transpose(A_np, (1, 0)).reshape(a_shape) - np.testing.assert_allclose(A_ref.flatten(), A.numpy().flatten()) - - [email protected]( - "task", - [ - ( - (1, 4, 32), # a_shape - TileLayout(S[1, 4, 32]), # layoutA - [0, 0, 0], - [1, 4, 32], - tvm.cuda(0), - ), - ( - (2, 2, 8, 64), # a_shape - TileLayout(S[2, 2, 8, 64]), # layoutA - [1, 1, 0, 0], - [1, 1, 8, 64], - tvm.cuda(0), - ), - ((1, 10, 40), TileLayout(S[1, 10, 40]), [0, 5, 3], [1, 4, 32], tvm.cuda(0)), - ], -) [email protected]("dtype", ["uint8", "float16", "int32"]) -def test_vectorized_permute_dims_nd(task, dtype): - a_shape, layoutA, st, extent, dev = task - ndim = len(a_shape) - region = list(slice(st[i], st[i] + extent[i]) for i in range(ndim)) - order = [*list(range(ndim - 2)), ndim - 1, ndim - 2] - - # fmt: off - @Tx.prim_func - def permute_dims(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, a_shape, dtype, layout=layoutA) - - with Tx.kernel(): - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([32]) - with Tx.cta(): - with Tx.warp(): - Tx.permute_dims(A[tuple(region)], order) - # fmt: on - - target = tvm.target.Target("cuda") - with target: - mod = tvm.IRModule({"main": permute_dims}) - - mod = tvm.compile(mod, target=target, tir_pipeline="tirx") - print(mod.mod.imports[0].inspect_source()) - - np.random.seed(0) - A_np = tvm.testing.generate_random_array(dtype, a_shape) - - A = tvm.runtime.tensor(A_np, dev) - mod(A) - A_ref = A_np.copy() - A_ref[tuple(region)] = np.transpose(A_np[tuple(region)], order).reshape(extent) - np.testing.assert_allclose(A_ref.flatten(), A.numpy().flatten()) - - -if __name__ == 
"__main__": - tvm.testing.main() diff --git a/tests/python/tirx/operator/tile_primitive/cuda/test_permute_layout.py b/tests/python/tirx/operator/tile_primitive/cuda/test_permute_layout.py new file mode 100644 index 0000000000..bd986ea3e5 --- /dev/null +++ b/tests/python/tirx/operator/tile_primitive/cuda/test_permute_layout.py @@ -0,0 +1,425 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring + +"""Tests for ``Tx.permute_layout``. + +Coverage: + +- The algorithm helpers (`_bank_free`, `_check_bijection`, `_choose_xor_k`) + directly, with a NumPy oracle. +- End-to-end compiled-kernel byte-for-byte equivalence on CUDA for the SF + fp8-blockwise-gemm transpose shapes (BLK_SFA = 128, 256) plus a few + generic linear↔stride-permuted layouts across 4-byte dtypes (u32, i32, + fp32). +- Reject cases: non-warp scope, dtype mismatch, shape mismatch, swizzle/ + compose layouts, layouts whose strides don't form a bijection on the + slice. Each must surface as a ``RuntimeError`` from the dispatcher and + NOT silently emit a wrong kernel.
+""" + +from __future__ import annotations + +import math + +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm.script import tirx as Tx +from tvm.tirx.layout import S, SwizzleLayout, TileLayout + +# Helpers exposed by the dispatcher module for direct algorithm tests. +from tvm.tirx.operator.tile_primitive.cuda.permute_layout.warp_xor_swizzle import ( + _bank_free, + _check_bijection, + _choose_xor_k, +) + +# --------------------------------------------------------------------------- +# Algorithm-only tests (no CUDA needed). +# --------------------------------------------------------------------------- + + +def _np_layout_offset(extent, strides, multi_idx): + return int(sum(s * i for s, i in zip(strides, multi_idx))) + + +def _expected_permute(src_np, src_strides, dst_strides, extent): + """Compute the expected output: dst at byte offset ``L_dst(i)`` holds the + value at ``src`` byte offset ``L_src(i)``, for every logical index i. + """ + V = math.prod(extent) + dst_np = np.zeros_like(src_np) + for flat in range(V): + idx = [] + rem = flat + for e in reversed(extent): + idx.append(rem % e) + rem //= e + idx = list(reversed(idx)) + src_off = _np_layout_offset(extent, src_strides, idx) + dst_off = _np_layout_offset(extent, dst_strides, idx) + dst_np.reshape(-1)[dst_off] = src_np.reshape(-1)[src_off] + return dst_np + + +def test_bank_free_sf_128_u32(): + """SF BLK_SFA=128: write phase has 4-way conflict at k=0, free at k=2.""" + extent = [4, 32] + src = [32, 1] + dst = [1, 4] + bytes_per = 4 + P = 4 + assert _bank_free(extent, src, bytes_per, P, 0) + assert not _bank_free(extent, dst, bytes_per, P, 0) + assert _bank_free(extent, dst, bytes_per, P, 2) + assert _choose_xor_k(extent, src, dst, bytes_per, P) == 2 + + +def test_bank_free_sf_256_u32(): + """SF BLK_SFA=256: same shift=3 (k=2) handles the high block too.""" + extent = [2, 4, 32] + src = [128, 32, 1] + dst = [128, 1, 4] + bytes_per = 4 + P = 8 + assert _bank_free(extent, src, 
bytes_per, P, 0) + assert not _bank_free(extent, dst, bytes_per, P, 0) + assert _bank_free(extent, dst, bytes_per, P, 2) + assert _choose_xor_k(extent, src, dst, bytes_per, P) == 2 + + +def test_identity_no_xor(): + """L_src == L_dst => k=0 (no XOR needed and the op is essentially a copy).""" + assert _choose_xor_k([4, 32], [32, 1], [32, 1], 4, 4) == 0 + # A 2D buffer with row-major to row-major is a true no-op. + assert _bank_free([4, 32], [32, 1], 4, 4, 0) + + +def test_bijection_check_rejects_aliased(): + """If two logical indices map to the same physical byte, reject.""" + # Stride 0 on a non-singleton extent => alias. + assert not _check_bijection([4, 32], [0, 1]) + # A transposed (column-major) layout is still a bijection, so it passes. + assert _check_bijection([4, 32], [1, 4]) + + +def test_dtype_widths_choose_xor_k(): + """Each dtype's outcome: + + The unvectorized algorithm is provably correct only when every per-lane + access maps to a single 4-byte bank. For 4-byte dtypes that always holds + (one element per bank), so we expect a valid k. For sub-4-byte dtypes + with stride-1 reads, multiple lanes share a bank no matter how we permute + register slots — the dispatcher correctly rejects those (k is None). + """ + extent = [4, 32] + src = [32, 1] # linear + dst = [1, 4] # transposed + # u32: this is the SF case; the algorithm must find k=2. + assert _choose_xor_k(extent, src, dst, 4, 4) == 2 + # u16/fp16, u8: stride-1 in bytes < 4 packs >1 lane into the same bank; + # register-slot XOR cannot fix that, so the dispatcher rejects. + assert _choose_xor_k(extent, src, dst, 2, 4) is None + assert _choose_xor_k(extent, src, dst, 1, 4) is None + + +# --------------------------------------------------------------------------- +# End-to-end compiled-kernel tests on CUDA.
+# --------------------------------------------------------------------------- + + +def _has_cuda(): + try: + return tvm.cuda(0).exist + except Exception: + return False + + +needs_cuda = pytest.mark.skipif(not _has_cuda(), reason="needs CUDA") + + +def _compile_and_run(prim_func, np_inputs): + target = tvm.target.Target("cuda") + with target: + mod = tvm.IRModule({"main": prim_func}) + mod = tvm.compile(mod, target=target, tir_pipeline="tirx") + dev = tvm.cuda(0) + tensors = [tvm.runtime.tensor(a, dev) for a in np_inputs] + mod(*tensors) + return [t.numpy() for t in tensors], mod.mod.imports[0].inspect_source() + + +@needs_cuda [email protected]( + "name, pipe, blk, dtype", + [ + ("sf_128_u32", 2, 128, "uint32"), + ("sf_256_u32", 2, 256, "uint32"), + ("sf_128_i32", 2, 128, "int32"), + ("sf_128_fp32", 2, 128, "float32"), + ], +) +def test_sf_blockwise_transpose(name, pipe, blk, dtype): + """SF blockwise-GEMM scale-factor transpose, the canonical use case.""" + high = blk // 128 if blk >= 128 else 1 + # Use 4D logical shape (PIPE, high, 4, 32) to keep the high-block factored. + shape = (pipe, high, 4, 32) + + # Element strides for src (linear) and dst (transposed within each + # 128-block). Stage stride = blk; each 128-block contributes 128 to the + # high stride. 
+ src_strides = (blk, 128, 32, 1) + dst_strides = (blk, 128, 1, 4) + pre = TileLayout(S[shape:src_strides]) + post = TileLayout(S[shape:dst_strides]) + + # fmt: off + @Tx.prim_func + def f(A: Tx.handle, B: Tx.handle): + A_buf = Tx.match_buffer(A, shape, dtype, layout=pre) + B_buf = Tx.match_buffer(B, shape, dtype, layout=post) + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + with Tx.warp(): + for s in Tx.serial(0, pipe): + Tx.permute_layout( + B_buf[s, 0:high, 0:4, 0:32], A_buf[s, 0:high, 0:4, 0:32] + ) + # fmt: on + + np.random.seed(0) + A_np = tvm.testing.generate_random_array(dtype, shape) + B_np = np.zeros_like(A_np) + + [_, B_out], src = _compile_and_run(f, [A_np, B_np]) + + # The dispatcher must have picked the XOR-swizzled variant; check that + # the generated CUDA contains the per-lane XOR pattern. This is the + # "no perf regression" smoke test: any future variant that omits the + # XOR would re-introduce 4-way bank conflicts. + assert ">> 3" in src, f"expected XOR-swizzle (lane>>3) in CUDA for {name}" + assert "warp_sync" in src or "syncwarp" in src + + # Byte-for-byte equality via numpy reference. 
+ for s in range(pipe): + A_flat = A_np[s].reshape(-1) + B_flat = B_out[s].reshape(-1) + ref = _expected_permute( + A_flat, + list(src_strides[1:]), + list(dst_strides[1:]), + list(shape[1:]), + ) + np.testing.assert_array_equal(B_flat, ref, err_msg=f"{name} stage {s}") + + +@needs_cuda +def test_identity_passes_through_as_copy(): + """L_src == L_dst should still compile and produce a correct (identity) copy.""" + shape = (4, 32) + layout = TileLayout(S[shape : (32, 1)]) + + # fmt: off + @Tx.prim_func + def f(A: Tx.handle, B: Tx.handle): + A_buf = Tx.match_buffer(A, shape, "uint32", layout=layout) + B_buf = Tx.match_buffer(B, shape, "uint32", layout=layout) + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + with Tx.warp(): + Tx.permute_layout(B_buf, A_buf) + # fmt: on + + np.random.seed(0) + A_np = tvm.testing.generate_random_array("uint32", shape) + B_np = np.zeros_like(A_np) + [_, B_out], _ = _compile_and_run(f, [A_np, B_np]) + np.testing.assert_array_equal(B_out, A_np) + + +@needs_cuda [email protected]("dtype", ["uint32", "int32", "float32"]) [email protected]( + "shape, src_strides, dst_strides", + [ + # (8, 32) → (8, 32) transposed: src linear, dst column-major. + ((8, 32), (32, 1), (1, 8)), + # (16, 32): per_thread = 16 — tests P=16 path. 
+ ((16, 32), (32, 1), (1, 16)), + ], +) +def test_generic_transpose(shape, src_strides, dst_strides, dtype): + """Generic linear↔transposed pairs at various P values.""" + pre = TileLayout(S[shape:src_strides]) + post = TileLayout(S[shape:dst_strides]) + + # fmt: off + @Tx.prim_func + def f(A: Tx.handle, B: Tx.handle): + A_buf = Tx.match_buffer(A, shape, dtype, layout=pre) + B_buf = Tx.match_buffer(B, shape, dtype, layout=post) + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + with Tx.warp(): + Tx.permute_layout(B_buf, A_buf) + # fmt: on + + np.random.seed(0) + A_np = tvm.testing.generate_random_array(dtype, shape) + B_np = np.zeros_like(A_np) + [_, B_out], _ = _compile_and_run(f, [A_np, B_np]) + + ref = _expected_permute(A_np.reshape(-1), list(src_strides), list(dst_strides), list(shape)) + np.testing.assert_array_equal(B_out.reshape(-1), ref) + + +# --------------------------------------------------------------------------- +# Reject cases: the dispatcher must surface a clear error, never silently +# emit a wrong kernel. 
+# --------------------------------------------------------------------------- + + +def _build_and_assert_rejected(shape, src_layout, dst_layout, dtype, msg_substr): + # fmt: off + @Tx.prim_func + def f(A: Tx.handle, B: Tx.handle): + A_buf = Tx.match_buffer(A, shape, dtype, layout=src_layout) + B_buf = Tx.match_buffer(B, shape, dtype, layout=dst_layout) + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + with Tx.warp(): + Tx.permute_layout(B_buf, A_buf) + # fmt: on + + target = tvm.target.Target("cuda") + with target, pytest.raises(RuntimeError) as exc_info: + mod = tvm.IRModule({"main": f}) + tvm.compile(mod, target=target, tir_pipeline="tirx") + assert msg_substr in str(exc_info.value), ( + f"expected reject reason to mention {msg_substr!r}, got: {exc_info.value}" + ) + + +def test_reject_dtype_mismatch(): + shape = (4, 32) + layout = TileLayout(S[shape : (32, 1)]) + + # fmt: off + @Tx.prim_func + def f(A: Tx.handle, B: Tx.handle): + A_buf = Tx.match_buffer(A, shape, "uint32", layout=layout) + B_buf = Tx.match_buffer(B, shape, "uint16", layout=layout) + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + with Tx.warp(): + Tx.permute_layout(B_buf, A_buf) + # fmt: on + + target = tvm.target.Target("cuda") + with target, pytest.raises(RuntimeError) as exc_info: + tvm.compile(tvm.IRModule({"main": f}), target=target, tir_pipeline="tirx") + assert "dtype mismatch" in str(exc_info.value) + + +def test_reject_shape_mismatch(): + src_layout = TileLayout(S[(4, 32) : (32, 1)]) + dst_layout = TileLayout(S[(8, 16) : (16, 1)]) + + # fmt: off + @Tx.prim_func + def f(A: Tx.handle, B: Tx.handle): + A_buf = Tx.match_buffer(A, (4, 32), "uint32", layout=src_layout) + B_buf = Tx.match_buffer(B, (8, 16), "uint32", layout=dst_layout) + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + with Tx.warp(): + Tx.permute_layout(B_buf, A_buf) + # fmt: on + + target = tvm.target.Target("cuda") + with target, 
pytest.raises(RuntimeError) as exc_info: + tvm.compile(tvm.IRModule({"main": f}), target=target, tir_pipeline="tirx") + assert "shape mismatch" in str(exc_info.value) + + +def test_reject_swizzle_layout(): + """ComposeLayout(SwizzleLayout, TileLayout) is not supported by the warp variant.""" + from tvm.tirx.layout import ComposeLayout + + inner = TileLayout(S[(4, 32) : (32, 1)]) + sw = SwizzleLayout(per_element=2, swizzle_len=2, atom_len=4) + swizzled = ComposeLayout(sw, inner) + plain = TileLayout(S[(4, 32) : (1, 4)]) + + # fmt: off + @Tx.prim_func + def f(A: Tx.handle, B: Tx.handle): + A_buf = Tx.match_buffer(A, (4, 32), "uint32", layout=swizzled) + B_buf = Tx.match_buffer(B, (4, 32), "uint32", layout=plain) + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + with Tx.warp(): + Tx.permute_layout(B_buf, A_buf) + # fmt: on + + target = tvm.target.Target("cuda") + with target, pytest.raises(RuntimeError) as exc_info: + tvm.compile(tvm.IRModule({"main": f}), target=target, tir_pipeline="tirx") + assert "TileLayout" in str(exc_info.value) + + +def test_reject_non_warp_scope(): + layout_pre = TileLayout(S[(4, 32) : (32, 1)]) + layout_post = TileLayout(S[(4, 32) : (1, 4)]) + + # fmt: off + @Tx.prim_func + def f(A: Tx.handle, B: Tx.handle): + A_buf = Tx.match_buffer(A, (4, 32), "uint32", layout=layout_pre) + B_buf = Tx.match_buffer(B, (4, 32), "uint32", layout=layout_post) + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + Tx.permute_layout(B_buf, A_buf) # cta scope, not warp + # fmt: on + + target = tvm.target.Target("cuda") + with target, pytest.raises(RuntimeError) as exc_info: + tvm.compile(tvm.IRModule({"main": f}), target=target, tir_pipeline="tirx") + assert "warp" in str(exc_info.value) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tirx/test_op.py b/tests/python/tirx/test_op.py index 8de3462c7c..82295cdb3c 100644 --- a/tests/python/tirx/test_op.py +++ b/tests/python/tirx/test_op.py 
@@ -179,17 +179,6 @@ def test_buffer_replacer_no_shared_default(): assert len(r2.buffer_map) == 0 -def test_permute_dims_buffer_property(): - """Regression test for F2: PermuteDims.buffer should return args[0], not recurse.""" - from tvm.tirx.operator.tile_primitive.ops import PermuteDims - - A = decl_buffer((64, 64), "float32", scope="global") - pd = PermuteDims(A[0:64, 0:64], [1, 0]) - # This would stack overflow before the fix - buf = pd.buffer - assert buf is not None - - def test_gemm_async_partial_scale_factor(): """Regression test for F7: gemm_async must reject partial scale factors.""" from tvm.tirx.script.builder.tirx import gemm_async @@ -219,5 +208,4 @@ if __name__ == "__main__": test_tx_existing_op_not_overridden() test_opcall_downcast_tolerant() test_buffer_replacer_no_shared_default() - test_permute_dims_buffer_property() test_gemm_async_partial_scale_factor()
