gemini-code-assist[bot] commented on code in PR #19581:
URL: https://github.com/apache/tvm/pull/19581#discussion_r3254807304


##########
python/tvm/tirx/operator/intrinsics/cuda/cp_async.py:
##########
@@ -0,0 +1,910 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
# pylint: disable=redefined-builtin, invalid-name, too-many-arguments, too-many-locals, too-many-positional-arguments
+"""PTX cp.async / cp.async.bulk / cp.async.bulk.tensor intrinsics.
+
+Each PTX form table entry is registered as one ``device_intrinsic``.
+User-facing wrappers in ``tvm.tirx.op`` keep their v1 signatures;
+``register_codegen`` dispatchers below decode the (cp_size, fill_mode,
+predicate) / (dim, cta_mask, tile_mode) arguments to pick the right form.
+Bodies are hand-written ``asm volatile(...)`` strings.  The file is grouped
+as cp.async, cp.async.bulk.tensor, cp.async.bulk non-TMA, and CUDA
+compatibility helpers.
+"""
+
+import tvm
+from tvm.tirx.op import cuda_func_call
+
+from .._schema import device_intrinsic
+from .registry import CODEGEN_REGISTRY, register_codegen
+from .utils import parse_str
+
+_PREFETCH_CHOICES = ("", "64", "128", "256")
+_DIM_CHOICES = (1, 2, 3, 4, 5)
+_TILE_MODE_CHOICES = ("tile", "tile_gather4")
+
+
+def _safe(s):
+    return s.replace("::", "_").replace(".", "_")
+
+
+# =============================================================================
+# cp.async forms from the PTX Syntax block.
+#
+# Includes commit/wait plus the non-bulk shared/global copy forms.
+# =============================================================================
# cp.async.commit_group takes no operands or attributes, so one fixed helper
# name and body are used.
device_intrinsic(
    "ptx_cp_async_commit_group",
    body='    asm volatile("cp.async.commit_group;");',
    helper_name="tvm_builtin_ptx_cp_async_commit_group",
)
# cp.async.wait_group N: N is the single attribute. It is spliced into both the
# helper name and the asm text, so each distinct N gets its own helper name.
device_intrinsic(
    "ptx_cp_async_wait_group",
    body=lambda n: '    asm volatile("cp.async.wait_group %d;");' % int(n),
    helper_name=lambda n: "tvm_builtin_ptx_cp_async_wait_group_%d" % int(n),
    n_attrs=1,
)
+
+
# cp.async non-bulk copy forms (src-size is a 32-bit integer; ignore-src is a
# predicate operand):
#   Form 1: cp.async.ca.shared.global ... [dst], [src], cp-size{, src-size}{, cache-policy}
#   Form 2: cp.async.cg.shared.global ... [dst], [src], 16{, src-size}{, cache-policy}
#   Form 3: cp.async.ca.shared.global ... [dst], [src], cp-size{, ignore-src}{, cache-policy}
#   Form 4: cp.async.cg.shared.global ... [dst], [src], 16{, ignore-src}{, cache-policy}
+
+
+def _cp_async_modifier_str(has_cache_hint, prefetch_size):
+    s = ""
+    if has_cache_hint:
+        s += ".L2::cache_hint"
+    if prefetch_size:
+        s += f".L2::{prefetch_size}B"
+    return s
+
+
+def _make_form_parts(ca_or_cg, fixed_cp_size, extra):
+    """Build a parts callable for one of the cp.async PTX forms.
+
+    Args layout: (dst, src [, extra_int], cache_policy, has_cache, 
prefetch_size [, cp_size_attr])
+    Forwarded operands: dst, src [, extra_int], cache_policy.
+    Trailing attrs: has_cache, prefetch_size [, cp_size if .ca].
+    """
+    n_op = 3 if extra is not None else 2
+    n_attrs = 2 if fixed_cp_size is not None else 3
+    extra_in_name = f"_with_{extra}" if extra is not None else ""
+
+    def _parts(*args):
+        # Operand args (forwarded) come first, then attr args.
+        attr_args = args[-n_attrs:]
+        has_cache = _bool_attr(attr_args[0])
+        prefetch_size = parse_str(attr_args[1])
+        cp_size = fixed_cp_size if fixed_cp_size is not None else 
int(attr_args[2])
+        modifier = _cp_async_modifier_str(has_cache, prefetch_size)
+        cache_operand = ', "l"(cache_policy)' if has_cache else ""
+        # name parts
+        name_cache = "_cache_hint" if has_cache else ""
+        name_prefetch = f"_prefetch_{prefetch_size}" if prefetch_size else ""
+        name = (
+            f"tvm_builtin_ptx_cp_async_{ca_or_cg}_{cp_size}"
+            f"{name_cache}{name_prefetch}{extra_in_name}"
+        )
+        sig = (
+            "(void* dst, void* src"
+            + (f", int {extra}" if extra else "")
+            + ", unsigned long long cache_policy)"
+        )
+        instr_base = f"cp.async.{ca_or_cg}.shared.global{modifier}"
+        if extra is None:
+            cache_arg = ", %2" if has_cache else ""
+            body = (
+                "    unsigned int dst_addr = __cvta_generic_to_shared(dst);\n"
+                f'    asm volatile("{instr_base} [%0], [%1], 
{cp_size}{cache_arg};\\n"\n'
+                f'                 :: "r"(dst_addr), "l"(src){cache_operand} : 
"memory");'
+            )
+        else:
+            cache_arg = ", %3" if has_cache else ""
+            body = (
+                "    unsigned int dst_addr = __cvta_generic_to_shared(dst);\n"
+                f'    asm volatile("{instr_base} [%0], [%1], {cp_size}, 
%2{cache_arg};\\n"\n'
+                f'                 :: "r"(dst_addr), "l"(src), "r"({extra})'
+                f'{cache_operand} : "memory");'
+            )
+        return name, sig, body
+
+    return _parts, n_op + n_attrs - n_op  # n_attrs
+
+
def _register_nb_form(op_name, ca_or_cg, fixed_cp_size, extra):
    """Register the non-bulk cp.async form ``ptx_cp_async_<op_name>``.

    The C signature depends only on *extra*, never on the attrs, so it is
    computed once here. The helper name and body are derived from the
    arguments passed to the callables, via the parts callable built by
    ``_make_form_parts``.

    Returns the number of forwarded operands ahead of ``cache_policy``:
    2 (dst, src) for the plain form, or 3 when *extra* is given.
    """
    extra_param = f", int {extra}" if extra else ""
    c_signature = f"(void* dst, void* src{extra_param}, unsigned long long cache_policy)"
    parts_fn, attr_count = _make_form_parts(ca_or_cg, fixed_cp_size, extra)

    def _helper_name(*attrs, fn=parts_fn):
        return fn(*attrs)[0]

    def _helper_body(*attrs, fn=parts_fn):
        return fn(*attrs)[2]

    device_intrinsic(
        f"ptx_cp_async_{op_name}",
        n_attrs=attr_count,
        c_signature=c_signature,
        helper_name=_helper_name,
        body=_helper_body,
    )
    return 3 if extra is not None else 2
+
+
# Non-bulk cp.async form table, registered in this order:
#   (op_name, cache qualifier, fixed cp-size, extra operand)
# Forms 1/2 take src-size and forms 3/4 take ignore-src. The last two rows are
# the plain forms, with the optional operand omitted. .ca reads cp-size from an
# attr; .cg is fixed at 16 bytes.
for _nb_form in (
    ("ca_src_size", "ca", None, "src_size"),
    ("cg_src_size", "cg", 16, "src_size"),
    ("ca_ignore_src", "ca", None, "ignore_src"),
    ("cg_ignore_src", "cg", 16, "ignore_src"),
    ("ca", "ca", None, None),
    ("cg", "cg", 16, None),
):
    _register_nb_form(*_nb_form)
del _nb_form
+
+
def _make_setp_at_p_helper(ca_or_cg, cp_size, has_cache, prefetch):
    """Emit a predicate-gated (``setp`` + ``@p``) cp.async helper.

    This is a wrapper convenience, not a PTX form: the copy is issued only when
    ``predicate == 1``. Otherwise it is skipped and dst is left untouched. It is
    produced directly here as a one-off helper rather than registered as a
    device_intrinsic.

    Returns ``(func_name, source_code)``, where *source_code* is a complete
    ``__forceinline__ __device__`` definition.
    """
    suffix = ""
    if has_cache:
        suffix += "_cache_hint"
    if prefetch:
        suffix += f"_prefetch_{prefetch}"
    func_name = f"tvm_builtin_ptx_cp_async_{cp_size}{suffix}_predicate"

    modifier = _cp_async_modifier_str(has_cache, prefetch)
    if has_cache:
        cache_arg, cache_operand = ", %4", ', "l"(cache_policy)'
    else:
        cache_arg, cache_operand = "", ""

    body_lines = [
        "  unsigned int dst_addr = __cvta_generic_to_shared(dst);",
        "  __asm__ __volatile__(",
        '    "{\\n"',
        '    " .reg .pred p;\\n"',
        '    " setp.eq.u32 p, %3, 1;\\n"',
        f'    " @p cp.async.{ca_or_cg}.shared.global{modifier} [%0], [%1], %2{cache_arg};\\n"',
        '    "}\\n"',
        f'    :: "r"(dst_addr), "l"(src), "n"({cp_size}), "r"(predicate){cache_operand}',
        "  );",
    ]
    signature = "(void* dst, void* src, int predicate, unsigned long long cache_policy)"
    source_code = (
        f"\n__forceinline__ __device__ void {func_name}{signature} {{\n"
        + "\n".join(body_lines)
        + "\n}\n"
    )
    return func_name, source_code
+
+
+@register_codegen("ptx_cp_async")
+def codegen_ptx_cp_async(*args):
+    """Map the wrapper API to the 4 PTX form table entries.
+
+    Accepts three call shapes (sorted by arity):
+
+    * 5 args ``(dst_ptr, dst_offset, src_ptr, src_offset, cp_size)`` —
+      the legacy form emitted by ``s_tir/transform/InjectPTXAsyncCopy``.
+      Offsets are folded into the pointers via ``tvm_access_ptr`` (in
+      bytes; offsets are pre-scaled by the pass) and the call is
+      forwarded with default cache / predicate / fill_mode.
+    * 6 args ``(dst_ptr, dst_offset, src_ptr, src_offset, cp_size,
+      predicate)`` — same as 5-arg form with an explicit predicate.
+    * 8 args ``(dst_ptr, src_ptr, cp_size, cache_policy, has_cache_hint,
+      prefetch_size, predicate, fill_mode)`` — the fork-native wrapper
+      API.
+
+    The three resulting form_kinds:
+
+    * ``fill_mode == "zero"`` -> form 1/2 (src-size = predicate ? cp_size : 0)
+    * ``predicate != -1`` and no fill_mode -> form 1/2 wrapped in setp+@p
+      (wrapper convenience; not a PTX form)
+    * else -> form 1/2 with src-size omitted (the "plain" degenerate)
+    """
+    from tvm.tirx.op import if_then_else
+
+    if len(args) in (5, 6):
+        # Legacy InjectPTXAsyncCopy emission: (dst_ptr, dst_off, src_ptr,
+        # src_off, cp_size [, predicate]). Offsets are element indices into
+        # the typed buffers (the pass uses index_factor=1 except for the
+        # shared.dyn-merged byte-buffer path). Emit a C helper that scales
+        # the offset by the buffer element size, then runs cp.async.
+        #
+        # PTX plain form for both .ca and .cg is just
+        # ``cp.async.<v>.shared.global [dst], [src], cp_size;`` — three
+        # operands, no trailing src-size / cache-policy.
+        from tvm import DataType
+
+        dst_ptr_in, dst_offset, src_ptr_in, src_offset, cp_size = args[:5]
+        predicate = args[5] if len(args) == 6 else -1
+        cp_size_v = int(cp_size)
+        ca_or_cg = "cg" if cp_size_v == 16 else "ca"
+
+        # Recover the per-side element dtype from each pointer's type
+        # annotation (Var has type_annotation = PointerType(PrimType(dtype))).
+        # InjectPTXAsyncCopy emits offsets in element-units of each side's
+        # buffer dtype (dst gets dst_offset * src_elem_size only when dst is a
+        # merged shared.dyn byte buffer, in which case dst_elem_dtype is uint8
+        # and the resulting scale-by-1 is a no-op).
+        def _elem_bytes(ptr):
+            ta = getattr(ptr, "type_annotation", None)
+            if ta is None or getattr(ta, "element_type", None) is None:
+                return 1
+            et = ta.element_type
+            if not hasattr(et, "dtype"):
+                return 1
+            bits = DataType(str(et.dtype)).bits
+            assert bits % 8 == 0, f"non-byte element dtype: {et.dtype}"
+            return bits // 8
+
+        dst_elem_bytes = _elem_bytes(dst_ptr_in)
+        src_elem_bytes = _elem_bytes(src_ptr_in)
+        has_predicate = not (
+            (isinstance(predicate, int) and predicate == -1)
+            or (hasattr(predicate, "value") and int(predicate.value) == -1)
+        )
+
+        def _scale(n):
+            return "" if n == 1 else f" * {n}"
+
+        dst_scale = _scale(dst_elem_bytes)
+        src_scale = _scale(src_elem_bytes)
+        if has_predicate:
+            func_name = (
+                
f"ptx_cp_async_legacy_pred_{ca_or_cg}_{cp_size_v}_{dst_elem_bytes}_{src_elem_bytes}"
+            )
+            body = (
+                f"  uint8_t* dst_p = (uint8_t*)dst + dst_off{dst_scale};\n"
+                f"  uint8_t* src_p = (uint8_t*)src + src_off{src_scale};\n"
+                "  unsigned int dst_addr = __cvta_generic_to_shared(dst_p);\n"
+                "  __asm__ __volatile__(\n"
+                '    "{\\n"\n'
+                '    " .reg .pred p;\\n"\n'
+                '    " setp.eq.u32 p, %3, 1;\\n"\n'
+                f'    " @p cp.async.{ca_or_cg}.shared.global'
+                ' [%0], [%1], %2;\\n"\n'
+                '    "}\\n"\n'
+                f'    :: "r"(dst_addr), "l"(src_p), "n"({cp_size_v}), 
"r"(predicate)\n'
+                "  );"
+            )
+            source_code = (
+                f"\n__forceinline__ __device__ void {func_name}"
+                "(void* dst, int dst_off, void* src, int src_off, int 
predicate) {\n"
+                f"{body}\n"
+                "}\n"
+            )
+            return cuda_func_call(
+                func_name,
+                dst_ptr_in,
+                dst_offset,
+                src_ptr_in,
+                src_offset,
+                predicate,
+                source_code=source_code,
+            )
+        # No predicate — plain cp.async.
+        func_name = 
f"ptx_cp_async_legacy_{ca_or_cg}_{cp_size_v}_{dst_elem_bytes}_{src_elem_bytes}"
+        body = (
+            f"  uint8_t* dst_p = (uint8_t*)dst + dst_off{dst_scale};\n"
+            f"  uint8_t* src_p = (uint8_t*)src + src_off{src_scale};\n"
+            "  unsigned int dst_addr = __cvta_generic_to_shared(dst_p);\n"
+            f'  asm volatile("cp.async.{ca_or_cg}.shared.global'
+            ' [%0], [%1], %2;"\n'
+            f'    :: "r"(dst_addr), "l"(src_p), "n"({cp_size_v}));'
+        )
+        source_code = (
+            f"\n__forceinline__ __device__ void {func_name}"
+            "(void* dst, int dst_off, void* src, int src_off) {\n"
+            f"{body}\n"
+            "}\n"
+        )
+        return cuda_func_call(
+            func_name,
+            dst_ptr_in,
+            dst_offset,
+            src_ptr_in,
+            src_offset,
+            source_code=source_code,
+        )

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
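   The legacy `codegen_ptx_cp_async` path recovers each side's element size
from the pointer's `type_annotation`. That works for `Var` nodes with an explicit
`PointerType(PrimType(dtype))`. For a generic handle, however, or an expression
whose type information was lost in an earlier lowering pass, `_elem_bytes`
silently returns 1. The offsets are then treated as byte offsets, so for element
types wider than one byte the copy targets the wrong address without any
diagnostic. Consider failing loudly (or passing the element size explicitly)
instead of defaulting to 1, or ensuring that the `InjectPTXAsyncCopy` pass always
preserves these annotations.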



##########
python/tvm/tirx/lang/alloc_pool.py:
##########
@@ -0,0 +1,510 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""SMEM and TMEM bump-allocator pools for TIRX kernels."""
+
+from __future__ import annotations
+
+import functools
+import operator
+
+from tvm import DataType
+from tvm.tirx.layout import S, TCol, TileLayout, TLane
+
+# ---------------------------------------------------------------------------
+# ir_builder helpers — imported lazily to avoid circular deps at module level
+# ---------------------------------------------------------------------------
+
# Cached reference to ``tvm.tirx.script.builder.ir``. It is filled in on first
# use by ``_get_ir`` rather than at import time, to avoid circular imports.
_ir = None


def _get_ir():
    """Return the ``tvm.tirx.script.builder.ir`` module, importing it lazily."""
    global _ir
    if _ir is not None:
        return _ir
    from tvm.tirx.script.builder import ir as builder_ir

    _ir = builder_ir
    return _ir
+
+
def _get_frame():
    """Return the ``tvm.tirx.script.builder.frame`` module, imported on demand."""
    from tvm.tirx.script.builder import frame as builder_frame

    return builder_frame
+
+
+# ---------------------------------------------------------------------------
+# Shared utilities
+# ---------------------------------------------------------------------------
+
# Unique sentinel object. Judging by its name, it presumably marks a pool
# argument that was not supplied (its uses are outside this view).
_POOL_UNSET = object()


def _default_tmem_layout(rows, cols):
    """Build the default TMEM ``TileLayout`` for a ``rows x cols`` allocation.

    The row axis maps to ``TLane`` and the column axis to ``TCol``, each with
    unit stride.
    """
    extents = (rows, cols)
    strides = (1 @ TLane, 1 @ TCol)
    return TileLayout(S[extents:strides])
+
+
def _emit_stmt(expr):
    """Wrap *expr* with the builder's ``evaluate`` and hand it to ``add_to_parent``."""
    builder = _get_ir()
    stmt = builder.evaluate(expr)
    builder.add_to_parent(stmt)
+
+
+def _shape_product(shape):
+    return functools.reduce(operator.mul, shape, 1)
+
+
def _auto_swizzle_mode(dtype):
    """Return the default swizzle mode for a shared-memory MMA allocation.

    The choice currently ignores *dtype* and always uses the 128-byte atom.
    """
    from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode

    _ = dtype  # kept in the signature but not consulted
    return SwizzleMode.SWIZZLE_128B_ATOM
+
+
def _swizzle_atom_bytes(swizzle_mode):
    """Return the byte width of one swizzle-atom row for *swizzle_mode*.

    ``SWIZZLE_NONE`` maps to 0. Any mode other than the four listed raises
    ``KeyError``.
    """
    from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode

    modes = (
        SwizzleMode.SWIZZLE_NONE,
        SwizzleMode.SWIZZLE_32B_ATOM,
        SwizzleMode.SWIZZLE_64B_ATOM,
        SwizzleMode.SWIZZLE_128B_ATOM,
    )
    widths = (0, 32, 64, 128)
    return dict(zip(modes, widths))[swizzle_mode]
+
+
+def _suggest_swizzle_for_row_bytes(row_bytes):
+    """Pick the largest valid swizzle mode whose atom row fits within 
*row_bytes*."""
+
+    for atom_bytes, mode in (
+        (128, "SWIZZLE_128B_ATOM"),
+        (64, "SWIZZLE_64B_ATOM"),
+        (32, "SWIZZLE_32B_ATOM"),
+    ):
+        if row_bytes >= atom_bytes and row_bytes % atom_bytes == 0:
+            return mode
+    return "SWIZZLE_NONE"
+
+
+def _validate_mma_alloc_shape(shape, dtype, swizzle_mode):
+    """Validate that *shape* / *dtype* / *swizzle_mode* are mutually 
compatible.
+
+    ``mma_shared_layout`` tiles a swizzle atom of shape ``[8, swizzle_bytes / 
dtype_bytes]``
+    over the last two logical dimensions of *shape*. If the row width or row 
count of
+    the request is smaller than (or not a multiple of) the atom, the underlying
+    ``Layout.tile_to`` lowers to a ``floordiv``/``floormod`` by zero and 
raises an
+    opaque internal "Divide by zero" diagnostic from ``tile_tile_ops.cc``. 
Catch the
+    misconfiguration here so callers see *what* is wrong and *how* to fix it.
+
+    Validation skipped when *swizzle_mode* is ``SWIZZLE_NONE`` (no atom).
+    """
+    from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode
+
+    if swizzle_mode == SwizzleMode.SWIZZLE_NONE:
+        return
+
+    if len(shape) < 2:
+        raise ValueError(
+            f"alloc_mma shape={tuple(shape)} has fewer than 2 dimensions; "
+            f"swizzled MMA layouts tile over the last two dims (rows, cols). "
+            f"Use swizzle_mode='none' for 1-D allocations."
+        )
+
+    # Only validate concrete int dims; symbolic dims fall through (the analyzer
+    # in C++ will still ICHECK on them, but at least we don't false-positive).
+    rows = shape[-2]
+    cols = shape[-1]
+    if not (isinstance(rows, int) and isinstance(cols, int)):
+        return
+
+    dtype_bytes = DataType(dtype).bits // 8
+    if dtype_bytes == 0:
+        # Sub-byte dtype (e.g. float4); ``cols`` is already in element units, 
so
+        # use a fractional check expressed via bits.
+        col_bits = cols * DataType(dtype).bits
+        atom_bits = _swizzle_atom_bytes(swizzle_mode) * 8
+        if col_bits < atom_bits or col_bits % atom_bits != 0:
+            row_bytes = col_bits // 8 if col_bits % 8 == 0 else col_bits / 8
+            atom_bytes = _swizzle_atom_bytes(swizzle_mode)
+            suggestion = _suggest_swizzle_for_row_bytes(col_bits // 8 if 
col_bits >= 8 else 0)
+            raise ValueError(
+                f"alloc_mma shape={tuple(shape)} with dtype={dtype!r} produces 
"
+                f"{row_bytes}B rows, which is incompatible with the 
{atom_bytes}B "
+                f"swizzle atom selected by {swizzle_mode.name}. "
+                f"Use swizzle_mode=SwizzleMode.{suggestion}, or widen 
shape[-1] "
+                f"to a multiple of "
+                f"{(atom_bits + DataType(dtype).bits - 1) // 
DataType(dtype).bits} elements."
+            )
+    else:
+        row_bytes = cols * dtype_bytes
+        atom_bytes = _swizzle_atom_bytes(swizzle_mode)
+        if row_bytes < atom_bytes or row_bytes % atom_bytes != 0:
+            suggestion = _suggest_swizzle_for_row_bytes(row_bytes)
+            min_cols = atom_bytes // dtype_bytes
+            raise ValueError(
+                f"alloc_mma shape={tuple(shape)} with dtype={dtype!r} produces 
"
+                f"{row_bytes}B rows, which is incompatible with the 
{atom_bytes}B "
+                f"swizzle atom selected by {swizzle_mode.name}. "
+                f"Use swizzle_mode=SwizzleMode.{suggestion}, or widen 
shape[-1] "
+                f"to a multiple of {min_cols} elements (>= {atom_bytes}B at 
{dtype})."
+            )
+
+    # Atom rows is always 8 (see ``mma_atom_shape`` in tma_utils.py).
+    atom_rows = 8
+    if rows < atom_rows or rows % atom_rows != 0:
+        raise ValueError(
+            f"alloc_mma shape={tuple(shape)} has shape[-2]={rows}, but the "
+            f"{swizzle_mode.name} atom requires shape[-2] to be a positive "
+            f"multiple of {atom_rows}. Use swizzle_mode='none', or widen 
shape[-2] "
+            f"to a multiple of {atom_rows}."
+        )

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `_validate_mma_alloc_shape` function provides excellent defensive checks
for hardware alignment requirements. However, it skips validation whenever either
of the last two dimensions is not a plain Python `int`. That covers symbolic
dimensions, and also constant `IntImm` dimensions. Since TIR often uses symbolic
variables for tiling, consider using `arith.Analyzer` to check whether the
alignment constraints can be proven, or at least emitting a warning when they
cannot be verified statically.



##########
python/tvm/tirx/operator/intrinsics/cuda/cp_async.py:
##########
@@ -0,0 +1,910 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
# pylint: disable=redefined-builtin, invalid-name, too-many-arguments, too-many-locals, too-many-positional-arguments
+"""PTX cp.async / cp.async.bulk / cp.async.bulk.tensor intrinsics.
+
+Each PTX form table entry is registered as one ``device_intrinsic``.
+User-facing wrappers in ``tvm.tirx.op`` keep their v1 signatures;
+``register_codegen`` dispatchers below decode the (cp_size, fill_mode,
+predicate) / (dim, cta_mask, tile_mode) arguments to pick the right form.
+Bodies are hand-written ``asm volatile(...)`` strings.  The file is grouped
+as cp.async, cp.async.bulk.tensor, cp.async.bulk non-TMA, and CUDA
+compatibility helpers.
+"""
+
+import tvm
+from tvm.tirx.op import cuda_func_call
+
+from .._schema import device_intrinsic
+from .registry import CODEGEN_REGISTRY, register_codegen
+from .utils import parse_str
+
+_PREFETCH_CHOICES = ("", "64", "128", "256")
+_DIM_CHOICES = (1, 2, 3, 4, 5)
+_TILE_MODE_CHOICES = ("tile", "tile_gather4")
+
+
+def _safe(s):
+    return s.replace("::", "_").replace(".", "_")
+
+
+# =============================================================================
+# cp.async forms from the PTX Syntax block.
+#
+# Includes commit/wait plus the non-bulk shared/global copy forms.
+# =============================================================================
# cp.async.commit_group takes no operands or attributes, so one fixed helper
# name and body are used.
device_intrinsic(
    "ptx_cp_async_commit_group",
    body='    asm volatile("cp.async.commit_group;");',
    helper_name="tvm_builtin_ptx_cp_async_commit_group",
)
# cp.async.wait_group N: N is the single attribute. It is spliced into both the
# helper name and the asm text, so each distinct N gets its own helper name.
device_intrinsic(
    "ptx_cp_async_wait_group",
    body=lambda n: '    asm volatile("cp.async.wait_group %d;");' % int(n),
    helper_name=lambda n: "tvm_builtin_ptx_cp_async_wait_group_%d" % int(n),
    n_attrs=1,
)
+
+
# cp.async non-bulk copy forms (src-size is a 32-bit integer; ignore-src is a
# predicate operand):
#   Form 1: cp.async.ca.shared.global ... [dst], [src], cp-size{, src-size}{, cache-policy}
#   Form 2: cp.async.cg.shared.global ... [dst], [src], 16{, src-size}{, cache-policy}
#   Form 3: cp.async.ca.shared.global ... [dst], [src], cp-size{, ignore-src}{, cache-policy}
#   Form 4: cp.async.cg.shared.global ... [dst], [src], 16{, ignore-src}{, cache-policy}
+
+
+def _cp_async_modifier_str(has_cache_hint, prefetch_size):
+    s = ""
+    if has_cache_hint:
+        s += ".L2::cache_hint"
+    if prefetch_size:
+        s += f".L2::{prefetch_size}B"
+    return s
+
+
+def _make_form_parts(ca_or_cg, fixed_cp_size, extra):
+    """Build a parts callable for one of the cp.async PTX forms.
+
+    Args layout: (dst, src [, extra_int], cache_policy, has_cache, 
prefetch_size [, cp_size_attr])
+    Forwarded operands: dst, src [, extra_int], cache_policy.
+    Trailing attrs: has_cache, prefetch_size [, cp_size if .ca].
+    """
+    n_op = 3 if extra is not None else 2
+    n_attrs = 2 if fixed_cp_size is not None else 3
+    extra_in_name = f"_with_{extra}" if extra is not None else ""
+
+    def _parts(*args):
+        # Operand args (forwarded) come first, then attr args.
+        attr_args = args[-n_attrs:]
+        has_cache = _bool_attr(attr_args[0])
+        prefetch_size = parse_str(attr_args[1])
+        cp_size = fixed_cp_size if fixed_cp_size is not None else 
int(attr_args[2])
+        modifier = _cp_async_modifier_str(has_cache, prefetch_size)
+        cache_operand = ', "l"(cache_policy)' if has_cache else ""
+        # name parts
+        name_cache = "_cache_hint" if has_cache else ""
+        name_prefetch = f"_prefetch_{prefetch_size}" if prefetch_size else ""
+        name = (
+            f"tvm_builtin_ptx_cp_async_{ca_or_cg}_{cp_size}"
+            f"{name_cache}{name_prefetch}{extra_in_name}"
+        )
+        sig = (
+            "(void* dst, void* src"
+            + (f", int {extra}" if extra else "")
+            + ", unsigned long long cache_policy)"
+        )
+        instr_base = f"cp.async.{ca_or_cg}.shared.global{modifier}"
+        if extra is None:
+            cache_arg = ", %2" if has_cache else ""
+            body = (
+                "    unsigned int dst_addr = __cvta_generic_to_shared(dst);\n"
+                f'    asm volatile("{instr_base} [%0], [%1], 
{cp_size}{cache_arg};\\n"\n'
+                f'                 :: "r"(dst_addr), "l"(src){cache_operand} : 
"memory");'
+            )
+        else:
+            cache_arg = ", %3" if has_cache else ""
+            body = (
+                "    unsigned int dst_addr = __cvta_generic_to_shared(dst);\n"
+                f'    asm volatile("{instr_base} [%0], [%1], {cp_size}, 
%2{cache_arg};\\n"\n'
+                f'                 :: "r"(dst_addr), "l"(src), "r"({extra})'
+                f'{cache_operand} : "memory");'
+            )
+        return name, sig, body
+
+    return _parts, n_op + n_attrs - n_op  # n_attrs
+
+
def _register_nb_form(op_name, ca_or_cg, fixed_cp_size, extra):
    """Register the non-bulk cp.async form ``ptx_cp_async_<op_name>``.

    The C signature depends only on *extra*, never on the attrs, so it is
    computed once here. The helper name and body are derived from the
    arguments passed to the callables, via the parts callable built by
    ``_make_form_parts``.

    Returns the number of forwarded operands ahead of ``cache_policy``:
    2 (dst, src) for the plain form, or 3 when *extra* is given.
    """
    extra_param = f", int {extra}" if extra else ""
    c_signature = f"(void* dst, void* src{extra_param}, unsigned long long cache_policy)"
    parts_fn, attr_count = _make_form_parts(ca_or_cg, fixed_cp_size, extra)

    def _helper_name(*attrs, fn=parts_fn):
        return fn(*attrs)[0]

    def _helper_body(*attrs, fn=parts_fn):
        return fn(*attrs)[2]

    device_intrinsic(
        f"ptx_cp_async_{op_name}",
        n_attrs=attr_count,
        c_signature=c_signature,
        helper_name=_helper_name,
        body=_helper_body,
    )
    return 3 if extra is not None else 2
+
+
# Non-bulk cp.async form table, registered in this order:
#   (op_name, cache qualifier, fixed cp-size, extra operand)
# Forms 1/2 take src-size and forms 3/4 take ignore-src. The last two rows are
# the plain forms, with the optional operand omitted. .ca reads cp-size from an
# attr; .cg is fixed at 16 bytes.
for _nb_form in (
    ("ca_src_size", "ca", None, "src_size"),
    ("cg_src_size", "cg", 16, "src_size"),
    ("ca_ignore_src", "ca", None, "ignore_src"),
    ("cg_ignore_src", "cg", 16, "ignore_src"),
    ("ca", "ca", None, None),
    ("cg", "cg", 16, None),
):
    _register_nb_form(*_nb_form)
del _nb_form
+
+
def _make_setp_at_p_helper(ca_or_cg, cp_size, has_cache, prefetch):
    """Emit a predicate-gated (``setp`` + ``@p``) cp.async helper.

    This is a wrapper convenience, not a PTX form: the copy is issued only when
    ``predicate == 1``. Otherwise it is skipped and dst is left untouched. It is
    produced directly here as a one-off helper rather than registered as a
    device_intrinsic.

    Returns ``(func_name, source_code)``, where *source_code* is a complete
    ``__forceinline__ __device__`` definition.
    """
    suffix = ""
    if has_cache:
        suffix += "_cache_hint"
    if prefetch:
        suffix += f"_prefetch_{prefetch}"
    func_name = f"tvm_builtin_ptx_cp_async_{cp_size}{suffix}_predicate"

    modifier = _cp_async_modifier_str(has_cache, prefetch)
    if has_cache:
        cache_arg, cache_operand = ", %4", ', "l"(cache_policy)'
    else:
        cache_arg, cache_operand = "", ""

    body_lines = [
        "  unsigned int dst_addr = __cvta_generic_to_shared(dst);",
        "  __asm__ __volatile__(",
        '    "{\\n"',
        '    " .reg .pred p;\\n"',
        '    " setp.eq.u32 p, %3, 1;\\n"',
        f'    " @p cp.async.{ca_or_cg}.shared.global{modifier} [%0], [%1], %2{cache_arg};\\n"',
        '    "}\\n"',
        f'    :: "r"(dst_addr), "l"(src), "n"({cp_size}), "r"(predicate){cache_operand}',
        "  );",
    ]
    signature = "(void* dst, void* src, int predicate, unsigned long long cache_policy)"
    source_code = (
        f"\n__forceinline__ __device__ void {func_name}{signature} {{\n"
        + "\n".join(body_lines)
        + "\n}\n"
    )
    return func_name, source_code
+
+
+@register_codegen("ptx_cp_async")
+def codegen_ptx_cp_async(*args):
+    """Map the wrapper API to the 4 PTX form table entries.
+
+    Accepts three call shapes (sorted by arity):
+
+    * 5 args ``(dst_ptr, dst_offset, src_ptr, src_offset, cp_size)`` —
+      the legacy form emitted by ``s_tir/transform/InjectPTXAsyncCopy``.
+      Offsets are folded into the pointers via ``tvm_access_ptr`` (in
+      bytes; offsets are pre-scaled by the pass) and the call is
+      forwarded with default cache / predicate / fill_mode.
+    * 6 args ``(dst_ptr, dst_offset, src_ptr, src_offset, cp_size,
+      predicate)`` — same as 5-arg form with an explicit predicate.
+    * 8 args ``(dst_ptr, src_ptr, cp_size, cache_policy, has_cache_hint,
+      prefetch_size, predicate, fill_mode)`` — the fork-native wrapper
+      API.
+
+    The three resulting form_kinds:
+
+    * ``fill_mode == "zero"`` -> form 1/2 (src-size = predicate ? cp_size : 0)
+    * ``predicate != -1`` and no fill_mode -> form 1/2 wrapped in setp+@p
+      (wrapper convenience; not a PTX form)
+    * else -> form 1/2 with src-size omitted (the "plain" degenerate)
+    """
+    from tvm.tirx.op import if_then_else
+
+    if len(args) in (5, 6):
+        # Legacy InjectPTXAsyncCopy emission: (dst_ptr, dst_off, src_ptr,
+        # src_off, cp_size [, predicate]). Offsets are element indices into
+        # the typed buffers (the pass uses index_factor=1 except for the
+        # shared.dyn-merged byte-buffer path). Emit a C helper that scales
+        # the offset by the buffer element size, then runs cp.async.
+        #
+        # PTX plain form for both .ca and .cg is just
+        # ``cp.async.<v>.shared.global [dst], [src], cp_size;`` — three
+        # operands, no trailing src-size / cache-policy.
+        from tvm import DataType
+
+        dst_ptr_in, dst_offset, src_ptr_in, src_offset, cp_size = args[:5]
+        predicate = args[5] if len(args) == 6 else -1
+        cp_size_v = int(cp_size)
+        ca_or_cg = "cg" if cp_size_v == 16 else "ca"
+
+        # Recover the per-side element dtype from each pointer's type
+        # annotation (Var has type_annotation = PointerType(PrimType(dtype))).
+        # InjectPTXAsyncCopy emits offsets in element-units of each side's
+        # buffer dtype (dst gets dst_offset * src_elem_size only when dst is a
+        # merged shared.dyn byte buffer, in which case dst_elem_dtype is uint8
+        # and the resulting scale-by-1 is a no-op).
+        def _elem_bytes(ptr):
+            ta = getattr(ptr, "type_annotation", None)
+            if ta is None or getattr(ta, "element_type", None) is None:
+                return 1
+            et = ta.element_type
+            if not hasattr(et, "dtype"):
+                return 1
+            bits = DataType(str(et.dtype)).bits
+            assert bits % 8 == 0, f"non-byte element dtype: {et.dtype}"
+            return bits // 8
+
+        dst_elem_bytes = _elem_bytes(dst_ptr_in)
+        src_elem_bytes = _elem_bytes(src_ptr_in)
+        has_predicate = not (
+            (isinstance(predicate, int) and predicate == -1)
+            or (hasattr(predicate, "value") and int(predicate.value) == -1)
+        )
+
+        def _scale(n):
+            return "" if n == 1 else f" * {n}"
+
+        dst_scale = _scale(dst_elem_bytes)
+        src_scale = _scale(src_elem_bytes)
+        if has_predicate:
+            func_name = (
+                
f"ptx_cp_async_legacy_pred_{ca_or_cg}_{cp_size_v}_{dst_elem_bytes}_{src_elem_bytes}"
+            )
+            body = (
+                f"  uint8_t* dst_p = (uint8_t*)dst + dst_off{dst_scale};\n"
+                f"  uint8_t* src_p = (uint8_t*)src + src_off{src_scale};\n"
+                "  unsigned int dst_addr = __cvta_generic_to_shared(dst_p);\n"
+                "  __asm__ __volatile__(\n"
+                '    "{\\n"\n'
+                '    " .reg .pred p;\\n"\n'
+                '    " setp.eq.u32 p, %3, 1;\\n"\n'
+                f'    " @p cp.async.{ca_or_cg}.shared.global'
+                ' [%0], [%1], %2;\\n"\n'
+                '    "}\\n"\n'
+                f'    :: "r"(dst_addr), "l"(src_p), "n"({cp_size_v}), 
"r"(predicate)\n'
+                "  );"
+            )
+            source_code = (
+                f"\n__forceinline__ __device__ void {func_name}"
+                "(void* dst, int dst_off, void* src, int src_off, int 
predicate) {\n"
+                f"{body}\n"
+                "}\n"
+            )
+            return cuda_func_call(
+                func_name,
+                dst_ptr_in,
+                dst_offset,
+                src_ptr_in,
+                src_offset,
+                predicate,
+                source_code=source_code,
+            )
+        # No predicate — plain cp.async.
+        func_name = 
f"ptx_cp_async_legacy_{ca_or_cg}_{cp_size_v}_{dst_elem_bytes}_{src_elem_bytes}"
+        body = (
+            f"  uint8_t* dst_p = (uint8_t*)dst + dst_off{dst_scale};\n"
+            f"  uint8_t* src_p = (uint8_t*)src + src_off{src_scale};\n"
+            "  unsigned int dst_addr = __cvta_generic_to_shared(dst_p);\n"
+            f'  asm volatile("cp.async.{ca_or_cg}.shared.global'
+            ' [%0], [%1], %2;"\n'
+            f'    :: "r"(dst_addr), "l"(src_p), "n"({cp_size_v}));'
+        )
+        source_code = (
+            f"\n__forceinline__ __device__ void {func_name}"
+            "(void* dst, int dst_off, void* src, int src_off) {\n"
+            f"{body}\n"
+            "}\n"
+        )
+        return cuda_func_call(
+            func_name,
+            dst_ptr_in,
+            dst_offset,
+            src_ptr_in,
+            src_offset,
+            source_code=source_code,
+        )
+    elif len(args) == 8:
+        (
+            dst_ptr,
+            src_ptr,
+            cp_size,
+            cache_policy,
+            has_cache_hint,
+            prefetch_size,
+            predicate,
+            fill_mode,
+        ) = args
+    else:
+        raise ValueError(f"ptx_cp_async codegen expects 5/6/8 args, got 
{len(args)}")
+
+    cp_size_v = int(cp_size)
+    ca_or_cg = "cg" if cp_size_v == 16 else "ca"
+    pref = "" if int(prefetch_size) == -1 else str(int(prefetch_size))
+    fill = parse_str(fill_mode)
+    has_cache = _bool_attr(has_cache_hint)
+    has_predicate = not (
+        (isinstance(predicate, int) and predicate == -1)
+        or (hasattr(predicate, "value") and int(predicate.value) == -1)
+    )
+
+    if fill == "zero":
+        src_size = if_then_else(predicate != 0, cp_size_v, 0)
+        op = f"tirx.ptx_cp_async_{ca_or_cg}_src_size"
+        if cp_size_v == 16:
+            args = [dst_ptr, src_ptr, src_size, cache_policy, has_cache, pref]
+        else:
+            args = [dst_ptr, src_ptr, src_size, cache_policy, has_cache, pref, 
cp_size_v]
+        result = CODEGEN_REGISTRY[op](args)
+        return result[0] if isinstance(result, tuple) else result
+
+    if has_predicate:
+        func_name, source_code = _make_setp_at_p_helper(ca_or_cg, cp_size_v, 
has_cache, pref)
+        return cuda_func_call(
+            func_name, dst_ptr, src_ptr, predicate, cache_policy, 
source_code=source_code
+        )
+
+    # Plain — form 1/2 with src-size omitted.
+    op = f"tirx.ptx_cp_async_{ca_or_cg}"
+    if cp_size_v == 16:
+        args = [dst_ptr, src_ptr, cache_policy, has_cache, pref]
+    else:
+        args = [dst_ptr, src_ptr, cache_policy, has_cache, pref, cp_size_v]
+    result = CODEGEN_REGISTRY[op](args)
+    return result[0] if isinstance(result, tuple) else result
+
+
+# =============================================================================
+# cp.async.bulk.tensor (TMA) — one device_intrinsic per arity variant of each
+# PTX form. Per-dim coord operands materialise via the ``c_signature`` 
callable.
+# =============================================================================
+
+
+def _is_sm100_or_higher():
+    target = tvm.target.Target.current()
+    if target is None:
+        return False
+    arch = target.arch[3:]
+    if not arch[-1].isdigit():
+        arch = arch[:-1]
+    return int(arch) >= 100
+
+
+def _resolve_cta_group_str(cta_group):
+    if cta_group == 2 or (cta_group != -1 and _is_sm100_or_higher()):
+        return f".cta_group::{cta_group}"
+    return ""
+
+
+def _coord_template(coord_count, start_slot):
+    inner = ", ".join(f"%{start_slot + i}" for i in range(coord_count))
+    return f"{{{inner}}}"
+
+
+def _coord_constraints(coord_count):
+    return ", ".join(f'"r"(coord{i})' for i in range(coord_count))
+
+
+def _coord_sig(n):
+    return ", ".join(f"int coord{i}" for i in range(n))
+
+
+# PTX cp.async.bulk.tensor global -> shared::cluster form:
+#   cp.async.bulk.tensor.dim.dst.src{.load_mode}.completion_mechanism
+#       {.multicast}{.cta_group}{.level::cache_hint}
+#       [dstMem], [tensorMap, tensorCoords], [mbar]{, im2colInfo}
+#       {, ctaMask} {, cache-policy}
+#   .dst = {.shared::cluster}; .src = {.global}
+#   .completion_mechanism = {.mbarrier::complete_tx::bytes}
+#   .multicast = {.multicast::cluster}
+#   .cta_group = {.cta_group::1, .cta_group::2}
+#   .load_mode = {.tile, .tile::gather4, .im2col, .im2col::w, .im2col::w::128}
+#   .level::cache_hint = {.L2::cache_hint}
+# This registration supports tile/tile::gather4 modes; ctaMask is only used
+# when the optional ``.multicast::cluster`` modifier is enabled.
+def _g2cluster_parts(*args):
+    attrs = args[-6:]
+    dim = int(attrs[0])
+    cta_group = int(attrs[1])
+    has_cache = _bool_attr(attrs[2])
+    tile_mode = parse_str(attrs[3])
+    bar_is_addr = _bool_attr(attrs[4])
+    multicast = _bool_attr(attrs[5])
+    coord_count = 5 if tile_mode == "tile_gather4" else dim
+    bar_type = "unsigned int bar_addr" if bar_is_addr else "void* bar"
+    sig = (
+        f"(void* dst, {bar_type}, unsigned long long tensormap_addr, "
+        "uint16_t cta_mask, unsigned long long cache_policy"
+        + (", " + _coord_sig(coord_count) if coord_count else "")
+        + ")"
+    )
+    name = (
+        f"ptx_cp_async_bulk_tensor_g2cluster_{tile_mode}_{dim}d"
+        f"{'_multicast' if multicast else ''}"
+        f"{'_cache_hint' if has_cache else ''}{'_bar_addr' if bar_is_addr else 
''}"
+    )
+    tile_modifier = ".tile::gather4" if tile_mode == "tile_gather4" else ""
+    cta_group_str = _resolve_cta_group_str(cta_group)
+    multicast_inst = ".multicast::cluster" if multicast else ""
+    cache_inst = ".L2::cache_hint" if has_cache else ""
+    mask_arg = ',\n          "h"(cta_mask)' if multicast else ""
+    cache_arg = ',\n          "l"(cache_policy)' if has_cache else ""
+    mask_slot = ", %3" if multicast else ""
+    cache_slot = ", %4" if multicast and has_cache else ", %3" if has_cache 
else ""
+    coord_start = 5 if multicast and has_cache else 4 if multicast or 
has_cache else 3
+    coord_tpl = _coord_template(coord_count, coord_start)
+    instr = (
+        f"cp.async.bulk.tensor.{dim}d.shared::cluster.global{tile_modifier}"
+        f".mbarrier::complete_tx::bytes{multicast_inst}"
+        f"{cta_group_str}{cache_inst}"
+    )
+    bar_addr_decl = (
+        "" if bar_is_addr else "    unsigned int bar_addr = 
__cvta_generic_to_shared(bar);\n"
+    )
+    body = (
+        "    unsigned int dst_addr = __cvta_generic_to_shared(dst);\n"
+        f"{bar_addr_decl}"
+        "    asm volatile(\n"
+        f'        "{instr} [%0], [%1, {coord_tpl}], 
[%2]{mask_slot}{cache_slot};"\n'
+        "        :\n"
+        f'        : "r"(dst_addr), "l"(tensormap_addr), 
"r"(bar_addr){mask_arg}{cache_arg},\n'
+        f"          {_coord_constraints(coord_count)}\n"
+        '        : "memory"\n'
+        "    );"
+    )
+    return name, sig, body

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   In `_g2cluster_parts`, the `coord_start` calculation logic is complex and 
depends on multiple boolean flags. If the PTX instruction format changes or 
more modifiers are added, this manual slot tracking will be error-prone. 
Consider using a more structured approach for mapping operands to `%` 
placeholders in the inline assembly template.



##########
python/tvm/tirx/operator/intrinsics/cuda/math.py:
##########
@@ -0,0 +1,501 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin, invalid-name
+"""Math intrinsics.
+
+PTX side:
+* ``add{.rnd}{.ftz}.f32x2`` / ``sub`` / ``mul`` / ``fma`` — packed f32x2.
+* ``ex2.approx.ftz.f32`` / ``rcp.approx.ftz.f32`` — special functions.
+* ``max.f32`` / ``min.f32`` — 3-operand reduction form.
+
+CUDA side:
+* warp / CTA reductions (templated butterfly shuffle-XOR).
+"""
+
+from tvm.tirx.op import cuda_func_call
+
+from .._schema import device_intrinsic
+from .registry import register_codegen
+from .utils import parse_str, validate_power_of_two_range
+
+# =============================================================================
+# Packed f32x2 arithmetic — `add{.rnd}{.ftz}.f32x2 d, a, b ;` and friends.
+# Inputs are packed into a `.b64` register (low half = elem 0, high half =
+# elem 1); the body packs/unpacks via ``make_float2`` + ``reinterpret_cast``.
+# =============================================================================
+
+# PTX add/sub/mul/fma over (f32 | f32x2 | f64), DPS form.
+#   add{.rnd}{.ftz}{.sat}.f32     [d], a, b
+#   add{.rnd}{.ftz}.f32x2          [d], a, b      (a,b are packed-as-u64)
+#   add{.rnd}.f64                  [d], a, b
+#   (sub / mul same shape; fma adds a `c` operand)
+# Inputs a/b/c are register operands (scalar fp32 / packed u64 / scalar fp64).
+# Result is written through `d` (a pointer).
+_PACKED_ROUNDING = ("rz", "rn", "rm", "rp")
+
+
+# Per-dtype operand types and asm constraints.
+#  - c_in: C type of input register operand (matches PTX register type)
+#  - out_cast: pointer cast applied at d_addr (callers may pass 
float*/double*/...)
+#  - in_cstr / out_cstr: GCC asm constraint letter
+_DTYPE_INFO = {
+    "f32": {"c_in": "float", "out_cast": "float*", "in_cstr": "f", "out_cstr": 
"f"},
+    "f32x2": {
+        "c_in": "unsigned long long",
+        "out_cast": "uint64_t*",
+        "in_cstr": "l",
+        "out_cstr": "l",
+    },
+    "f64": {"c_in": "double", "out_cast": "double*", "in_cstr": "d", 
"out_cstr": "d"},
+}
+
+
+def _ptx_arith_modifier_string(dtype, rounding, ftz, sat):
+    """Build the `.rnd.ftz.sat` modifier substring + name suffix."""
+    rnd = parse_str(rounding)
+    assert rnd in _PACKED_ROUNDING, f"invalid rounding {rnd!r}, expected one 
of {_PACKED_ROUNDING}"
+    ftz_b = bool(int(ftz)) if hasattr(ftz, "value") else bool(ftz)
+    sat_b = bool(int(sat)) if hasattr(sat, "value") else bool(sat)
+    if dtype == "f64" and (ftz_b or sat_b):
+        raise ValueError("PTX <op>.f64 does not accept .ftz or .sat")
+    if dtype == "f32x2" and sat_b:
+        raise ValueError("PTX <op>.f32x2 does not accept .sat")
+    mod = f".{rnd}"
+    if ftz_b:
+        mod += ".ftz"
+    if sat_b:
+        mod += ".sat"
+    name_suffix = f"_{rnd}"
+    if ftz_b:
+        name_suffix += "_ftz"
+    if sat_b:
+        name_suffix += "_sat"
+    return mod, name_suffix
+
+
+def _ptx_binary_arith_parts(op, dtype):
+    """Return (name_fn, sig, body_fn) for ptx_{op}_{dtype} binary form."""
+    info = _DTYPE_INFO[dtype]
+    # Destination is ``void*`` so callers can pass any element-type pointer
+    # (float* / double* / uint64_t*); body reinterpret-casts to the right type.
+    sig = f"(void* d, {info['c_in']} a, {info['c_in']} b)"
+
+    def _name(d, a, b, rounding, ftz, sat):
+        _, suf = _ptx_arith_modifier_string(dtype, rounding, ftz, sat)
+        return f"tvm_builtin_ptx_{op}_{dtype}{suf}"
+
+    out_c = info["out_cstr"]
+    in_c = info["in_cstr"]
+    out_cast = info["out_cast"]
+
+    def _body(d, a, b, rounding, ftz, sat):
+        mod, _ = _ptx_arith_modifier_string(dtype, rounding, ftz, sat)
+        return (
+            f'    asm volatile("{op}{mod}.{dtype} %0, %1, %2;"\n'
+            f'        : "={out_c}"(*reinterpret_cast<{out_cast}>(d))\n'
+            f'        : "{in_c}"(a), "{in_c}"(b));'
+        )
+
+    return _name, sig, _body
+
+
+def _ptx_fma_parts(dtype):
+    """Return (name_fn, sig, body_fn) for ptx_fma_{dtype}."""
+    info = _DTYPE_INFO[dtype]
+    sig = f"(void* d, {info['c_in']} a, {info['c_in']} b, {info['c_in']} c)"
+
+    def _name(d, a, b, c, rounding, ftz, sat):
+        _, suf = _ptx_arith_modifier_string(dtype, rounding, ftz, sat)
+        return f"tvm_builtin_ptx_fma_{dtype}{suf}"
+
+    out_c = info["out_cstr"]
+    in_c = info["in_cstr"]
+    out_cast = info["out_cast"]
+
+    def _body(d, a, b, c, rounding, ftz, sat):
+        mod, _ = _ptx_arith_modifier_string(dtype, rounding, ftz, sat)
+        return (
+            f'    asm volatile("fma{mod}.{dtype} %0, %1, %2, %3;"\n'
+            f'        : "={out_c}"(*reinterpret_cast<{out_cast}>(d))\n'
+            f'        : "{in_c}"(a), "{in_c}"(b), "{in_c}"(c));'
+        )
+
+    return _name, sig, _body
+
+
+# Register 12 ops: {add, sub, mul, fma} x {f32, f32x2, f64}.
+for _dtype in ("f32", "f32x2", "f64"):
+    for _op in ("add", "sub", "mul"):
+        _name_fn, _sig, _body_fn = _ptx_binary_arith_parts(_op, _dtype)
+        device_intrinsic(
+            f"ptx_{_op}_{_dtype}",
+            n_attrs=3,  # rounding, ftz, sat
+            helper_name=_name_fn,
+            c_signature=_sig,
+            body=_body_fn,
+        )
+    _name_fn, _sig, _body_fn = _ptx_fma_parts(_dtype)
+    device_intrinsic(
+        f"ptx_fma_{_dtype}",
+        n_attrs=3,
+        helper_name=_name_fn,
+        c_signature=_sig,
+        body=_body_fn,
+    )
+del _dtype, _op, _name_fn, _sig, _body_fn
+
+
+# =============================================================================
+# ex2.approx.ftz.f32 / rcp.approx.ftz.f32 — 1 form each.
+# =============================================================================
+device_intrinsic(
+    "ptx_exp2",
+    c_signature="(float x)",
+    return_type="float",
+    body=(
+        "    float result;\n"
+        '    asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(result) : 
"f"(x));\n'
+        "    return result;"
+    ),
+)
+device_intrinsic(
+    "ptx_rcp",
+    c_signature="(float x)",
+    return_type="float",
+    body=(
+        "    float result;\n"
+        '    asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(result) : 
"f"(x));\n'
+        "    return result;"
+    ),
+)
+
+
+# =============================================================================
+# 3-operand max.f32 / min.f32 — the f32, 3-operand form-table entry of the
+# redux/reduction-style fp32 max/min ops.
+# =============================================================================
+_ABC_SIG = "(float a, float b, float c)"
+device_intrinsic(
+    "ptx_reduce3_max_f32",
+    c_signature=_ABC_SIG,
+    return_type="float",
+    body=(
+        "    float result;\n"
+        '    asm volatile("max.f32 %0, %1, %2, %3;"\n'
+        '                 : "=f"(result) : "f"(a), "f"(b), "f"(c));\n'
+        "    return result;"
+    ),
+)
+device_intrinsic(
+    "ptx_reduce3_min_f32",
+    c_signature=_ABC_SIG,
+    return_type="float",
+    body=(
+        "    float result;\n"
+        '    asm volatile("min.f32 %0, %1, %2, %3;"\n'
+        '                 : "=f"(result) : "f"(a), "f"(b), "f"(c));\n'
+        "    return result;"
+    ),
+)
+
+
+_BINARY_F32_SIG = "(float a, float b)"
+
+
+def _ptx_max_f32_body(a, b, ftz, nan):
+    ftz_b = bool(int(ftz)) if hasattr(ftz, "value") else bool(ftz)
+    nan_b = bool(int(nan)) if hasattr(nan, "value") else bool(nan)
+    ftz_suffix = ".ftz" if ftz_b else ""
+    nan_suffix = ".NaN" if nan_b else ""
+    return (
+        "    float result;\n"
+        f'    asm volatile("max{ftz_suffix}{nan_suffix}.f32 %0, %1, %2;"\n'
+        '                 : "=f"(result) : "f"(a), "f"(b));\n'
+        "    return result;"
+    )
+
+
+def _ptx_max_f32_name(a, b, ftz, nan):
+    ftz_b = bool(int(ftz)) if hasattr(ftz, "value") else bool(ftz)
+    nan_b = bool(int(nan)) if hasattr(nan, "value") else bool(nan)
+    suffix = ""
+    if ftz_b:
+        suffix += "_ftz"
+    if nan_b:
+        suffix += "_nan"
+    return f"tvm_builtin_ptx_max_f32{suffix}"
+
+
+device_intrinsic(
+    "ptx_max_f32",
+    n_attrs=2,
+    helper_name=_ptx_max_f32_name,
+    c_signature=_BINARY_F32_SIG,
+    return_type="float",
+    body=_ptx_max_f32_body,
+)
+
+
+# =============================================================================
+# CUDA-side warp / CTA reductions (templated butterfly shuffle-XOR).
+# Emitted directly via ``cuda_func_call`` — the helper signature uses a
+# single template parameter ``T`` for both arg and return, which doesn't
+# match the operand-driven C signature pattern.
+# =============================================================================
+
+# (accumulation expression, identity value for cross-warp padding)
+_OP_TABLE = {
+    "sum": ("val += shuffled;", "T(0)"),
+    "max": ("val = max(val, shuffled);", "-INFINITY"),
+    "min": ("val = min(val, shuffled);", "INFINITY"),
+}
+
+
+def _validate_op(op_str, context):
+    if op_str not in _OP_TABLE:
+        raise ValueError(f"Unsupported {context} op '{op_str}', expected one 
of {list(_OP_TABLE)}")
+    return _OP_TABLE[op_str]
+
+
+def _warp_reduce_source(func_name, width_int, step_expr):
+    return (
+        f"\ntemplate <typename T>\n"
+        f"__forceinline__ __device__ T {func_name}(T val) {{\n"
+        f"    #pragma unroll\n"
+        f"    for (int mask = {width_int} >> 1; mask > 0; mask >>= 1) {{\n"
+        "        T shuffled = __shfl_xor_sync(0xFFFFFFFF, val, mask);\n"
+        f"        {step_expr}\n"
+        "    }\n"
+        "    return val;\n"
+        "}\n"
+    )
+
+
+@register_codegen("cuda_warp_reduce")
+def codegen_cuda_warp_reduce(value, op, width):
+    op_str = parse_str(op)
+    width_int = validate_power_of_two_range(width, 2, 32, "warp_reduce width")
+    step_expr, _ = _validate_op(op_str, "warp_reduce")
+
+    func_name = f"tvm_builtin_cuda_warp_reduce_{op_str}_{width_int}"
+    source_code = _warp_reduce_source(func_name, width_int, step_expr)
+    return cuda_func_call(func_name, value, source_code=source_code, 
return_type=value.dtype)
+

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `codegen_cuda_warp_reduce` function generates a helper templated on `T` 
and passes `value.dtype` as the return type. If `value` is a packed type (e.g. 
`f32x2` carried as `uint64`), `T` becomes that 64-bit integer type: the 
shuffle still moves the whole 64-bit word, but the `sum`/`max`/`min` step then 
runs as integer arithmetic instead of lane-wise on the two packed floats, so 
packed inputs need explicit unpack/repack handling.



##########
python/tvm/tirx/lang/tile_scheduler.py:
##########
@@ -0,0 +1,818 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Reusable tile scheduler helpers for TIR tests/kernels.
+
+These classes emit TIR via @Tx.inline. Decorate with @Tx.meta_class so that
+instances are automatically treated as meta values inside @Tx.prim_func.
+"""
+
+from tvm.script import tirx as Tx
+
+
[email protected]_class
+class BaseTileScheduler:
+    """Base class for tile schedulers with common state and macros."""
+
+    def __init__(self, prefix: str):
+        self.m_idx = Tx.local_scalar("int32")
+        self.n_idx = Tx.local_scalar("int32")
+        self.linear_idx = Tx.local_scalar("int32")
+
+    @Tx.inline
+    def update_current_m_n_idx(self, linear_idx):
+        # To be implemented by subclasses
+        pass
+
+    @Tx.inline
+    def init(self, linear_init):
+        self.linear_idx = linear_init
+        self.update_current_m_n_idx(linear_init)
+
+    @Tx.inline
+    def next_tile(self, step):
+        self.linear_idx = self.linear_idx + step
+        self.update_current_m_n_idx(self.linear_idx)
+
+    def valid(self, total_tiles):
+        return self.linear_idx < total_tiles
+
+
+class ClusterPersistentScheduler2D(BaseTileScheduler):
+    """
+    Tile scheduler for cluster-based persistent kernels.
+
+    Distributes a 2D tile grid across persistent clusters using group-major 
ordering
+    for L2 cache locality. Each cluster starts at its cluster_id and strides by
+    num_clusters to process tiles.
+
+    Tile Ordering (group-major for L2 locality):
+    - Tiles are grouped into "L2 groups" of `l2_group_size` rows
+    - Within a group, tiles are visited in column-major order within the group
+    - Groups are processed in row-major order
+
+    Example with 4x4 tiles, l2_group_size=2:
+        Group 0 (rows 0-1):  0  2  4  6
+                             1  3  5  7
+        Group 1 (rows 2-3):  8 10 12 14
+                             9 11 13 15
+
+    Serpentine Mode (serpentine=True):
+    - Uses CUTLASS-style 2D block swizzle with serpentine traversal
+    - Grid is divided into swizzle_size x swizzle_size blocks
+    - Within each block, tiles are visited in row-major order
+    - Blocks are traversed in serpentine order (even block-rows forward, odd 
backward)
+    - This provides better L2 locality by reusing both A and B tiles
+
+    Example with 4x4 tiles, swizzle_size=2, serpentine=True:
+        Block layout:
+          Block(0,0)  Block(0,1)
+          Block(1,0)  Block(1,1)
+
+        Tile numbering with serpentine:
+               n=0  n=1  n=2  n=3
+          m=0   0    1   14   15
+          m=1   2    3   12   13
+          m=2   4    5   10   11
+          m=3   6    7    8    9
+
+        Traversal: Block(0,0) -> Block(1,0) -> Block(1,1) -> Block(0,1)
+                   (serpentine: down in col 0, then up in col 1)
+
+    Parameters
+    ----------
+    prefix : str
+        Prefix for TIR variable names
+    num_m_tiles : int | Tx.ExprLike
+        Total number of tiles in M dimension (can be runtime expression)
+    num_n_tiles : int
+        Total number of tiles in N dimension
+    num_clusters : int
+        Number of persistent clusters (determines stride)
+    l2_group_size : int
+        Number of M-tile rows per L2 locality group (default: 8)
+        When serpentine=True, this is used as swizzle_size for 2D blocks
+    cluster_m : int
+        Cluster dimension in M for hierarchical scheduling (default: 1)
+    cluster_n : int
+        Cluster dimension in N for hierarchical scheduling (default: 1)
+    serpentine : bool
+        If True, use CUTLASS-style 2D block swizzle with serpentine traversal 
(default: False)
+
+    Attributes
+    ----------
+    m_idx : Tx.local_scalar
+        Current M tile index (output)
+    n_idx : Tx.local_scalar
+        Current N tile index (output)
+    work_idx : Tx.local_scalar
+        Global work item index for this cluster
+    tile_count : Tx.local_scalar
+        Number of tiles processed by this cluster so far
+
+    Usage
+    -----
+    ```python
+    scheduler = ClusterPersistentScheduler2D(
+        "sched", num_m_tiles=M_TILES, num_n_tiles=N_TILES,
+        num_clusters=NUM_CLUSTERS, l2_group_size=8
+    )
+    scheduler.init(cluster_id)  # cluster_id = cta_idx // CLUSTER_SIZE
+
+    while scheduler.valid():
+        m = Tx.meta_var(scheduler.m_idx)  # current M tile
+        n = Tx.meta_var(scheduler.n_idx)  # current N tile
+        # ... process tile (m, n) ...
+        scheduler.next_tile()
+    ```
+
+    Examples
+    --------
+    Example 1: Basic persistent kernel
+    ```
+    num_m_tiles=4, num_n_tiles=4, num_clusters=3, l2_group_size=2
+    cluster_m=1, cluster_n=1 (default, no tile subdivision)
+
+    Group-major tile numbering (l2_group_size=2):
+           n=0  n=1  n=2  n=3
+      m=0   0    2    4    6   ┐ L2 group 0
+      m=1   1    3    5    7   ┘
+      m=2   8   10   12   14   ┐ L2 group 1
+      m=3   9   11   13   15   ┘
+
+    Work distribution (cluster starts at cluster_id, strides by 
num_clusters=3):
+      cluster 0: work_idx 0,3,6,9,12,15  -> tiles 0,3,6,9,12,15
+      cluster 1: work_idx 1,4,7,10,13    -> tiles 1,4,7,10,13
+      cluster 2: work_idx 2,5,8,11,14    -> tiles 2,5,8,11,14
+
+    Tile grid (which cluster handles each tile):
+           n=0  n=1  n=2  n=3
+      m=0   C0   C2   C1   C0   ┐ L2 group 0
+      m=1   C1   C0   C2   C1   ┘
+      m=2   C2   C1   C0   C2   ┐ L2 group 1
+      m=3   C0   C2   C1   C0   ┘
+
+    Tile sequence per cluster (in execution order):
+      cluster 0: (0,0)->(1,1)->(0,3)->(3,0)->(2,2)->(3,3)
+      cluster 1: (1,0)->(0,2)->(1,3)->(2,1)->(3,2)
+      cluster 2: (0,1)->(1,2)->(2,0)->(3,1)->(2,3)
+    ```
+
+    Example 2: 2SM GEMM (typical B200 config)
+    ```
+    M=1024, N=512, CTA_M=128, MMA_N=128, CLUSTER_M=2, CLUSTER_N=1
+    => M_TILES=8, N_TILES=4
+    => CLUSTER_M_TILES=4, CLUSTER_N_TILES=4 (scheduler at cluster granularity)
+
+    Scheduler params:
+      num_m_tiles=4, num_n_tiles=4, num_clusters=74, l2_group_size=8
+      cluster_m=1, cluster_n=1
+
+    Key: Scheduler outputs CLUSTER-level tiles.
+         All CTAs in same cluster get SAME (m_idx, n_idx) from scheduler.
+         CTAs differentiate via cluster_rank (computed OUTSIDE scheduler):
+           cluster_rank = cta_idx % CLUSTER_SIZE
+           cb_m = cluster_rank % CLUSTER_M   # 0 or 1 for 2SM
+           cb_n = cluster_rank // CLUSTER_M  # 0 for 2SM
+
+    Final CTA tile:
+      cta_m = m_idx * CLUSTER_M + cb_m
+      cta_n = n_idx * CLUSTER_N + cb_n
+
+    Example: cluster 5 gets scheduler tile (1,2)
+      CTA rank=0 (cb_m=0): actual tile (2,2)
+      CTA rank=1 (cb_m=1): actual tile (3,2)
+    ```
+    """
+
+    def __init__(
+        self,
+        prefix: str,
+        num_m_tiles,
+        num_n_tiles: int,
+        num_clusters: int,
+        l2_group_size: int = 8,
+        cluster_m: int = 1,
+        cluster_n: int = 1,
+        serpentine: bool = False,
+    ):
+        super().__init__(prefix)
+        self._num_m_tiles = num_m_tiles
+        self._num_n_tiles = num_n_tiles
+        self._num_clusters = num_clusters
+        self._l2_group_size = l2_group_size
+        self._cluster_m = cluster_m
+        self._cluster_n = cluster_n
+        self._serpentine = serpentine
+
+        # Rename internal state for clarity
+        self.work_idx = self.linear_idx  # alias: global work item index
+        self.tile_count = Tx.local_scalar("int32")
+        self.tile_idx = self.tile_count  # alias for backward compatibility
+
+        is_static_m = isinstance(num_m_tiles, int)
+
+        # Number of tile columns after accounting for cluster_n
+        n_tile_cols = (num_n_tiles + cluster_n - 1) // cluster_n
+        self._N_TILE_COLS = n_tile_cols
+
+        if is_static_m:
+            self._M_TILE_ROWS = (num_m_tiles + cluster_m - 1) // cluster_m
+            self._FULL_GROUPS = self._M_TILE_ROWS // l2_group_size
+        else:
+            # Dynamic expressions for runtime M
+            self._M_TILE_ROWS = Tx.truncdiv(
+                self._num_m_tiles + self._cluster_m - 1, self._cluster_m
+            )
+            self._FULL_GROUPS = Tx.truncdiv(self._M_TILE_ROWS, 
self._l2_group_size)
+
+        self._TAIL_ROWS = self._M_TILE_ROWS - self._FULL_GROUPS * l2_group_size
+        self._TOTAL_TILES = self._M_TILE_ROWS * n_tile_cols * cluster_m * 
cluster_n
+
+        # For serpentine mode: precompute block counts
+        if serpentine:
+            self._N_BLOCKS = n_tile_cols // l2_group_size  # full blocks in N
+            self._M_BLOCKS = (
+                self._M_TILE_ROWS // l2_group_size
+                if is_static_m
+                else Tx.truncdiv(self._M_TILE_ROWS, l2_group_size)
+            )
+            self._BLOCK_SIZE = l2_group_size * l2_group_size  # tiles per block
+            self._FULL_BLOCK_TILES = self._M_BLOCKS * self._N_BLOCKS * 
self._BLOCK_SIZE
+            # Residual tiles (not covered by full blocks)
+            self._RESIDUAL_N = n_tile_cols - self._N_BLOCKS * l2_group_size
+            self._RESIDUAL_M = self._M_TILE_ROWS - self._M_BLOCKS * 
l2_group_size
+
+    # fmt: off
+    @Tx.inline
+    def update_current_m_n_idx(self, work_idx):
+        """Convert global work index to (m_idx, n_idx) tile coordinates.
+
+        ``work_idx`` is decomposed as
+        ``(tile_linear * CLUSTER_N + cluster_n_offset) * CLUSTER_M + cluster_m_offset``.
+        The two offsets select the position inside a cluster; ``tile_linear``
+        indexes a cluster-sized tile and is mapped to ``(tile_row, tile_col)``
+        by either the serpentine or the group-major ordering.  The results are
+        written to ``self.m_idx`` and ``self.n_idx``.
+        """
+        CLUSTER_M = Tx.meta_var(self._cluster_m)
+        CLUSTER_N = Tx.meta_var(self._cluster_n)
+
+        # Extract hierarchical cluster-local offsets
+        # (the M offset is the fastest-varying digit, then the N offset).
+        cluster_m_offset = Tx.meta_var(work_idx % CLUSTER_M)
+        t = Tx.meta_var(work_idx // CLUSTER_M)
+        cluster_n_offset = Tx.meta_var(t % CLUSTER_N)
+        tile_linear = Tx.meta_var(t // CLUSTER_N)
+
+        # Expand a cluster-level tile coordinate back into this work item's
+        # m_idx / n_idx by re-applying the cluster-local offsets.
+        @Tx.inline
+        def set_tile_coords(tile_row, tile_col):
+            self.m_idx = tile_row * CLUSTER_M + cluster_m_offset
+            self.n_idx = tile_col * CLUSTER_N + cluster_n_offset
+
+        # ``self._serpentine`` is the ``serpentine`` constructor argument, so
+        # this selects the ordering configured at construction time.
+        if self._serpentine:
+            self._update_serpentine(tile_linear, set_tile_coords)
+        else:
+            self._update_group_major(tile_linear, set_tile_coords)
+
+    def _update_group_major(self, tile_linear, set_tile_coords):
+        """Group-major ordering with parse-time pruning of statically-dead 
branches.
+
+        The TIR script parser does not constant-fold ``if False: ...``, so a
+        Python-literal ``FULL_GROUPS == 0`` would otherwise produce
+        ``T.bitwise_and(T.bool(False), tile_linear < 0)`` IR plus the dead
+        then-leg.  Branch in plain Python here and only invoke the inline
+        emitter that can actually fire.
+        """
+        full_zero = isinstance(self._FULL_GROUPS, int) and self._FULL_GROUPS 
== 0
+        tail_zero = isinstance(self._TAIL_ROWS, int) and self._TAIL_ROWS == 0
+        if full_zero and tail_zero:
+            self._gm_emit_zero(set_tile_coords)
+        elif full_zero:
+            self._gm_emit_tail_only(tile_linear, set_tile_coords)
+        elif tail_zero:
+            self._gm_emit_full_only(tile_linear, set_tile_coords)
+        else:
+            self._gm_emit_full_and_tail(tile_linear, set_tile_coords)
+
+    @Tx.inline
+    def _gm_emit_zero(self, set_tile_coords):
+        # Used when both the full-group and the tail regions are statically
+        # empty: there is no tile to map, so every work item falls back to (0, 0).
+        set_tile_coords(0, 0)
+
+    @Tx.inline
+    def _gm_emit_full_only(self, tile_linear, set_tile_coords):
+        # Group-major mapping when the tail is statically empty
+        # (``_TAIL_ROWS == 0``): every tile row belongs to a full group of
+        # GROUP_SIZE rows spanning all N_TILE_COLS columns.  Inside a group the
+        # row index varies fastest, so consecutive work items walk down the
+        # group's rows before moving to the next column.
+        FULL_GROUPS = Tx.meta_var(self._FULL_GROUPS)
+        GROUP_SIZE = Tx.meta_var(self._l2_group_size)
+        # Number of tiles in one group.
+        GROUP_SPAN = Tx.meta_var(self._l2_group_size * self._N_TILE_COLS)
+        if (FULL_GROUPS > 0) & (tile_linear < FULL_GROUPS * GROUP_SPAN):
+            group_id: Tx.let = tile_linear // GROUP_SPAN
+            within_group: Tx.let = tile_linear % GROUP_SPAN
+            tile_row: Tx.let = group_id * GROUP_SIZE + (within_group % 
GROUP_SIZE)
+            tile_col: Tx.let = within_group // GROUP_SIZE
+            set_tile_coords(tile_row, tile_col)
+        else:
+            # Past the last full group.  With no tail this means past the last
+            # tile, so fall back to tile (0, 0).
+            set_tile_coords(0, 0)
+
+    @Tx.inline
+    def _gm_emit_tail_only(self, tile_linear, set_tile_coords):
+        # Group-major mapping when there are statically no full groups
+        # (``_FULL_GROUPS == 0``), so every tile row is a tail row.  Tail tiles
+        # are laid out with the row index varying fastest over TAIL_ROWS rows.
+        FULL_GROUPS = Tx.meta_var(self._FULL_GROUPS)
+        TAIL_ROWS = Tx.meta_var(self._TAIL_ROWS)
+        GROUP_SIZE = Tx.meta_var(self._l2_group_size)
+        GROUP_SPAN = Tx.meta_var(self._l2_group_size * self._N_TILE_COLS)
+        if TAIL_ROWS > 0:
+            # Offset inside the tail region; FULL_GROUPS is 0 on this path, so
+            # this equals tile_linear.
+            rem: Tx.let = tile_linear - FULL_GROUPS * GROUP_SPAN
+            tile_row: Tx.let = FULL_GROUPS * GROUP_SIZE + (rem % TAIL_ROWS)
+            # NOTE(review): not bounded by N_TILE_COLS; work items past the last
+            # tile produce an out-of-range column, so callers should check
+            # valid() first.
+            tile_col: Tx.let = rem // TAIL_ROWS
+            set_tile_coords(tile_row, tile_col)
+        else:
+            set_tile_coords(0, 0)
+
+    @Tx.inline
+    def _gm_emit_full_and_tail(self, tile_linear, set_tile_coords):
+        # General group-major mapping: FULL_GROUPS groups of GROUP_SIZE rows
+        # (same layout as _gm_emit_full_only), followed by a tail of TAIL_ROWS
+        # rows (same layout as _gm_emit_tail_only).  This is also the path used
+        # when M is a runtime PrimExpr, because _FULL_GROUPS / _TAIL_ROWS are
+        # then not Python ints.
+        FULL_GROUPS = Tx.meta_var(self._FULL_GROUPS)
+        TAIL_ROWS = Tx.meta_var(self._TAIL_ROWS)
+        GROUP_SIZE = Tx.meta_var(self._l2_group_size)
+        GROUP_SPAN = Tx.meta_var(self._l2_group_size * self._N_TILE_COLS)
+        if (FULL_GROUPS > 0) & (tile_linear < FULL_GROUPS * GROUP_SPAN):
+            # Full-group region: row index varies fastest inside a group.
+            group_id: Tx.let = tile_linear // GROUP_SPAN
+            within_group: Tx.let = tile_linear % GROUP_SPAN
+            tile_row: Tx.let = group_id * GROUP_SIZE + (within_group % 
GROUP_SIZE)
+            tile_col: Tx.let = within_group // GROUP_SIZE
+            set_tile_coords(tile_row, tile_col)
+        elif TAIL_ROWS > 0:
+            # Tail region, starting right after the last full group.
+            # NOTE(review): tile_col is not bounded by N_TILE_COLS for work
+            # items past the last tile; callers should check valid() first.
+            rem: Tx.let = tile_linear - FULL_GROUPS * GROUP_SPAN
+            tile_row: Tx.let = FULL_GROUPS * GROUP_SIZE + (rem % TAIL_ROWS)
+            tile_col: Tx.let = rem // TAIL_ROWS
+            set_tile_coords(tile_row, tile_col)
+        else:
+            set_tile_coords(0, 0)
+
+    @Tx.inline
+    def _update_serpentine(self, tile_linear, set_tile_coords):
+        """CUTLASS-style 2D block swizzle with serpentine traversal.
+
+        Algorithm:
+        1. Divide grid into swizzle_size x swizzle_size blocks
+        2. Within each block, visit tiles in row-major order
+        3. Blocks are traversed column by column (along N)
+        4. Within each column of blocks, use serpentine:
+           - Even columns: top to bottom
+           - Odd columns: bottom to top
+
+        This maximizes L2 reuse for both A and B matrices.
+        """
+        S = Tx.meta_var(self._l2_group_size)  # swizzle_size
+        M_BLOCKS = Tx.meta_var(self._M_BLOCKS)
+        N_BLOCKS = Tx.meta_var(self._N_BLOCKS)
+        BLOCK_SIZE = Tx.meta_var(self._BLOCK_SIZE)  # S * S
+        FULL_BLOCK_TILES = Tx.meta_var(self._FULL_BLOCK_TILES)
+        M_TILE_ROWS = Tx.meta_var(self._M_TILE_ROWS)
+        Tx.meta_var(self._N_TILE_COLS)
+        RESIDUAL_N = Tx.meta_var(self._RESIDUAL_N)
+        RESIDUAL_M = Tx.meta_var(self._RESIDUAL_M)
+
+        # Check if we're in the full block region
+        if (M_BLOCKS > 0) & (N_BLOCKS > 0) & (tile_linear < FULL_BLOCK_TILES):
+            # Which block (in linear order along columns of blocks)
+            block_linear: Tx.let = tile_linear // BLOCK_SIZE
+            within_block: Tx.let = tile_linear % BLOCK_SIZE
+
+            # Block column and row
+            block_col: Tx.let = block_linear // M_BLOCKS
+            block_row_raw: Tx.let = block_linear % M_BLOCKS
+
+            # Serpentine: odd columns go bottom-to-top
+            block_row: Tx.let = Tx.Select(
+                block_col % 2 == 0,
+                block_row_raw,
+                M_BLOCKS - 1 - block_row_raw
+            )
+
+            # Position within block (row-major within block)
+            local_row: Tx.let = within_block // S
+            local_col: Tx.let = within_block % S
+
+            tile_row: Tx.let = block_row * S + local_row
+            tile_col: Tx.let = block_col * S + local_col
+            set_tile_coords(tile_row, tile_col)
+
+        elif RESIDUAL_N > 0:
+            # Residual tiles in the rightmost partial column of blocks
+            # These are tiles where n >= N_BLOCKS * S
+            rem: Tx.let = tile_linear - FULL_BLOCK_TILES
+
+            # First handle the right residual strip (full M height, partial N 
width)
+            right_strip_tiles: Tx.let = M_TILE_ROWS * RESIDUAL_N
+            if rem < right_strip_tiles:
+                # Row-major within the right strip
+                tile_row: Tx.let = rem // RESIDUAL_N
+                tile_col: Tx.let = N_BLOCKS * S + (rem % RESIDUAL_N)
+                set_tile_coords(tile_row, tile_col)
+            elif RESIDUAL_M > 0:
+                # Bottom residual strip (already covered in right strip 
overlap)
+                # This handles corner case - shouldn't normally reach here
+                # as right strip already covers full M height
+                set_tile_coords(0, 0)
+            else:
+                set_tile_coords(0, 0)
+
+        elif RESIDUAL_M > 0:
+            # Bottom residual strip only (no right residual)
+            rem: Tx.let = tile_linear - FULL_BLOCK_TILES
+            bottom_strip_tiles: Tx.let = RESIDUAL_M * (N_BLOCKS * S)
+            if rem < bottom_strip_tiles:
+                tile_row: Tx.let = M_BLOCKS * S + (rem % RESIDUAL_M)
+                tile_col: Tx.let = rem // RESIDUAL_M
+                set_tile_coords(tile_row, tile_col)
+            else:
+                set_tile_coords(0, 0)
+        else:
+            # Fallback
+            set_tile_coords(0, 0)
+
+    @Tx.inline
+    def init(self, cluster_id):
+        """Initialize scheduler for a given cluster.
+
+        Sets ``linear_idx`` to the cluster's first work item, resets
+        ``tile_count`` and computes ``m_idx`` / ``n_idx`` for that item.
+
+        Parameters
+        ----------
+        cluster_id : int
+            The cluster's index (typically cta_idx // CLUSTER_SIZE)
+        """
+        # A cluster's first work item is its own id; later items are reached
+        # via next_tile() / next_tile_stride().
+        self.linear_idx = cluster_id
+        # Count of next_tile*() advances performed so far.
+        self.tile_count = 0
+        self.update_current_m_n_idx(cluster_id)
+
+    @Tx.inline
+    def next_tile(self):
+        """Advance to the next tile for this cluster."""
+        # Step by ``_num_clusters`` (set outside this view; presumably the
+        # number of clusters sharing the work list) so clusters interleave
+        # over the global work indices.
+        self.linear_idx = self.linear_idx + self._num_clusters
+        self.tile_count = self.tile_count + 1
+        self.update_current_m_n_idx(self.linear_idx)
+
+    @Tx.inline
+    def next_tile_stride(self, stride: int):
+        """Advance by a custom stride (for non-standard scheduling)."""
+        # Same as next_tile(), but with a caller-supplied stride in place of
+        # ``_num_clusters``.
+        self.linear_idx = self.linear_idx + stride
+        self.tile_count = self.tile_count + 1
+        self.update_current_m_n_idx(self.linear_idx)
+    # fmt: on
+
+    def valid(self):
+        """Check if this cluster has more tiles to process."""
+        return self.linear_idx < self._TOTAL_TILES
+
+
+class GroupMajor3D(BaseTileScheduler):
+    """
+    3D grouped-row scheduler (M,N,K) with tail handling on M.
+
+    Args
+    ----
+    prefix: str
+    m_tiles: int | T PrimExpr   # tiles along M (static or runtime)
+    n_tiles: int                # tiles along N (static)
+    k_tiles: int                # tiles along K (static)
+    group_rows: int             # rows per group along M
+    step: int = 1               # default stride for next_tile()
+    """
+
+    def __init__(
+        self, prefix: str, m_tiles, n_tiles: int, k_tiles: int, group_rows: 
int, step: int = 1
+    ):
+        super().__init__(prefix)
+        self._step = step
+        self.tile_idx = Tx.local_scalar("int32")
+        self.k_idx = Tx.local_scalar("int32")
+
+        # ---- constants / primexprs baked once ----
+        self._G = group_rows
+        self._N = n_tiles
+        self._K = k_tiles
+
+        if isinstance(m_tiles, int):
+            self._GROUPS = m_tiles // group_rows
+            self._FINAL_ROWS = m_tiles - self._GROUPS * group_rows
+            self._SAFE_FINAL_ROWS = max(self._FINAL_ROWS, 1)
+            self._GROUP_SIZE = group_rows * n_tiles * k_tiles
+            self._TOTAL = m_tiles * n_tiles * k_tiles
+        else:
+            self._GROUPS = Tx.truncdiv(m_tiles, group_rows)
+            self._FINAL_ROWS = m_tiles - self._GROUPS * group_rows
+            self._SAFE_FINAL_ROWS = Tx.max(self._FINAL_ROWS, 1)
+            self._GROUP_SIZE = self._G * self._N * self._K
+            self._TOTAL = m_tiles * n_tiles * k_tiles
+
+        # handy composites used in macro
+        self._FULL_BOUND = self._GROUPS * self._GROUP_SIZE
+        self._HAS_FULL = self._GROUPS > 0
+        self._HAS_TAIL = self._FINAL_ROWS > 0
+
+    # fmt: off
+    @Tx.inline
+    def update_current_m_n_idx(self, linear_idx):
+        # full-group formulas
+        full_m: Tx.let = Tx.floordiv(linear_idx, self._GROUP_SIZE) * self._G + 
Tx.floormod(
+            linear_idx, self._G
+        )
+        full_n: Tx.let = Tx.floormod(Tx.floordiv(linear_idx, self._G), self._N)
+        full_k: Tx.let = Tx.floordiv(Tx.floormod(linear_idx, 
self._GROUP_SIZE), self._G * self._N)
+
+        # tail formulas (relative to FULL_BOUND)
+        # Use _SAFE_FINAL_ROWS (max(FINAL_ROWS, 1)) to avoid divide-by-zero 
when there is no tail
+        rem: Tx.let = linear_idx - self._FULL_BOUND
+        tail_m: Tx.let = self._GROUPS * self._G + Tx.floormod(rem, 
self._SAFE_FINAL_ROWS)
+        tail_n: Tx.let = Tx.floordiv(rem, self._SAFE_FINAL_ROWS) % self._N
+        tail_k: Tx.let = Tx.floordiv(rem, self._SAFE_FINAL_ROWS * self._N)
+
+        # choose phase
+        if self._HAS_FULL & (linear_idx < self._FULL_BOUND):
+            self.m_idx = full_m
+            self.n_idx = full_n
+            self.k_idx = full_k
+        elif self._HAS_TAIL:
+            self.m_idx = tail_m
+            self.n_idx = tail_n
+            self.k_idx = tail_k
+        else:
+            self.m_idx = 0
+            self.n_idx = 0
+            self.k_idx = 0
+

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   In `GroupMajor3D.update_current_m_n_idx`, the tail formulas use 
`_SAFE_FINAL_ROWS` (which is `max(FINAL_ROWS, 1)`) to avoid division by zero. 
While this prevents the crash, it might lead to incorrect `m_idx` calculations 
if `linear_idx` somehow exceeds the expected bounds when no tail exists. Ensure 
that the `valid()` check is always strictly enforced before calling this update.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to