This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch tir-bench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit f7f287f88f0487838e043a50cdc43999fd9b849d
Author: Bohan Hou <[email protected]>
AuthorDate: Sun May 17 21:50:11 2026 -0400

    feat(op): add permute_layout primitive; remove permute_dims (#629)
---
 include/tvm/tirx/tirx_op.h                         |   7 -
 python/tvm/tirx/layout.py                          |  24 ++
 .../tvm/tirx/operator/tile_primitive/__init__.py   |   2 +-
 .../cuda/permute_dims/vectorized_last_2d.py        | 151 --------
 .../{permute_dims => permute_layout}/__init__.py   |   2 +-
 .../cuda/permute_layout/warp_xor_swizzle.py        | 389 +++++++++++++++++++
 python/tvm/tirx/operator/tile_primitive/ops.py     |  53 ++-
 python/tvm/tirx/script/builder/tirx.py             |  47 ++-
 src/tirx/op/tirx.cc                                |   1 -
 .../tile_primitive/cuda/test_gemm_async.py         |  35 +-
 .../tile_primitive/cuda/test_permute_dims.py       | 152 --------
 .../tile_primitive/cuda/test_permute_layout.py     | 425 +++++++++++++++++++++
 tests/python/tirx/test_op.py                       |  12 -
 13 files changed, 940 insertions(+), 360 deletions(-)

diff --git a/include/tvm/tirx/tirx_op.h b/include/tvm/tirx/tirx_op.h
index 7da9e9af0e..18a40f1f5c 100644
--- a/include/tvm/tirx/tirx_op.h
+++ b/include/tvm/tirx/tirx_op.h
@@ -220,13 +220,6 @@ class DispatchContext : public ffi::ObjectRef {
  */
 TVM_DLL const Op& cast();
 
-/*!
- * \brief See pesudo code below:
- *
- * Tx.permute_dims(BufferRegion buffer, List order)
- */
-TVM_DLL const Op& permute_dims();
-
 /*!
  * \brief See pesudo code below:
  *
diff --git a/python/tvm/tirx/layout.py b/python/tvm/tirx/layout.py
index d5c29faee8..e05d9a37f7 100644
--- a/python/tvm/tirx/layout.py
+++ b/python/tvm/tirx/layout.py
@@ -403,6 +403,30 @@ class Layout(Object):
         else:
             raise ValueError(f"Unsupported layout type: {type(self)}")
 
+    def broadcast(self, num: int, position: int = -1, axis: "Axis | str" = "m") -> "Layout":
+        """Insert a stride-0 broadcast dim of extent ``num`` at ``position``.
+
+        ``position`` indexes the *resulting* shard (``numpy.expand_dims``-style,
+        not ``list.insert``): negative values count from the end, so ``-1``
+        appends after the last shard dim.  The new dim has stride 0 — moving
+        along it doesn't change the offset, so one element is "seen" ``num`` times.
+
+        Useful for layouts where a consumer reads the same SMEM datum
+        multiple times (e.g. ``sf_reuse`` over MMA-K steps).
+        """
+        if isinstance(self, TileLayout):
+            if isinstance(axis, str):
+                axis = Axis.get(axis)
+            new_iter = Iter(num, 0, axis)
+            shard = list(self.shard)
+            insert_at = position if position >= 0 else len(shard) + 1 + position
+            shard.insert(insert_at, new_iter)
+            return TileLayout.from_iters(shard, self.replica, self.offset)
+        elif isinstance(self, ComposeLayout):
+            return ComposeLayout(self.swizzle, self.tile_layout.broadcast(num, position, axis))
+        else:
+            raise ValueError(f"broadcast not supported for {type(self)}")
+
     def pack(self, num: int) -> "Layout":
         """Pack the layout, where num contiguous elements in the layout are packed into a single element.
 
diff --git a/python/tvm/tirx/operator/tile_primitive/__init__.py b/python/tvm/tirx/operator/tile_primitive/__init__.py
index 345059bd68..23c76ff3fb 100644
--- a/python/tvm/tirx/operator/tile_primitive/__init__.py
+++ b/python/tvm/tirx/operator/tile_primitive/__init__.py
@@ -28,7 +28,7 @@ from .registry import DispatchContext
 from .cuda.copy import *
 from .cuda.reduction import *
 from .cuda.copy_async import *
-from .cuda.permute_dims import *
+from .cuda.permute_layout import *
 from .cuda.gemm_async import *
 from .cuda.elementwise import *
 from .trn import *
diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/vectorized_last_2d.py b/python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/vectorized_last_2d.py
deleted file mode 100644
index c468ed1d92..0000000000
--- a/python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/vectorized_last_2d.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""CUDA permute_dims dispatch: vectorized_permute_dims_last_2d variant."""
-
-import math
-
-from tvm.script import tirx as Tx
-from tvm.tirx import Buffer, BufferRegion, PrimFunc
-from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, register_dispatch
-from tvm.tirx.stmt import TilePrimitiveCall
-
-from ..common import get_indices, get_st_extent
-
-
-def validate_deepgemm_permute_dims(op_call: TilePrimitiveCall, sctx: DispatchContext) -> bool:
-    op_call = TilePrimitiveCall.downcast(op_call)
-    if isinstance(op_call.buffer, Buffer):
-        buffer: Buffer = op_call.buffer
-        extent = buffer.shape
-    elif isinstance(op_call.buffer, BufferRegion):
-        buffer: Buffer = op_call.buffer.buffer
-        st, extent = get_st_extent(op_call.buffer)
-
-    order = op_call.order
-    if sctx.is_warp:
-        assert "threadIdx.y" not in sctx.launch_params and "threadIdx.z" not in sctx.launch_params
-        ndim = len(order)
-        expected_order = [*list(range(ndim - 2)), ndim - 1, ndim - 2]
-        if list(order) != expected_order:
-            return False
-        if not math.prod(extent[:-2]) == 1:
-            return False
-        strides = list(buffer.strides)
-        if not (strides == [] or (strides[-1] == 1 and strides[-2] == extent[-1])):
-            return False
-        return True
-    return False
-
-
-def vectorized_permute_dims_last_2d_impl(
-    op_call: TilePrimitiveCall, sctx: DispatchContext
-) -> PrimFunc | None:
-    op_call = TilePrimitiveCall.downcast(op_call)
-    if isinstance(op_call.buffer, Buffer):
-        buffer: Buffer = op_call.buffer
-        extent = shape = buffer.shape
-        st = [0] * len(extent)
-    elif isinstance(op_call.buffer, BufferRegion):
-        buffer: Buffer = op_call.buffer.buffer
-        shape = buffer.shape
-        st, extent = get_st_extent(op_call.buffer)
-
-    M, N = extent[-2:]
-    vec_len = op_call.config.get("vec_len")
-
-    if vec_len is None:
-        for vec_len in range(4, 0, -1):
-            if M % vec_len == 0:
-                break
-
-    if not shape[-1] % vec_len == 0:
-        vec_len = 1
-    if not (st[-2] * shape[-1] + st[-1]) % vec_len == 0:
-        vec_len = 1
-
-    # Thread and vectorization setup
-    if sctx.is_warp:
-        tid_x = sctx.launch_params["threadIdx.x"]
-        assert "threadIdx.y" not in sctx.launch_params and "threadIdx.z" not in sctx.launch_params
-
-        # fmt: off
-        @Tx.prim_func
-        def impl():
-            warp_size = Tx.meta_var(32)
-            lane_id = Tx.meta_var(tid_x % warp_size)
-            reg_trans = Tx.alloc_buffer((N // warp_size, M // vec_len, vec_len), buffer.dtype, scope="local")  # noqa: E501
-            for wi in Tx.unroll(0, N // warp_size):
-                for vi in Tx.unroll(0, M // vec_len):
-                    for vec in Tx.unroll(vec_len):
-                        old_index = Tx.meta_var(get_indices((vi * vec_len + vec) * N + wi * warp_size + lane_id, st, extent))  # noqa: E501
-                        reg_trans[wi, vi, vec] = buffer[tuple(old_index)]
-            Tx.cuda.warp_sync()
-            for wi in Tx.unroll(0, N // warp_size):
-                for vi in Tx.unroll(0, M // vec_len):
-                    for vec in Tx.vectorized(vec_len):
-                        new_index = Tx.meta_var(get_indices((wi * warp_size + lane_id) * M + vi * vec_len + vec, st, extent))  # noqa: E501
-                        buffer[tuple(new_index)] = reg_trans[wi, vi, vec]
-            Tx.cuda.warp_sync()
-        # fmt: on
-    else:
-        raise NotImplementedError
-    return impl
-
-
-# === Variant: permute_dims/vectorized_permute_dims_last_2d (priority=20) ===
-#
-# When: shared-memory buffer with TileLayout, permutation swaps only the last
-# 2 dimensions (e.g. [0,1,3,2] for 4D), at warp scope. In-place transpose.
-#
-# Before (TilePrimitiveCall):
-#     with Tx.warp():
-#         Tx.permute_dims(A_smem[0:64, 0:64], order=[1, 0])
-#         # A_smem: shared float16 (64, 64), in-place transpose
-#
-# After (warp-level register-buffered transpose, vec_len=4):
-#     lane_id = threadIdx.x % 32
-#     reg_trans = Tx.alloc_buffer((2, 16, 4), "float16", scope="local")
-#     # Phase 1: read rows into registers (each lane reads a column stripe)
-#     for wi in Tx.unroll(2):                          # N // warp_size
-#         for vi in Tx.unroll(16):                     # M // vec_len
-#             for vec in Tx.unroll(4):
-#                 reg_trans[wi, vi, vec] = A_smem[(vi*4+vec)*64 + wi*32+lane_id]
-#     Tx.cuda.warp_sync()
-#     # Phase 2: write back transposed (column index becomes row)
-#     for wi in Tx.unroll(2):
-#         for vi in Tx.unroll(16):
-#             for vec in Tx.vectorized(4):
-#                 A_smem[(wi*32+lane_id)*64 + vi*4+vec] = reg_trans[wi, vi, vec]
-#     Tx.cuda.warp_sync()
-@register_dispatch(
-    "permute_dims",
-    "cuda",
-    variant="vectorized_permute_dims_last_2d",
-    priority=20,
-    when=[
-        predicate(
-            "validate_deepgemm_permute_dims",
-            lambda op, sctx: (
-                validate_deepgemm_permute_dims(op, sctx),
-                "validate_deepgemm_permute_dims failed",
-            ),
-        )
-    ],
-)
-def permute_dims_dispatch(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None:
-    return vectorized_permute_dims_last_2d_impl(op, sctx)
diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/__init__.py b/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/__init__.py
similarity index 95%
rename from python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/__init__.py
rename to python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/__init__.py
index 172da2d78b..e406e9c3fd 100644
--- a/python/tvm/tirx/operator/tile_primitive/cuda/permute_dims/__init__.py
+++ b/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/__init__.py
@@ -15,4 +15,4 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from .vectorized_last_2d import *
+from .warp_xor_swizzle import *
diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py b/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py
new file mode 100644
index 0000000000..3868d21e72
--- /dev/null
+++ b/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py
@@ -0,0 +1,389 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CUDA permute_layout dispatch: warp register-staged layout permutation (in-place
+or dst != src), with a per-lane XOR-swizzle chosen to avoid SMEM bank conflicts.
+
+The dispatcher reasons about the **layout's shard**, not the buffer's
+declared shape (the two can differ — a buffer with ``shape=(PIPE, M, K)``
+may carry a layout whose shard has more dims internally, with grouping
+mapping shard segments onto buffer dims).  Concretely:
+
+    src_sliced = src.layout.slice(src.shape, region).canonicalize()
+    dst_sliced = dst.layout.slice(dst.shape, region).canonicalize()
+    # The canonical shards usually differ in structure (a linear layout
+    # collapses to 1D while a transposed one keeps its dims), so always
+    # regroup dst onto the iterated (extent != 1) slice dims, then src onto dst.
+    dst_sliced, seps = dst_sliced.group([e for e in slice_extents if e != 1])
+    src_sliced, _ = src_sliced.group([it.extent for it in dst_sliced.shard])
+    extent  = [int(it.extent) for it in dst_sliced.shard]   # iteration shape
+    src_str = [int(it.stride) for it in src_sliced.shard]
+    dst_str = [int(it.stride) for it in dst_sliced.shard]
+
+The algorithm:
+
+    regs[P]
+    for r in 0..P:
+        j  = r XOR ((lane >> SHIFT) & MASK)
+        i  = lane + j * 32                             # flat logical index
+        idx = decompose(i, extent)                     # iter multi-dim index
+        regs[r] = src[project(idx, src.shape, slice_starts)]
+    warp_sync()
+    for r in 0..P:
+        j  = r XOR ((lane >> SHIFT) & MASK)
+        i  = lane + j * 32
+        idx = decompose(i, extent)
+        dst[project(idx, dst.shape, slice_starts)] = regs[r]
+    warp_sync()
+
+where ``project`` mixed-radix-folds the iter shard dims back onto the
+buffer's iterated slice dims (so the emit's index matches buf.shape rank,
+which TIR's BufferLoad/Store requires).
+
+SHIFT and MASK come from simulating the bank pattern at **shard granularity**
+(where strides are affine): the smallest k in 0..log2(P) that makes both
+phases bank-conflict-free wins; if none does, the dispatch fails.
+
+Correctness rests on:
+
+* For each lane, ``r ↦ r XOR const`` is a bijection on ``[0, P)``.
+* Therefore ``(lane, r) ↦ flat`` is a bijection onto ``[0, V)``.
+* Both layouts are verified bijections on the slice (every logical
+  position has a unique byte offset under that layout).
+* The mixed-radix projection from iter shard idx to buf coord is exactly
+  what TIR's BufferLoad does internally when buf.shape rank < shard rank
+  — so iter shard's strides and the buffer-indexed byte offset agree.
+"""
+
+from __future__ import annotations
+
+import math
+
+from tvm.runtime import DataType
+from tvm.script import tirx as Tx
+from tvm.tirx import Buffer, BufferRegion, IntImm, PrimFunc
+from tvm.tirx.layout import TileLayout, _flatten_coord
+from tvm.tirx.operator.tile_primitive import DispatchContext, fail, register_dispatch
+from tvm.tirx.stmt import TilePrimitiveCall
+
+from ..common import get_indices, get_st_extent
+
+# ---------- helpers ----------------------------------------------------------
+
+
+def _as_buffer_and_region(arg):
+    """Normalize a Buffer or BufferRegion to (buffer, start_list, extent_list)."""
+    if isinstance(arg, Buffer):
+        buf = arg
+        extent = list(buf.shape)
+        st = [0] * len(extent)
+    elif isinstance(arg, BufferRegion):
+        buf = arg.buffer
+        st, extent = get_st_extent(arg)
+    else:
+        raise TypeError(f"unexpected permute_layout arg type: {type(arg)}")
+    return buf, list(st), list(extent)
+
+
+def _as_int(x):
+    """Return ``x`` as a Python int if it is int-like (int, IntImm, int ``.value``, or int()-able), else None."""
+    if isinstance(x, int):
+        return x
+    if isinstance(x, IntImm):
+        return int(x.value)
+    if hasattr(x, "value") and isinstance(x.value, int):
+        return int(x.value)
+    try:
+        return int(x)  # NOTE(review): int() truncates non-integral floats (3.7 -> 3); confirm intended
+    except (TypeError, ValueError):
+        return None
+
+
+def _layout_shard_int(layout):
+    """Return (extents, strides) as int lists from a TileLayout's shard, or (None, None)."""
+    if not isinstance(layout, TileLayout):
+        return None, None
+    extents, strides = [], []
+    for it in layout.shard:
+        e = _as_int(it.extent)
+        s = _as_int(it.stride)
+        if e is None or s is None:
+            return None, None
+        extents.append(e)
+        strides.append(s)
+    return extents, strides
+
+
+def _decompose_row_major(i, extent):  # flat index i -> row-major multi-index over `extent` (last dim fastest)
+    out, rem = [], i
+    for e in reversed(extent):
+        out.append(rem % e)
+        rem //= e
+    return list(reversed(out))
+
+
+def _eval_offset(idx, strides):  # element offset of a multi-index: sum(idx[d] * strides[d])
+    return sum(i * s for i, s in zip(idx, strides))
+
+
+def _check_bijection(extent, strides):
+    """True iff the prod(extent) flat indices map to pairwise-distinct offsets (offsets need not be contiguous)."""
+    V = math.prod(extent)
+    seen = set()
+    for i in range(V):
+        off = _eval_offset(_decompose_row_major(i, extent), strides)
+        if off in seen:
+            return False
+        seen.add(off)
+    return len(seen) == V  # always True here: any duplicate returned False above
+
+
+def _bank_free(extent, strides, dtype_bytes, P, k):
+    """True iff every slot r ∈ [0, P) is bank-conflict-free: same-word lanes broadcast, >4B elements go 128B per phase."""
+    T, BANKS, BANK_W = 32, 32, 4
+    shift, mask = 5 - k, (1 << k) - 1
+    words = max(1, dtype_bytes // BANK_W)  # 32-bit words covered by one element
+    phase = T // words  # lanes served per wavefront: 32 (<=4B), 16 (8B), 8 (16B)
+    for r in range(P):
+        for p0 in range(0, T, phase):
+            owner = {}  # bank -> the single word this phase may touch in that bank
+            for lane in range(p0, p0 + phase):
+                j = r ^ ((lane >> shift) & mask)
+                idx = _decompose_row_major(lane + j * T, extent)
+                w0 = _eval_offset(idx, strides) * dtype_bytes // BANK_W
+                for w in range(w0, w0 + words):
+                    if owner.setdefault(w % BANKS, w) != w:
+                        return False
+    return True
+
+
+def _choose_xor_k(extent, src_strides, dst_strides, dtype_bytes, P):  # smallest bank-free k for both phases, else None
+    max_k = int(math.log2(P)) if P > 0 else 0
+    for k in range(max_k + 1):
+        if _bank_free(extent, src_strides, dtype_bytes, P, k) and _bank_free(
+            extent, dst_strides, dtype_bytes, P, k
+        ):
+            return k
+    return None
+
+
+# ---------- validator + dispatch impl ---------------------------------------
+
+
+def _gather(op_call):
+    op_call = TilePrimitiveCall.downcast(op_call)
+    dst_arg, src_arg = op_call.args[0], op_call.args[1]  # permute_layout(dst, src) argument order
+    src_buf, src_st, src_ext = _as_buffer_and_region(src_arg)
+    dst_buf, dst_st, dst_ext = _as_buffer_and_region(dst_arg)
+    return src_buf, src_st, src_ext, dst_buf, dst_st, dst_ext
+
+
+def _why_reject(op_call, sctx):  # -> None when this variant applies, else a human-readable reason
+    if not sctx.is_warp:
+        return f"scope {sctx.scope_kind!r} is not 'warp'"
+    if "threadIdx.y" in sctx.launch_params or "threadIdx.z" in sctx.launch_params:
+        return "multi-dim threadIdx is not supported"
+
+    src_buf, src_st, src_ext, dst_buf, dst_st, dst_ext = _gather(op_call)
+
+    if src_buf.dtype != dst_buf.dtype:
+        return f"dtype mismatch: dst={dst_buf.dtype} vs src={src_buf.dtype}"
+
+    src_ext_i = [_as_int(e) for e in src_ext]
+    dst_ext_i = [_as_int(e) for e in dst_ext]
+    if None in src_ext_i or None in dst_ext_i:
+        return "extents must be compile-time integers"
+    if src_ext_i != dst_ext_i:
+        return f"slice shape mismatch: src={src_ext_i} vs dst={dst_ext_i}"
+
+    dtype_bytes = DataType(src_buf.dtype).bits // 8
+    if dtype_bytes not in (1, 2, 4, 8, 16):
+        return f"unsupported dtype byte width: {dtype_bytes}"
+
+    if not isinstance(src_buf.layout, TileLayout):
+        return "src buffer's layout is not a plain TileLayout"
+    if not isinstance(dst_buf.layout, TileLayout):
+        return "dst buffer's layout is not a plain TileLayout"
+
+    # Slice + canonicalize both layouts.  The result's shard describes the
+    # iteration domain; runtime starts (like ``ks``) are folded into the
+    # layout's offset, separate from the shard's affine part.
+    src_region = [(s, s + e) for s, e in zip(src_st, src_ext)]
+    dst_region = [(s, s + e) for s, e in zip(dst_st, dst_ext)]
+    src_sliced = src_buf.layout.slice(list(src_buf.shape), src_region)
+    dst_sliced = dst_buf.layout.slice(list(dst_buf.shape), dst_region)
+    if src_sliced is None or dst_sliced is None:
+        return "layout.slice failed"
+    src_sliced = src_sliced.canonicalize()
+    dst_sliced = dst_sliced.canonicalize()
+
+    # Iteration shape: regroup dst onto the iterated buf dims; the result's
+    # shard may stay finer than iter_buf_extents (one buf dim ↔ several shard
+    # dims via seps), which is fine.  Then regroup src to match dst's shard
+    # extents exactly so both phases share the same iteration index space.
+    iter_buf_extents = [e for e in src_ext_i if e != 1]
+    try:
+        dst_grouped, _ = dst_sliced.group(iter_buf_extents)
+        src_grouped, _ = src_sliced.group([int(it.extent) for it in dst_grouped.shard])
+    except Exception as e:
+        return f"layout.group failed: {e}"
+
+    dst_ext_, dst_str_ = _layout_shard_int(dst_grouped)
+    src_ext_, src_str_ = _layout_shard_int(src_grouped)
+    if dst_ext_ is None or src_ext_ is None:
+        return "regrouped layout shard contains non-integer extent/stride"
+    if src_ext_ != dst_ext_:
+        return f"src shard {src_ext_} doesn't match dst shard {dst_ext_} after regrouping"
+
+    extent = dst_ext_
+    V = math.prod(extent)
+    T = 32
+    if V == 0 or V % T != 0:
+        return f"volume {V} not divisible by warp size {T}"
+    P = V // T
+    if P == 0 or (P & (P - 1)) != 0 or P > T:
+        return f"per-thread count {P} must be power of 2 in [1, {T}]"
+
+    if not _check_bijection(extent, src_str_):
+        return "src layout (regrouped) is not a bijection on the slice"
+    if not _check_bijection(extent, dst_str_):
+        return "dst layout is not a bijection on the slice"
+    return None
+
+
+def _impl(op_call, sctx):
+    src_buf, src_st, src_ext, dst_buf, dst_st, dst_ext = _gather(op_call)
+    src_ext_i = [_as_int(e) for e in src_ext]
+    # dst extents equal src extents here: enforced by _why_reject before _impl runs.
+
+    src_region = [(s, s + e) for s, e in zip(src_st, src_ext)]
+    dst_region = [(s, s + e) for s, e in zip(dst_st, dst_ext)]
+    src_sliced = src_buf.layout.slice(list(src_buf.shape), src_region).canonicalize()
+    dst_sliced = dst_buf.layout.slice(list(dst_buf.shape), dst_region).canonicalize()
+
+    iter_buf_extents = [e for e in src_ext_i if e != 1]
+    dst_grouped, dst_seps = dst_sliced.group(iter_buf_extents)
+    src_grouped, _ = src_sliced.group([int(it.extent) for it in dst_grouped.shard])
+
+    extent, dst_str_ = _layout_shard_int(dst_grouped)
+    _, src_str_ = _layout_shard_int(src_grouped)
+    V = math.prod(extent)
+    P = V // 32
+    dtype_bytes = DataType(src_buf.dtype).bits // 8
+
+    k_opt = _choose_xor_k(extent, src_str_, dst_str_, dtype_bytes, P)
+    if k_opt is None:
+        fail(f"no XOR-bits k ∈ [0, log2(P)={int(math.log2(P))}] makes both phases bank-free")
+
+    shift = 5 - k_opt
+    mask = (1 << k_opt) - 1
+
+    iter_buf_dims = [i for i, e in enumerate(src_ext_i) if e != 1]
+    seps = list(dst_seps)
+
+    def _project(iter_idx, st_list):
+        buf_idx = list(st_list)
+        for bi in range(len(seps) - 1):
+            lo, hi = seps[bi], seps[bi + 1]
+            flat = _flatten_coord(iter_idx[lo:hi], extent[lo:hi])
+            buf_idx[iter_buf_dims[bi]] = st_list[iter_buf_dims[bi]] + flat
+        return tuple(buf_idx)
+
+    tid_x = sctx.launch_params["threadIdx.x"]
+    dtype = src_buf.dtype
+
+    # fmt: off
+    @Tx.prim_func
+    def impl():
+        warp_size = Tx.meta_var(32)
+        lane_id = Tx.meta_var(tid_x % warp_size)
+        regs = Tx.alloc_buffer((P,), dtype, scope="local")
+        # Phase 1: read via L_src
+        for r in Tx.unroll(0, P):
+            j = Tx.meta_var(r ^ ((lane_id >> shift) & mask))
+            flat = Tx.meta_var(lane_id + j * warp_size)
+            iter_idx = Tx.meta_var(get_indices(flat, [0] * len(extent), extent))
+            src_idx = Tx.meta_var(_project(iter_idx, src_st))
+            regs[r] = src_buf[tuple(src_idx)]
+        Tx.cuda.warp_sync()
+        # Phase 2: write via L_dst
+        for r in Tx.unroll(0, P):
+            j = Tx.meta_var(r ^ ((lane_id >> shift) & mask))
+            flat = Tx.meta_var(lane_id + j * warp_size)
+            iter_idx = Tx.meta_var(get_indices(flat, [0] * len(extent), extent))
+            dst_idx = Tx.meta_var(_project(iter_idx, dst_st))
+            dst_buf[tuple(dst_idx)] = regs[r]
+        Tx.cuda.warp_sync()
+    # fmt: on
+    return impl
+
+
+# === Variant: permute_layout/warp_xor_swizzle (priority=20) ============
+#
+# When: warp scope; matching dst/src dtype + slice shape; both buffers carry
+# a plain TileLayout; after slice + canonicalize (dst regrouped onto the
+# iterated buf dims, src regrouped onto dst's shard), the iteration covers a
+# power-of-2 ≤32 elements per lane; both layouts are bijections on the slice;
+# and some XOR-bits ``k`` makes both phases bank-conflict-free.
+#
+# Buffer ``shape`` rank does NOT need to equal layout ``shard`` rank — the
+# dispatcher uses the layout shard for iteration (after slice+canon) and
+# projects back onto ``buf.shape`` via mixed-radix grouping for the emit.
+#
+# Before (TilePrimitiveCall):
+#     with Tx.warp():
+#         # SFA_smem: u32 (PIPE, BLK_SFA//32, 32), layout shard 4D
+#         #   (PIPE, BLK_SFA//128, 4, 32) strides (BLK_SFA, 128, 32, 1)
+#         # SFA_post: same shape; layout shard 4D, strides (BLK_SFA, 128, 1, 4)
+#         Tx.permute_layout(SFA_post[ks, :, :], SFA_smem[ks, :, :])
+#
+# After (BLK_SFA=128, P=4, k=2, shift=3):
+#     lane_id = threadIdx.x % 32
+#     regs = Tx.alloc_buffer((4,), "uint32", scope="local")
+#     for r in Tx.unroll(4):
+#         j = r ^ ((lane_id >> 3) & 0x3)
+#         flat = lane_id + j * 32
+#         (g, l) = decompose(flat, extent=[4, 32])
+#         regs[r] = src[ks, g, l]
+#     Tx.cuda.warp_sync()
+#     for r in Tx.unroll(4):
+#         j = r ^ ((lane_id >> 3) & 0x3)
+#         flat = lane_id + j * 32
+#         (g, l) = decompose(flat, extent=[4, 32])
+#         dst[ks, g, l] = regs[r]
+#     Tx.cuda.warp_sync()
+@register_dispatch(
+    "permute_layout",
+    "cuda",
+    variant="warp_xor_swizzle",
+    priority=20,
+)
+def permute_layout_dispatch(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc:
+    reason = _why_reject(op, sctx)
+    if reason is not None:
+        fail(reason)
+    return _impl(op, sctx)
+
+
+__all__ = [
+    "_bank_free",
+    "_check_bijection",
+    "_choose_xor_k",
+    "_decompose_row_major",
+    "_eval_offset",
+    "permute_layout_dispatch",
+]
diff --git a/python/tvm/tirx/operator/tile_primitive/ops.py b/python/tvm/tirx/operator/tile_primitive/ops.py
index 7795e76dbf..97f16def6e 100644
--- a/python/tvm/tirx/operator/tile_primitive/ops.py
+++ b/python/tvm/tirx/operator/tile_primitive/ops.py
@@ -551,27 +551,60 @@ class ComposeOp(TilePrimitiveCall):
         )
 
 
-class PermuteDims(TilePrimitiveCall):
-    """Permute the tensor dimensions with given order."""
+def _register_permute_layout_op():
+    """Register tirx.permute_layout dynamically (Python-only, no C++ rebuild).
 
-    op = get_tirx_op("permute_dims")
+    Mirrors the TIRX_DEFINE_DISPATCH_OP macro: marks the op as a TIRx op
+    and a dispatch op so the well-formed verifier and printer accept it.
+    """
+
+    tirx_name = "tirx.permute_layout"
+    try:
+        return Op.get(tirx_name)
+    except Exception:
+        from tvm.ir import _ffi_api as ir_ffi
+        from tvm.ir.op import register_op_attr
+
+        ir_ffi.RegisterOp(tirx_name, "Move a buffer region's data into another physical layout (dst from src; may alias).")
+        register_op_attr(tirx_name, "TIsTIRxOp", True)
+        register_op_attr(tirx_name, "TIsDispatchOp", True)
+        register_op_attr(tirx_name, "TScriptPrinterName", "permute_layout")
+        return Op.get(tirx_name)
+
+
+_register_permute_layout_op()
 
-    order = ArgProperty(1)
+
+class PermuteLayout(TilePrimitiveCall):
+    """Move data so the buffer's bytes are arranged under a different layout.
+
+    Logical shape is preserved; only the byte placement changes. ``dst`` and
+    ``src`` carry their own TileLayouts; on lowering, the dispatcher reads
+    those layouts and emits a register-staged warp transpose, optionally
+    inserting a bank-conflict-avoiding XOR-swizzle on the per-lane register
+    slots.
+
+    Args: ``permute_layout(dst_region, src_region)``.
+    ``dst`` and ``src`` may alias the same underlying SMEM (in-place).
+    """
+
+    op = get_tirx_op("permute_layout")
 
     @property
-    def buffer(self) -> PrimExpr:
-        """Get the source expressions (inputs) of the operator."""
+    def dst(self) -> PrimExpr:
         return self.args[0]
 
+    @property
+    def src(self) -> PrimExpr:
+        return self.args[1]
+
     @property
     def srcs(self) -> list[PrimExpr]:
-        """Get the source expressions (inputs) of the operator."""
-        return [self.buffer]
+        return [self.src]
 
     @property
     def dsts(self) -> list[PrimExpr]:
-        """Get the destination expressions (outputs) of the operator."""
-        return [self.buffer]
+        return [self.dst]
 
 
 class GenericOp(TilePrimitiveCall):
diff --git a/python/tvm/tirx/script/builder/tirx.py b/python/tvm/tirx/script/builder/tirx.py
index efe79e1aa5..c73a8e615a 100644
--- a/python/tvm/tirx/script/builder/tirx.py
+++ b/python/tvm/tirx/script/builder/tirx.py
@@ -1325,33 +1325,50 @@ def reshape(buffer: Buffer, shape: list[PrimExpr]):
     )
 
 
-def permute_dims(
-    buffer: BufferRegion | Buffer,
-    order: list[PrimExpr | int],
+def permute_layout(
+    dst: BufferRegion | Buffer,
+    src: BufferRegion | Buffer,
     workspace: dict[str, Buffer] | None = None,
     dispatch: str | None = None,
     **kwargs,
 ):
-    """Permute the tensor dimensions with given order.
+    """Move data so the buffer's bytes are arranged under a different layout.
 
+    Logical shape is preserved (``dst.shape == src.shape``); only the
+    byte placement changes (``dst.layout != src.layout``). ``dst`` and
+    ``src`` may alias the same SMEM (in-place) or be two distinct buffers.
 
     Parameters
     ----------
-    buffer : Union[BufferRegion, Buffer]
-        The tensor to be permuted.
+    dst : Union[BufferRegion, Buffer]
+        Destination view (carries the target layout).
+    src : Union[BufferRegion, Buffer]
+        Source view (carries the current layout).
+    workspace : Dict[str, Buffer]
+        Optional workspace for the operator.
+    dispatch : Optional[str]
+        Force a specific dispatch variant by name.
+    """
 
-    order : List[Union[PrimExpr, int]]
-        The permuting order.
+    # Promote Buffer to BufferRegion covering the full extent, matching the
+    # convention used by ``Tx.<dynamic>`` fallback registration.
+    from tvm.tirx import Buffer as _TBuffer
 
-    workspace : Dict[str, Buffer]
-        The workspace of the operator.
+    def _to_region(b):
+        if isinstance(b, _TBuffer):
+            slices = [slice(None) for _ in range(len(b.shape))]
+            return b[slices]
+        return b
 
-    config : Dict[str, Any]
-        The scheduler configuration.
-    """
     config = kwargs or {}
     return f_insert(
-        tirx_op.PermuteDims(buffer, order, workspace=workspace, config=config, dispatch=dispatch)
+        tirx_op.PermuteLayout(
+            _to_region(dst),
+            _to_region(src),
+            workspace=workspace,
+            config=config,
+            dispatch=dispatch,
+        )
     )
 
 
@@ -1379,7 +1396,7 @@ __all__ = [
     "min",
     "minimum",
     "mul",
-    "permute_dims",
+    "permute_layout",
     "reciprocal",
     "reduce_negate",
     "select",
diff --git a/src/tirx/op/tirx.cc b/src/tirx/op/tirx.cc
index 2f205c7c3e..2e244e3f90 100644
--- a/src/tirx/op/tirx.cc
+++ b/src/tirx/op/tirx.cc
@@ -214,7 +214,6 @@ TIRX_DEFINE_DISPATCH_OP(select);
 TIRX_DEFINE_DISPATCH_OP(cast);
 TIRX_DEFINE_DISPATCH_OP(fma);
 TIRX_DEFINE_DISPATCH_OP(silu);
-TIRX_DEFINE_DISPATCH_OP(permute_dims);
 
 /********************* Compose Ops **********************/
 #define TIRX_DEFINE_COMPOSE_OP(OpName) \
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/test_gemm_async.py b/tests/python/tirx/operator/tile_primitive/cuda/test_gemm_async.py
index 164a903b96..df7389e2ed 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/test_gemm_async.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/test_gemm_async.py
@@ -653,6 +653,7 @@ def test_gemm_block_scaled_fp8_cta_group_1(task):
     F32_BYTES = 4
     F128_BYTES = 16
     SF_smem_layout = TileLayout(S[(4, 32) : (32, 1)])
+    SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)])
 
     # fmt: off
     @Tx.prim_func
