This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 460000202e [KVCache] Support passing in attn_score_scaling_factor into KV cache (#16606)
460000202e is described below

commit 460000202e7741e98cabe89c1d6a3ea2810f6049
Author: Rick Zhou <[email protected]>
AuthorDate: Mon Feb 19 23:23:58 2024 -0500

    [KVCache] Support passing in attn_score_scaling_factor into KV cache (#16606)
    
    In GPT-2, the attention computation requires an additional feature,
    `scale_attn_by_inverse_layer_idx`, which applies a per-layer scaling
    factor to the attention scores before the softmax function is applied.
    
    This PR adds support for passing this scaling factor into the KV cache.
---
 3rdparty/flashinfer                                |  2 +-
 src/runtime/relax_vm/kv_cache.h                    |  5 ++-
 src/runtime/relax_vm/paged_kv_cache.cc             | 45 ++++++++++++----------
 ..._builtin_paged_attention_kv_cache_flashinfer.py |  4 +-
 ...runtime_builtin_paged_attention_kv_cache_tir.py | 13 ++++---
 5 files changed, 39 insertions(+), 30 deletions(-)

diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index 47686efcad..f1f6a0de4e 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 47686efcad186096250ba0f1209ed63ceaaeea58
+Subproject commit f1f6a0de4e595b777e29cc0dc370c15bd1d736fb
diff --git a/src/runtime/relax_vm/kv_cache.h b/src/runtime/relax_vm/kv_cache.h
index b201ab93f6..82e32b3af5 100644
--- a/src/runtime/relax_vm/kv_cache.h
+++ b/src/runtime/relax_vm/kv_cache.h
@@ -135,7 +135,8 @@ class AttentionKVCache : public Object {
    * \param o_data The output O data, in layout `(total_length, num_qo_heads, 
head_dim)`.
    */
   virtual void Attention(int64_t layer_id, NDArray q_data, NDArray k_data, 
NDArray v_data,
-                         Optional<NDArray> mask, NDArray o_data) = 0;
+                         Optional<NDArray> mask, NDArray o_data,
+                         double attn_score_scaling_factor) = 0;
 
   /*!
    * \brief Compute attention with Q/K/V data which are concatenated along
@@ -148,7 +149,7 @@ class AttentionKVCache : public Object {
    * \sa AttentionKVCache::Attention
    */
   virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, 
Optional<NDArray> mask,
-                                     NDArray o_data) = 0;
+                                     NDArray o_data, double 
attn_score_scaling_factor) = 0;
 
   /************** Positions **************/
 
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index d5ddef7527..70fa3daee7 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -645,7 +645,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   }
 
   void Attention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray 
v_data,
-                 Optional<NDArray> mask, NDArray o_data) final {
+                 Optional<NDArray> mask, NDArray o_data, double 
attn_score_scaling_factor) final {
     // Part 1. Shape and dtype check.
     NDArray pages = pages_[layer_id];
     CHECK(q_data.DataType() == pages.DataType());
@@ -695,11 +695,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
     // Part 3: append k/v data to kv-cache
     f_transpose_append_(pages_[layer_id], k_data, v_data, 
append_position_map_view_);
     // Part 4: perform attention
-    AttentionInternal(layer_id, q_data, k_data, v_data, o_data);
+    AttentionInternal(layer_id, q_data, k_data, v_data, o_data, 
attn_score_scaling_factor);
   }
 
   void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, 
