This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch tir-bench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 36b274d7387ddcd7c8cfafe1f667e5bfb99cc785
Author: Bohan Hou <[email protected]>
AuthorDate: Tue May 19 01:11:13 2026 -0400

    feat(tvmscript): add @Tx.jit decorator, Tx.constexpr params, Tx.wg_reg_tile (#635)
    
    @Tx.jit captures the kernel and defers parsing until
    .specialize(**constexpr_kwargs); each constexpr param can have a default
    value. Use it to replace the `def tir_kernel(M, N, K): ... @Tx.prim_func
    def kernel(...)` closure-wrapper pattern, saving one indentation level.
    
    Also add Tx.wg_reg_tile(elem_per_thread, dtype) sugar for the recurring
    warpgroup register tile + wg_local_layout + scope='local' alloc pattern.
    
    Parser changes (TIRx-side only, core parser untouched):
    - entry.py: _ConstexprProxy sentinel, TIRJit class with specialize() and
      per-kwargs PrimFunc cache, jit() decorator factory.
    - parser.py: visit_function_def skips args whose annotation is the
      constexpr sentinel (their value flows via extra_vars from .specialize()).
      find_decorator_annotation also recognizes @Tx.jit alongside @Tx.prim_func
      for keyword passthrough (private, persistent, is_stir).
    
    Builder changes:
    - builder/ir.py: wg_reg_tile() wraps alloc_buffer with the (128, X)
      shape + wg_local_layout(X) + scope='local' pattern.
    
    Tests: tests/python/tirx/test_jit.py covers loop-bound constexpr,
    multi-constexpr in Buffer shape, body-expression constexpr, cache
    identity, default values, missing/extra arg errors, @Tx.jit + nested
    @Tx.inline interop, type-checking the returned PrimFunc.
---
 python/tvm/tirx/script/builder/ir.py      |  34 ++++-
 python/tvm/tirx/script/parser/__init__.py |  16 ++-
 python/tvm/tirx/script/parser/entry.py    | 175 +++++++++++++++++++++++
 python/tvm/tirx/script/parser/parser.py   |  12 +-
 tests/python/tirx/test_jit.py             | 225 ++++++++++++++++++++++++++++++
 5 files changed, 455 insertions(+), 7 deletions(-)

diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py
index c906cde4d3..85a4d5e93f 100644
--- a/python/tvm/tirx/script/builder/ir.py
+++ b/python/tvm/tirx/script/builder/ir.py
@@ -85,7 +85,16 @@ from tvm.tirx.expr import (
     Sub,
 )
 from tvm.tirx.generic import cast
-from tvm.tirx.layout import ComposeLayout, Iter, Layout, R, S, SwizzleLayout, TileLayout
+from tvm.tirx.layout import (
+    ComposeLayout,
+    Iter,
+    Layout,
+    R,
+    S,
+    SwizzleLayout,
+    TileLayout,
+    wg_local_layout,
+)
 
 from . import _ffi_api, frame, utils
 from .external_kernel import call_kernel
@@ -890,6 +899,28 @@ def alloc_buffer(
     return buf
 
 
+def wg_reg_tile(elem_per_thread: int, dtype: str = "float32") -> Buffer:
+    """Warpgroup-wide ``(128, elem_per_thread)`` register tile in local scope.
+
+    Sugar for the recurring pattern::
+
+        Tx.alloc_buffer(
+            (128, elem_per_thread), dtype,
+            layout=wg_local_layout(elem_per_thread),
+            scope="local",
+        )
+
+    Used to stage a tcgen05 load: each of the 128 threads in a warpgroup
+    owns one row of ``elem_per_thread`` contiguous elements.
+    """
+    return alloc_buffer(
+        (128, elem_per_thread),
+        dtype,
+        layout=wg_local_layout(elem_per_thread),
+        scope="local",
+    )
+
+
 def sblock_alloc_buffer(
     shape: list[PrimExpr] | tuple[PrimExpr] | PrimExpr | Integral,
     dtype: str = "float32",
@@ -3726,6 +3757,7 @@ __all__ = [
     "sblock_attr",
     "alloc_buffer",
     "sblock_alloc_buffer",
+    "wg_reg_tile",
     "axis",
     "serial",
     "parallel",
diff --git a/python/tvm/tirx/script/parser/__init__.py b/python/tvm/tirx/script/parser/__init__.py
index 2ca0179a83..5f6b8d38f1 100644
--- a/python/tvm/tirx/script/parser/__init__.py
+++ b/python/tvm/tirx/script/parser/__init__.py
@@ -24,14 +24,24 @@ from tvm.tirx.script.builder import ir as _tir
 
 from . import operation as _operation
 from . import parser as _parser
-from .entry import Buffer, Ptr
+from .entry import Buffer, Ptr, constexpr
 
 if TYPE_CHECKING:
     # pylint: disable=invalid-name
     # Define prim_func and make it type check as static method
     # so most tvmscript won't trigger pylint error here.
     prim_func = staticmethod
+    jit = staticmethod
 else:
-    from .entry import inline, macro, prim_func
+    from .entry import inline, jit, macro, prim_func
 
-__all__ = _tir.__all__ + ["Buffer", "Ptr", "bool", "prim_func", "inline", "macro"]
+__all__ = _tir.__all__ + [
+    "Buffer",
+    "Ptr",
+    "bool",
+    "constexpr",
+    "inline",
+    "jit",
+    "macro",
+    "prim_func",
+]
diff --git a/python/tvm/tirx/script/parser/entry.py b/python/tvm/tirx/script/parser/entry.py
index 94b1e2afcb..8d9e0f0cf3 100644
--- a/python/tvm/tirx/script/parser/entry.py
+++ b/python/tvm/tirx/script/parser/entry.py
@@ -211,6 +211,164 @@ def inline(*args, definition_depth: int | None = None, defining_var_table=None)
 setattr(inline, "dispatch_token", "tir.inline")
 
 
+class TIRJit:
+    """Top-level kernel decorator with constexpr params + ``.specialize()``.
+
+    Parses the function body lazily: parsing is deferred until ``.specialize()``
+    supplies concrete values for the params annotated as ``Tx.constexpr``. The
+    return type of ``.specialize()`` is a ``tvm.tirx.PrimFunc``, identical in
+    type to what ``@Tx.prim_func`` produces today.
+
+    Constexpr params are removed from the resulting PrimFunc's parameter list;
+    their values are baked into the IR (e.g. into ``Tx.Buffer((M, K), ...)``
+    shape annotations and into the body).
+    """
+
+    def __init__(
+        self,
+        func: Callable,
+        check_well_formed: bool = True,
+        is_stir: bool = False,
+        persistent: bool = False,
+        private: bool = False,
+    ) -> None:
+        self.func = func
+        self.check_well_formed = check_well_formed
+        self.is_stir = is_stir
+        self.persistent = persistent  # pylint: disable=unused-private-member
+        self.private = private  # pylint: disable=unused-private-member
+        # Resolved closure vars (computed once; the function itself is the
+        # capture point, so this never changes between specializations).
+        self._closure_vars: dict[str, Any] = utils.inspect_function_capture(func)
+        # Detect which params are marked Tx.constexpr. With PEP 563
+        # (``from __future__ import annotations``), each annotation is a
+        # string; we eval them one-by-one so a constexpr probe is not
+        # blocked by sibling annotations that reference yet-undefined names
+        # (e.g. ``A: Tx.Buffer((N,), ...)`` referencing constexpr ``N``).
+        raw_anns = getattr(func, "__annotations__", {}) or {}
+        eval_globals = {**func.__globals__, **self._closure_vars}
+        sig = inspect.signature(func)
+        constexpr_names: set[str] = set()
+        constexpr_defaults: dict[str, Any] = {}
+        for name, param in sig.parameters.items():
+            ann = raw_anns.get(name)
+            if isinstance(ann, str):
+                try:
+                    ann = eval(ann, eval_globals)  # pylint: disable=eval-used
+                except Exception:  # pylint: disable=broad-except
+                    ann = None
+            if ann is constexpr:
+                constexpr_names.add(name)
+                if param.default is not inspect.Parameter.empty:
+                    constexpr_defaults[name] = param.default
+        self.constexpr_names: frozenset[str] = frozenset(constexpr_names)
+        self.constexpr_defaults: dict[str, Any] = constexpr_defaults
+        self._cache: dict[tuple, PrimFunc] = {}
+
+    def specialize(self, **constexpr_kwargs) -> PrimFunc:
+        """Build a concrete PrimFunc by binding the constexpr params.
+
+        Parameters
+        ----------
+        **constexpr_kwargs
+            One value per ``Tx.constexpr``-annotated parameter. All such
+            parameters must be supplied; passing names that are not
+            constexpr-annotated is an error.
+
+        Returns
+        -------
+        PrimFunc
+            A concrete TIRx PrimFunc, identical in type to the output of
+            ``@Tx.prim_func``.
+        """
+        extra = constexpr_kwargs.keys() - self.constexpr_names
+        if extra:
+            raise TypeError(
+                f"{self.func.__name__}.specialize() got unexpected arg(s): "
+                f"{sorted(extra)} (constexpr params are: {sorted(self.constexpr_names)})"
+            )
+        effective = {**self.constexpr_defaults, **constexpr_kwargs}
+        missing = self.constexpr_names - effective.keys()
+        if missing:
+            raise TypeError(
+                f"{self.func.__name__}.specialize() missing constexpr arg(s) "
+                f"(no default provided): {sorted(missing)}"
+            )
+
+        try:
+            cache_key = tuple(sorted(effective.items()))
+            cached = self._cache.get(cache_key)
+        except TypeError as err:
+            raise TypeError(
+                f"{self.func.__name__}.specialize(): all constexpr values must "
+                f"be hashable (got: {effective!r})"
+            ) from err
+        if cached is not None:
+            return cached
+
+        extra_vars = {**self._closure_vars, **effective}
+        prim_func = parse(
+            self.func,
+            extra_vars,
+            check_well_formed=self.check_well_formed,
+            is_stir=self.is_stir,
+        )
+        setattr(prim_func, "__name__", self.func.__name__)
+        self._cache[cache_key] = prim_func
+        return prim_func
+
+
+def jit(
+    func: Callable | None = None,
+    private: bool = False,
+    check_well_formed: bool = True,
+    is_stir: bool = False,
+    persistent: bool = False,
+) -> "TIRJit | Callable":
+    """Decorator: capture the kernel and defer parsing until ``.specialize()``.
+
+    Use ``@Tx.jit`` (instead of ``@Tx.prim_func``) when the kernel takes
+    compile-time parameters annotated with ``Tx.constexpr``. The resulting
+    object exposes ``.specialize(**constexpr_kwargs)``, which returns a
+    ``tvm.tirx.PrimFunc``.
+
+    Example::
+
+        from tvm.script import tirx as Tx
+
+        @Tx.jit
+        def add(
+            A: Tx.Buffer((N,), "float32"),
+            B: Tx.Buffer((N,), "float32"),
+            *,
+            N: Tx.constexpr,
+        ):
+            with Tx.kernel():
+                ...
+
+        kernel = add.specialize(N=1024)  # returns a PrimFunc
+    """
+
+    def decorator_wrapper(func: Callable) -> TIRJit:
+        if not inspect.isfunction(func):
+            raise TypeError(f"Expect a function, but got: {func}")
+        return TIRJit(
+            func,
+            check_well_formed=check_well_formed,
+            is_stir=is_stir,
+            persistent=persistent,
+            private=private,
+        )
+
+    if func is not None:
+        return decorator_wrapper(func)
+    setattr(decorator_wrapper, "dispatch_token", "tirx")
+    return decorator_wrapper
+
+
+setattr(jit, "dispatch_token", "tirx")
+
+
 class TIRMacro(ScriptMacro):
     """Specialization of the ScriptMacro class for TIR.
 
@@ -342,5 +500,22 @@ class PtrProxy:
         return self(*keys)
 
 
+class _ConstexprProxy:
+    """Sentinel marker for compile-time (specialization-time) parameters.
+
+    Used as a parameter annotation in ``@Tx.jit`` decorated functions to mark
+    a parameter as constexpr — its value is supplied to ``.specialize(**kwargs)``
+    rather than at call time, and it is removed from the generated PrimFunc's
+    runtime parameter list.
+    """
+
+    def __or__(self, other):
+        return self
+
+    def __ror__(self, other):
+        return self
+
+
 Buffer = BufferProxy()  # pylint: disable=invalid-name
 Ptr = PtrProxy()  # pylint: disable=invalid-name
+constexpr = _ConstexprProxy()  # pylint: disable=invalid-name
diff --git a/python/tvm/tirx/script/parser/parser.py b/python/tvm/tirx/script/parser/parser.py
index da1f59e96c..e7556db7f3 100644
--- a/python/tvm/tirx/script/parser/parser.py
+++ b/python/tvm/tirx/script/parser/parser.py
@@ -34,6 +34,7 @@ from tvm.tirx.script import builder as T
 from tvm.tirx.script.builder.ir import name_meta_class_value
 from tvm.tirx.stmt import BufferRegion
 
+from .entry import constexpr as _constexpr_sentinel
 from .entry import inline
 
 
@@ -241,9 +242,9 @@ def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: b
     Check the value of given annotation (argument name) in the prim_func decorator.
     Returns the value of the annotation if present, otherwise giving the default value.
     """
-    # look for the named argument in the prim_func decorator
+    # look for the named argument in the prim_func / jit decorator
     for dec in node.decorator_list:
-        if not isinstance(dec, doc.Call) or dec.func.attr != "prim_func":
+        if not isinstance(dec, doc.Call) or dec.func.attr not in ("prim_func", "jit"):
             continue
         for keyword in dec.keywords:
             if keyword.arg == annotation:
@@ -633,12 +634,17 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
                         self.report_error(arg, "Type annotation required for function parameters.")
                     try:
                         ann = self.eval_expr(arg.annotation)
-                        if callable(ann):
+                        if callable(ann) and ann is not _constexpr_sentinel:
                             ann = ann()
                     except Exception:  # pylint: disable=broad-except
                         ann = func_annotation.get(arg.arg, None)
                         if ann is None:
                             raise
+                    if ann is _constexpr_sentinel:
+                        # Tx.constexpr param: value was bound in extra_vars by
+                        # TIRJit.specialize() and lives in an outer var_table
+                        # frame; do not register a runtime PrimFunc param.
+                        continue
                     param = T.arg(arg.arg, ann)
                     self.var_table.add(arg.arg, param)
                 self.visit_body(node.body)
diff --git a/tests/python/tirx/test_jit.py b/tests/python/tirx/test_jit.py
new file mode 100644
index 0000000000..637563867c
--- /dev/null
+++ b/tests/python/tirx/test_jit.py
@@ -0,0 +1,225 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# ruff: noqa: F821
+"""Tests for ``@Tx.jit`` + ``Tx.constexpr``."""
+
+from __future__ import annotations
+
+import pytest
+
+import tvm
+from tvm.ir import assert_structural_equal
+from tvm.script import tirx as Tx
+
+
+def test_int_constexpr_specializes_loop_bound():
+    @Tx.jit(private=True)
+    def add(
+        A: Tx.Buffer((N,), "int32"),
+        B: Tx.Buffer((N,), "int32"),
+        C: Tx.Buffer((N,), "int32"),
+        *,
+        N: Tx.constexpr,
+    ):
+        for i in range(N):
+            C[i] = A[i] + B[i]
+
+    @Tx.prim_func(private=True)
+    def expected(
+        A: Tx.Buffer((128,), "int32"),
+        B: Tx.Buffer((128,), "int32"),
+        C: Tx.Buffer((128,), "int32"),
+    ):
+        for i in range(128):
+            C[i] = A[i] + B[i]
+
+    assert_structural_equal(add.specialize(N=128), expected, map_free_vars=True)
+
+
+def test_constexpr_in_2d_buffer_shape():
+    @Tx.jit(private=True)
+    def matadd(
+        A: Tx.Buffer((M, K), "int32"),
+        B: Tx.Buffer((M, K), "int32"),
+        C: Tx.Buffer((M, K), "int32"),
+        *,
+        M: Tx.constexpr,
+        K: Tx.constexpr,
+    ):
+        for m in range(M):
+            for k in range(K):
+                C[m, k] = A[m, k] + B[m, k]
+
+    @Tx.prim_func(private=True)
+    def expected(
+        A: Tx.Buffer((4, 8), "int32"),
+        B: Tx.Buffer((4, 8), "int32"),
+        C: Tx.Buffer((4, 8), "int32"),
+    ):
+        for m in range(4):
+            for k in range(8):
+                C[m, k] = A[m, k] + B[m, k]
+
+    assert_structural_equal(matadd.specialize(M=4, K=8), expected, map_free_vars=True)
+
+
+def test_constexpr_in_body_expression():
+    @Tx.jit(private=True)
+    def scaled_copy(
+        A: Tx.Buffer((N,), "int32"),
+        B: Tx.Buffer((N,), "int32"),
+        *,
+        N: Tx.constexpr,
+        SCALE: Tx.constexpr,
+    ):
+        for i in range(N):
+            B[i] = A[i] * SCALE
+
+    @Tx.prim_func(private=True)
+    def expected(
+        A: Tx.Buffer((16,), "int32"),
+        B: Tx.Buffer((16,), "int32"),
+    ):
+        for i in range(16):
+            B[i] = A[i] * 3
+
+    assert_structural_equal(scaled_copy.specialize(N=16, SCALE=3), expected, map_free_vars=True)
+
+
+def test_specialize_cache_returns_same_instance():
+    @Tx.jit(private=True)
+    def k(
+        A: Tx.Buffer((N,), "int32"),
+        *,
+        N: Tx.constexpr,
+    ):
+        for i in range(N):
+            A[i] = 0
+
+    a = k.specialize(N=8)
+    b = k.specialize(N=8)
+    assert a is b
+
+
+def test_specialize_different_args_produce_different_funcs():
+    @Tx.jit(private=True)
+    def k(
+        A: Tx.Buffer((N,), "int32"),
+        *,
+        N: Tx.constexpr,
+    ):
+        for i in range(N):
+            A[i] = 0
+
+    assert k.specialize(N=8) is not k.specialize(N=16)
+
+
+def test_specialize_missing_constexpr_raises():
+    @Tx.jit(private=True)
+    def k(
+        A: Tx.Buffer((N,), "int32"),
+        *,
+        N: Tx.constexpr,
+        SCALE: Tx.constexpr,
+    ):
+        for i in range(N):
+            A[i] = SCALE
+
+    with pytest.raises(TypeError, match="missing"):
+        k.specialize(N=8)
+
+
+def test_specialize_extra_kwarg_raises():
+    @Tx.jit(private=True)
+    def k(
+        A: Tx.Buffer((N,), "int32"),
+        *,
+        N: Tx.constexpr,
+    ):
+        for i in range(N):
+            A[i] = 0
+
+    with pytest.raises(TypeError, match="unexpected"):
+        k.specialize(N=8, BOGUS=42)
+
+
+def test_jit_kernel_with_nested_inline_helper():
+    @Tx.jit(private=True)
+    def k(
+        A: Tx.Buffer((N,), "int32"),
+        *,
+        N: Tx.constexpr,
+    ):
+        @Tx.inline
+        def double(x):
+            A[x] = A[x] * 2
+
+        for i in range(N):
+            double(i)
+
+    @Tx.prim_func(private=True)
+    def expected(
+        A: Tx.Buffer((4,), "int32"),
+    ):
+        for i in range(4):
+            A[i] = A[i] * 2
+
+    assert_structural_equal(k.specialize(N=4), expected, map_free_vars=True)
+
+
+def test_constexpr_default_value():
+    @Tx.jit(private=True)
+    def k(
+        A: Tx.Buffer((N,), "int32"),
+        *,
+        N: Tx.constexpr,
+        SCALE: Tx.constexpr = 7,
+    ):
+        for i in range(N):
+            A[i] = SCALE
+
+    @Tx.prim_func(private=True)
+    def expected(
+        A: Tx.Buffer((8,), "int32"),
+    ):
+        for i in range(8):
+            A[i] = 7
+
+    assert_structural_equal(k.specialize(N=8), expected, map_free_vars=True)
+    # Override the default
+    overridden = k.specialize(N=8, SCALE=99)
+    assert k.specialize(N=8) is not overridden
+
+
+def test_specialize_returns_primfunc():
+    @Tx.jit(private=True)
+    def k(
+        A: Tx.Buffer((N,), "int32"),
+        *,
+        N: Tx.constexpr,
+    ):
+        for i in range(N):
+            A[i] = 0
+
+    spec = k.specialize(N=8)
+    assert isinstance(spec, tvm.tirx.PrimFunc)
+    # Specialized PrimFunc has only the runtime params (constexpr stripped).
+    assert len(spec.params) == 1
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])

Reply via email to