@@ -673,6 +674,8 @@ def test_gemm_block_scaled_fp8_cta_group_1(task):
             B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout)
             SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout)
             SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout)
+            SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout)
+            SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout)
             tmem_addr = Tx.alloc_shared([1], "uint32")
             tma_mbar = Tx.alloc_shared([1], "uint64")
             mma_mbar = Tx.alloc_shared([1], "uint64")
@@ -715,8 +718,8 @@ def test_gemm_block_scaled_fp8_cta_group_1(task):
             # Transpose scale factors in shared memory
             if Tx.filter(warp_id, 0, 1):
                 with Tx.warp():
-                    Tx.permute_dims(SFA_smem[:, :], [1, 0])
-                    Tx.permute_dims(SFB_smem[:, :], [1, 0])
+                    Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :])
+                    Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :])
             Tx.cuda.cta_sync()
 
             # Copy SFA/SFB from shared to TMEM via tcgen05.cp, then issue MMA
@@ -856,6 +859,7 @@ def test_gemm_block_scaled_fp8_cta_group_2(task):
     F32_BYTES = 4
     F128_BYTES = 16
     SF_smem_layout = TileLayout(S[(4, 32) : (32, 1)])
+    SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)])
 
     # fmt: off
     @Tx.prim_func
@@ -877,6 +881,8 @@ def test_gemm_block_scaled_fp8_cta_group_2(task):
             B_smem = Tx.alloc_buffer(B_shape_per_cta, B_dtype, scope="shared", layout=B_layout)
             SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout)
             SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout)
