This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dc2c5a28c9 [TVMScript][TIR] Add source kernel integration via call_kernel (#17434)
dc2c5a28c9 is described below
commit dc2c5a28c9132aa314cca237ffbe32e1bad8dd2a
Author: Wuwei Lin <[email protected]>
AuthorDate: Thu Oct 3 06:50:45 2024 -0700
[TVMScript][TIR] Add source kernel integration via call_kernel (#17434)
* [TVMScript][TIR] Add source kernel integration via call_kernel
* lint
* lint
---
.../tvm/script/ir_builder/tir/external_kernel.py | 62 ++++++++++++-
tests/python/relax/test_tir_call_source_kernel.py | 100 +++++++++++++++++++++
2 files changed, 160 insertions(+), 2 deletions(-)
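
In short, this change lets a TIR PrimFunc launch a hand-written CUDA kernel
supplied as a plain source string, reusing the existing call_kernel entry
point. A minimal sketch of the user-facing API, mirroring the test added in
this commit (add_cuda_source is a CUDA source string defining add_kernel;
m, x, y, and output are the PrimFunc's variables and buffers):

    # Inside a T.prim_func body; launch args are (grid dims, block dims).
    BLOCK_SIZE = T.meta_var(64)
    T.call_kernel(
        add_cuda_source,                               # raw CUDA source string
        ((T.ceildiv(m, BLOCK_SIZE),), (BLOCK_SIZE,)),  # (grid, block)
        x.data,
        y.data,
        output.data,
        m,
        kernel_name="add_kernel",                      # symbol to launch
    )
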
diff --git a/python/tvm/script/ir_builder/tir/external_kernel.py b/python/tvm/script/ir_builder/tir/external_kernel.py
index 8c2467fad3..405e1e6cbf 100644
--- a/python/tvm/script/ir_builder/tir/external_kernel.py
+++ b/python/tvm/script/ir_builder/tir/external_kernel.py
@@ -18,14 +18,16 @@
 import json
 import logging
 import tempfile
+from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union

 from tvm import __version__ as tvm_version
 from tvm import tir
-from tvm.runtime import Module, load_module
+from tvm.runtime import Module, load_module, const
+from tvm.contrib import nvcc


-class BaseKernel:
+class BaseKernel:  # pylint: disable=too-few-public-methods
     """Base class for external kernels."""

     def compile_to_device_module(
@@ -91,6 +93,60 @@ class BaseKernel:
         return kernel_module
+class SourceKernel(BaseKernel):  # pylint: disable=too-few-public-methods
+    """A kernel from source code."""
+
+    def __init__(self, source_code: str):
+        self.source_code = source_code
+
+    def compile_to_device_module(  # pylint: disable=arguments-differ
+        self, grid: List[List[Union[int, tir.PrimExpr]]], *args: List[Any], **kwargs: Dict[str, Any]
+    ) -> Tuple[str, Module, List[Any]]:
+        """Compile the kernel to a device module."""
+        from tvm.relax.frontend.nn import SourceModule  # pylint: disable=import-outside-toplevel
+
+        kernel_name = kwargs["kernel_name"]
+        assert len(grid) == 2, (
+            "grid should be two lists of integers, representing the dimensions of "
+            "['blockIdx.x', 'blockIdx.y', 'blockIdx.z'] and "
+            "['threadIdx.x', 'threadIdx.y', 'threadIdx.z']"
+        )
+        assert isinstance(grid[0], (list, tuple)) and isinstance(grid[1], (list, tuple))
+        launch_param_tags = ["blockIdx.x", "blockIdx.y", "blockIdx.z"][: len(grid[0])] + [
+            "threadIdx.x",
+            "threadIdx.y",
+            "threadIdx.z",
+        ][: len(grid[1])]
+        runtime_args = [arg if hasattr(arg, "dtype") else const(arg) for arg in args]
+        kernel_arg_types = [arg.dtype for arg in runtime_args]
+        runtime_args = runtime_args + list(grid[0]) + list(grid[1])
+
+        # Reuse compilation path from SourceModule
+        compile_options = SourceModule.get_compile_options("cu")
+        source_code = self.source_code
+        try:
+            source_path = Path(source_code)
+            if source_path.is_file():
+                with open(source_path, "r") as f:
+                    source_code = f.read()
+        except:  # pylint: disable=bare-except
+            pass
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            ptx_path = f"{temp_dir}/{kernel_name}.ptx"
+            nvcc.compile_cuda(
+                source_code, target_format="ptx", options=compile_options, path_target=ptx_path
+            )
+            with open(ptx_path, "r") as f:
+                ptx = f.read()
+
+            kernel_module = self._create_cuda_module(
+                ptx, kernel_arg_types, launch_param_tags, kernel_name
+            )
+
+        return kernel_name, kernel_module, runtime_args
+
+
 def call_kernel(
     kernel,
     launch_args: List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]],
@@ -123,6 +179,8 @@ def call_kernel(
         from .triton import TritonKernel  # pylint: disable=import-outside-toplevel

         kernel = TritonKernel(kernel)
+    elif kernel_type == "builtins.str":
+        kernel = SourceKernel(kernel)
     else:
         raise ValueError("Unsupported kernel type {}".format(kernel_type))
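
Note that SourceKernel also probes the string with pathlib.Path, so a path to
a .cu file on disk is accepted as well as inline source. A hedged sketch of
that variant (the file path kernels/add.cu is hypothetical, not part of this
commit):

    # Hypothetical: pass a path instead of inline source; SourceKernel reads
    # the file contents when Path(source).is_file() is true.
    T.call_kernel(
        "kernels/add.cu",  # hypothetical path, for illustration only
        ((T.ceildiv(m, BLOCK_SIZE),), (BLOCK_SIZE,)),
        x.data,
        y.data,
        output.data,
        m,
        kernel_name="add_kernel",
    )
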
diff --git a/tests/python/relax/test_tir_call_source_kernel.py b/tests/python/relax/test_tir_call_source_kernel.py
new file mode 100644
index 0000000000..9a877ad35f
--- /dev/null
+++ b/tests/python/relax/test_tir_call_source_kernel.py
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import tir as T, ir as I, relax as R
+
+add_cuda_source = """
+extern "C" __global__ void add_kernel(float* x, float* y, float* output, int
n_elements) {
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
+ if (i < n_elements) {
+ output[i] = x[i] + y[i];
+ }
+}
+"""
+
+
+@tvm.testing.requires_cuda
+def test_tir_call_source_kernel():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle) -> None:
+            T.func_attr({"global_symbol": "add"})
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,), "float32")
+            y = T.match_buffer(y_handle, (m,), "float32")
+            output = T.match_buffer(output_handle, (m,), "float32")
+            with T.block("root"):
+                T.reads(x[0:m], y[0:m])
+                T.writes(output[0:m])
+                BLOCK_SIZE = T.meta_var(64)
+                T.call_kernel(
+                    add_cuda_source,
+                    ((T.ceildiv(m, BLOCK_SIZE),), (BLOCK_SIZE,)),
+                    x.data,
+                    y.data,
+                    output.data,
+                    m,
+                    kernel_name="add_kernel",
+                )
+        @R.function
+        def main(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")):
+            m = T.int64()
+            with R.dataflow():
+                output = R.call_tir(Module.add, [x, y], relax.TensorStructInfo((m,), "float32"))
+                R.output(output)
+            return output
+
+    @I.ir_module
+    class Parsed:
+        @T.prim_func
+        def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle):
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,))
+            y = T.match_buffer(y_handle, (m,))
+            output = T.match_buffer(output_handle, (m,))
+            with T.block("root"):
+                T.reads(x[0:m], y[0:m])
+                T.writes(output[0:m])
+                T.call_packed(
+                    "add_kernel",
+                    x.data,
+                    y.data,
+                    output.data,
+                    m,
+                    (m + T.int64(64) - T.int64(1)) // T.int64(64),
+                    64,
+                )
+
+    tvm.ir.assert_structural_equal(Module["add"], Parsed["add"])
+    assert len(Module.get_attr("external_mods")) == 1
+
+    device = tvm.cuda(0)
+    x_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device)
+    y_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device)
+    output_np = x_nd.numpy() + y_nd.numpy()
+
+    with tvm.target.Target("cuda"):
+        lib = relax.build(Module)
+    output_nd = tvm.runtime.relax_vm.VirtualMachine(lib, device)["main"](x_nd, y_nd)
+    tvm.testing.assert_allclose(output_nd.numpy(), output_np, rtol=1e-5)
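
For reference, the compilation step exercised above goes through
tvm.contrib.nvcc; a standalone sketch of what SourceKernel does internally
(assuming a local CUDA toolchain is available; returns compiled PTX bytes):

    from tvm.contrib import nvcc
    from tvm.relax.frontend.nn import SourceModule

    # Compile the CUDA source above to PTX, as SourceKernel does internally.
    ptx = nvcc.compile_cuda(
        add_cuda_source,
        target_format="ptx",
        options=SourceModule.get_compile_options("cu"),
    )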