This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 460000202e [KVCache] Support passing in attn_score_scaling_factor into KV cache (#16606)
460000202e is described below
commit 460000202e7741e98cabe89c1d6a3ea2810f6049
Author: Rick Zhou <[email protected]>
AuthorDate: Mon Feb 19 23:23:58 2024 -0500
[KVCache] Support passing in attn_score_scaling_factor into KV cache
(#16606)
In GPT-2, attention calculation requires an additional feature
`scale_attn_by_inverse_layer_idx`. It provides a scaling factor
per attention layer when calculating the attention score,
before applying the softmax function.
This PR supports this additional parameter in KV cache.
---
3rdparty/flashinfer | 2 +-
src/runtime/relax_vm/kv_cache.h | 5 ++-
src/runtime/relax_vm/paged_kv_cache.cc | 45 ++++++++++++----------
..._builtin_paged_attention_kv_cache_flashinfer.py | 4 +-
...runtime_builtin_paged_attention_kv_cache_tir.py | 13 ++++---
5 files changed, 39 insertions(+), 30 deletions(-)
diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index 47686efcad..f1f6a0de4e 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 47686efcad186096250ba0f1209ed63ceaaeea58
+Subproject commit f1f6a0de4e595b777e29cc0dc370c15bd1d736fb
diff --git a/src/runtime/relax_vm/kv_cache.h b/src/runtime/relax_vm/kv_cache.h
index b201ab93f6..82e32b3af5 100644
--- a/src/runtime/relax_vm/kv_cache.h
+++ b/src/runtime/relax_vm/kv_cache.h
@@ -135,7 +135,8 @@ class AttentionKVCache : public Object {
* \param o_data The output O data, in layout `(total_length, num_qo_heads,
head_dim)`.
*/
virtual void Attention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
- Optional<NDArray> mask, NDArray o_data) = 0;
+ Optional<NDArray> mask, NDArray o_data,
+ double attn_score_scaling_factor) = 0;
/*!
* \brief Compute attention with Q/K/V data which are concatenated along
@@ -148,7 +149,7 @@ class AttentionKVCache : public Object {
* \sa AttentionKVCache::Attention
*/
virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
- NDArray o_data) = 0;
+ NDArray o_data, double attn_score_scaling_factor) = 0;
/************** Positions **************/
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index d5ddef7527..70fa3daee7 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -645,7 +645,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
}
void Attention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
- Optional<NDArray> mask, NDArray o_data) final {
- Optional<NDArray> mask, NDArray o_data) final {
+ Optional<NDArray> mask, NDArray o_data, double attn_score_scaling_factor) final {
// Part 1. Shape and dtype check.
NDArray pages = pages_[layer_id];
CHECK(q_data.DataType() == pages.DataType());
@@ -695,11 +695,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
// Part 3: append k/v data to kv-cache
f_transpose_append_(pages_[layer_id], k_data, v_data, append_position_map_view_);
// Part 4: perform attention
- AttentionInternal(layer_id, q_data, k_data, v_data, o_data);
+ AttentionInternal(layer_id, q_data, k_data, v_data, o_data, attn_score_scaling_factor);
}
void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
- NDArray o_data) final {
+ NDArray o_data, double attn_score_scaling_factor) final {
// Part 1. Shape and dtype check.
NDArray pages = pages_[layer_id];
CHECK(qkv_data.DataType() == pages.DataType());
@@ -743,7 +743,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
// Part 3: append k/v data to kv-cache
f_transpose_append_(pages_[layer_id], k_data, v_data, append_position_map_view_);
// Part 4: perform attention
- AttentionInternal(layer_id, q_data, k_data, v_data, o_data);
+ AttentionInternal(layer_id, q_data, k_data, v_data, o_data, attn_score_scaling_factor);
}
NDArray GetQueryPositions() const final {
@@ -992,22 +992,24 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
* input k/v data and the k/v data in cache on the given layer.
*/
void AttentionInternal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
- NDArray output) {
+ NDArray output, double attn_score_scaling_factor) {
CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1.";
if (append_before_attn_) {
f_attention_decode_(
/*depth=*/0, q_data, pages_[layer_id],
page_indptr_on_depths_view_[0],
page_indices_on_depths_view_[0], last_page_len_on_depths_view_[0],
k_rope_pos_offset_view_[0], q_rope_position_map_view_, output,
merged_attn_scores_view_,
- /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_);
+ /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
+ attn_score_scaling_factor);
} else {
// Compute appended text self-attention
- f_attention_prefill_ragged_.value()(
- q_data, cur_append_length_indptr_view_, k_data, v_data, cur_append_length_indptr_view_,
- q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, output,
- merged_attn_scores_view_,
- /*causal=*/1,
- /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_);
+ f_attention_prefill_ragged_.value()(q_data, cur_append_length_indptr_view_, k_data, v_data,
+ cur_append_length_indptr_view_, q_rope_position_map_view_,
+ k_ragged_rope_pos_offset_view_, output,
+ merged_attn_scores_view_,
+ /*causal=*/1,
+ /*rotary_mode=*/rope_mode_ == RoPEMode::kInline,
+ rotary_scale_, rotary_theta_, attn_score_scaling_factor);
for (int d = 0; d < num_depths_; ++d) {
if (page_indices_on_depths_view_[d]->shape[0] == 0) {
@@ -1020,7 +1022,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
k_rope_pos_offset_view_[d],
q_rope_position_map_view_,
temp_attn_output_view_, temp_attn_scores_view_,
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_,
- rotary_theta_);
+ rotary_theta_, attn_score_scaling_factor);
} else {
// Use prefill kernel for depth d
f_attention_prefill_(
@@ -1029,7 +1031,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
last_page_len_on_depths_view_[d], k_rope_pos_offset_view_[d],
q_rope_position_map_view_, temp_attn_output_view_,
temp_attn_scores_view_,
/*causal=*/0,
- /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_);
+ /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
+ attn_score_scaling_factor);
}
f_merge_inplace_.value()(output, merged_attn_scores_view_, temp_attn_output_view_, temp_attn_scores_view_);
@@ -1245,15 +1248,17 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_query_positions")
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_debug_get_kv")
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::DebugGetKV);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")
- .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id, NDArray q_data,
- NDArray k_data, NDArray v_data, NDArray o_data) {
+ .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id,
+ double attn_score_scaling_factor, NDArray q_data, NDArray k_data,
+ NDArray v_data, NDArray o_data) {
kv_cache->Attention(layer_id, std::move(q_data), std::move(k_data), std::move(v_data),
- NullOpt, std::move(o_data));
+ NullOpt, std::move(o_data), attn_score_scaling_factor);
});
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention_with_fused_qkv")
- .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id, NDArray qkv_data,
- NDArray o_data) {
- kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data));
+ .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id,
+ double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) {
+ kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data),
+ attn_score_scaling_factor);
});
} // namespace relax_vm
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index fccef312c1..967e71ecd3 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -447,11 +447,11 @@ def apply_attention(
keys = tvm.nd.array(keys_np, device=device)
values = tvm.nd.array(values_np, device=device)
outputs = tvm.nd.empty(queries.shape, dtype, device=device)
- fattention(kv_cache, layer_id, queries, keys, values, outputs)
+ fattention(kv_cache, layer_id, 1.0, queries, keys, values, outputs)
else:
qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device)
outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
- fattention_with_fuse_qkv(kv_cache, layer_id, qkv, outputs)
+ fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
# Compute attention expected results.
outputs = np.expand_dims(outputs.numpy(), axis=0)
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 8bd9da3bbb..2a4f7e87bd 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -273,11 +273,11 @@ def apply_attention(
keys = tvm.nd.array(keys_np, device=device)
values = tvm.nd.array(values_np, device=device)
outputs = tvm.nd.empty(queries.shape, dtype, device=device)
- fattention(kv_cache, layer_id, queries, keys, values, outputs)
+ fattention(kv_cache, layer_id, 1.0, queries, keys, values, outputs)
else:
qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device)
outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
- fattention_with_fuse_qkv(kv_cache, layer_id, qkv, outputs)
+ fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
# Compute attention expected results.
outputs = np.expand_dims(outputs.numpy(), axis=0)
@@ -751,6 +751,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
rotary_mode: T.int32,
rope_scale: T.float32,
rope_theta: T.float32,
+ attn_score_scaling_factor: T.float32,
):
batch_size = T.int32(is_size_var=True)
total_len = T.int32(is_size_var=True)
@@ -901,7 +902,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
i, j, k = T.axis.remap("SSR", [li, lj, lk])
with T.init():
S_local[i, j] = 0.0
- S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * sm_scale
+ S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
T.tvm_storage_sync("shared")
for li, lj in T.grid(tile_x, tile_z):
with T.block("S_store"):
@@ -1083,6 +1084,7 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
rotary_mode: T.int32,
rope_scale: T.float32,
rope_theta: T.float32,
+ attn_score_scaling_factor: T.float32,
):
T.func_attr({"tir.is_scheduled": 1})
B = T.int32(is_size_var=True)
@@ -1194,7 +1196,7 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
# compute S = Q * K * sm_scale
S_reduce_local[0] = 0
for vec in T.serial(VEC_SIZE):
- S_reduce_local[0] += Q_local[vec] * K_local[vec] * sm_scale
+ S_reduce_local[0] += Q_local[vec] * K_local[vec] * attn_score_scaling_factor * sm_scale
with T.block("block_cross_thread"):
T.reads(S_reduce_local[0])
@@ -1304,6 +1306,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
rotary_mode: T.int32,
rope_scale: T.float32,
rope_theta: T.float32,
+ attn_score_scaling_factor: T.float32,
):
batch_size = T.int32(is_size_var=True)
qo_len = T.int32(is_size_var=True)
@@ -1442,7 +1445,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
i, j, k = T.axis.remap("SSR", [li, lj, lk])
with T.init():
S_local[i, j] = 0.0
- S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * sm_scale
+ S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
T.tvm_storage_sync("shared")
for li, lj in T.grid(tile_x, tile_z):
with T.block("S_store"):