+            SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout)
+            SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout)
             tmem_addr = Tx.alloc_shared([1], "uint32")
             tma_mbar = Tx.alloc_shared([1], "uint64")
             mma_mbar = Tx.alloc_shared([1], "uint64")
@@ -923,8 +929,8 @@ def test_gemm_block_scaled_fp8_cta_group_2(task):
             # Transpose scale factors (both CTAs)
             if Tx.filter(warp_id, 0, 1):
                 with Tx.warp():
-                    Tx.permute_dims(SFA_smem[:, :], [1, 0])
-                    Tx.permute_dims(SFB_smem[:, :], [1, 0])
+                    Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :])
+                    Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :])
             Tx.cuda.cta_sync()
 
             # Copy SFA/SFB from shared to TMEM via tcgen05.cp (both CTAs, cta_group=2)
@@ -1057,6 +1063,7 @@ def test_gemm_block_scaled_nvfp4_cta_group_1():
     F32_BYTES = 4
     F128_BYTES = 16
     SF_smem_layout = TileLayout(S[(4, 32) : (32, 1)])
+    SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)])
 
     # fmt: off
     @Tx.prim_func
@@ -1080,6 +1087,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_1():
 
             SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout)
             SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout)
+            SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout)
+            SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout)
             tmem_addr = Tx.alloc_shared([1], "uint32")
             tma_mbar = Tx.alloc_shared([1], "uint64")
             mma_mbar = Tx.alloc_shared([1], "uint64")
