This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 51cfb70f86 [Fix][Dlight] Fix GeneralReduction for log-sum-exp (#16923)
51cfb70f86 is described below

commit 51cfb70f868c057d0d73aa60bc96b99ce722ecd2
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Apr 25 20:31:46 2024 -0400

    [Fix][Dlight] Fix GeneralReduction for log-sum-exp (#16923)
    
    This PR fixes the GeneralReduction dlight rule so that it can schedule
    the log-sum-exp function.
    
    Prior to this PR, the rule made a strong assumption about the pattern
    of the given function: it could schedule softmax, but failed to
    schedule log-sum-exp due to a pattern mismatch. This PR generalizes
    the rule so that it matches the log-sum-exp pattern and applies the
    subsequent scheduling.
    
    A regression test is added.
---
 python/tvm/dlight/gpu/general_reduction.py        |  35 +++--
 tests/python/dlight/test_gpu_general_reduction.py | 149 ++++++++++++++++++++++
 2 files changed, 176 insertions(+), 8 deletions(-)

diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py
index 28b68a8b62..ef6bb1db91 100644
--- a/python/tvm/dlight/gpu/general_reduction.py
+++ b/python/tvm/dlight/gpu/general_reduction.py
@@ -18,7 +18,7 @@
 """Reduction rule for operators including softmax, layer norm, RMS norm, etc"""
 from typing import List, Union
 
-from tvm import tir
+from tvm import arith, tir
 from tvm.target import Target
 
 from ..base import normalize_prim_func, try_inline_contiguous_spatial
@@ -57,13 +57,32 @@ class GeneralReduction(GPUScheduleRule):
         # Align the number of block iters of the last block.
         num_last_block_iter = len(block_infos[-1].dom_kind())
         if num_last_block_iter < len(dom_kind):
-            index_map = tir.IndexMap.from_func(
-                lambda *iters: (
-                    [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter)
-                    + list(iters)
-                ),
-                ndim=num_last_block_iter,
-            )
+
+            def f_layout_mapping(*iters):
+                analyzer = arith.Analyzer()
+                # Try to match the iters of last block to the iters of the first block.
+                # For matched positions, use the iter from the input `iters`.
+                # For unmatched positions, use a new iter which is constant 0.
+                num_matched = 0
+                target_layout_iters = []
+                for block_iter in block_infos[0].iters:
+                    if num_matched < len(iters) and analyzer.can_prove_equal(
+                        block_iter.dom, block_infos[-1].iters[num_matched].dom
+                    ):
+                        target_layout_iters.append(iters[num_matched])
+                        num_matched += 1
+                    else:
+                        target_layout_iters.append(tir.const(0, iters[0].dtype))
+
+                # If all the iters of the last block can match, return the new layout.
+                if num_matched == len(iters):
+                    return target_layout_iters
+                # Otherwise, fallback to appending zeros in the beginning.
+                return [tir.const(0, iters[0].dtype)] * (
+                    len(dom_kind) - num_last_block_iter
+                ) + list(iters)
+
+            index_map = tir.IndexMap.from_func(f_layout_mapping, ndim=num_last_block_iter)
             sch.transform_block_layout(block_infos[-1].block_rv, index_map)
 
         try:
diff --git a/tests/python/dlight/test_gpu_general_reduction.py b/tests/python/dlight/test_gpu_general_reduction.py
index 44c9a4a126..e1a9a8e018 100644
--- a/tests/python/dlight/test_gpu_general_reduction.py
+++ b/tests/python/dlight/test_gpu_general_reduction.py
@@ -453,5 +453,154 @@ def test_group_norm():
     _check(Before, After)
 
 
+def test_logsumexp():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            batch_size = T.int64(is_size_var=True)
+            vocab_size = T.int64(is_size_var=True)
+            num_chunks = T.int64(is_size_var=True)
+            A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32")
+            blocked_lse = T.match_buffer(var_blocked_lse, (batch_size, num_chunks), dtype="float32")
+            A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096)), dtype="float32")
+            temp_max = T.alloc_buffer((batch_size, num_chunks), dtype="float32")
+            temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32")
+
+            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
+                with T.block("pad"):
+                    v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
+                    A_pad[v0, v1, v2] = T.if_then_else(
+                        v1 * T.int64(4096) + v2 < vocab_size,
+                        A[v0, v1 * T.int64(4096) + v2],
+                        T.min_value("float32"),
+                    )
+
+            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
+                with T.block("max"):
+                    v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
+                    with T.init():
+                        temp_max[v0, v1] = T.min_value("float32")
+                    temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2])
+
+            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
+                with T.block("sum_exp"):
+                    v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
+                    with T.init():
+                        temp_sum[v0, v1] = T.float32(0)
+                    temp_sum[v0, v1] += T.if_then_else(
+                        v1 * T.int64(4096) + v2 < vocab_size,
+                        T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]),
+                        T.float32(0),
+                    )
+
+            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
+                with T.block("log"):
+                    v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
+                    blocked_lse[v0, v1] = T.log(temp_sum[v0, v1]) + temp_max[v0, v1]
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
+            A = T.match_buffer(var_A, (batch_size, vocab_size))
+            num_chunks = T.int64(is_size_var=True)
+            blocked_lse = T.match_buffer(var_blocked_lse, (batch_size, num_chunks))
+            temp_max_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
+            temp_sum_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
+            for ax0_ax1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
+                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                        for ax2_fused_0 in T.serial(
+                            T.int64(16),
+                            annotations={
+                                "pragma_auto_unroll_max_step": 256,
+                                "pragma_unroll_explicit": 1,
+                            },
+                        ):
+                            with T.block("max"):
+                                v0 = T.axis.spatial(
+                                    batch_size,
+                                    ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
+                                )
+                                v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
+                                v2 = T.axis.reduce(
+                                    T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1
+                                )
+                                T.reads(A[v0, v1 * T.int64(4096) + v2])
+                                T.writes(temp_max_shared[v0, v1])
+                                with T.init():
+                                    temp_max_shared[v0, v1] = T.min_value("float32")
+                                temp_max_shared[v0, v1] = T.max(
+                                    temp_max_shared[v0, v1],
+                                    T.if_then_else(
+                                        v1 * T.int64(4096) + v2 < vocab_size,
+                                        A[v0, v1 * T.int64(4096) + v2],
+                                        T.min_value("float32"),
+                                    ),
+                                )
+                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                        for ax2_fused_0 in T.serial(
+                            T.int64(16),
+                            annotations={
+                                "pragma_auto_unroll_max_step": 256,
+                                "pragma_unroll_explicit": 1,
+                            },
+                        ):
+                            with T.block("sum_exp"):
+                                v0 = T.axis.spatial(
+                                    batch_size,
+                                    ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
+                                )
+                                v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
+                                v2 = T.axis.reduce(
+                                    T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1
+                                )
+                                T.reads(A[v0, v1 * T.int64(4096) + v2], temp_max_shared[v0, v1])
+                                T.writes(temp_sum_shared[v0, v1])
+                                with T.init():
+                                    temp_sum_shared[v0, v1] = T.float32(0)
+                                temp_sum_shared[v0, v1] = temp_sum_shared[v0, v1] + T.if_then_else(
+                                    v1 * T.int64(4096) + v2 < vocab_size,
+                                    T.exp(
+                                        (
+                                            T.if_then_else(
+                                                v1 * T.int64(4096) + v2 < vocab_size,
+                                                A[v0, v1 * T.int64(4096) + v2],
+                                                T.min_value("float32"),
+                                            )
+                                            - temp_max_shared[v0, v1]
+                                        )
+                                    ),
+                                    T.float32(0),
+                                )
+                for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                    for ax2_0 in T.serial(
+                        T.int64(1),
+                        annotations={
+                            "pragma_auto_unroll_max_step": 256,
+                            "pragma_unroll_explicit": 1,
+                        },
+                    ):
+                        with T.block("log"):
+                            v0 = T.axis.spatial(
+                                batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks
+                            )
+                            v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks)
+                            v2 = T.axis.spatial(T.int64(1), ax2_0 * T.int64(256) + ax2_1)
+                            T.where(ax2_0 * T.int64(256) + ax2_1 < T.int64(1))
+                            T.reads(temp_sum_shared[v0, v1], temp_max_shared[v0, v1])
+                            T.writes(blocked_lse[v0, v1])
+                            blocked_lse[v0, v1] = (
+                                T.log(temp_sum_shared[v0, v1]) + temp_max_shared[v0, v1]
+                            )
+
+    _check(Before, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to