This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a9436b8154 [Fix][Builtin] Fix "GetQueryPosition" of PagedKVCache (#16746)
a9436b8154 is described below
commit a9436b81542c74d2b8ca7e15d561ec985d9912a8
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Mar 20 02:34:28 2024 -0400
[Fix][Builtin] Fix "GetQueryPosition" of PagedKVCache (#16746)
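Since #16692 introduced the copy stream separation, the function
`GetQueryPositions` also needs to eagerly synchronize the compute
stream with the copy stream in order to work properly. This PR fixes
the previous incorrect behavior by calling
`ComputeStreamWaitForCopyStream()` inside `GetQueryPositions`, which
in turn requires dropping the `const` qualifier from the method.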
---
src/runtime/relax_vm/kv_state.h | 2 +-
src/runtime/relax_vm/paged_kv_cache.cc | 9 +++++----
2 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index 2227944b86..f6857a9dce 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -159,7 +159,7 @@ class AttentionKVCacheObj : public KVStateObj {
* This function is supposed to be invoked after calling BeginForward.
* \return The in-sequence query positions, in shape `(total_length,)`.
*/
- virtual NDArray GetQueryPositions() const = 0;
+ virtual NDArray GetQueryPositions() = 0;
/************** Debug Helpers **************/
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 0c64800cec..9c3ee5d427 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -838,10 +838,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}
- NDArray GetQueryPositions() const final {
- CHECK(!dirty_aux_data_device_)
- << "The auxiliary arrays are not synchronized to device. Please call "
- "`BeginForward` to synchronize before calling `GetQueryPositions`.";
+ NDArray GetQueryPositions() final {
+ // Sync the copy stream and the compute stream.
+ ComputeStreamWaitForCopyStream();
+ // The auxiliary data structure on device must have been synchronized.
+ ICHECK(!dirty_aux_data_device_);
return q_rope_position_map_view_;
};