This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c211e482be [Triton] Support latest `triton.compile` interface (#17913)
c211e482be is described below

commit c211e482bee8ebd139e8c58624e18c988b4ec4b2
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu May 1 00:35:31 2025 -0400

    [Triton] Support latest `triton.compile` interface (#17913)
    
    This PR updates the Triton JIT compilation so that it supports
    the latest `triton.compile` interface introduced in Triton 3.3.0.
---
 python/tvm/script/ir_builder/tir/triton.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/python/tvm/script/ir_builder/tir/triton.py b/python/tvm/script/ir_builder/tir/triton.py
index 2d37d93a6d..4f0f2cd520 100644
--- a/python/tvm/script/ir_builder/tir/triton.py
+++ b/python/tvm/script/ir_builder/tir/triton.py
@@ -16,13 +16,16 @@
 # under the License.
 """Triton kernel integration with TIR"""
 
-from typing import Tuple, List, Union, Any, Dict
+from typing import Any, Dict, List, Tuple, Union
 
 import triton
+from packaging import version
 from triton.runtime.jit import type_canonicalisation_dict
+
 from tvm import tir
-from tvm.topi.utils import get_const_int
 from tvm.runtime import Module
+from tvm.topi.utils import get_const_int
+
 from .external_kernel import BaseKernel
 
 
@@ -70,7 +73,12 @@ class TritonKernel(BaseKernel):
             : len(grid)
         ]
         launch_args = [num_warps * 32] + list(grid)
-        kernel_arg_types = [arg.dtype for arg in kernel_args]
+        if version.parse(triton.__version__) >= version.parse("3.3.0"):
+            kernel_arg_types = [
+                arg.dtype if not isinstance(arg, int) else "int64" for arg in kernel_args
+            ]
+        else:
+            kernel_arg_types = [arg.dtype for arg in kernel_args]
         if triton_kernel.metadata.shared > 0:
             # Add shared memory size to the launch arguments
             launch_param_tags.append("tir.use_dyn_shared_memory")
@@ -98,6 +106,9 @@ class TritonKernel(BaseKernel):
         for i, arg in enumerate(args):
             if kernel_params[i].is_constexpr:
                 constants[kernel_params[i].name] = get_const_int(arg)
+                if version.parse(triton.__version__) >= version.parse("3.3.0"):
+                    signature[kernel_params[i].name] = "constexpr"
+                    kernel_args.append(arg)
                 continue
             if arg.dtype == "handle":
                 assert isinstance(arg, tir.Var)
@@ -110,6 +121,10 @@ class TritonKernel(BaseKernel):
 
         # TODO: Support default argument in the kernel
         # TODO: Add specialization for aligned buffer pointers
-        source = triton.compiler.ASTSource(fn=func, constants=constants, signature=signature)
+        if version.parse(triton.__version__) >= version.parse("3.3.0"):
+            kwargs = {"constexprs": constants}
+        else:
+            kwargs = {"constants": constants}
+        source = triton.compiler.ASTSource(fn=func, signature=signature, **kwargs)
         compiled = triton.compiler.compile(source, options=kwargs)
         return compiled, kernel_args

Reply via email to