This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 278a6af085 [Relax][TIR] Introduce new `cumsum` op for gpu (#16934)
278a6af085 is described below

commit 278a6af085d1a149bc9ae4ff4a7ac4b33fc6b6bb
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri Apr 26 23:15:38 2024 +0800

    [Relax][TIR] Introduce new `cumsum` op for gpu (#16934)
---
 python/tvm/relax/backend/dispatch_sort_scan.py     |  41 +++++
 python/tvm/relax/backend_tir/__init__.py           |   1 +
 python/tvm/relax/backend_tir/cumsum.py             | 193 +++++++++++++++++++++
 .../relax/test_backend_dispatch_sort_scan.py       |  38 +++-
 4 files changed, 268 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py 
b/python/tvm/relax/backend/dispatch_sort_scan.py
index eb82e49d9a..870e6138d7 100644
--- a/python/tvm/relax/backend/dispatch_sort_scan.py
+++ b/python/tvm/relax/backend/dispatch_sort_scan.py
@@ -154,7 +154,48 @@ class SortScanDispatcher(PyExprMutator):
         if call.op.name in ("relax.cumprod", "relax.cumsum"):
             tgt = self._get_target(call.struct_info)
             axis = int(call.attrs.axis) if call.attrs.axis is not None else 
call.attrs.axis
+            shape = call.struct_info.shape
             kwargs = {}
+            if (
+                (axis == -1 or axis == len(shape) - 1)
+                and is_gpu_target(tgt)
+                and not can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan")
+                and call.op.name == "relax.cumsum"
+                and call.attrs.exclusive == 0
+            ):
+                from tvm.relax.backend_tir import (  # pylint: 
disable=import-outside-toplevel
+                    gpu_2d_continuous_cumsum,
+                )
+
+                dim = 1
+                for i in range(len(shape) - 1):
+                    dim *= shape[i]
+                in_dtype = call.args[0].struct_info.dtype
+                out_dtype = call.attrs.dtype
+                out_dtype = out_dtype or in_dtype
+                cumsum_2d_shape = relax.ShapeExpr([dim, shape[-1]])
+                reshape = relax.call_pure_packed(
+                    "vm.builtin.reshape",
+                    call.args[0],
+                    cumsum_2d_shape,
+                    sinfo_args=relax.TensorStructInfo(cumsum_2d_shape, 
out_dtype),
+                )
+                gv = self.builder_.add_func(
+                    gpu_2d_continuous_cumsum(in_dtype=in_dtype, 
out_dtype=out_dtype),
+                    "gpu_2d_continuous_cumsum",
+                )
+                cumsum = relax.call_tir(
+                    gv,
+                    reshape,
+                    out_sinfo=relax.TensorStructInfo(cumsum_2d_shape, 
out_dtype),
+                )
+                return relax.call_pure_packed(
+                    "vm.builtin.reshape",
+                    cumsum,
+                    shape,
+                    sinfo_args=call.struct_info,
+                )
+
             with tgt:
                 if call.op.name == "relax.cumsum":
                     te_func = topi.cuda.cumsum if is_gpu_target(tgt) else 
topi.cumsum
diff --git a/python/tvm/relax/backend_tir/__init__.py 
b/python/tvm/relax/backend_tir/__init__.py
index eeb8fe438f..10def47b8d 100644
--- a/python/tvm/relax/backend_tir/__init__.py
+++ b/python/tvm/relax/backend_tir/__init__.py
@@ -18,3 +18,4 @@
 
 from . import contrib
 from .pattern import get_tir_pattern
+from .cumsum import gpu_2d_continuous_cumsum
diff --git a/python/tvm/relax/backend_tir/cumsum.py 
b/python/tvm/relax/backend_tir/cumsum.py
new file mode 100644
index 0000000000..ade961ecf1
--- /dev/null
+++ b/python/tvm/relax/backend_tir/cumsum.py
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-nested-blocks
+"""Backend kernels for cumsum operator."""
+
+import math
+from typing import Optional
+
+from tvm.script import tir as T
+from tvm.tir import PrimFunc
+
+
+def _is_power_of_two(n: int):
+    """Check if n is a power of 2."""
+    return n > 0 and (n & (n - 1)) == 0
+
+
def gpu_2d_continuous_cumsum(
    ty_len: int = 4,
    tx_len: int = 32,
    thread_elem: int = 4,
    in_dtype: str = "int32",
    out_dtype: Optional[str] = None,
) -> PrimFunc:
    """Generate a GPU kernel for a 2D continuous cumsum, i.e. the cumsum axis is -1.

    The returned PrimFunc takes an ``(m, n)`` input buffer of ``in_dtype`` and an
    ``(m, n)`` output buffer of ``out_dtype`` (both shaped by int64 variables) and
    writes the inclusive prefix sum of every row into the output.

    Each thread block scans ``tx_len * ty_len * thread_elem`` consecutive elements
    of a row. When a row spans several blocks, the per-block totals are scanned in
    further rounds and then added back to the blocks that follow them.

    Parameters
    ----------
    ty_len : int
        The extent of threadIdx.y. Must be a power of 2.

    tx_len : int
        The extent of threadIdx.x. Must be a power of 2.

    thread_elem : int
        The number of consecutive elements processed by a single thread.
        Must be a power of 2.

    in_dtype : str
        The input data type.

    out_dtype : Optional[str]
        The output (and accumulation) data type; if None, it is the same as
        ``in_dtype``.

    Returns
    -------
    cumsum : PrimFunc
        The generated cumsum kernel.

    Raises
    ------
    ValueError
        If ``ty_len``, ``tx_len`` or ``thread_elem`` is not a power of 2.
    """
    # Validate the plain Python ints rather than TIR constants, so the check does
    # not depend on TIR constant folding of `>`, `&` and `==`.
    if not all(_is_power_of_two(v) for v in (tx_len, ty_len, thread_elem)):
        raise ValueError("Configuration of TX, TY, N must be power of 2")

    out_dtype = out_dtype or in_dtype

    # Configuration for GPU kernel
    TX = T.int64(tx_len)  # extent of threadIdx.x
    TY = T.int64(ty_len)  # extent of threadIdx.y
    N = T.int64(thread_elem)  # number of elements in single thread

    # number of elements to be processed by single warp (one threadIdx.y row)
    warp_elem = T.int64(tx_len * thread_elem)
    # number of elements to be processed by single block(SM)
    block_elem = T.int64(tx_len * ty_len * thread_elem)

    # log2 is exact here because all three sizes are powers of two (checked above).
    LOG_TX = T.int64(int(math.log2(tx_len)))
    LOG_BLOCK_N = T.int64(int(math.log2(tx_len * ty_len * thread_elem)))

    @T.macro
    def block_inclusive_inside_block(
        batch: T.int64,
        cur_len: T.int64,
        source: T.Buffer,
        output: T.Buffer,
        tmp_buf: T.Buffer,
        src_offset: T.int64,
        tmp_offset: T.int64,
    ):
        # Inclusive scan of each `block_elem`-sized chunk of
        # source[by, src_offset : src_offset + cur_len], stored into the same range
        # of `output`; the total of chunk bx is stored to tmp_buf[by, tmp_offset + bx].
        for by in T.thread_binding(batch, thread="blockIdx.y"):
            for bx in T.thread_binding(T.ceildiv(cur_len, block_elem), thread="blockIdx.x"):
                with T.block():
                    local_buf = T.alloc_buffer((thread_elem,), out_dtype, scope="local")
                    shared_buf = T.alloc_buffer((block_elem,), out_dtype, scope="shared")
                    for ty in T.thread_binding(TY, thread="threadIdx.y"):
                        for tx in T.thread_binding(TX, thread="threadIdx.x"):
                            tx_idx = bx * block_elem + ty * warp_elem + tx * thread_elem
                            # Load data from global memory; positions past cur_len become 0
                            for i in T.vectorized(N):
                                local_buf[i] = T.if_then_else(
                                    tx_idx + i < cur_len,
                                    T.Cast(out_dtype, source[by, src_offset + tx_idx + i]),
                                    T.Cast(out_dtype, 0),
                                )
                            # Inclusive scan inside thread
                            for i in T.unroll(1, N):
                                local_buf[i] += local_buf[i - 1]
                            # Store data to shared memory
                            for i in T.vectorized(N):
                                shared_buf[ty * warp_elem + tx * thread_elem + i] = local_buf[i]
                            # Inclusive scan inside warp: at step i every thread adds the
                            # running total (last element) of the thread 2**i positions below.
                            # NOTE(review): no explicit barrier is written between these
                            # in-place steps; correctness presumably relies on TVM's
                            # shared-memory sync insertion and/or warp lockstep - confirm.
                            for i in T.unroll(LOG_TX):
                                for j in T.vectorized(N):
                                    idx: T.int64 = ty * warp_elem + tx * thread_elem
                                    if tx >= (1 << i):
                                        shared_buf[idx + j] += shared_buf[
                                            idx - (1 << i) * thread_elem + N - 1
                                        ]
                            # Inclusive scan inside block: the ty == 0 threads add the
                            # (already updated) total of warp i - 1 to warp i, in order.
                            for i in T.unroll(1, TY):
                                for j in T.vectorized(N):
                                    if ty == 0:
                                        idx: T.int64 = i * warp_elem + tx * thread_elem
                                        shared_buf[idx + j] += shared_buf[i * warp_elem - 1]
                            # Write the scanned block back to global memory
                            for i in T.vectorized(N):
                                idx: T.int64 = ty * warp_elem + tx * thread_elem + i
                                if bx * block_elem + idx < cur_len:
                                    output[by, src_offset + bx * block_elem + idx] = shared_buf[idx]
                            # Record the block total with a single store from one thread
                            if tx == 0 and ty == 0:
                                tmp_buf[by, tmp_offset + bx] = shared_buf[block_elem - 1]

    @T.macro
    def update_cross_block(
        batch: T.int64,
        cur_len: T.int64,
        source: T.Buffer,
        output: T.Buffer,
        src_offset: T.int64,
        out_offset: T.int64,
    ):
        # Add source[by, src_offset + bx - 1] (expected to hold the scanned totals of
        # the preceding blocks) to every element of block bx in
        # output[by, out_offset : out_offset + cur_len]; block 0 receives 0.
        for by in T.thread_binding(batch, thread="blockIdx.y"):
            for bx in T.thread_binding(T.ceildiv(cur_len, block_elem), thread="blockIdx.x"):
                for ty in T.thread_binding(TY, thread="threadIdx.y"):
                    for tx in T.thread_binding(TX, thread="threadIdx.x"):
                        for i in T.serial(N):
                            # Adjacent tx values touch adjacent elements.
                            idx: T.int64 = bx * block_elem + ty * warp_elem + i * TX + tx
                            if idx < cur_len:
                                output[by, out_offset + idx] += T.if_then_else(
                                    bx > 0, source[by, src_offset + bx - 1], 0
                                )

    @T.prim_func(private=True)
    def cumsum(var_a: T.handle, var_out: T.handle):
        T.func_attr({"tir.is_scheduled": 1})  # prevent further scheduling
        m, n = T.int64(), T.int64()
        A = T.match_buffer(var_a, [m, n], dtype=in_dtype)
        Out = T.match_buffer(var_out, [m, n], dtype=out_dtype)
        # Scratch space for block totals: round i scans the segment starting at
        # i * ceildiv(n, block_elem) and stores its block totals in the next segment.
        Tmp = T.alloc_buffer([m, n], dtype=out_dtype)
        # Number of additional rounds used to scan the block totals.
        ceil_log2 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
        total_rounds = ceil_log2 // LOG_BLOCK_N

        # Scan each block of A into Out; block totals go to segment 0 of Tmp.
        block_inclusive_inside_block(
            m, n, A, Out, Tmp, src_offset=T.int64(0), tmp_offset=T.int64(0)
        )
        # Scan the block totals level by level (in place in Tmp).
        for i in range(total_rounds):
            cur_len = T.ceildiv(n, 1 << (LOG_BLOCK_N * (i + 1)))
            block_inclusive_inside_block(
                m,
                cur_len,
                Tmp,
                Tmp,
                Tmp,
                src_offset=i * T.ceildiv(n, block_elem),
                tmp_offset=(i + 1) * T.ceildiv(n, block_elem),
            )
        # Propagate the scanned totals from the higher levels back down to level 0.
        for i in range(total_rounds - 1):
            real_idx = total_rounds - 1 - i - 1
            cur_len = T.ceildiv(n, 1 << (LOG_BLOCK_N * (real_idx + 1)))
            update_cross_block(
                m,
                cur_len,
                Tmp,
                Tmp,
                src_offset=(real_idx + 1) * T.ceildiv(n, block_elem),
                out_offset=real_idx * T.ceildiv(n, block_elem),
            )
        # Finally add the scanned level-0 block totals to the output.
        update_cross_block(m, n, Tmp, Out, src_offset=0, out_offset=0)

    return cumsum
diff --git a/tests/python/relax/test_backend_dispatch_sort_scan.py 
b/tests/python/relax/test_backend_dispatch_sort_scan.py
index 0fb39dfc9c..a539621060 100644
--- a/tests/python/relax/test_backend_dispatch_sort_scan.py
+++ b/tests/python/relax/test_backend_dispatch_sort_scan.py
@@ -15,18 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import numpy as np
 import pytest
 
 import tvm
-from tvm import topi, relax, tir, dlight
 import tvm.script
 import tvm.testing
-from tvm.script import relax as R, tir as T, ir as I
+from tvm import dlight, relax, tir, topi
 from tvm.contrib.thrust import can_use_thrust
-
-
-from tvm.relax.backend import DispatchSortScan
 from tvm.ir.base import assert_structural_equal
+from tvm.relax.backend import DispatchSortScan
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 def test_dispatch_scanop():
@@ -399,5 +400,32 @@ def test_dispatch_topk_gpu():
     assert_structural_equal(mod, expected_mod)
 
 
@tvm.testing.parametrize_targets("cuda", "vulkan -supports_int64=1")
def test_dispatch_cumsum_gpu(target, dev):
    """Test cumsum kernel dispatch and numerical correctness.

    Runs once per GPU target; ``tvm.testing.parametrize_targets`` skips targets
    that are not enabled or not available on the current machine and provides
    the matching ``dev``.
    """

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor(("m", "n"), "int32")):
            with R.dataflow():
                gv = R.cumsum(x, axis=-1, exclusive=False)
                R.output(gv)
            return gv

    # 2000 columns exceed one block of the default gpu_2d_continuous_cumsum
    # configuration (32 * 4 * 4 = 512 elements). When that kernel is selected,
    # this also exercises the cross-block propagation path.
    size = (8, 2000)
    np_data = np.random.randint(0, 10, size).astype("int32")
    np_cumsum = np.cumsum(np_data, axis=-1)
    with tvm.target.Target(target):
        mod = DispatchSortScan()(Module)
        ex = tvm.relax.build(mod, target)
    vm = tvm.relax.VirtualMachine(ex, dev)
    tvm_data = tvm.nd.array(np_data, dev)
    cumsum = vm["main"](tvm_data)
    tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum)


if __name__ == "__main__":
    tvm.testing.main()

Reply via email to