This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e201a6aa5 [TIRx] Preserve Triton call_kernel compile options (#19728)
3e201a6aa5 is described below

commit 3e201a6aa50b50942bf44297a669795f9a7c126d
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 11 07:21:14 2026 -0400

    [TIRx] Preserve Triton call_kernel compile options (#19728)
    
    Previously `_generate_triton_kernel` overwrote the user-provided kwargs
    with the constexpr dict before calling triton.compiler.compile, so
    options such as num_warps passed to T.call_kernel were silently dropped.
    Pass constexprs to ASTSource and forward the user kwargs as compile
    options.
    
    The pre-3.3 compatibility branches are removed in favor of an explicit
    minimum-version check: they were never exercised in CI (which does not
    install Triton), and Triton >= 3.3 has shipped with PyTorch since 2.7.
    
    The integration test now matches the actual lowering, where constexpr
    parameters (BLOCK_SIZE) appear as runtime kernel arguments in
    call_packed, and passes num_warps=8 expecting a thread extent of 256 to
    cover the option forwarding.
---
 python/tvm/tirx/script/builder/triton.py           | 25 ++++++++++------------
 .../python/contrib/test_tir_triton_integration.py  | 10 ++++++++-
 2 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/python/tvm/tirx/script/builder/triton.py b/python/tvm/tirx/script/builder/triton.py
index 5c5d2b5567..14f2d92bab 100644
--- a/python/tvm/tirx/script/builder/triton.py
+++ b/python/tvm/tirx/script/builder/triton.py
@@ -29,6 +29,11 @@ from tvm.topi.utils import get_const_int
 
 from .external_kernel import BaseKernel
 
+if version.parse(triton.__version__) < version.parse("3.3.0"):
+    raise ImportError(
+        f"TIR Triton integration requires Triton >= 3.3.0, but found Triton {triton.__version__}"
+    )
+
 
 class TritonKernel(BaseKernel):
     """A kernel from Triton JIT function.
@@ -74,12 +79,9 @@ class TritonKernel(BaseKernel):
             : len(grid)
         ]
         launch_args = [num_warps * 32] + list(grid)
-        if version.parse(triton.__version__) >= version.parse("3.3.0"):
-            kernel_arg_types = [
-                arg.dtype if not isinstance(arg, int) else "int64" for arg in kernel_args
-            ]
-        else:
-            kernel_arg_types = [arg.dtype for arg in kernel_args]
+        kernel_arg_types = [
+            arg.dtype if not isinstance(arg, int) else "int64" for arg in kernel_args
+        ]
         if triton_kernel.metadata.shared > 0:
             # Add shared memory size to the launch arguments
             launch_param_tags.append("tirx.use_dyn_shared_memory")
@@ -107,9 +109,8 @@ class TritonKernel(BaseKernel):
         for i, arg in enumerate(args):
             if kernel_params[i].is_constexpr:
                 constants[kernel_params[i].name] = get_const_int(arg)
-                if version.parse(triton.__version__) >= version.parse("3.3.0"):
-                    signature[kernel_params[i].name] = "constexpr"
-                    kernel_args.append(arg)
+                signature[kernel_params[i].name] = "constexpr"
+                kernel_args.append(arg)
                 continue
             if arg.dtype == "handle":
                 assert isinstance(arg, tirx.Var)
@@ -122,10 +123,6 @@ class TritonKernel(BaseKernel):
 
         # TODO: Support default argument in the kernel
         # TODO: Add specialization for aligned buffer pointers
-        if version.parse(triton.__version__) >= version.parse("3.3.0"):
-            kwargs = {"constexprs": constants}
-        else:
-            kwargs = {"constants": constants}
-        source = triton.compiler.ASTSource(fn=func, signature=signature, **kwargs)
+        source = triton.compiler.ASTSource(fn=func, signature=signature, constexprs=constants)
         compiled = triton.compiler.compile(source, options=kwargs)
         return compiled, kernel_args
diff --git a/tests/python/contrib/test_tir_triton_integration.py b/tests/python/contrib/test_tir_triton_integration.py
index 29fb44adda..33d7962e8f 100644
--- a/tests/python/contrib/test_tir_triton_integration.py
+++ b/tests/python/contrib/test_tir_triton_integration.py
@@ -31,8 +31,12 @@ from tvm.script import tirx as T
 try:
     import triton
     import triton.language as tl
+    from packaging import version
 except ImportError:
     pytestmark = pytest.skip("Triton is not available", allow_module_level=True)
+else:
+    if version.parse(triton.__version__) < version.parse("3.3.0"):
+        pytestmark = pytest.skip("Triton >= 3.3.0 is required", allow_module_level=True)
 
 
 @tvm.testing.requires_cuda
@@ -76,6 +80,7 @@ def test_tir_triton_integration():
                     output.data,
                     m,
                     BLOCK_SIZE,
+                    num_warps=8,
                 )
 
         @R.function
@@ -86,6 +91,8 @@ def test_tir_triton_integration():
                 R.output(output)
             return output
 
+    # Constexpr parameters (BLOCK_SIZE) stay in the kernel arguments, and the
+    # thread extent is 256 because the kernel is compiled with num_warps=8.
     @I.ir_module(s_tir=True)
     class Parsed:
         @T.prim_func(s_tir=True)
@@ -103,7 +110,8 @@ def test_tir_triton_integration():
                     y.data,
                     output.data,
                     m,
-                    128,
+                    64,
+                    256,
                     (m + T.int64(64) - T.int64(1)) // T.int64(64),
                 )
 

Reply via email to