This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d9f0838dad [KVCache] Fix kernel dispatch based on attention kinds (#18122)
d9f0838dad is described below

commit d9f0838dad8cdea4f6d6cba361ac35436f4b2f97
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jul 7 19:29:55 2025 -0400

    [KVCache] Fix kernel dispatch based on attention kinds (#18122)
    
    * [KVCache] Fix kernel dispatch based on attention kinds
    
    This PR fixes a few kernel dispatch issues caused by the recent
    introduction of `mha_sliding` as a new attention kind.
    
    Tested on Qwen3 1.7B with MLC-LLM.
    
    * Fix lint
    
    ---------
    
    Co-authored-by: Yong Wu <[email protected]>
---
 python/tvm/relax/frontend/nn/llm/kv_cache.py | 52 +++++++++++++---------------
 1 file changed, 24 insertions(+), 28 deletions(-)
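
The gist of the change, as a minimal standalone sketch: collapse a per-layer
list of attention kinds to one representative kind before dispatching kernels,
and route the new "mha_sliding" kind to the regular MHA kernels. The helper
name normalize_attn_kind is hypothetical; the patch inlines this logic in both
the FlashInferPagedKVCache and TIRPagedKVCache constructors shown below.

    from typing import List, Union

    def normalize_attn_kind(attn_kind: Union[str, List[str]]) -> str:
        # A per-layer list is represented by its first entry, as in the patch.
        attn_kind_single = attn_kind[0] if isinstance(attn_kind, list) else attn_kind
        # Sliding-window MHA dispatches to the same kernels as plain MHA.
        if attn_kind_single == "mha_sliding":
            attn_kind_single = "mha"
        return attn_kind_single

    assert normalize_attn_kind(["mha_sliding", "mha"]) == "mha"
    assert normalize_attn_kind("mla") == "mla"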

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index a1d742739a..e6e171da99 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -374,20 +374,15 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-me
         if rope_mode == RopeMode.INLINE:
             assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not support partial rotary dim."
 
+        attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
+        if attn_kind_single == "mha_sliding":
+            attn_kind_single = "mha"
         flashinfer_prefill_mods = rx.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
             dtype_q=dtype,
             dtype_kv=dtype,
             dtype_o=dtype,
-            qk_head_dim=(
-                qk_head_dim
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
-                else mla_original_qk_head_dim
-            ),
-            v_head_dim=(
-                v_head_dim
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
-                else mla_original_v_head_dim
-            ),
+            qk_head_dim=(qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim),
+            v_head_dim=(v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim),
             target=target,
             enable_inline_rope=rope_mode == RopeMode.INLINE,
         )
@@ -400,7 +395,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-me
                 v_head_dim=v_head_dim,
                 target=target,
             )
-            if (attn_kind == "mha" or isinstance(attn_kind, List))
+            if attn_kind_single == "mha"
             else []
         )
         flashinfer_mla_mods = (
@@ -412,7 +407,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-me
                 head_dim_kpe=qk_head_dim - v_head_dim,
                 target=target,
             )
-            if attn_kind == "mla"
+            if attn_kind_single == "mla"
             else []
         )
         self.extern_mods = flashinfer_prefill_mods + flashinfer_decode_mods + flashinfer_mla_mods
@@ -429,21 +424,21 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-me
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
             ]
-            if (attn_kind == "mha" or isinstance(attn_kind, List))
+            if attn_kind_single == "mha"
             else [rx.Tuple([]) for _ in range(6)]
         )
-            mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else [])
+            mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind_single == "mla" else [])
         attn_merge_functions = [
             bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
         ]
-        if attn_kind == "mla":
+        if attn_kind_single == "mla":
             attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
 
-
         if isinstance(attn_kind, List):
             attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
         else:
             attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
+
         args = [
             rx.ShapeExpr(
                 [
@@ -459,9 +454,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-me
             rx.PrimValue(num_key_value_heads),
             rx.PrimValue(qk_head_dim),
             rx.PrimValue(v_head_dim),
-            rx.ShapeExpr(
-                [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
-            ),
+            rx.ShapeExpr(attn_kind),
             rx.PrimValue(enable_disaggregation),
             rx.PrimValue(rope_mode),
             rx.PrimValue(rope_scale),
@@ -475,7 +468,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-me
             mla_function,
             rx.Tuple(attn_merge_functions),
             bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-            bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+            bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
             bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
             bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
             # fmt: on
@@ -567,6 +560,9 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
         target : Target
             The target to build the model to.
         """
+        attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
+        if attn_kind_single == "mha_sliding":
+            attn_kind_single = "mha"
         if isinstance(attn_kind, List):
             attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
         else:
@@ -605,7 +601,7 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
         ]
 
         if str(target.kind) == "llvm":
-            if attn_kind == "mla":
+            if attn_kind_single == "mla":
                 raise ValueError("MLA is not supported in TIR kernels for now.")
             # pylint: disable=line-too-long
             # fmt: off
@@ -631,9 +627,9 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
         else:
             # pylint: disable=line-too-long
             # fmt: off
-            ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim
-            ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim
-            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
+            ragged_qk_head_dim = qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim
+            ragged_v_head_dim = v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim
+            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind_single == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
             mha_functions = (
                 [
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]),
@@ -643,14 +639,14 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
                 ]
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
+                if attn_kind_single == "mha"
                 else [rx.Tuple([]) for _ in range(6)]
             )
-            mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else [])
+            mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind_single == "mla" else [])
             attn_merge_functions = [
                 bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
             ]
-            if attn_kind == "mla":
+            if attn_kind_single == "mla":
                 attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
             args.extend(mha_functions)
             args.append(mla_function)
@@ -658,7 +654,7 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
                 [
                     rx.Tuple(attn_merge_functions),
                     bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-                    bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+                    bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
                     bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
                     bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
                 ]
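
For reference, the head-dim selection that the normalized kind feeds into can
be sketched as below. select_head_dims is a hypothetical helper; qk_head_dim,
v_head_dim, mla_original_qk_head_dim, and mla_original_v_head_dim are the
constructor parameters that appear in the diff above.

    from typing import Tuple

    def select_head_dims(attn_kind_single: str, qk_head_dim: int, v_head_dim: int,
                         mla_original_qk_head_dim: int,
                         mla_original_v_head_dim: int) -> Tuple[int, int]:
        # MHA (including normalized "mha_sliding") keeps the model's head dims;
        # MLA kernels are generated with the original pre-compression dims.
        if attn_kind_single == "mha":
            return qk_head_dim, v_head_dim
        return mla_original_qk_head_dim, mla_original_v_head_dim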
