This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e468426bfd [Fix][Relax] Add the missing tree-attn func arg for KV
cache creation (#17345)
e468426bfd is described below
commit e468426bfd43fadb555ef0e561b9047a5d89852e
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Sep 8 06:42:06 2024 -0400
[Fix][Relax] Add the missing tree-attn func arg for KV cache creation
(#17345)
This PR fixes the TIRPagedKVCache construction issue, which is caused
by the missing tree-attention-with-paged-KV-cache kernel.
---
python/tvm/relax/frontend/nn/llm/kv_cache.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py
b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 7b14c67a2e..ae0537f0d9 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -375,6 +375,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-methods
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers,
num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype,
target), "kv_cache_compact_kv_copy"),
bb.add_func(tree_attn(num_key_value_heads, num_attention_heads,
head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
+ bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads,
num_attention_heads, head_dim, dtype, rope_scaling, target),
"tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
rope_ext_factors,
# fmt: on
# pylint: enable=line-too-long