Optional<NDArray> mask,
-                             NDArray o_data) final {
+                             NDArray o_data, double attn_score_scaling_factor) 
final {
     // Part 1. Shape and dtype check.
     NDArray pages = pages_[layer_id];
     CHECK(qkv_data.DataType() == pages.DataType());
@@ -743,7 +743,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
     // Part 3: append k/v data to kv-cache
     f_transpose_append_(pages_[layer_id], k_data, v_data, 
append_position_map_view_);
     // Part 4: perform attention
-    AttentionInternal(layer_id, q_data, k_data, v_data, o_data);
+    AttentionInternal(layer_id, q_data, k_data, v_data, o_data, 
attn_score_scaling_factor);
   }
 
   NDArray GetQueryPositions() const final {
@@ -992,22 +992,24 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
    * input k/v data and the k/v data in cache on the given layer.
    */
   void AttentionInternal(int64_t layer_id, NDArray q_data, NDArray k_data, 
NDArray v_data,
-                         NDArray output) {
+                         NDArray output, double attn_score_scaling_factor) {
     CHECK_GE(num_depths_, 1) << "The number of effective depths must be 
greater or equal to 1.";
     if (append_before_attn_) {
       f_attention_decode_(
           /*depth=*/0, q_data, pages_[layer_id], 
page_indptr_on_depths_view_[0],
           page_indices_on_depths_view_[0], last_page_len_on_depths_view_[0],
           k_rope_pos_offset_view_[0], q_rope_position_map_view_, output, 
merged_attn_scores_view_,
-          /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, 
rotary_theta_);
+          /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, 
rotary_theta_,
+          attn_score_scaling_factor);
     } else {
       // Compute appended text self-attention
-      f_attention_prefill_ragged_.value()(
-          q_data, cur_append_length_indptr_view_, k_data, v_data, 
cur_append_length_indptr_view_,
-          q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, output,
-          merged_attn_scores_view_,
-          /*causal=*/1,
-          /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, 
rotary_theta_);
+      f_attention_prefill_ragged_.value()(q_data, 
cur_append_length_indptr_view_, k_data, v_data,
+                                          cur_append_length_indptr_view_, 
q_rope_position_map_view_,
+                                          k_ragged_rope_pos_offset_view_, 
output,
+                                          merged_attn_scores_view_,
+                                          /*causal=*/1,
+                                          /*rotary_mode=*/rope_mode_ == 
RoPEMode::kInline,
+                                          rotary_scale_, rotary_theta_, 
attn_score_scaling_factor);
 
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
@@ -1020,7 +1022,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
                               k_rope_pos_offset_view_[d], 
q_rope_position_map_view_,
                               temp_attn_output_view_, temp_attn_scores_view_,
                               /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, 
rotary_scale_,
-                              rotary_theta_);
+                              rotary_theta_, attn_score_scaling_factor);
         } else {
           // Use prefill kernel for depth d
           f_attention_prefill_(
@@ -1029,7 +1031,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
               last_page_len_on_depths_view_[d], k_rope_pos_offset_view_[d],
               q_rope_position_map_view_, temp_attn_output_view_, 
temp_attn_scores_view_,
               /*causal=*/0,
-              /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, 
rotary_theta_);
+              /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, 
rotary_theta_,
+              attn_score_scaling_factor);
         }
         f_merge_inplace_.value()(output, merged_attn_scores_view_, 
temp_attn_output_view_,
                                  temp_attn_scores_view_);
@@ -1245,15 +1248,17 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_query_positions")
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_debug_get_kv")
     
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::DebugGetKV);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")
-    .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id, 
NDArray q_data,
-                       NDArray k_data, NDArray v_data, NDArray o_data) {
+    .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id,
+                       double attn_score_scaling_factor, NDArray q_data, 
NDArray k_data,
+                       NDArray v_data, NDArray o_data) {
       kv_cache->Attention(layer_id, std::move(q_data), std::move(k_data), 
std::move(v_data),
-                          NullOpt, std::move(o_data));
+                          NullOpt, std::move(o_data), 
attn_score_scaling_factor);
     });
 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention_with_fused_qkv")
