This is an automated email from the ASF dual-hosted git repository. spectrometerHBH pushed a commit to branch tir-bench in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 36b274d7387ddcd7c8cfafe1f667e5bfb99cc785 Author: Bohan Hou <[email protected]> AuthorDate: Tue May 19 01:11:13 2026 -0400 feat(tvmscript): add @Tx.jit decorator, Tx.constexpr params, Tx.wg_reg_tile (#635) @Tx.jit captures the kernel and defers parsing until .specialize(**constexpr_kwargs) is called; each constexpr param may declare a default value. Use it to replace the "def tir_kernel(M, N, K): ... @Tx.prim_func def kernel(...)" closure-wrapper pattern, which saves one indentation level. Also add Tx.wg_reg_tile(elem_per_thread, dtype), sugar for the recurring alloc pattern of a warpgroup register tile with wg_local_layout and scope='local'. Parser changes (TIRx-side only, core parser untouched): - entry.py: _ConstexprProxy sentinel, TIRJit class with specialize() and a per-kwargs PrimFunc cache, and the jit() decorator factory. - parser.py: visit_function_def skips args whose annotation is the constexpr sentinel (their values are supplied through extra_vars by .specialize()). find_decorator_annotation also recognizes @Tx.jit alongside @Tx.prim_func for keyword passthrough (private, persistent, is_stir). Builder changes: - builder/ir.py: wg_reg_tile() wraps alloc_buffer with the (128, X) shape + wg_local_layout(X) + scope='local' pattern. Tests: tests/python/tirx/test_jit.py covers a constexpr loop bound, multiple constexprs in a Buffer shape, a constexpr in a body expression, cache identity, default values, missing/extra arg errors, @Tx.jit + nested @Tx.inline interop, and type-checking the returned PrimFunc. 
--- python/tvm/tirx/script/builder/ir.py | 34 ++++- python/tvm/tirx/script/parser/__init__.py | 16 ++- python/tvm/tirx/script/parser/entry.py | 175 +++++++++++++++++++++++ python/tvm/tirx/script/parser/parser.py | 12 +- tests/python/tirx/test_jit.py | 225 ++++++++++++++++++++++++++++++ 5 files changed, 455 insertions(+), 7 deletions(-) diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py index c906cde4d3..85a4d5e93f 100644 --- a/python/tvm/tirx/script/builder/ir.py +++ b/python/tvm/tirx/script/builder/ir.py @@ -85,7 +85,16 @@ from tvm.tirx.expr import ( Sub, ) from tvm.tirx.generic import cast -from tvm.tirx.layout import ComposeLayout, Iter, Layout, R, S, SwizzleLayout, TileLayout +from tvm.tirx.layout import ( + ComposeLayout, + Iter, + Layout, + R, + S, + SwizzleLayout, + TileLayout, + wg_local_layout, +) from . import _ffi_api, frame, utils from .external_kernel import call_kernel @@ -890,6 +899,28 @@ def alloc_buffer( return buf +def wg_reg_tile(elem_per_thread: int, dtype: str = "float32") -> Buffer: + """Warpgroup-wide ``(128, elem_per_thread)`` register tile in local scope. + + Sugar for the recurring pattern:: + + Tx.alloc_buffer( + (128, elem_per_thread), dtype, + layout=wg_local_layout(elem_per_thread), + scope="local", + ) + + Used to stage a tcgen05 load: each of the 128 threads in a warpgroup + owns one row of ``elem_per_thread`` contiguous elements. 
+ """ + return alloc_buffer( + (128, elem_per_thread), + dtype, + layout=wg_local_layout(elem_per_thread), + scope="local", + ) + + def sblock_alloc_buffer( shape: list[PrimExpr] | tuple[PrimExpr] | PrimExpr | Integral, dtype: str = "float32", @@ -3726,6 +3757,7 @@ __all__ = [ "sblock_attr", "alloc_buffer", "sblock_alloc_buffer", + "wg_reg_tile", "axis", "serial", "parallel", diff --git a/python/tvm/tirx/script/parser/__init__.py b/python/tvm/tirx/script/parser/__init__.py index 2ca0179a83..5f6b8d38f1 100644 --- a/python/tvm/tirx/script/parser/__init__.py +++ b/python/tvm/tirx/script/parser/__init__.py @@ -24,14 +24,24 @@ from tvm.tirx.script.builder import ir as _tir from . import operation as _operation from . import parser as _parser -from .entry import Buffer, Ptr +from .entry import Buffer, Ptr, constexpr if TYPE_CHECKING: # pylint: disable=invalid-name # Define prim_func and make it type check as static method # so most tvmscript won't trigger pylint error here. prim_func = staticmethod + jit = staticmethod else: - from .entry import inline, macro, prim_func + from .entry import inline, jit, macro, prim_func -__all__ = _tir.__all__ + ["Buffer", "Ptr", "bool", "prim_func", "inline", "macro"] +__all__ = _tir.__all__ + [ + "Buffer", + "Ptr", + "bool", + "constexpr", + "inline", + "jit", + "macro", + "prim_func", +] diff --git a/python/tvm/tirx/script/parser/entry.py b/python/tvm/tirx/script/parser/entry.py index 94b1e2afcb..8d9e0f0cf3 100644 --- a/python/tvm/tirx/script/parser/entry.py +++ b/python/tvm/tirx/script/parser/entry.py @@ -211,6 +211,164 @@ def inline(*args, definition_depth: int | None = None, defining_var_table=None) setattr(inline, "dispatch_token", "tir.inline") +class TIRJit: + """Top-level kernel decorator with constexpr params + ``.specialize()``. + + Parses the function body lazily: parsing is deferred until ``.specialize()`` + supplies concrete values for the params annotated as ``Tx.constexpr``. 
The + return type of ``.specialize()`` is a ``tvm.tirx.PrimFunc``, identical in + type to what ``@Tx.prim_func`` produces today. + + Constexpr params are removed from the resulting PrimFunc's parameter list; + their values are baked into the IR (e.g. into ``Tx.Buffer((M, K), ...)`` + shape annotations and into the body). + """ + + def __init__( + self, + func: Callable, + check_well_formed: bool = True, + is_stir: bool = False, + persistent: bool = False, + private: bool = False, + ) -> None: + self.func = func + self.check_well_formed = check_well_formed + self.is_stir = is_stir + self.persistent = persistent # pylint: disable=unused-private-member + self.private = private # pylint: disable=unused-private-member + # Resolved closure vars (computed once; the function itself is the + # capture point, so this never changes between specializations). + self._closure_vars: dict[str, Any] = utils.inspect_function_capture(func) + # Detect which params are marked Tx.constexpr. With PEP 563 + # (``from __future__ import annotations``), each annotation is a + # string; we eval them one-by-one so a constexpr probe is not + # blocked by sibling annotations that reference yet-undefined names + # (e.g. ``A: Tx.Buffer((N,), ...)`` referencing constexpr ``N``). 
+ raw_anns = getattr(func, "__annotations__", {}) or {} + eval_globals = {**func.__globals__, **self._closure_vars} + sig = inspect.signature(func) + constexpr_names: set[str] = set() + constexpr_defaults: dict[str, Any] = {} + for name, param in sig.parameters.items(): + ann = raw_anns.get(name) + if isinstance(ann, str): + try: + ann = eval(ann, eval_globals) # pylint: disable=eval-used + except Exception: # pylint: disable=broad-except + ann = None + if ann is constexpr: + constexpr_names.add(name) + if param.default is not inspect.Parameter.empty: + constexpr_defaults[name] = param.default + self.constexpr_names: frozenset[str] = frozenset(constexpr_names) + self.constexpr_defaults: dict[str, Any] = constexpr_defaults + self._cache: dict[tuple, PrimFunc] = {} + + def specialize(self, **constexpr_kwargs) -> PrimFunc: + """Build a concrete PrimFunc by binding the constexpr params. + + Parameters + ---------- + **constexpr_kwargs + One value per ``Tx.constexpr``-annotated parameter. All such + parameters must be supplied; passing names that are not + constexpr-annotated is an error. + + Returns + ------- + PrimFunc + A concrete TIRx PrimFunc, identical in type to the output of + ``@Tx.prim_func``. 
+ """ + extra = constexpr_kwargs.keys() - self.constexpr_names + if extra: + raise TypeError( + f"{self.func.__name__}.specialize() got unexpected arg(s): " + f"{sorted(extra)} (constexpr params are: {sorted(self.constexpr_names)})" + ) + effective = {**self.constexpr_defaults, **constexpr_kwargs} + missing = self.constexpr_names - effective.keys() + if missing: + raise TypeError( + f"{self.func.__name__}.specialize() missing constexpr arg(s) " + f"(no default provided): {sorted(missing)}" + ) + + try: + cache_key = tuple(sorted(effective.items())) + cached = self._cache.get(cache_key) + except TypeError as err: + raise TypeError( + f"{self.func.__name__}.specialize(): all constexpr values must " + f"be hashable (got: {effective!r})" + ) from err + if cached is not None: + return cached + + extra_vars = {**self._closure_vars, **effective} + prim_func = parse( + self.func, + extra_vars, + check_well_formed=self.check_well_formed, + is_stir=self.is_stir, + ) + setattr(prim_func, "__name__", self.func.__name__) + self._cache[cache_key] = prim_func + return prim_func + + +def jit( + func: Callable | None = None, + private: bool = False, + check_well_formed: bool = True, + is_stir: bool = False, + persistent: bool = False, +) -> "TIRJit | Callable": + """Decorator: capture the kernel and defer parsing until ``.specialize()``. + + Use ``@Tx.jit`` (instead of ``@Tx.prim_func``) when the kernel takes + compile-time parameters annotated with ``Tx.constexpr``. The resulting + object exposes ``.specialize(**constexpr_kwargs)``, which returns a + ``tvm.tirx.PrimFunc``. + + Example:: + + from tvm.script import tirx as Tx + + @Tx.jit + def add( + A: Tx.Buffer((N,), "float32"), + B: Tx.Buffer((N,), "float32"), + *, + N: Tx.constexpr, + ): + with Tx.kernel(): + ... 
+ + kernel = add.specialize(N=1024) # returns a PrimFunc + """ + + def decorator_wrapper(func: Callable) -> TIRJit: + if not inspect.isfunction(func): + raise TypeError(f"Expect a function, but got: {func}") + return TIRJit( + func, + check_well_formed=check_well_formed, + is_stir=is_stir, + persistent=persistent, + private=private, + ) + + if func is not None: + return decorator_wrapper(func) + setattr(decorator_wrapper, "dispatch_token", "tirx") + return decorator_wrapper + + +setattr(jit, "dispatch_token", "tirx") + + class TIRMacro(ScriptMacro): """Specialization of the ScriptMacro class for TIR. @@ -342,5 +500,22 @@ class PtrProxy: return self(*keys) +class _ConstexprProxy: + """Sentinel marker for compile-time (specialization-time) parameters. + + Used as a parameter annotation in ``@Tx.jit`` decorated functions to mark + a parameter as constexpr — its value is supplied to ``.specialize(**kwargs)`` + rather than at call time, and it is removed from the generated PrimFunc's + runtime parameter list. + """ + + def __or__(self, other): + return self + + def __ror__(self, other): + return self + + Buffer = BufferProxy() # pylint: disable=invalid-name Ptr = PtrProxy() # pylint: disable=invalid-name +constexpr = _ConstexprProxy() # pylint: disable=invalid-name diff --git a/python/tvm/tirx/script/parser/parser.py b/python/tvm/tirx/script/parser/parser.py index da1f59e96c..e7556db7f3 100644 --- a/python/tvm/tirx/script/parser/parser.py +++ b/python/tvm/tirx/script/parser/parser.py @@ -34,6 +34,7 @@ from tvm.tirx.script import builder as T from tvm.tirx.script.builder.ir import name_meta_class_value from tvm.tirx.stmt import BufferRegion +from .entry import constexpr as _constexpr_sentinel from .entry import inline @@ -241,9 +242,9 @@ def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: b Check the value of given annotation (argument name) in the prim_func decorator. 
Returns the value of the annotation if present, otherwise giving the default value. """ - # look for the named argument in the prim_func decorator + # look for the named argument in the prim_func / jit decorator for dec in node.decorator_list: - if not isinstance(dec, doc.Call) or dec.func.attr != "prim_func": + if not isinstance(dec, doc.Call) or dec.func.attr not in ("prim_func", "jit"): continue for keyword in dec.keywords: if keyword.arg == annotation: @@ -633,12 +634,17 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: self.report_error(arg, "Type annotation required for function parameters.") try: ann = self.eval_expr(arg.annotation) - if callable(ann): + if callable(ann) and ann is not _constexpr_sentinel: ann = ann() except Exception: # pylint: disable=broad-except ann = func_annotation.get(arg.arg, None) if ann is None: raise + if ann is _constexpr_sentinel: + # Tx.constexpr param: value was bound in extra_vars by + # TIRJit.specialize() and lives in an outer var_table + # frame; do not register a runtime PrimFunc param. + continue param = T.arg(arg.arg, ann) self.var_table.add(arg.arg, param) self.visit_body(node.body) diff --git a/tests/python/tirx/test_jit.py b/tests/python/tirx/test_jit.py new file mode 100644 index 0000000000..637563867c --- /dev/null +++ b/tests/python/tirx/test_jit.py @@ -0,0 +1,225 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ruff: noqa: F821 +"""Tests for ``@Tx.jit`` + ``Tx.constexpr``.""" + +from __future__ import annotations + +import pytest + +import tvm +from tvm.ir import assert_structural_equal +from tvm.script import tirx as Tx + + +def test_int_constexpr_specializes_loop_bound(): + @Tx.jit(private=True) + def add( + A: Tx.Buffer((N,), "int32"), + B: Tx.Buffer((N,), "int32"), + C: Tx.Buffer((N,), "int32"), + *, + N: Tx.constexpr, + ): + for i in range(N): + C[i] = A[i] + B[i] + + @Tx.prim_func(private=True) + def expected( + A: Tx.Buffer((128,), "int32"), + B: Tx.Buffer((128,), "int32"), + C: Tx.Buffer((128,), "int32"), + ): + for i in range(128): + C[i] = A[i] + B[i] + + assert_structural_equal(add.specialize(N=128), expected, map_free_vars=True) + + +def test_constexpr_in_2d_buffer_shape(): + @Tx.jit(private=True) + def matadd( + A: Tx.Buffer((M, K), "int32"), + B: Tx.Buffer((M, K), "int32"), + C: Tx.Buffer((M, K), "int32"), + *, + M: Tx.constexpr, + K: Tx.constexpr, + ): + for m in range(M): + for k in range(K): + C[m, k] = A[m, k] + B[m, k] + + @Tx.prim_func(private=True) + def expected( + A: Tx.Buffer((4, 8), "int32"), + B: Tx.Buffer((4, 8), "int32"), + C: Tx.Buffer((4, 8), "int32"), + ): + for m in range(4): + for k in range(8): + C[m, k] = A[m, k] + B[m, k] + + assert_structural_equal(matadd.specialize(M=4, K=8), expected, map_free_vars=True) + + +def test_constexpr_in_body_expression(): + @Tx.jit(private=True) + def scaled_copy( + A: Tx.Buffer((N,), "int32"), + B: Tx.Buffer((N,), "int32"), + *, + N: Tx.constexpr, + SCALE: Tx.constexpr, + ): + for i in 
range(N): + B[i] = A[i] * SCALE + + @Tx.prim_func(private=True) + def expected( + A: Tx.Buffer((16,), "int32"), + B: Tx.Buffer((16,), "int32"), + ): + for i in range(16): + B[i] = A[i] * 3 + + assert_structural_equal(scaled_copy.specialize(N=16, SCALE=3), expected, map_free_vars=True) + + +def test_specialize_cache_returns_same_instance(): + @Tx.jit(private=True) + def k( + A: Tx.Buffer((N,), "int32"), + *, + N: Tx.constexpr, + ): + for i in range(N): + A[i] = 0 + + a = k.specialize(N=8) + b = k.specialize(N=8) + assert a is b + + +def test_specialize_different_args_produce_different_funcs(): + @Tx.jit(private=True) + def k( + A: Tx.Buffer((N,), "int32"), + *, + N: Tx.constexpr, + ): + for i in range(N): + A[i] = 0 + + assert k.specialize(N=8) is not k.specialize(N=16) + + +def test_specialize_missing_constexpr_raises(): + @Tx.jit(private=True) + def k( + A: Tx.Buffer((N,), "int32"), + *, + N: Tx.constexpr, + SCALE: Tx.constexpr, + ): + for i in range(N): + A[i] = SCALE + + with pytest.raises(TypeError, match="missing"): + k.specialize(N=8) + + +def test_specialize_extra_kwarg_raises(): + @Tx.jit(private=True) + def k( + A: Tx.Buffer((N,), "int32"), + *, + N: Tx.constexpr, + ): + for i in range(N): + A[i] = 0 + + with pytest.raises(TypeError, match="unexpected"): + k.specialize(N=8, BOGUS=42) + + +def test_jit_kernel_with_nested_inline_helper(): + @Tx.jit(private=True) + def k( + A: Tx.Buffer((N,), "int32"), + *, + N: Tx.constexpr, + ): + @Tx.inline + def double(x): + A[x] = A[x] * 2 + + for i in range(N): + double(i) + + @Tx.prim_func(private=True) + def expected( + A: Tx.Buffer((4,), "int32"), + ): + for i in range(4): + A[i] = A[i] * 2 + + assert_structural_equal(k.specialize(N=4), expected, map_free_vars=True) + + +def test_constexpr_default_value(): + @Tx.jit(private=True) + def k( + A: Tx.Buffer((N,), "int32"), + *, + N: Tx.constexpr, + SCALE: Tx.constexpr = 7, + ): + for i in range(N): + A[i] = SCALE + + @Tx.prim_func(private=True) + def expected( + A: 
Tx.Buffer((8,), "int32"), + ): + for i in range(8): + A[i] = 7 + + assert_structural_equal(k.specialize(N=8), expected, map_free_vars=True) + # Override the default + overridden = k.specialize(N=8, SCALE=99) + assert k.specialize(N=8) is not overridden + + +def test_specialize_returns_primfunc(): + @Tx.jit(private=True) + def k( + A: Tx.Buffer((N,), "int32"), + *, + N: Tx.constexpr, + ): + for i in range(N): + A[i] = 0 + + spec = k.specialize(N=8) + assert isinstance(spec, tvm.tirx.PrimFunc) + # Specialized PrimFunc has only the runtime params (constexpr stripped). + assert len(spec.params) == 1 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
