This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dc2c5a28c9 [TVMScript][TIR] Add source kernel integration via call_kernel (#17434)
dc2c5a28c9 is described below

commit dc2c5a28c9132aa314cca237ffbe32e1bad8dd2a
Author: Wuwei Lin <[email protected]>
AuthorDate: Thu Oct 3 06:50:45 2024 -0700

    [TVMScript][TIR] Add source kernel integration via call_kernel (#17434)
    
    * [TVMScript][TIR] Add source kernel integration via call_kernel
    
    * lint
    
    * lint
---
 .../tvm/script/ir_builder/tir/external_kernel.py   |  62 ++++++++++++-
 tests/python/relax/test_tir_call_source_kernel.py  | 100 +++++++++++++++++++++
 2 files changed, 160 insertions(+), 2 deletions(-)
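
For context, the change lets a TIR prim_func launch a hand-written CUDA
kernel directly from a source string (or a path to a .cu file). A minimal
sketch of the pattern, distilled from the test added below; the kernel and
the buffer shapes are illustrative:

    from tvm.script import tir as T

    add_cuda_source = """
    extern "C" __global__ void add_kernel(float* x, float* y, float* output, int n_elements) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n_elements) output[i] = x[i] + y[i];
    }
    """

    @T.prim_func
    def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle):
        m = T.int64()
        x = T.match_buffer(x_handle, (m,), "float32")
        y = T.match_buffer(y_handle, (m,), "float32")
        output = T.match_buffer(output_handle, (m,), "float32")
        with T.block("root"):
            T.reads(x[0:m], y[0:m])
            T.writes(output[0:m])
            BLOCK_SIZE = T.meta_var(64)
            T.call_kernel(
                add_cuda_source,                               # inline CUDA source
                ((T.ceildiv(m, BLOCK_SIZE),), (BLOCK_SIZE,)),  # (grid dims, block dims)
                x.data, y.data, output.data, m,                # kernel arguments
                kernel_name="add_kernel",
            )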

diff --git a/python/tvm/script/ir_builder/tir/external_kernel.py b/python/tvm/script/ir_builder/tir/external_kernel.py
index 8c2467fad3..405e1e6cbf 100644
--- a/python/tvm/script/ir_builder/tir/external_kernel.py
+++ b/python/tvm/script/ir_builder/tir/external_kernel.py
@@ -18,14 +18,16 @@
 import json
 import logging
 import tempfile
+from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union
 
 from tvm import __version__ as tvm_version
 from tvm import tir
-from tvm.runtime import Module, load_module
+from tvm.runtime import Module, load_module, const
+from tvm.contrib import nvcc
 
 
-class BaseKernel:
+class BaseKernel:  # pylint: disable=too-few-public-methods
     """Base class for external kernels."""
 
     def compile_to_device_module(
@@ -91,6 +93,60 @@ class BaseKernel:
         return kernel_module
 
 
+class SourceKernel(BaseKernel):  # pylint: disable=too-few-public-methods
+    """A kernel from source code."""
+
+    def __init__(self, source_code: str):
+        self.source_code = source_code
+
+    def compile_to_device_module(  # pylint: disable=arguments-differ
+        self, grid: List[List[Union[int, tir.PrimExpr]]], *args: List[Any], **kwargs: Dict[str, Any]
+    ) -> Tuple[str, Module, List[Any]]:
+        """Compile the kernel to a device module."""
+        from tvm.relax.frontend.nn import SourceModule  # pylint: disable=import-outside-toplevel
+
+        kernel_name = kwargs["kernel_name"]
+        assert len(grid) == 2, (
+            "grid should be two list of integers, representing the dimension 
of "
+            "['blockIdx.x', 'blockIdx.y', 'blockIdx.z'] and "
+            "['threadIdx.x', 'threadIdx.y', 'threadIdx.z']"
+        )
+        assert isinstance(grid[0], (list, tuple)) and isinstance(grid[1], (list, tuple))
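+        # Keep only as many block/thread tags as grid dimensions were given.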
+        launch_param_tags = ["blockIdx.x", "blockIdx.y", "blockIdx.z"][: len(grid[0])] + [
+            "threadIdx.x",
+            "threadIdx.y",
+            "threadIdx.z",
+        ][: len(grid[1])]
+        runtime_args = [arg if hasattr(arg, "dtype") else const(arg) for arg in args]
+        kernel_arg_types = [arg.dtype for arg in runtime_args]
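+        # The launch extents are appended after the kernel arguments; at run
+        # time they are matched against launch_param_tags to set grid/block.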
+        runtime_args = runtime_args + list(grid[0]) + list(grid[1])
+
+        # Reuse compilation path from SourceModule
+        compile_options = SourceModule.get_compile_options("cu")
+        source_code = self.source_code
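+        # If the string names an existing file, read the kernel source from
+        # it; otherwise treat the string itself as the source code.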
+        try:
+            source_path = Path(source_code)
+            if source_path.is_file():
+                with open(source_path, "r") as f:
+                    source_code = f.read()
+        except:  # pylint: disable=bare-except
+            pass
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            ptx_path = f"{temp_dir}/{kernel_name}.ptx"
+            nvcc.compile_cuda(
+                source_code, target_format="ptx", options=compile_options, path_target=ptx_path
+            )
+            with open(ptx_path, "r") as f:
+                ptx = f.read()
+
+            kernel_module = self._create_cuda_module(
+                ptx, kernel_arg_types, launch_param_tags, kernel_name
+            )
+
+        return kernel_name, kernel_module, runtime_args
+
+
 def call_kernel(
     kernel,
    launch_args: List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]],
@@ -123,6 +179,8 @@ def call_kernel(
        from .triton import TritonKernel  # pylint: disable=import-outside-toplevel
 
         kernel = TritonKernel(kernel)
+    elif kernel_type == "builtins.str":
+        kernel = SourceKernel(kernel)
     else:
         raise ValueError("Unsupported kernel type {}".format(kernel_type))
 
diff --git a/tests/python/relax/test_tir_call_source_kernel.py b/tests/python/relax/test_tir_call_source_kernel.py
new file mode 100644
index 0000000000..9a877ad35f
--- /dev/null
+++ b/tests/python/relax/test_tir_call_source_kernel.py
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import tir as T, ir as I, relax as R
+
+add_cuda_source = """
+extern "C" __global__ void add_kernel(float* x, float* y, float* output, int 
n_elements) {
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < n_elements) {
+        output[i] = x[i] + y[i];
+    }
+}
+"""
+
+
[email protected]_cuda
+def test_tir_call_source_kernel():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle) -> None:
+            T.func_attr({"global_symbol": "add"})
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,), "float32")
+            y = T.match_buffer(y_handle, (m,), "float32")
+            output = T.match_buffer(output_handle, (m,), "float32")
+            with T.block("root"):
+                T.reads(x[0:m], y[0:m])
+                T.writes(output[0:m])
+                BLOCK_SIZE = T.meta_var(64)
+                T.call_kernel(
+                    add_cuda_source,
+                    ((T.ceildiv(m, BLOCK_SIZE),), (BLOCK_SIZE,)),
+                    x.data,
+                    y.data,
+                    output.data,
+                    m,
+                    kernel_name="add_kernel",
+                )
+
+        @R.function
+        def main(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")):
+            m = T.int64()
+            with R.dataflow():
+                output = R.call_tir(Module.add, [x, y], relax.TensorStructInfo((m,), "float32"))
+                R.output(output)
+            return output
+
+    @I.ir_module
+    class Parsed:
+        @T.prim_func
+        def add(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle):
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,))
+            y = T.match_buffer(y_handle, (m,))
+            output = T.match_buffer(output_handle, (m,))
+            with T.block("root"):
+                T.reads(x[0:m], y[0:m])
+                T.writes(output[0:m])
+                T.call_packed(
+                    "add_kernel",
+                    x.data,
+                    y.data,
+                    output.data,
+                    m,
+                    (m + T.int64(64) - T.int64(1)) // T.int64(64),
+                    64,
+                )
+
+    tvm.ir.assert_structural_equal(Module["add"], Parsed["add"])
+    assert len(Module.get_attr("external_mods")) == 1
+
+    device = tvm.cuda(0)
+    x_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device)
+    y_nd = tvm.nd.array(np.random.rand(256).astype(np.float32), device)
+    output_np = x_nd.numpy() + y_nd.numpy()
+
+    with tvm.target.Target("cuda"):
+        lib = relax.build(Module)
+        output_nd = tvm.runtime.relax_vm.VirtualMachine(lib, device)["main"](x_nd, y_nd)
+        tvm.testing.assert_allclose(output_nd.numpy(), output_np, rtol=1e-5)
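
Since SourceKernel also treats a string that names an existing file as a
path (see the Path check in external_kernel.py above), the same kernel can
live in a standalone .cu file. A minimal sketch under that assumption; the
file name, the temp-dir handling, and add_from_file are illustrative, and
only the first argument to T.call_kernel changes:

    import pathlib
    import tempfile

    # Hypothetical: persist the CUDA source to a .cu file and pass its path.
    cu_path = str(pathlib.Path(tempfile.mkdtemp()) / "add_kernel.cu")
    pathlib.Path(cu_path).write_text(add_cuda_source)

    @T.prim_func
    def add_from_file(x_handle: T.handle, y_handle: T.handle, output_handle: T.handle):
        m = T.int64()
        x = T.match_buffer(x_handle, (m,), "float32")
        y = T.match_buffer(y_handle, (m,), "float32")
        output = T.match_buffer(output_handle, (m,), "float32")
        with T.block("root"):
            T.reads(x[0:m], y[0:m])
            T.writes(output[0:m])
            BLOCK_SIZE = T.meta_var(64)
            T.call_kernel(
                cu_path,  # a path instead of inline source; SourceKernel reads the file
                ((T.ceildiv(m, BLOCK_SIZE),), (BLOCK_SIZE,)),
                x.data, y.data, output.data, m,
                kernel_name="add_kernel",
            )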
