This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new efc2ae9846 [KVCache] Support returning query positions (#16578)
efc2ae9846 is described below

commit efc2ae9846054506ee4596d49c55b6ecc7c89800
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Feb 16 07:53:16 2024 -0500

    [KVCache] Support returning query positions (#16578)
    
    This PR adds a new function to PagedKVCache to
    return in-sequence positions for each location
    in a batch of sequences that is being forwarded.
    This function helps apply positional embeddings
    for language models that do not use rotary positional
    embeddings (RoPE).
---
 src/runtime/relax_vm/kv_cache.h        | 9 +++++++++
 src/runtime/relax_vm/paged_kv_cache.cc | 9 +++++++++
 2 files changed, 18 insertions(+)

diff --git a/src/runtime/relax_vm/kv_cache.h b/src/runtime/relax_vm/kv_cache.h
index 4f4b538cb3..b201ab93f6 100644
--- a/src/runtime/relax_vm/kv_cache.h
+++ b/src/runtime/relax_vm/kv_cache.h
@@ -150,6 +150,15 @@ class AttentionKVCache : public Object {
   virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
                                      NDArray o_data) = 0;
 
+  /************** Positions **************/
+
+  /*!
+   * \brief Get the in-sequence position of each query slot in the batch being forwarded.
+   * \note This function must be invoked after BeginForward.
+   * \return The in-sequence query positions, in shape `(total_length,)`.
+   */
+  virtual NDArray GetQueryPositions() const = 0;
+
   /************** Debug Helpers **************/
 
   /*!
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index a8c38ca4ed..7417d90e02 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -744,6 +744,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
     AttentionInternal(layer_id, q_data, k_data, v_data, o_data);
   }
 
+  NDArray GetQueryPositions() const final {
+    CHECK(!dirty_aux_data_device_)
+        << "The auxiliary arrays are not synchronized to device. Please call "
+           "`BeginForward` to synchronize before calling `GetQueryPositions`.";
+    return q_rope_position_map_view_;  // Device-side view; synced by BeginForward.
+  }
+
   void DebugGetKV(int64_t seq_id, int64_t start_pos, int64_t end_pos, NDArray k_data,
                   NDArray v_data) final {
     CHECK(f_debug_get_kv_.defined())
@@ -1231,6 +1238,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_begin_forward")
     .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::BeginForward);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_end_forward")
     .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::EndForward);
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_query_positions")
+    .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::GetQueryPositions);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_debug_get_kv")
     .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::DebugGetKV);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")

Reply via email to