This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 51cfb70f86 [Fix][Dlight] Fix GeneralReduction for log-sum-exp (#16923)
51cfb70f86 is described below
commit 51cfb70f868c057d0d73aa60bc96b99ce722ecd2
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Apr 25 20:31:46 2024 -0400
[Fix][Dlight] Fix GeneralReduction for log-sum-exp (#16923)
This PR fixes the GeneralReduction dlight rule so that it can support
scheduling the log-sum-exp function.
Prior to this PR, the rule made a strong assumption about the pattern
of the given function, which allowed it to schedule softmax but caused
it to fail on log-sum-exp due to a pattern mismatch. This PR enhances
the rule so that it can match the pattern of log-sum-exp and apply the
subsequent scheduling.
A regression test is added.
---
python/tvm/dlight/gpu/general_reduction.py | 35 +++--
tests/python/dlight/test_gpu_general_reduction.py | 149 ++++++++++++++++++++++
2 files changed, 176 insertions(+), 8 deletions(-)
diff --git a/python/tvm/dlight/gpu/general_reduction.py
b/python/tvm/dlight/gpu/general_reduction.py
index 28b68a8b62..ef6bb1db91 100644
--- a/python/tvm/dlight/gpu/general_reduction.py
+++ b/python/tvm/dlight/gpu/general_reduction.py
@@ -18,7 +18,7 @@
"""Reduction rule for operators including softmax, layer norm, RMS norm, etc"""
from typing import List, Union
-from tvm import tir
+from tvm import arith, tir
from tvm.target import Target
from ..base import normalize_prim_func, try_inline_contiguous_spatial
@@ -57,13 +57,32 @@ class GeneralReduction(GPUScheduleRule):
# Align the number of block iters of the last block.
num_last_block_iter = len(block_infos[-1].dom_kind())
if num_last_block_iter < len(dom_kind):
- index_map = tir.IndexMap.from_func(
- lambda *iters: (
- [tir.const(0, iters[0].dtype)] * (len(dom_kind) -
num_last_block_iter)
- + list(iters)
- ),
- ndim=num_last_block_iter,
- )
+
+ def f_layout_mapping(*iters):
+ analyzer = arith.Analyzer()
+ # Try to match the iters of last block to the iters of the
first block.
+ # For matched positions, use the iter from the input `iters`.
+ # For unmatched positions, use a new iter which is constant 0.
+ num_matched = 0
+ target_layout_iters = []
+ for block_iter in block_infos[0].iters:
+ if num_matched < len(iters) and analyzer.can_prove_equal(
+ block_iter.dom, block_infos[-1].iters[num_matched].dom
+ ):
+ target_layout_iters.append(iters[num_matched])
+ num_matched += 1
+ else:
+ target_layout_iters.append(tir.const(0,
iters[0].dtype))
+
+ # If all the iters of the last block can match, return the new
layout.
+ if num_matched == len(iters):
+ return target_layout_iters
+ # Otherwise, fallback to appending zeros in the beginning.
+ return [tir.const(0, iters[0].dtype)] * (
+ len(dom_kind) - num_last_block_iter
+ ) + list(iters)
+
+ index_map = tir.IndexMap.from_func(f_layout_mapping,
ndim=num_last_block_iter)
sch.transform_block_layout(block_infos[-1].block_rv, index_map)
try:
diff --git a/tests/python/dlight/test_gpu_general_reduction.py
b/tests/python/dlight/test_gpu_general_reduction.py
index 44c9a4a126..e1a9a8e018 100644
--- a/tests/python/dlight/test_gpu_general_reduction.py
+++ b/tests/python/dlight/test_gpu_general_reduction.py
@@ -453,5 +453,154 @@ def test_group_norm():
_check(Before, After)
+def test_logsumexp():
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ batch_size = T.int64(is_size_var=True)
+ vocab_size = T.int64(is_size_var=True)
+ num_chunks = T.int64(is_size_var=True)
+ A = T.match_buffer(var_A, (batch_size, vocab_size),
dtype="float32")
+ blocked_lse = T.match_buffer(var_blocked_lse, (batch_size,
num_chunks), dtype="float32")
+ A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096)),
dtype="float32")
+ temp_max = T.alloc_buffer((batch_size, num_chunks),
dtype="float32")
+ temp_sum = T.alloc_buffer((batch_size, num_chunks),
dtype="float32")
+
+ for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
+ with T.block("pad"):
+ v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
+ A_pad[v0, v1, v2] = T.if_then_else(
+ v1 * T.int64(4096) + v2 < vocab_size,
+ A[v0, v1 * T.int64(4096) + v2],
+ T.min_value("float32"),
+ )
+
+ for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
+ with T.block("max"):
+ v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
+ with T.init():
+ temp_max[v0, v1] = T.min_value("float32")
+ temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1,
v2])
+
+ for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
+ with T.block("sum_exp"):
+ v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
+ with T.init():
+ temp_sum[v0, v1] = T.float32(0)
+ temp_sum[v0, v1] += T.if_then_else(
+ v1 * T.int64(4096) + v2 < vocab_size,
+ T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]),
+ T.float32(0),
+ )
+
+ for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
+ with T.block("log"):
+ v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
+ blocked_lse[v0, v1] = T.log(temp_sum[v0, v1]) +
temp_max[v0, v1]
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
+ T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ batch_size, vocab_size = T.int64(is_size_var=True),
T.int64(is_size_var=True)
+ A = T.match_buffer(var_A, (batch_size, vocab_size))
+ num_chunks = T.int64(is_size_var=True)
+ blocked_lse = T.match_buffer(var_blocked_lse, (batch_size,
num_chunks))
+ temp_max_shared = T.alloc_buffer((batch_size, num_chunks),
scope="shared")
+ temp_sum_shared = T.alloc_buffer((batch_size, num_chunks),
scope="shared")
+ for ax0_ax1_fused in T.thread_binding(batch_size * num_chunks,
thread="blockIdx.x"):
+ for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+ for ax2_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ for ax2_fused_0 in T.serial(
+ T.int64(16),
+ annotations={
+ "pragma_auto_unroll_max_step": 256,
+ "pragma_unroll_explicit": 1,
+ },
+ ):
+ with T.block("max"):
+ v0 = T.axis.spatial(
+ batch_size,
+ ax0_ax1_fused % (num_chunks * batch_size)
// num_chunks + ax0,
+ )
+ v1 = T.axis.spatial(num_chunks, ax0_ax1_fused
% num_chunks + ax1)
+ v2 = T.axis.reduce(
+ T.int64(4096), ax2_fused_0 * T.int64(256)
+ ax2_fused_1
+ )
+ T.reads(A[v0, v1 * T.int64(4096) + v2])
+ T.writes(temp_max_shared[v0, v1])
+ with T.init():
+ temp_max_shared[v0, v1] =
T.min_value("float32")
+ temp_max_shared[v0, v1] = T.max(
+ temp_max_shared[v0, v1],
+ T.if_then_else(
+ v1 * T.int64(4096) + v2 < vocab_size,
+ A[v0, v1 * T.int64(4096) + v2],
+ T.min_value("float32"),
+ ),
+ )
+ for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+ for ax2_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ for ax2_fused_0 in T.serial(
+ T.int64(16),
+ annotations={
+ "pragma_auto_unroll_max_step": 256,
+ "pragma_unroll_explicit": 1,
+ },
+ ):
+ with T.block("sum_exp"):
+ v0 = T.axis.spatial(
+ batch_size,
+ ax0_ax1_fused % (num_chunks * batch_size)
// num_chunks + ax0,
+ )
+ v1 = T.axis.spatial(num_chunks, ax0_ax1_fused
% num_chunks + ax1)
+ v2 = T.axis.reduce(
+ T.int64(4096), ax2_fused_0 * T.int64(256)
+ ax2_fused_1
+ )
+ T.reads(A[v0, v1 * T.int64(4096) + v2],
temp_max_shared[v0, v1])
+ T.writes(temp_sum_shared[v0, v1])
+ with T.init():
+ temp_sum_shared[v0, v1] = T.float32(0)
+ temp_sum_shared[v0, v1] = temp_sum_shared[v0,
v1] + T.if_then_else(
+ v1 * T.int64(4096) + v2 < vocab_size,
+ T.exp(
+ (
+ T.if_then_else(
+ v1 * T.int64(4096) + v2 <
vocab_size,
+ A[v0, v1 * T.int64(4096) + v2],
+ T.min_value("float32"),
+ )
+ - temp_max_shared[v0, v1]
+ )
+ ),
+ T.float32(0),
+ )
+ for ax2_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ for ax2_0 in T.serial(
+ T.int64(1),
+ annotations={
+ "pragma_auto_unroll_max_step": 256,
+ "pragma_unroll_explicit": 1,
+ },
+ ):
+ with T.block("log"):
+ v0 = T.axis.spatial(
+ batch_size, ax0_ax1_fused % (num_chunks *
batch_size) // num_chunks
+ )
+ v1 = T.axis.spatial(num_chunks, ax0_ax1_fused %
num_chunks)
+ v2 = T.axis.spatial(T.int64(1), ax2_0 *
T.int64(256) + ax2_1)
+ T.where(ax2_0 * T.int64(256) + ax2_1 < T.int64(1))
+ T.reads(temp_sum_shared[v0, v1],
temp_max_shared[v0, v1])
+ T.writes(blocked_lse[v0, v1])
+ blocked_lse[v0, v1] = (
+ T.log(temp_sum_shared[v0, v1]) +
temp_max_shared[v0, v1]
+ )
+
+ _check(Before, After)
+
+
if __name__ == "__main__":
tvm.testing.main()