This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new efc2ae9846 [KVCache] Support returning query positions (#16578)
efc2ae9846 is described below
commit efc2ae9846054506ee4596d49c55b6ecc7c89800
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Feb 16 07:53:16 2024 -0500
[KVCache] Support returning query positions (#16578)
This PR adds a new function, `GetQueryPositions`, to PagedKVCache that
returns the in-sequence position of each location in the batch of
sequences currently being forwarded. This function helps apply
positional embeddings for language models that do not use rotary
positional embeddings (RoPE).
---
src/runtime/relax_vm/kv_cache.h | 9 +++++++++
src/runtime/relax_vm/paged_kv_cache.cc | 9 +++++++++
2 files changed, 18 insertions(+)
diff --git a/src/runtime/relax_vm/kv_cache.h b/src/runtime/relax_vm/kv_cache.h
index 4f4b538cb3..b201ab93f6 100644
--- a/src/runtime/relax_vm/kv_cache.h
+++ b/src/runtime/relax_vm/kv_cache.h
@@ -150,6 +150,15 @@ class AttentionKVCache : public Object {
virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data,
Optional<NDArray> mask,
NDArray o_data) = 0;
+ /************** Positions **************/
+
+ /*!
+ * \brief Get the in-sequence positions of each slot in the query.
+ * Must be invoked after BeginForward, which prepares the batch being forwarded;
+ * implementations may reject calls made before that synchronization.
+ * Intended for models that apply positional embeddings other than rotary (RoPE).
+ * \return The in-sequence query positions, in shape `(total_length,)`.
+ * \note NOTE(review): `total_length` presumably is the summed query length of all
+ *       sequences passed to BeginForward - confirm against the implementations.
+ */
+ virtual NDArray GetQueryPositions() const = 0;
+
/************** Debug Helpers **************/
/*!
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index a8c38ca4ed..7417d90e02 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -744,6 +744,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
AttentionInternal(layer_id, q_data, k_data, v_data, o_data);
}
+ // Returns the query-position array for the current batch. Requires BeginForward to have
+ // synchronized the auxiliary arrays to device; fails a CHECK otherwise.
+ // NOTE(review): this hands out the cached member itself, which presumably shares storage
+ // rewritten by the next BeginForward - confirm callers do not hold it across steps.
+ NDArray GetQueryPositions() const final {
+ CHECK(!dirty_aux_data_device_)
+ << "The auxiliary arrays are not synchronized to device. Please call "
+ "`BeginForward` to synchronize before calling `GetQueryPositions`.";
+ return q_rope_position_map_view_;
+ }
+
void DebugGetKV(int64_t seq_id, int64_t start_pos, int64_t end_pos, NDArray
k_data,
NDArray v_data) final {
CHECK(f_debug_get_kv_.defined())
@@ -1231,6 +1238,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_begin_forward")
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::BeginForward);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_end_forward")
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::EndForward);
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_query_positions")
+    .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::GetQueryPositions);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_debug_get_kv")
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::DebugGetKV);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")