@@ -1122,8 +1131,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_1():
             # Transpose scale factors in shared memory
             if Tx.filter(warp_id, 0, 1):
                 with Tx.warp():
-                    Tx.permute_dims(SFA_smem[:, :], [1, 0])
-                    Tx.permute_dims(SFB_smem[:, :], [1, 0])
+                    Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :])
+                    Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :])
             Tx.cuda.cta_sync()
 
             # Copy SFA/SFB from shared to TMEM via tcgen05.cp, then issue MMA
@@ -1244,6 +1253,7 @@ def test_gemm_block_scaled_nvfp4_cta_group_2():
     F32_BYTES = 4
     F128_BYTES = 16
     SF_smem_layout = TileLayout(S[(4, 32) : (32, 1)])
+    SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)])
 
     # fmt: off
     @Tx.prim_func
@@ -1268,6 +1278,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_2():
 
             SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout)
             SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout)
+            SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout)
+            SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout)
             tmem_addr = Tx.alloc_shared([1], "uint32")
             tma_mbar = Tx.alloc_shared([1], "uint64")
             mma_mbar = Tx.alloc_shared([1], "uint64")
@@ -1314,8 +1326,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_2():
             # Transpose scale factors
             if Tx.filter(warp_id, 0, 1):
                 with Tx.warp():
