This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d9f0838dad [KVCache] Fix kernel dispatch based on attention kinds (#18122)
d9f0838dad is described below
commit d9f0838dad8cdea4f6d6cba361ac35436f4b2f97
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jul 7 19:29:55 2025 -0400
[KVCache] Fix kernel dispatch based on attention kinds (#18122)
* [KVCache] Fix kernel dispatch based on attention kinds
This PR fixes a few kernel dispatch issues due to the recent
introduction of `mha_sliding` as a new attention kind.
Tested on Qwen3 1.7B with MLC-LLM.
* Fix lint
---------
Co-authored-by: Yong Wu <[email protected]>
---
python/tvm/relax/frontend/nn/llm/kv_cache.py | 52 +++++++++++++---------------
1 file changed, 24 insertions(+), 28 deletions(-)
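
Note: below is a minimal, standalone Python sketch (not the TVM code itself) of the dispatch normalization this patch introduces: `attn_kind` may be either a single string or a per-layer list, and `mha_sliding` layers should reuse the regular MHA kernel paths. The helper name `normalize_attn_kind` is hypothetical and used only for illustration.

    from typing import List, Union

    def normalize_attn_kind(attn_kind: Union[str, List[str]]) -> str:
        # Take the first layer's kind when a per-layer list is given.
        single = attn_kind[0] if isinstance(attn_kind, list) else attn_kind
        # Sliding-window MHA dispatches to the same kernels as plain MHA.
        return "mha" if single == "mha_sliding" else single

    # Example: a model whose first layer uses sliding-window attention still
    # selects the MHA kernels, while an MLA model keeps the MLA path.
    assert normalize_attn_kind(["mha_sliding", "mha"]) == "mha"
    assert normalize_attn_kind("mla") == "mla"

The diff below applies the same normalization (via `attn_kind_single`) in both the FlashInfer and TIR KV cache constructors.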
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index a1d742739a..e6e171da99 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -374,20 +374,15 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me
if rope_mode == RopeMode.INLINE:
assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not support partial rotary dim."
+ attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
+ if attn_kind_single == "mha_sliding":
+ attn_kind_single = "mha"
flashinfer_prefill_mods = rx.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
dtype_q=dtype,
dtype_kv=dtype,
dtype_o=dtype,
- qk_head_dim=(
- qk_head_dim
- if (attn_kind == "mha" or isinstance(attn_kind, List))
- else mla_original_qk_head_dim
- ),
- v_head_dim=(
- v_head_dim
- if (attn_kind == "mha" or isinstance(attn_kind, List))
- else mla_original_v_head_dim
- ),
+ qk_head_dim=(qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim),
+ v_head_dim=(v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim),
target=target,
enable_inline_rope=rope_mode == RopeMode.INLINE,
)
@@ -400,7 +395,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me
v_head_dim=v_head_dim,
target=target,
)
- if (attn_kind == "mha" or isinstance(attn_kind, List))
+ if attn_kind_single == "mha"
else []
)
flashinfer_mla_mods = (
@@ -412,7 +407,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me
head_dim_kpe=qk_head_dim - v_head_dim,
target=target,
)
- if attn_kind == "mla"
+ if attn_kind_single == "mla"
else []
)
self.extern_mods = flashinfer_prefill_mods + flashinfer_decode_mods + flashinfer_mla_mods
@@ -429,21 +424,21 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
]
- if (attn_kind == "mha" or isinstance(attn_kind, List))
+ if attn_kind_single == "mha"
else [rx.Tuple([]) for _ in range(6)]
)
- mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else [])
+ mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind_single == "mla" else [])
attn_merge_functions = [
bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
]
- if attn_kind == "mla":
+ if attn_kind_single == "mla":
attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
-
if isinstance(attn_kind, List):
attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
else:
attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
+
args = [
rx.ShapeExpr(
[
@@ -459,9 +454,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me
rx.PrimValue(num_key_value_heads),
rx.PrimValue(qk_head_dim),
rx.PrimValue(v_head_dim),
- rx.ShapeExpr(
- [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
- ),
+ rx.ShapeExpr(attn_kind),
rx.PrimValue(enable_disaggregation),
rx.PrimValue(rope_mode),
rx.PrimValue(rope_scale),
@@ -475,7 +468,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me
mla_function,
rx.Tuple(attn_merge_functions),
bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
- bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+ bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
# fmt: on
@@ -567,6 +560,9 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods
target : Target
The target to build the model to.
"""
+ attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
+ if attn_kind_single == "mha_sliding":
+ attn_kind_single = "mha"
if isinstance(attn_kind, List):
attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
else:
@@ -605,7 +601,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods
]
if str(target.kind) == "llvm":
- if attn_kind == "mla":
+ if attn_kind_single == "mla":
raise ValueError("MLA is not supported in TIR kernels for now.")
# pylint: disable=line-too-long
# fmt: off
@@ -631,9 +627,9 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods
else:
# pylint: disable=line-too-long
# fmt: off
- ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim
- ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim
- args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
+ ragged_qk_head_dim = qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim
+ ragged_v_head_dim = v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim
+ args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind_single == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
mha_functions = (
[
rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]),
@@ -643,14 +639,14 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
]
- if (attn_kind == "mha" or isinstance(attn_kind, List))
+ if attn_kind_single == "mha"
else [rx.Tuple([]) for _ in range(6)]
)
- mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else [])
+ mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind_single == "mla" else [])
attn_merge_functions = [
bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
]
- if attn_kind == "mla":
+ if attn_kind_single == "mla":
attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
args.extend(mha_functions)
args.append(mla_function)
@@ -658,7 +654,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods
[
rx.Tuple(attn_merge_functions),
bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
- bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+ bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
]