This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2c4afbb5ea [Relax][KV Cache] Refactor `_attention_sequence_prefill`
function to … (#17362)
2c4afbb5ea is described below
commit 2c4afbb5eace6c52f30d35a5c70465ca63c27a0f
Author: Mengshiun Yu <[email protected]>
AuthorDate: Wed Sep 11 09:55:35 2024 -0400
[Relax][KV Cache] Refactor `_attention_sequence_prefill` function to …
(#17362)
This PR removes batch_size from the function signature,
instead deriving it as a size variable within the function body.
---
python/tvm/relax/frontend/nn/llm/kv_cache.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py
b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index ae0537f0d9..9b16fc2fbf 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -1237,7 +1237,7 @@ def _merge_state_inplace(num_heads, head_dim, v_dtype,
target: Target):
def _attention_sequence_prefill(
- batch_size, h_kv, h_q, d, dtype, target: Target, causal=0,
attn_score_scaling_factor=1.0
+ h_kv, h_q, d, dtype, target: Target, causal=0,
attn_score_scaling_factor=1.0
): # pylint: disable=line-too-long
LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes
group_size = h_q // h_kv
@@ -1264,6 +1264,7 @@ def _attention_sequence_prefill(
var_output: T.handle, # [total_len, h_q, d]
var_lse: T.handle # [total_len, h_q]
):
+ batch_size = T.int32(is_size_var=True)
qo_len = T.int32(is_size_var=True)
kv_len = T.int32(is_size_var=True)
q = T.match_buffer(var_q, (batch_size, qo_len, h_q, d), dtype)