-                    Tx.permute_dims(SFA_smem[:, :], [1, 0])
-                    Tx.permute_dims(SFB_smem[:, :], [1, 0])
+                    Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :])
+                    Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :])
             Tx.cuda.cta_sync()
 
             # Copy SFA/SFB from shared to TMEM via tcgen05.cp
@@ -1456,6 +1468,7 @@ def test_gemm_block_scaled_fp8_sf_id():
     F32_BYTES = 4
     F128_BYTES = 16
     SF_smem_layout = TileLayout(S[(4, 32) : (32, 1)])
+    SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)])
 
     # fmt: off
     @Tx.prim_func
@@ -1476,6 +1489,8 @@ def test_gemm_block_scaled_fp8_sf_id():
             B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout)
             SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout)
             SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout)
+            SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout)
+            SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout)
             tmem_addr = Tx.alloc_shared([1], "uint32")
             tma_mbar = Tx.alloc_shared([1], "uint64")
             mma_mbar = Tx.alloc_shared([1], "uint64")
@@ -1518,8 +1533,8 @@ def test_gemm_block_scaled_fp8_sf_id():
             # Transpose scale factors in shared memory
             if Tx.filter(warp_id, 0, 1):
                 with Tx.warp():
-                    Tx.permute_dims(SFA_smem[:, :], [1, 0])
-                    Tx.permute_dims(SFB_smem[:, :], [1, 0])
+                    Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :])
+                    Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :])
             Tx.cuda.cta_sync()
 
             # Copy SF to TMEM, then single MMA call (schedule auto-derives sf_id per ki)
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/test_permute_dims.py b/tests/python/tirx/operator/tile_primitive/cuda/test_permute_dims.py
deleted file mode 100644
index 3cea1eb9d6..0000000000
--- a/tests/python/tirx/operator/tile_primitive/cuda/test_permute_dims.py
+++ /dev/null
@@ -1,152 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=missing-function-docstring
-import ml_dtypes
-import numpy as np
-import pytest
-
-import tvm
-import tvm.testing
-from tvm.script import tirx as Tx
-from tvm.tirx.layout import S, TileLayout
-
-ml_dtypes_dict = {
-    "float8_e4m3fn": ml_dtypes.float8_e4m3fn,
-    "float8_e5m2": ml_dtypes.float8_e5m2,
-    "bfloat16": ml_dtypes.bfloat16,
-    "int4": ml_dtypes.int4,
-}
-
-
-@pytest.mark.parametrize(
-    "task",
-    [
-        (
-            (4, 32),  # a_shape
-            TileLayout(S[4, 32]),  # layoutA
-            tvm.cuda(0),
-        ),
-        (
-            (4, 64),  # a_shape
-            TileLayout(S[4, 64]),  # layoutA
-            tvm.cuda(0),
-        ),
-        (
-            (3, 64),  # a_shape
-            TileLayout(S[3, 64]),  # layoutA
-            tvm.cuda(0),
-        ),
-        (
-            (9, 64),  # a_shape
-            TileLayout(S[9, 64]),  # layoutA
-            tvm.cuda(0),
-        ),
-    ],
-)
-@pytest.mark.parametrize("dtype", ["uint8", "float16", "int32"])
-def test_vectorized_permute_dims_2d(task, dtype):
-    a_shape, layoutA, dev = task
-    list(slice(None) for _ in range(len(a_shape)))
-
-    # fmt: off
-    @Tx.prim_func
-    def permute_dims(A_ptr: Tx.handle) -> None:
-        A = Tx.match_buffer(A_ptr, a_shape, dtype, layout=layoutA)
-
-        with Tx.kernel():
-            cta_id = Tx.cta_id([1])
-            tid = Tx.thread_id([32])
-            with Tx.cta():
-                with Tx.warp():
-                    Tx.permute_dims(A, [1, 0])
-    # fmt: on
-
-    target = tvm.target.Target("cuda")
-    with target:
-        mod = tvm.IRModule({"main": permute_dims})
-
-        mod = tvm.compile(mod, target=target, tir_pipeline="tirx")
-        print(mod.mod.imports[0].inspect_source())
-
-        np.random.seed(0)
-        A_np = tvm.testing.generate_random_array(dtype, a_shape)
-
-        A = tvm.runtime.tensor(A_np, dev)
-        mod(A)
-        A_ref = np.transpose(A_np, (1, 0)).reshape(a_shape)
-        np.testing.assert_allclose(A_ref.flatten(), A.numpy().flatten())
-
-
-@pytest.mark.parametrize(
-    "task",
-    [
-        (
-            (1, 4, 32),  # a_shape
-            TileLayout(S[1, 4, 32]),  # layoutA
-            [0, 0, 0],
-            [1, 4, 32],
-            tvm.cuda(0),
-        ),
-        (
-            (2, 2, 8, 64),  # a_shape
-            TileLayout(S[2, 2, 8, 64]),  # layoutA
-            [1, 1, 0, 0],
-            [1, 1, 8, 64],
-            tvm.cuda(0),
-        ),
-        ((1, 10, 40), TileLayout(S[1, 10, 40]), [0, 5, 3], [1, 4, 32], tvm.cuda(0)),
-    ],
-)
-@pytest.mark.parametrize("dtype", ["uint8", "float16", "int32"])
-def test_vectorized_permute_dims_nd(task, dtype):
-    a_shape, layoutA, st, extent, dev = task
-    ndim = len(a_shape)
-    region = list(slice(st[i], st[i] + extent[i]) for i in range(ndim))
-    order = [*list(range(ndim - 2)), ndim - 1, ndim - 2]
-
-    # fmt: off
-    @Tx.prim_func
-    def permute_dims(A_ptr: Tx.handle) -> None:
-        A = Tx.match_buffer(A_ptr, a_shape, dtype, layout=layoutA)
-
-        with Tx.kernel():
-            cta_id = Tx.cta_id([1])
-            tid = Tx.thread_id([32])
-            with Tx.cta():
-                with Tx.warp():
-                    Tx.permute_dims(A[tuple(region)], order)
-    # fmt: on
-
-    target = tvm.target.Target("cuda")
-    with target:
-        mod = tvm.IRModule({"main": permute_dims})
-
-        mod = tvm.compile(mod, target=target, tir_pipeline="tirx")
-        print(mod.mod.imports[0].inspect_source())
-
-        np.random.seed(0)
-        A_np = tvm.testing.generate_random_array(dtype, a_shape)
-
-        A = tvm.runtime.tensor(A_np, dev)
-        mod(A)
-        A_ref = A_np.copy()
-        A_ref[tuple(region)] = np.transpose(A_np[tuple(region)], order).reshape(extent)
-        np.testing.assert_allclose(A_ref.flatten(), A.numpy().flatten())
-
-
-if __name__ == "__main__":
-    tvm.testing.main()
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/test_permute_layout.py b/tests/python/tirx/operator/tile_primitive/cuda/test_permute_layout.py
new file mode 100644
index 0000000000..bd986ea3e5
--- /dev/null
+++ b/tests/python/tirx/operator/tile_primitive/cuda/test_permute_layout.py
@@ -0,0 +1,425 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring
+
+"""Tests for ``Tx.permute_layout``.
+
+Coverage:
+
+- The algorithm helpers (`_bank_free`, `_check_bijection`, `_choose_xor_k`)
+  directly, with a NumPy oracle.
+- End-to-end compiled-kernel byte-for-byte equivalence on CUDA for the SF
+  fp8-blockwise-gemm transpose shapes (BLK_SFA = 128, 256) plus a few
+  generic linear↔stride-permuted layouts and additional dtypes (u8, fp16,
+  i32, u64).
+- Reject cases: non-warp scope, dtype mismatch, shape mismatch, swizzle/
+  compose layouts, layouts whose strides don't form a bijection on the
+  slice.  Each must surface as a ``RuntimeError`` from the dispatcher and
+  NOT silently emit a wrong kernel.
+"""
+
+from __future__ import annotations
+
+import math
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm.script import tirx as Tx
+from tvm.tirx.layout import S, SwizzleLayout, TileLayout
+
+# Helpers exposed by the dispatcher module for direct algorithm tests.
+from tvm.tirx.operator.tile_primitive.cuda.permute_layout.warp_xor_swizzle import (
+    _bank_free,
+    _check_bijection,
+    _choose_xor_k,
+)
+
+# ---------------------------------------------------------------------------
+# Algorithm-only tests (no CUDA needed).
+# ---------------------------------------------------------------------------
+
+
+def _np_layout_offset(extent, strides, multi_idx):
+    return int(sum(s * i for s, i in zip(strides, multi_idx)))
+
+
+def _expected_permute(src_np, src_strides, dst_strides, extent):
+    """Compute the expected output: dst at byte offset ``L_dst(i)`` holds the
+    value at ``src`` byte offset ``L_src(i)``, for every logical index i.
+    """
+    V = math.prod(extent)
+    dst_np = np.zeros_like(src_np)
+    for flat in range(V):
+        idx = []
+        rem = flat
+        for e in reversed(extent):
+            idx.append(rem % e)
+            rem //= e
+        idx = list(reversed(idx))
+        src_off = _np_layout_offset(extent, src_strides, idx)
+        dst_off = _np_layout_offset(extent, dst_strides, idx)
+        dst_np.reshape(-1)[dst_off] = src_np.reshape(-1)[src_off]
+    return dst_np
+
+
+def test_bank_free_sf_128_u32():
+    """SF BLK_SFA=128: write phase has 4-way conflict at k=0, free at k=2."""
+    extent = [4, 32]
+    src = [32, 1]
+    dst = [1, 4]
+    bytes_per = 4
+    P = 4
+    assert _bank_free(extent, src, bytes_per, P, 0)
+    assert not _bank_free(extent, dst, bytes_per, P, 0)
+    assert _bank_free(extent, dst, bytes_per, P, 2)
+    assert _choose_xor_k(extent, src, dst, bytes_per, P) == 2
+
+
+def test_bank_free_sf_256_u32():
+    """SF BLK_SFA=256: same shift=3 (k=2) handles the high block too."""
+    extent = [2, 4, 32]
+    src = [128, 32, 1]
+    dst = [128, 1, 4]
+    bytes_per = 4
+    P = 8
+    assert _bank_free(extent, src, bytes_per, P, 0)
+    assert not _bank_free(extent, dst, bytes_per, P, 0)
+    assert _bank_free(extent, dst, bytes_per, P, 2)
+    assert _choose_xor_k(extent, src, dst, bytes_per, P) == 2
+
+
+def test_identity_no_xor():
+    """L_src == L_dst => k=0 (no XOR needed and the op is essentially a copy)."""
+    assert _choose_xor_k([4, 32], [32, 1], [32, 1], 4, 4) == 0
+    # A 2D buffer with row-major to row-major is a true no-op.
+    assert _bank_free([4, 32], [32, 1], 4, 4, 0)
+
+
+def test_bijection_check_rejects_aliased():
+    """If two logical indices map to the same physical byte, reject."""
+    # Stride 0 on a non-singleton extent => alias.
+    assert not _check_bijection([4, 32], [0, 1])
+    # Negative or non-contiguous-but-bijective is still fine.
+    assert _check_bijection([4, 32], [1, 4])
+
+
+def test_dtype_widths_choose_xor_k():
+    """Each dtype's outcome:
+
+    The unvectorized algorithm is provably correct only when every per-lane
+    access maps to a single 4-byte bank.  For 4-byte dtypes that always holds
+    (one element per bank), so we expect a valid k.  For sub-4-byte dtypes
+    with stride-1 reads, multiple lanes share a bank no matter how we permute
+    register slots — the dispatcher correctly rejects those (k is None).
+    """
+    extent = [4, 32]
+    src = [32, 1]  # linear
+    dst = [1, 4]  # transposed
+    # u32: this is the SF case; the algorithm must find k=2.
+    assert _choose_xor_k(extent, src, dst, 4, 4) == 2
+    # u16/fp16, u8: stride-1 in bytes < 4 packs >1 lane into the same bank;
+    # register-slot XOR cannot fix that, so the dispatcher rejects.
+    assert _choose_xor_k(extent, src, dst, 2, 4) is None
+    assert _choose_xor_k(extent, src, dst, 1, 4) is None
+
+
+# ---------------------------------------------------------------------------
+# End-to-end compiled-kernel tests on CUDA.
+# ---------------------------------------------------------------------------
+
+
+def _has_cuda():
+    try:
+        return tvm.cuda(0).exist
+    except Exception:
+        return False
+
+
+needs_cuda = pytest.mark.skipif(not _has_cuda(), reason="needs CUDA")
+
+
+def _compile_and_run(prim_func, np_inputs):
+    target = tvm.target.Target("cuda")
+    with target:
+        mod = tvm.IRModule({"main": prim_func})
+        mod = tvm.compile(mod, target=target, tir_pipeline="tirx")
+    dev = tvm.cuda(0)
+    tensors = [tvm.runtime.tensor(a, dev) for a in np_inputs]
+    mod(*tensors)
+    return [t.numpy() for t in tensors], mod.mod.imports[0].inspect_source()
+
+
+@needs_cuda
+@pytest.mark.parametrize(
+    "name, pipe, blk, dtype",
+    [
+        ("sf_128_u32", 2, 128, "uint32"),
+        ("sf_256_u32", 2, 256, "uint32"),
+        ("sf_128_i32", 2, 128, "int32"),
+        ("sf_128_fp32", 2, 128, "float32"),
+    ],
+)
+def test_sf_blockwise_transpose(name, pipe, blk, dtype):
+    """SF blockwise-GEMM scale-factor transpose, the canonical use case."""
+    high = blk // 128 if blk >= 128 else 1
+    # Use 4D logical shape (PIPE, high, 4, 32) to keep the high-block factored.
+    shape = (pipe, high, 4, 32)
+
+    # Element strides for src (linear) and dst (transposed within each
+    # 128-block).  Stage stride = blk; each 128-block contributes 128 to the
+    # high stride.
+    src_strides = (blk, 128, 32, 1)
+    dst_strides = (blk, 128, 1, 4)
+    pre = TileLayout(S[shape:src_strides])
+    post = TileLayout(S[shape:dst_strides])
+
+    # fmt: off
+    @Tx.prim_func
+    def f(A: Tx.handle, B: Tx.handle):
+        A_buf = Tx.match_buffer(A, shape, dtype, layout=pre)
+        B_buf = Tx.match_buffer(B, shape, dtype, layout=post)
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                with Tx.warp():
+                    for s in Tx.serial(0, pipe):
+                        Tx.permute_layout(
+                            B_buf[s, 0:high, 0:4, 0:32], A_buf[s, 0:high, 0:4, 0:32]
+                        )
+    # fmt: on
+
+    np.random.seed(0)
+    A_np = tvm.testing.generate_random_array(dtype, shape)
+    B_np = np.zeros_like(A_np)
+
+    [_, B_out], src = _compile_and_run(f, [A_np, B_np])
+
+    # The dispatcher must have picked the XOR-swizzled variant; check that
+    # the generated CUDA contains the per-lane XOR pattern.  This is the
+    # "no perf regression" smoke test: any future variant that omits the
+    # XOR would re-introduce 4-way bank conflicts.
+    assert ">> 3" in src, f"expected XOR-swizzle (lane>>3) in CUDA for {name}"
+    assert "warp_sync" in src or "syncwarp" in src
+
+    # Byte-for-byte equality via numpy reference.
+    for s in range(pipe):
+        A_flat = A_np[s].reshape(-1)
+        B_flat = B_out[s].reshape(-1)
+        ref = _expected_permute(
+            A_flat,
+            list(src_strides[1:]),
+            list(dst_strides[1:]),
+            list(shape[1:]),
+        )
+        np.testing.assert_array_equal(B_flat, ref, err_msg=f"{name} stage {s}")
+
+
+@needs_cuda
+def test_identity_passes_through_as_copy():
+    """L_src == L_dst should still compile and produce a correct (identity) copy."""
+    shape = (4, 32)
+    layout = TileLayout(S[shape : (32, 1)])
+
+    # fmt: off
+    @Tx.prim_func
+    def f(A: Tx.handle, B: Tx.handle):
+        A_buf = Tx.match_buffer(A, shape, "uint32", layout=layout)
+        B_buf = Tx.match_buffer(B, shape, "uint32", layout=layout)
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                with Tx.warp():
+                    Tx.permute_layout(B_buf, A_buf)
+    # fmt: on
+
+    np.random.seed(0)
+    A_np = tvm.testing.generate_random_array("uint32", shape)
+    B_np = np.zeros_like(A_np)
+    [_, B_out], _ = _compile_and_run(f, [A_np, B_np])
+    np.testing.assert_array_equal(B_out, A_np)
+
+
+@needs_cuda
+@pytest.mark.parametrize("dtype", ["uint32", "int32", "float32"])
+@pytest.mark.parametrize(
+    "shape, src_strides, dst_strides",
+    [
+        # (8, 32) → (8, 32) transposed: src linear, dst column-major.
+        ((8, 32), (32, 1), (1, 8)),
+        # (16, 32): per_thread = 16 — tests P=16 path.
+        ((16, 32), (32, 1), (1, 16)),
+    ],
+)
+def test_generic_transpose(shape, src_strides, dst_strides, dtype):
+    """Generic linear↔transposed pairs at various P values."""
+    pre = TileLayout(S[shape:src_strides])
+    post = TileLayout(S[shape:dst_strides])
+
+    # fmt: off
+    @Tx.prim_func
+    def f(A: Tx.handle, B: Tx.handle):
+        A_buf = Tx.match_buffer(A, shape, dtype, layout=pre)
+        B_buf = Tx.match_buffer(B, shape, dtype, layout=post)
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                with Tx.warp():
+                    Tx.permute_layout(B_buf, A_buf)
+    # fmt: on
+
+    np.random.seed(0)
+    A_np = tvm.testing.generate_random_array(dtype, shape)
+    B_np = np.zeros_like(A_np)
+    [_, B_out], _ = _compile_and_run(f, [A_np, B_np])
+
+    ref = _expected_permute(A_np.reshape(-1), list(src_strides), list(dst_strides), list(shape))
+    np.testing.assert_array_equal(B_out.reshape(-1), ref)
+
+
+# ---------------------------------------------------------------------------
+# Reject cases: the dispatcher must surface a clear error, never silently
+# emit a wrong kernel.
+# ---------------------------------------------------------------------------
+
+
+def _build_and_assert_rejected(shape, src_layout, dst_layout, dtype, msg_substr):
+    # fmt: off
+    @Tx.prim_func
+    def f(A: Tx.handle, B: Tx.handle):
+        A_buf = Tx.match_buffer(A, shape, dtype, layout=src_layout)
+        B_buf = Tx.match_buffer(B, shape, dtype, layout=dst_layout)
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                with Tx.warp():
+                    Tx.permute_layout(B_buf, A_buf)
+    # fmt: on
+
+    target = tvm.target.Target("cuda")
+    with target, pytest.raises(RuntimeError) as exc_info:
+        mod = tvm.IRModule({"main": f})
+        tvm.compile(mod, target=target, tir_pipeline="tirx")
+    assert msg_substr in str(exc_info.value), (
+        f"expected reject reason to mention {msg_substr!r}, got: {exc_info.value}"
+    )
+
+
+def test_reject_dtype_mismatch():
+    shape = (4, 32)
+    layout = TileLayout(S[shape : (32, 1)])
+
+    # fmt: off
+    @Tx.prim_func
+    def f(A: Tx.handle, B: Tx.handle):
+        A_buf = Tx.match_buffer(A, shape, "uint32", layout=layout)
+        B_buf = Tx.match_buffer(B, shape, "uint16", layout=layout)
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                with Tx.warp():
+                    Tx.permute_layout(B_buf, A_buf)
+    # fmt: on
+
+    target = tvm.target.Target("cuda")
+    with target, pytest.raises(RuntimeError) as exc_info:
+        tvm.compile(tvm.IRModule({"main": f}), target=target, tir_pipeline="tirx")
+    assert "dtype mismatch" in str(exc_info.value)
+
+
+def test_reject_shape_mismatch():
+    src_layout = TileLayout(S[(4, 32) : (32, 1)])
+    dst_layout = TileLayout(S[(8, 16) : (16, 1)])
+
+    # fmt: off
+    @Tx.prim_func
+    def f(A: Tx.handle, B: Tx.handle):
+        A_buf = Tx.match_buffer(A, (4, 32), "uint32", layout=src_layout)
+        B_buf = Tx.match_buffer(B, (8, 16), "uint32", layout=dst_layout)
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                with Tx.warp():
+                    Tx.permute_layout(B_buf, A_buf)
+    # fmt: on
+
+    target = tvm.target.Target("cuda")
+    with target, pytest.raises(RuntimeError) as exc_info:
+        tvm.compile(tvm.IRModule({"main": f}), target=target, tir_pipeline="tirx")
+    assert "shape mismatch" in str(exc_info.value)
+
+
+def test_reject_swizzle_layout():
+    """ComposeLayout(SwizzleLayout, TileLayout) is not supported by the warp variant."""
+    from tvm.tirx.layout import ComposeLayout
+
+    inner = TileLayout(S[(4, 32) : (32, 1)])
+    sw = SwizzleLayout(per_element=2, swizzle_len=2, atom_len=4)
+    swizzled = ComposeLayout(sw, inner)
+    plain = TileLayout(S[(4, 32) : (1, 4)])
+
+    # fmt: off
+    @Tx.prim_func
+    def f(A: Tx.handle, B: Tx.handle):
+        A_buf = Tx.match_buffer(A, (4, 32), "uint32", layout=swizzled)
+        B_buf = Tx.match_buffer(B, (4, 32), "uint32", layout=plain)
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                with Tx.warp():
+                    Tx.permute_layout(B_buf, A_buf)
+    # fmt: on
+
+    target = tvm.target.Target("cuda")
+    with target, pytest.raises(RuntimeError) as exc_info:
+        tvm.compile(tvm.IRModule({"main": f}), target=target, tir_pipeline="tirx")
+    assert "TileLayout" in str(exc_info.value)
+
+
+def test_reject_non_warp_scope():
+    layout_pre = TileLayout(S[(4, 32) : (32, 1)])
+    layout_post = TileLayout(S[(4, 32) : (1, 4)])
+
+    # fmt: off
+    @Tx.prim_func
+    def f(A: Tx.handle, B: Tx.handle):
+        A_buf = Tx.match_buffer(A, (4, 32), "uint32", layout=layout_pre)
+        B_buf = Tx.match_buffer(B, (4, 32), "uint32", layout=layout_post)
+        with Tx.kernel():
+            Tx.cta_id([1])
+            Tx.thread_id([32])
+            with Tx.cta():
+                Tx.permute_layout(B_buf, A_buf)  # cta scope, not warp
+    # fmt: on
+
+    target = tvm.target.Target("cuda")
+    with target, pytest.raises(RuntimeError) as exc_info:
+        tvm.compile(tvm.IRModule({"main": f}), target=target, tir_pipeline="tirx")
+    assert "warp" in str(exc_info.value)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/tirx/test_op.py b/tests/python/tirx/test_op.py
index 8de3462c7c..82295cdb3c 100644
--- a/tests/python/tirx/test_op.py
+++ b/tests/python/tirx/test_op.py
@@ -179,17 +179,6 @@ def test_buffer_replacer_no_shared_default():
     assert len(r2.buffer_map) == 0
 
 
-def test_permute_dims_buffer_property():
-    """Regression test for F2: PermuteDims.buffer should return args[0], not recurse."""
-    from tvm.tirx.operator.tile_primitive.ops import PermuteDims
-
-    A = decl_buffer((64, 64), "float32", scope="global")
-    pd = PermuteDims(A[0:64, 0:64], [1, 0])
-    # This would stack overflow before the fix
-    buf = pd.buffer
-    assert buf is not None
-
-
 def test_gemm_async_partial_scale_factor():
     """Regression test for F7: gemm_async must reject partial scale factors."""
     from tvm.tirx.script.builder.tirx import gemm_async
@@ -219,5 +208,4 @@ if __name__ == "__main__":
     test_tx_existing_op_not_overridden()
     test_opcall_downcast_tolerant()
     test_buffer_replacer_no_shared_default()
-    test_permute_dims_buffer_property()
     test_gemm_async_partial_scale_factor()

Reply via email to