-    .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id, 
NDArray qkv_data,
-                       NDArray o_data) {
-      kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, 
std::move(o_data));
+    .set_body_typed([](PagedAttentionKVCache kv_cache, int64_t layer_id,
+                       double attn_score_scaling_factor, NDArray qkv_data, 
NDArray o_data) {
+      kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, 
std::move(o_data),
+                                      attn_score_scaling_factor);
     });
 
 }  // namespace relax_vm
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index fccef312c1..967e71ecd3 100644
--- 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -447,11 +447,11 @@ def apply_attention(
             keys = tvm.nd.array(keys_np, device=device)
             values = tvm.nd.array(values_np, device=device)
             outputs = tvm.nd.empty(queries.shape, dtype, device=device)
-            fattention(kv_cache, layer_id, queries, keys, values, outputs)
+            fattention(kv_cache, layer_id, 1.0, queries, keys, values, outputs)
         else:
             qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, 
values_np], axis=1), device)
             outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
-            fattention_with_fuse_qkv(kv_cache, layer_id, qkv, outputs)
+            fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
 
         # Compute attention expected results.
         outputs = np.expand_dims(outputs.numpy(), axis=0)
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 8bd9da3bbb..2a4f7e87bd 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -273,11 +273,11 @@ def apply_attention(
             keys = tvm.nd.array(keys_np, device=device)
             values = tvm.nd.array(values_np, device=device)
             outputs = tvm.nd.empty(queries.shape, dtype, device=device)
-            fattention(kv_cache, layer_id, queries, keys, values, outputs)
+            fattention(kv_cache, layer_id, 1.0, queries, keys, values, outputs)
         else:
             qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, 
values_np], axis=1), device)
             outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
-            fattention_with_fuse_qkv(kv_cache, layer_id, qkv, outputs)
+            fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
 
         # Compute attention expected results.
         outputs = np.expand_dims(outputs.numpy(), axis=0)
@@ -751,6 +751,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
         rotary_mode: T.int32,
         rope_scale: T.float32,
         rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
     ):
         batch_size = T.int32(is_size_var=True)
         total_len = T.int32(is_size_var=True)
@@ -901,7 +902,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
                                                     i, j, k = 
T.axis.remap("SSR", [li, lj, lk])
                                                     with T.init():
                                                         S_local[i, j] = 0.0
-                                                    S_local[i, j] += Q_smem[i, 
k] * K_smem[j, k] * sm_scale
+                                                    S_local[i, j] += Q_smem[i, 
k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
                                         T.tvm_storage_sync("shared")
                                         for li, lj in T.grid(tile_x, tile_z):
                                             with T.block("S_store"):
@@ -1083,6 +1084,7 @@ def _attention_decode(num_kv_heads, num_qo_heads, 
head_dim, qkv_dtype):
         rotary_mode: T.int32,
         rope_scale: T.float32,
         rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
     ):
         T.func_attr({"tir.is_scheduled": 1})
         B = T.int32(is_size_var=True)
@@ -1194,7 +1196,7 @@ def _attention_decode(num_kv_heads, num_qo_heads, 
head_dim, qkv_dtype):
                                         # compute S = Q * K * sm_scale
                                         S_reduce_local[0] = 0
                                         for vec in T.serial(VEC_SIZE):
-                                            S_reduce_local[0] += Q_local[vec] 
* K_local[vec] * sm_scale
+                                            S_reduce_local[0] += Q_local[vec] 
* K_local[vec] * attn_score_scaling_factor * sm_scale
 
                                         with T.block("block_cross_thread"):
                                             T.reads(S_reduce_local[0])
@@ -1304,6 +1306,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
         rotary_mode: T.int32,
         rope_scale: T.float32,
         rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
     ):
         batch_size = T.int32(is_size_var=True)
         qo_len = T.int32(is_size_var=True)
@@ -1442,7 +1445,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
                                                     i, j, k = 
T.axis.remap("SSR", [li, lj, lk])
                                                     with T.init():
                                                         S_local[i, j] = 0.0
-                                                    S_local[i, j] += Q_smem[i, 
k] * K_smem[j, k] * sm_scale
+                                                    S_local[i, j] += Q_smem[i, 
k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
                                         T.tvm_storage_sync("shared")
                                         for li, lj in T.grid(tile_x, tile_z):
                                             with T.block("S_store"):

Reply via email to