This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c211e482be [Triton] Support latest `triton.compile` interface (#17913)
c211e482be is described below
commit c211e482bee8ebd139e8c58624e18c988b4ec4b2
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu May 1 00:35:31 2025 -0400
[Triton] Support latest `triton.compile` interface (#17913)
This PR updates the Triton JIT compilation so that it supports
the `triton.compile` / `ASTSource` interface introduced in Triton 3.3.0,
while keeping the previous behavior for earlier Triton versions.
---
python/tvm/script/ir_builder/tir/triton.py | 23 +++++++++++++++++++----
1 file changed, 19 insertions(+), 4 deletions(-)
diff --git a/python/tvm/script/ir_builder/tir/triton.py b/python/tvm/script/ir_builder/tir/triton.py
index 2d37d93a6d..4f0f2cd520 100644
--- a/python/tvm/script/ir_builder/tir/triton.py
+++ b/python/tvm/script/ir_builder/tir/triton.py
@@ -16,13 +16,16 @@
# under the License.
"""Triton kernel integration with TIR"""
-from typing import Tuple, List, Union, Any, Dict
+from typing import Any, Dict, List, Tuple, Union
import triton
+from packaging import version
from triton.runtime.jit import type_canonicalisation_dict
+
from tvm import tir
-from tvm.topi.utils import get_const_int
from tvm.runtime import Module
+from tvm.topi.utils import get_const_int
+
from .external_kernel import BaseKernel
@@ -70,7 +73,12 @@ class TritonKernel(BaseKernel):
: len(grid)
]
launch_args = [num_warps * 32] + list(grid)
- kernel_arg_types = [arg.dtype for arg in kernel_args]
+ if version.parse(triton.__version__) >= version.parse("3.3.0"):
+ kernel_arg_types = [
+ arg.dtype if not isinstance(arg, int) else "int64" for arg in
kernel_args
+ ]
+ else:
+ kernel_arg_types = [arg.dtype for arg in kernel_args]
if triton_kernel.metadata.shared > 0:
# Add shared memory size to the launch arguments
launch_param_tags.append("tir.use_dyn_shared_memory")
@@ -98,6 +106,9 @@ class TritonKernel(BaseKernel):
for i, arg in enumerate(args):
if kernel_params[i].is_constexpr:
constants[kernel_params[i].name] = get_const_int(arg)
+ if version.parse(triton.__version__) >= version.parse("3.3.0"):
+ signature[kernel_params[i].name] = "constexpr"
+ kernel_args.append(arg)
continue
if arg.dtype == "handle":
assert isinstance(arg, tir.Var)
@@ -110,6 +121,10 @@ class TritonKernel(BaseKernel):
# TODO: Support default argument in the kernel
# TODO: Add specialization for aligned buffer pointers
- source = triton.compiler.ASTSource(fn=func, constants=constants,
signature=signature)
+ if version.parse(triton.__version__) >= version.parse("3.3.0"):
+ kwargs = {"constexprs": constants}
+ else:
+ kwargs = {"constants": constants}
+ source = triton.compiler.ASTSource(fn=func, signature=signature,
**kwargs)
compiled = triton.compiler.compile(source, options=kwargs)
return compiled, kernel_args