This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a9436b8154 [Fix][Builtin] Fix "GetQueryPositions" of PagedKVCache (#16746)
a9436b8154 is described below

commit a9436b81542c74d2b8ca7e15d561ec985d9912a8
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Mar 20 02:34:28 2024 -0400

    [Fix][Builtin] Fix "GetQueryPositions" of PagedKVCache (#16746)
    
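    Since #16692 separated the copy stream from the compute stream, the
    function `GetQueryPositions` must make the compute stream wait for the
    copy stream before returning the query position array; otherwise the
    auxiliary data may not yet be synchronized to the device. This PR adds
    that synchronization (which means the method can no longer be `const`)
    and turns the former user-facing dirty-data check into an internal check.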
---
 src/runtime/relax_vm/kv_state.h        | 2 +-
 src/runtime/relax_vm/paged_kv_cache.cc | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index 2227944b86..f6857a9dce 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -159,7 +159,7 @@ class AttentionKVCacheObj : public KVStateObj {
    * This function is supposed to be invoked after calling BeginForward.
    * \return The in-sequence query positions, in shape `(total_length,)`.
    */
-  virtual NDArray GetQueryPositions() const = 0;
+  virtual NDArray GetQueryPositions() = 0;
 
   /************** Debug Helpers **************/
 
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 0c64800cec..9c3ee5d427 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -838,10 +838,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
   }
 
-  NDArray GetQueryPositions() const final {
-    CHECK(!dirty_aux_data_device_)
-        << "The auxiliary arrays are not synchronized to device. Please call "
-           "`BeginForward` to synchronize before calling `GetQueryPositions`.";
+  NDArray GetQueryPositions() final {
+    // Sync the copy stream and the compute stream.
+    ComputeStreamWaitForCopyStream();
+    // The auxiliary data structure on device must have been synchronized.
+    ICHECK(!dirty_aux_data_device_);
     return q_rope_position_map_view_;
   };
 

Reply via email to