This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 384b7f7a74 [KVCache] Introducing auxiliary data manager (#16824)
384b7f7a74 is described below

commit 384b7f7a74453c4e5e2d2b01549b00314d20bac2
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Apr 1 08:06:38 2024 -0400

    [KVCache] Introducing auxiliary data manager (#16824)
    
    This PR introduces class `PagedKVCacheAuxDataManager` for PagedKVCache.
    This class manages all the integer auxiliary data required for
    paged attention and other KV cache operations, such as page table
    arrays, position arrays, etc.
    
    We introduce this class because, prior to this PR, we issued one
    host-to-device copy for each auxiliary array. This may cause extra
    overhead, since these auxiliary arrays are usually lightweight.
    One simple idea is to "merge" all the auxiliary arrays into a
    single one, and take slices of this large array for each original
    auxiliary array. By doing this, we only need to issue a single
    host-to-device copy for all the auxiliary arrays altogether.
    
    The introduction of `PagedKVCacheAuxDataManager` abstracts the
    interface that PagedKVCache uses to copy host arrays to device
    arrays, enabling us to support both the previous way of copying
    and the new way.
    
    To support slicing for attention-related TIR functions, we introduce
    `elem_offset` match in TIR functions in this PR.
    This PR also bumps FlashInfer to support the auxiliary array slicing.
---
 3rdparty/flashinfer                                |   2 +-
 CMakeLists.txt                                     |   2 +
 src/runtime/relax_vm/paged_kv_cache.cc             | 539 ++++++++++++++++-----
 ..._builtin_paged_attention_kv_cache_flashinfer.py |  19 +-
 ...runtime_builtin_paged_attention_kv_cache_tir.py |  68 ++-
 5 files changed, 480 insertions(+), 150 deletions(-)

diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index 0d04571b61..b20a460a82 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 0d04571b614c944b5831d080882107a98b9c6e65
+Subproject commit b20a460a82a457824182056aaa2c45d5d156791e
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d02a788279..435fe3b35b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -956,9 +956,11 @@ if (USE_FLASHINFER STREQUAL "ON")
   set(FLASHINFER_TVM_BINDING ON)
   set(FLASHINFER_TVM_HOME ${PROJECT_SOURCE_DIR})
   set(FLASHINFER_ENABLE_FP8 OFF)
+  set(FLASHINFER_ENABLE_BF16 OFF)
   set(FLASHINFER_PREFILL OFF)
   set(FLASHINFER_DECODE OFF)
   set(FLASHINFER_PAGE OFF)
+  set(FLASHINFER_CASCADE OFF)
   add_subdirectory(3rdparty/flashinfer)
 else ()
   message(STATUS "Build without FlashInfer")
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index 3ccab3826d..1e674d0ec6 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -22,6 +22,7 @@
  */
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/logging.h>
+#include <tvm/runtime/memory/memory_manager.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/registry.h>
 
@@ -190,6 +191,384 @@ enum class RoPEMode : int {
   kInline = 2,
 };
 
+/*!
+ * \brief The paged attention auxiliary data manager class.
+ * This class manages all the int32 auxiliary data on GPU device, such as
+ * page table, position arrays, etc..
+ *
+ * The core functions of this class is `CopyXXXAsync` and `CommitCopy`.
+ * `CopyXXXAsync` takes the input data on CPU host, and copy the input data
+ * to GPU in an asynchronous way, and returns the NDArray view of the data
+ * on GPU device.
+ *
+ * Being asynchronous here means the `CopyXXXAsync` function may not perform
+ * data copy from CPU to GPU at the time of being called. Therefore, the
+ * returned NDArray view may have wrong result, until `CommitCopy` is
+ * explicitly invoked and the data copy stream is synchronized.
+ *
+ * We design this manager class in order to reduce the data copy overhead.
+ */
+class PagedKVCacheAuxDataManager {
+ public:
+  PagedKVCacheAuxDataManager(DLDataType dtype_aux, Device device, 
TVMStreamHandle copy_stream)
+      : dtype_aux_(dtype_aux), device_(device), copy_stream_(copy_stream) {
+    ICHECK(DataType(dtype_aux) == DataType::Int(32));
+  }
+
+  virtual ~PagedKVCacheAuxDataManager() = default;
+  /*! \brief Copy the indptr array of append lengths after coalescing. (see 
GetChunkedBlockIds) */
+  virtual NDArray CopyQOIndptrOnDepthAsync(std::vector<int32_t>* data, int 
depth) = 0;
+  /*! \brief Copy the indptr array of page table. */
+  virtual NDArray CopyPageIndptrOnDepthAsync(std::vector<int32_t>* data, int 
depth) = 0;
+  /*! \brief Copy the indices array of page table. */
+  virtual NDArray CopyPageIndicesOnDepthAsync(std::vector<int32_t>* data, int 
depth) = 0;
+  /*! \brief Copy the array of KV slot number used in the last page of the 
seq. */
+  virtual NDArray CopyLastPageLenOnDepthAsync(std::vector<int32_t>* data, int 
depth) = 0;
+  /*!
+   * \brief Copy the length information of the sequences.
+   * Each NDArray is in shape `(3, n)`. "n" is the number of sequences.
+   * For a sequence "i", location
+   * - "(0, i)" is the number of KV slots used in the last page of the seq 
("last_page_len"),
+   * - "(1, i)" is the starting offset of the sliding window in the seq,
+   * - "(2, i)" is the attn sink length of the sequence.
+   * \note When sliding window is not enabled, only the
+   * "last_page_len" (a.k.a., the first "n" elements) will be effectively used.
+   */
+  virtual NDArray CopyLengthInfoOnDepthAsync(std::vector<int32_t>* 
last_page_len,
+                                             std::vector<int32_t>* 
sliding_window_offset,
+                                             std::vector<int32_t>* sink_size, 
int depth) = 0;
+  /*! \brief Copy the k position offset of applying RoPE for each sequence. */
+  virtual NDArray CopyKRoPEPosOffsetOnDepthAsync(std::vector<int32_t>* data, 
int depth) = 0;
+  /*!
+   * \brief Copy the append length indptr array on device.
+   * \note Since the Q/K/V data may have raggedness in terms of lengths,
+   * we represent the the append lengths in CSR format.
+   */
+  virtual NDArray CopyCurAppendLengthIndptrAsync(std::vector<int32_t>* data) = 
0;
+  /*! \brief Copy the k position offset of applying RoPE for each sequence. */
+  virtual NDArray CopyKRaggedRoPEPosOffsetAsync(std::vector<int32_t>* data) = 
0;
+  /*! \brief Copy the q position mapping of applying RoPE for each sequence. */
+  virtual NDArray CopyQRoPEPosMapAsync(std::vector<int32_t>* data) = 0;
+  /*!
+   * \brief Copy the corresponding position in global KV cache (pages)
+   * for each position along the length dimension of K/V data when
+   * appending new K/V data.
+   */
+  virtual NDArray CopyAppendPositionMapAsync(std::vector<int32_t>* data) = 0;
+  /*! \brief Commit all the copy operations since the last commit. */
+  virtual void CommitCopy() = 0;
+
+ protected:
+  /*! \brief The dtype of the auxiliary data. It is expected to be int32. */
+  const DLDataType dtype_aux_;
+  /*! \brief The device this PagedKVCache runs on. */
+  const Device device_;
+  /*! \brief The device stream for copying auxiliary data structure to GPU. */
+  const TVMStreamHandle copy_stream_;
+};
+
+/*!
+ * \brief The plain auxiliary data manager class.
+ * It simply issues one host-to-device copy operation for each `CopyXXXAsync`.
+ */
+class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
+ public:
+  explicit PlainPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t 
num_total_pages,
+                                           int64_t prefill_chunk_size, 
DLDataType dtype_aux,
+                                           DLDevice device, TVMStreamHandle 
copy_stream)
+      : PagedKVCacheAuxDataManager(dtype_aux, device, copy_stream) {
+    for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
+      qo_indptr_on_depths_device_.push_back(
+          NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device));
+      page_indptr_on_depths_device_.push_back(
+          NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device));
+      page_indices_on_depths_device_.push_back(
+          NDArray::Empty({num_total_pages}, dtype_aux_, device));
+      length_info_on_depths_device_.push_back(
+          NDArray::Empty({3, reserved_num_seqs}, dtype_aux_, device));
+      k_rope_pos_offset_on_depths_device_.push_back(
+          NDArray::Empty({reserved_num_seqs}, dtype_aux_, device));
+    }
+    cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, 
dtype_aux_, device);
+    k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, 
dtype_aux_, device);
+    q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size}, 
dtype_aux_, device);
+    append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, 
dtype_aux_, device);
+  }
+
+  NDArray CopyQOIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) 
final {
+    NDArray view = qo_indptr_on_depths_device_[depth].CreateView(
+        {static_cast<int64_t>(data->size())}, dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
+  NDArray CopyPageIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) 
final {
+    NDArray view = page_indptr_on_depths_device_[depth].CreateView(
+        {static_cast<int64_t>(data->size())}, dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
+  NDArray CopyPageIndicesOnDepthAsync(std::vector<int32_t>* data, int depth) 
final {
+    NDArray view = page_indices_on_depths_device_[depth].CreateView(
+        {static_cast<int64_t>(data->size())}, dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
+  NDArray CopyLastPageLenOnDepthAsync(std::vector<int32_t>* data, int depth) 
final {
+    NDArray view = length_info_on_depths_device_[depth].CreateView(
+        {static_cast<int64_t>(data->size())}, dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
+  NDArray CopyKRoPEPosOffsetOnDepthAsync(std::vector<int32_t>* data, int 
depth) final {
+    NDArray view = k_rope_pos_offset_on_depths_device_[depth].CreateView(
+        {static_cast<int64_t>(data->size())}, dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
+  NDArray CopyCurAppendLengthIndptrAsync(std::vector<int32_t>* data) final {
+    NDArray view = 
cur_append_length_indptr_device_.CreateView({static_cast<int64_t>(data->size())},
+                                                               dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
+  NDArray CopyKRaggedRoPEPosOffsetAsync(std::vector<int32_t>* data) final {
+    NDArray view = 
k_ragged_rope_pos_offset_device_.CreateView({static_cast<int64_t>(data->size())},
+                                                               dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
+  NDArray CopyQRoPEPosMapAsync(std::vector<int32_t>* data) final {
+    NDArray view =
+        
q_rope_position_map_device_.CreateView({static_cast<int64_t>(data->size())}, 
dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
+  NDArray CopyAppendPositionMapAsync(std::vector<int32_t>* data) final {
+    NDArray view =
+        
append_position_map_device_.CreateView({static_cast<int64_t>(data->size())}, 
dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
+
+  NDArray CopyLengthInfoOnDepthAsync(std::vector<int32_t>* last_page_len,
+                                     std::vector<int32_t>* 
sliding_window_offset,
+                                     std::vector<int32_t>* sink_size, int 
depth) final {
+    int n_elem = last_page_len->size();
+    ICHECK_GT(n_elem, 0);
+    NDArray view = length_info_on_depths_device_[depth].CreateView({3, 
n_elem}, dtype_aux_);
+    ShapeTuple copy_shape{n_elem};
+    CopyVecDataToArray(view, last_page_len->data(), copy_shape);
+    CopyVecDataToArray(view, sliding_window_offset->data(), copy_shape,
+                       /*dst_elem_offset=*/n_elem);
+    CopyVecDataToArray(view, sink_size->data(), copy_shape,
+                       /*dst_elem_offset=*/2 * n_elem);
+    return view;
+  }
+
+  // The commit of the plain auxiliary data manager is no-op.
+  void CommitCopy() final {}
+
+ private:
+  /*!
+   * \brief Copy a vector of data to the input NDArray.
+   * It optionally supports specifying the shape of copy and the element
+   * offset to the destination NDArray.
+   */
+  void CopyVecDataToArray(NDArray array, int32_t* vec_data, 
Optional<ShapeTuple> shape = NullOpt,
+                          int dst_elem_offset = 0) {
+    if (array->shape[0] == 0) {
+      return;
+    }
+    DLTensor copy_dst = *array.operator->();
+    if (shape.defined()) {
+      ICHECK_EQ(shape.value().size(), 1);
+      copy_dst.ndim = 1;
+      copy_dst.shape = shape.value()->data;
+    }
+    copy_dst.byte_offset = dst_elem_offset * sizeof(int32_t);
+
+    DLTensor copy_src;
+    copy_src.data = vec_data;
+    copy_src.device = Device{kDLCPU, 0};
+    copy_src.ndim = 1;
+    copy_src.dtype = array->dtype;
+    copy_src.shape = copy_dst.shape;
+    copy_src.strides = nullptr;
+    copy_src.byte_offset = 0;
+    NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
+  }
+
+  std::vector<NDArray> qo_indptr_on_depths_device_;
+  std::vector<NDArray> page_indptr_on_depths_device_;
+  std::vector<NDArray> page_indices_on_depths_device_;
+  std::vector<NDArray> length_info_on_depths_device_;
+  std::vector<NDArray> k_rope_pos_offset_on_depths_device_;
+  NDArray cur_append_length_indptr_device_;
+  NDArray k_ragged_rope_pos_offset_device_;
+  NDArray q_rope_position_map_device_;
+  NDArray append_position_map_device_;
+};
+
+/*!
+ * \brief The cached auxiliary data manager class.
+ * It allocates a large on-device array to store all the auxiliary data.
+ * For each `CopyXXXAsync`, it copies the input data to a local cache on host.
+ * In `CommitCopy`, it copies all the data in the local cache to the device
+ * array for a single time, and thus reduce the number of host-to-device 
copies needed.
+ */
+class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
+ public:
+  explicit CachedPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t 
num_total_pages,
+                                            int64_t prefill_chunk_size, 
DLDataType dtype_aux,
+                                            DLDevice device, TVMStreamHandle 
copy_stream)
+      : PagedKVCacheAuxDataManager(dtype_aux, device, copy_stream),
+        elem_byte_size_((dtype_aux.bits * dtype_aux.lanes + 7) / 8),
+        offset_alignment_(cuda_byte_alignment_ / elem_byte_size_) {
+    // - Calculate all the starting offsets of the auxiliary arrays in
+    // local cache and the large on-device array.
+    int64_t total_elems =
+        InitializeArrayElemOffset(reserved_num_seqs, num_total_pages, 
prefill_chunk_size);
+    copy_shape_ = {total_elems};
+    // - Initialize the host auxiliary data buffer.
+    merged_aux_data_host_.resize(total_elems);
+    // - Initialize the device auxiliary data buffer.
+    memory::Allocator* allocator =
+        memory::MemoryManager::GetOrCreateAllocator(device, 
memory::AllocatorType::kNaive);
+    ICHECK_NOTNULL(allocator);
+    merged_aux_data_device_ =
+        memory::Storage(allocator->Alloc(device, {total_elems}, dtype_aux), 
allocator);
+  }
+
+  NDArray CopyQOIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) 
final {
+    return CopyVecToCacheAtOffset(data, depth_offsets_[depth] + 
qo_indptr_in_depth_offset_);
+  }
+  NDArray CopyPageIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) 
final {
+    return CopyVecToCacheAtOffset(data, depth_offsets_[depth] + 
page_indptr_in_depth_offset_);
+  }
+  NDArray CopyPageIndicesOnDepthAsync(std::vector<int32_t>* data, int depth) 
final {
+    return CopyVecToCacheAtOffset(data, depth_offsets_[depth] + 
page_indices_in_depth_offset_);
+  }
+  NDArray CopyLastPageLenOnDepthAsync(std::vector<int32_t>* data, int depth) 
final {
+    return CopyVecToCacheAtOffset(data, depth_offsets_[depth] + 
length_info_in_depth_offset_);
+  }
+  NDArray CopyKRoPEPosOffsetOnDepthAsync(std::vector<int32_t>* data, int 
depth) final {
+    return CopyVecToCacheAtOffset(data, depth_offsets_[depth] + 
k_rope_pos_offset_in_depth_offset_);
+  }
+  NDArray CopyCurAppendLengthIndptrAsync(std::vector<int32_t>* data) final {
+    return CopyVecToCacheAtOffset(data, cur_append_length_indptr_offset_);
+  }
+  NDArray CopyKRaggedRoPEPosOffsetAsync(std::vector<int32_t>* data) final {
+    return CopyVecToCacheAtOffset(data, k_ragged_rope_pos_offset_offset_);
+  }
+  NDArray CopyQRoPEPosMapAsync(std::vector<int32_t>* data) final {
+    return CopyVecToCacheAtOffset(data, q_rope_position_map_offset_);
+  }
+  NDArray CopyAppendPositionMapAsync(std::vector<int32_t>* data) final {
+    return CopyVecToCacheAtOffset(data, append_position_map_offset_);
+  }
+  NDArray CopyLengthInfoOnDepthAsync(std::vector<int32_t>* last_page_len,
+                                     std::vector<int32_t>* 
sliding_window_offset,
+                                     std::vector<int32_t>* sink_size, int 
depth) final {
+    int64_t offset = depth_offsets_[depth] + length_info_in_depth_offset_;
+    int64_t n_elem = last_page_len->size();
+    std::memcpy(merged_aux_data_host_.data() + offset, last_page_len->data(),
+                n_elem * elem_byte_size_);
+    std::memcpy(merged_aux_data_host_.data() + offset + n_elem, 
sliding_window_offset->data(),
+                n_elem * elem_byte_size_);
+    std::memcpy(merged_aux_data_host_.data() + offset + 2 * n_elem, 
sink_size->data(),
+                n_elem * elem_byte_size_);
+    return merged_aux_data_device_->AllocNDArray(offset * elem_byte_size_, {3, 
n_elem}, dtype_aux_);
+  }
+
+  void CommitCopy() final {
+    DLTensor copy_dst;
+    copy_dst.data = merged_aux_data_device_->buffer.data;
+    copy_dst.device = device_;
+    copy_dst.ndim = 1;
+    copy_dst.dtype = dtype_aux_;
+    copy_dst.shape = copy_shape_.data();
+    copy_dst.strides = nullptr;
+    copy_dst.byte_offset = 0;
+
+    DLTensor copy_src = copy_dst;
+    copy_src.data = merged_aux_data_host_.data();
+    copy_src.device = Device{kDLCPU, 0};
+    NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
+  }
+
+ private:
+  /*!
+   * \brief Calculate the start element offsets of the auxiliary arrays in the 
local cache.
+   * \return Return the local cache size (total number of elements in the 
local cache).
+   */
+  int64_t InitializeArrayElemOffset(int64_t reserved_num_seqs, int64_t 
num_total_pages,
+                                    int64_t prefill_chunk_size) {
+    // For safety, we align the start offset of the arrays to 
`offset_alignment`.
+    auto f_ceil_div_elem_alignment = [this](int n) {
+      return (n + offset_alignment_ - 1) / offset_alignment_ * 
offset_alignment_;
+    };
+
+    // - Element offsets of the arrays that every depth has.
+    qo_indptr_in_depth_offset_ = 0;
+    page_indptr_in_depth_offset_ =
+        qo_indptr_in_depth_offset_ + 
f_ceil_div_elem_alignment(reserved_num_seqs + 1);
+    page_indices_in_depth_offset_ =
+        page_indptr_in_depth_offset_ + 
f_ceil_div_elem_alignment(reserved_num_seqs + 1);
+    length_info_in_depth_offset_ =
+        page_indices_in_depth_offset_ + 
f_ceil_div_elem_alignment(num_total_pages);
+    k_rope_pos_offset_in_depth_offset_ =
+        length_info_in_depth_offset_ + f_ceil_div_elem_alignment(3 * 
reserved_num_seqs);
+
+    // - Element offsets of each depth.
+    int64_t elem_per_depth =
+        k_rope_pos_offset_in_depth_offset_ + 
f_ceil_div_elem_alignment(reserved_num_seqs);
+    for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
+      depth_offsets_.push_back(d * elem_per_depth);
+    }
+
+    // - Element offsets of other arrays.
+    cur_append_length_indptr_offset_ = kPagedKVCacheMaxBlockDepth * 
elem_per_depth;
+    k_ragged_rope_pos_offset_offset_ =
+        cur_append_length_indptr_offset_ + 
f_ceil_div_elem_alignment(reserved_num_seqs + 1);
+    q_rope_position_map_offset_ =
+        k_ragged_rope_pos_offset_offset_ + 
f_ceil_div_elem_alignment(reserved_num_seqs);
+    append_position_map_offset_ =
+        q_rope_position_map_offset_ + 
f_ceil_div_elem_alignment(prefill_chunk_size);
+
+    // - The total number of elements after alignment.
+    return append_position_map_offset_ + 
f_ceil_div_elem_alignment(prefill_chunk_size);
+  }
+
+  /*!
+   * \brief Copy the input data to the cache at the given offset.
+   * And return the NDArray view of the cache starting at the offset.
+   */
+  NDArray CopyVecToCacheAtOffset(std::vector<int32_t>* data, int64_t offset) {
+    int64_t n_elem = data->size();
+    std::memcpy(merged_aux_data_host_.data() + offset, data->data(), n_elem * 
elem_byte_size_);
+    return merged_aux_data_device_->AllocNDArray(offset * elem_byte_size_, 
{n_elem}, dtype_aux_);
+  }
+
+  const int64_t cuda_byte_alignment_ = 256;
+  const int64_t elem_byte_size_;
+  const int64_t offset_alignment_;
+
+  int64_t qo_indptr_in_depth_offset_;
+  int64_t page_indptr_in_depth_offset_;
+  int64_t page_indices_in_depth_offset_;
+  int64_t length_info_in_depth_offset_;
+  int64_t k_rope_pos_offset_in_depth_offset_;
+  std::vector<int64_t> depth_offsets_;
+  int64_t cur_append_length_indptr_offset_;
+  int64_t k_ragged_rope_pos_offset_offset_;
+  int64_t q_rope_position_map_offset_;
+  int64_t append_position_map_offset_;
+
+  std::vector<int64_t> copy_shape_;
+  std::vector<int32_t> merged_aux_data_host_;
+  memory::Storage merged_aux_data_device_;
+};
+
 /*!
  * \brief The paged KV cache for attention.
  * - It supports managing the K/V data of **multiple sequences**.
@@ -278,41 +657,8 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   int64_t cur_batch_size_;
   /*! \brief The append lengths of the sequences in the current round of 
forwarding. */
   IntTuple cur_append_lengths_;
-  /*! \brief The indptr array of append lengths after coalescing. (see 
GetChunkedBlockIds) */
-  std::vector<NDArray> qo_indptr_on_depths_device_;
-  /*! \brief The indptr array of page table. */
-  std::vector<NDArray> page_indptr_on_depths_device_;
-  /*! \brief The indices array of page table. */
-  std::vector<NDArray> page_indices_on_depths_device_;
-  /*!
-   * \brief The length information of the sequences.
-   * Each NDArray is in shape `(3, n)`. "n" is the number of sequences.
-   * For a sequence "i", location
-   * - "(0, i)" is the number of KV slots used in the last page of the seq 
("last_page_len"),
-   * - "(1, i)" is the starting offset of the sliding window in the seq,
-   * - "(2, i)" is the attn sink length of the sequence.
-   * \note When sliding window is not enabled, only the
-   * "last_page_len" (a.k.a., the first "n" elements) will be effectively used.
-   */
-  std::vector<NDArray> length_info_on_depths_device_;
-  /*! \brief The k position offset of applying RoPE for each sequence. */
-  std::vector<NDArray> k_rope_pos_offset_device_;
-  /*!
-   * \brief The append length indptr array on device.
-   * \note Since the Q/K/V data may have raggedness in terms of lengths,
-   * we represent the the append lengths in CSR format.
-   */
-  NDArray cur_append_length_indptr_device_;
-  /*! \brief The k position offset of applying RoPE for each sequence. */
-  NDArray k_ragged_rope_pos_offset_device_;
-  /*! \brief The q position mapping of applying RoPE for each sequence. */
-  NDArray q_rope_position_map_device_;
-  /*!
-   * \brief The corresponding position in global KV cache (pages)
-   * for each position along the length dimension of K/V data when
-   * appending new K/V data.
-   */
-  NDArray append_position_map_device_;
+  /*! \brief The auxiliary data manager for attention. */
+  std::unique_ptr<PagedKVCacheAuxDataManager> aux_data_manager_;
 
   // Temporary arrays to store intermediate attention results.
   NDArray temp_attn_q_device_;
@@ -445,15 +791,6 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
           NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, 
head_dim}, dtype, device));
     }
     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
-      qo_indptr_on_depths_device_.push_back(
-          NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device));
-      page_indptr_on_depths_device_.push_back(
-          NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device));
-      page_indices_on_depths_device_.push_back(
-          NDArray::Empty({num_total_pages}, dtype_aux_, device));
-      length_info_on_depths_device_.push_back(
-          NDArray::Empty({3, reserved_num_seqs}, dtype_aux_, device));
-      k_rope_pos_offset_device_.push_back(NDArray::Empty({reserved_num_seqs}, 
dtype_aux_, device));
       temp_attn_workspace_.push_back(
           NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), 
device));
       qo_indptr_on_depths_view_.push_back(NDArray());
@@ -465,10 +802,6 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     // Additional workspace for the "prefill with ragged kv" kernel.
     temp_attn_workspace_.push_back(
         NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
-    cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, 
dtype_aux_, device);
-    k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, 
dtype_aux_, device);
-    q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size_}, 
dtype_aux_, device);
-    append_position_map_device_ = NDArray::Empty({prefill_chunk_size_}, 
dtype_aux_, device);
 
     temp_attn_q_device_ =
         NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype, 
device);
@@ -494,6 +827,17 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
       copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
     }
+
+    // Create the auxiliary data manager for attention.
+    // We only use the merged aux data for CUDA, since direct pointer
+    // operations may have issues on other platforms.
+    if (device_.device_type == DLDeviceType::kDLCUDA) {
+      aux_data_manager_ = std::make_unique<CachedPagedKVCacheAuxDataManager>(
+          reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, 
device, copy_stream_);
+    } else {
+      aux_data_manager_ = std::make_unique<PlainPagedKVCacheAuxDataManager>(
+          reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, 
device, copy_stream_);
+    }
   }
 
   ~PagedAttentionKVCacheObj() {
@@ -636,9 +980,17 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   }
 
   void CopySinglePage(int32_t src_page_id, int32_t tgt_page_id, int64_t 
copy_length) {
+    if (copy_stream_ != compute_stream_) {
+      // Set the copy stream for copy.
+      DeviceAPI::Get(device_)->SetStream(device_, copy_stream_);
+    }
     for (int layer = 0; layer < num_layers_; ++layer) {
       f_copy_single_page_(pages_[layer], src_page_id, tgt_page_id, 
copy_length);
     }
+    if (copy_stream_ != compute_stream_) {
+      // Set the compute stream back.
+      DeviceAPI::Get(device_)->SetStream(device_, compute_stream_);
+    }
   }
 
   void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
@@ -959,8 +1311,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         append_position_map.push_back(page_id * page_size_ + page_offset);
       }
     }
-    NDArray position_map_device =
-        NDArray::Empty({end_pos - start_pos}, dtype_aux_, 
cur_append_length_indptr_device_->device);
+    NDArray position_map_device = NDArray::Empty({end_pos - start_pos}, 
dtype_aux_, device_);
     position_map_device.CopyFromBytes(
         append_position_map.data() + start_pos,
         (end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 
8));
@@ -1319,32 +1670,6 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     DeviceAPI::Get(device_)->SyncStreamFromTo(device_, copy_stream_, 
compute_stream_);
   }
 
-  /*!
-   * \brief Copy a vector of data to the input NDArray.
-   * It optionally supports specifying the shape of copy and the element
-   * offset to the destination NDArray.
-   */
-  void CopyVecDataToArray(NDArray array, int32_t* vec_data, 
Optional<ShapeTuple> shape = NullOpt,
-                          int dst_elem_offset = 0) {
-    DLTensor copy_dst = *array.operator->();
-    if (shape.defined()) {
-      ICHECK_EQ(shape.value().size(), 1);
-      copy_dst.ndim = 1;
-      copy_dst.shape = shape.value()->data;
-    }
-    copy_dst.byte_offset = dst_elem_offset * sizeof(int32_t);
-
-    DLTensor copy_src;
-    copy_src.data = vec_data;
-    copy_src.device = Device{kDLCPU, 0};
-    copy_src.ndim = 1;
-    copy_src.dtype = array->dtype;
-    copy_src.shape = copy_dst.shape;
-    copy_src.strides = nullptr;
-    copy_src.byte_offset = 0;
-    NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
-  }
-
   /*!
    * \brief Synchronize auxiliary arrays to device.
    * \note This method resets the dirty flag to false, and needs to be
@@ -1369,29 +1694,21 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
 
     // 1. qo_indptr_on_depths
     for (int d = 0; d < num_depths_; ++d) {
-      qo_indptr_on_depths_view_[d] = qo_indptr_on_depths_device_[d].CreateView(
-          {static_cast<int64_t>(qo_indptr_on_depths_host_[d].size())}, 
dtype_aux_);
-      CopyVecDataToArray(qo_indptr_on_depths_view_[d], 
qo_indptr_on_depths_host_[d].data());
+      qo_indptr_on_depths_view_[d] =
+          
aux_data_manager_->CopyQOIndptrOnDepthAsync(&qo_indptr_on_depths_host_[d], d);
     }
-
     // 2. page_indptr_on_depths
     for (int d = 0; d < num_depths_; ++d) {
       ICHECK_EQ(page_indptr_on_depths_host_[d].size(), 
qo_indptr_on_depths_host_[d].size());
-      page_indptr_on_depths_view_[d] = 
page_indptr_on_depths_device_[d].CreateView(
-          {static_cast<int64_t>(page_indptr_on_depths_host_[d].size())}, 
dtype_aux_);
-      CopyVecDataToArray(page_indptr_on_depths_view_[d], 
page_indptr_on_depths_host_[d].data());
+      page_indptr_on_depths_view_[d] =
+          
aux_data_manager_->CopyPageIndptrOnDepthAsync(&page_indptr_on_depths_host_[d], 
d);
     }
-
     // 3. page_indices_on_depths
     for (int d = 0; d < num_depths_; ++d) {
       ICHECK_EQ(page_indices_on_depths_host_[d].size(), 
page_indptr_on_depths_host_[d].back());
-      page_indices_on_depths_view_[d] = 
page_indices_on_depths_device_[d].CreateView(
-          {static_cast<int64_t>(page_indices_on_depths_host_[d].size())}, 
dtype_aux_);
-      if (!page_indices_on_depths_host_[d].empty()) {
-        CopyVecDataToArray(page_indices_on_depths_view_[d], 
page_indices_on_depths_host_[d].data());
-      }
+      page_indices_on_depths_view_[d] =
+          
aux_data_manager_->CopyPageIndicesOnDepthAsync(&page_indices_on_depths_host_[d],
 d);
     }
-
     // 4. length_info_on_depths
     // last_page_len_on_depths_host_;
     // sliding_window_offset_on_depths_host_;
@@ -1404,54 +1721,34 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       if (!support_sliding_window_) {
         // Sliding window is not enabled, so we first copy "last_page_len".
         length_info_on_depths_view_[d] =
-            length_info_on_depths_device_[d].CreateView({num_seq_on_layer}, 
dtype_aux_);
-        CopyVecDataToArray(length_info_on_depths_view_[d], 
last_page_len_on_depths_host_[d].data());
+            
aux_data_manager_->CopyLastPageLenOnDepthAsync(&last_page_len_on_depths_host_[d],
 d);
       } else {
         // Sliding window is enabled,
-        length_info_on_depths_view_[d] =
-            length_info_on_depths_device_[d].CreateView({3, num_seq_on_layer}, 
dtype_aux_);
-        ShapeTuple copy_shape{num_seq_on_layer};
-        CopyVecDataToArray(length_info_on_depths_view_[d], 
last_page_len_on_depths_host_[d].data(),
-                           copy_shape);
-        CopyVecDataToArray(length_info_on_depths_view_[d],
-                           sliding_window_offset_on_depths_host_[d].data(), 
copy_shape,
-                           /*dst_elem_offset=*/num_seq_on_layer);
-        CopyVecDataToArray(length_info_on_depths_view_[d], 
sink_size_on_depths_host_[d].data(),
-                           copy_shape, /*dst_elem_offset=*/2 * 
num_seq_on_layer);
+        length_info_on_depths_view_[d] = 
aux_data_manager_->CopyLengthInfoOnDepthAsync(
+            &last_page_len_on_depths_host_[d], 
&sliding_window_offset_on_depths_host_[d],
+            &sink_size_on_depths_host_[d], d);
       }
     }
-
     // 5. k_rope_pos_offset_on_depths
     for (int d = 0; d < num_depths_; ++d) {
       ICHECK_EQ(k_rope_pos_offset_on_depths_host_[d].size() + 1,
                 qo_indptr_on_depths_host_[d].size());
-      k_rope_pos_offset_view_[d] = k_rope_pos_offset_device_[d].CreateView(
-          {static_cast<int64_t>(k_rope_pos_offset_on_depths_host_[d].size())}, 
dtype_aux_);
-      CopyVecDataToArray(k_rope_pos_offset_view_[d], 
k_rope_pos_offset_on_depths_host_[d].data());
+      k_rope_pos_offset_view_[d] = 
aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync(
+          &k_rope_pos_offset_on_depths_host_[d], d);
     }
-
     // 6. cur_append_lengths_indptr
     cur_append_length_indptr_view_ =
-        cur_append_length_indptr_device_.CreateView({num_sequences + 1}, 
dtype_aux_);
-    CopyVecDataToArray(cur_append_length_indptr_view_, 
cur_append_lengths_indptr_host_.data());
-
+        
aux_data_manager_->CopyCurAppendLengthIndptrAsync(&cur_append_lengths_indptr_host_);
     // 7. k_ragged_rope_pos_offset
     ICHECK_EQ(k_ragged_rope_pos_offset_host_.size(), num_sequences);
     k_ragged_rope_pos_offset_view_ =
-        k_ragged_rope_pos_offset_device_.CreateView({num_sequences}, 
dtype_aux_);
-    CopyVecDataToArray(k_ragged_rope_pos_offset_view_, 
k_ragged_rope_pos_offset_host_.data());
-
+        
aux_data_manager_->CopyKRaggedRoPEPosOffsetAsync(&k_ragged_rope_pos_offset_host_);
     // 8. q_rope_position_map
     ICHECK_EQ(q_rope_position_map_host_.size(), total_append_length);
-    q_rope_position_map_view_ =
-        q_rope_position_map_device_.CreateView({total_append_length}, 
dtype_aux_);
-    CopyVecDataToArray(q_rope_position_map_view_, 
q_rope_position_map_host_.data());
-
+    q_rope_position_map_view_ = 
aux_data_manager_->CopyQRoPEPosMapAsync(&q_rope_position_map_host_);
     // 9. append_position_map
     append_position_map_view_ =
-        append_position_map_device_.CreateView({total_append_length}, 
dtype_aux_);
-    CopyVecDataToArray(append_position_map_view_, 
append_position_map_host_.data());
-
+        
aux_data_manager_->CopyAppendPositionMapAsync(&append_position_map_host_);
     // 10. Create view for temporary arrays for attention computation.
     temp_attn_output_view_ = temp_attn_output_device_.CreateView(
         {total_append_length, num_qo_heads_, head_dim_}, 
temp_attn_output_device_->dtype);
@@ -1460,6 +1757,8 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     merged_attn_scores_view_ = merged_attn_scores_device_.CreateView(
         {total_append_length, num_qo_heads_}, 
merged_attn_scores_device_->dtype);
 
+    // - Commit the copy.
+    aux_data_manager_->CommitCopy();
     // - Reset the dirty flag to false.
     dirty_aux_data_device_ = false;
   }
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index c71b0dde3e..4823e9b243 100644
--- 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -80,11 +80,13 @@ def kv_cache_transpose_append(
     ntoken = T.SizeVar("ntoken", "int64")
     page_size = T.SizeVar("page_size", "int64")
     num_pages = T.int64()
-
+    position_map_elem_offset = T.int32()
     pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, page_size, 
head_dim), dtype)
     k_data = T.match_buffer(var_k_data, (ntoken, num_kv_heads, head_dim), 
dtype)
     v_data = T.match_buffer(var_v_data, (ntoken, num_kv_heads, head_dim), 
dtype)
-    position_map = T.match_buffer(var_position_map, (ntoken,), "int32")
+    position_map = T.match_buffer(
+        var_position_map, (ntoken,), "int32", 
elem_offset=position_map_elem_offset
+    )
 
     for global_pos, h, f in T.grid(ntoken, num_kv_heads, head_dim):
         with T.block("k_transpose_append"):
@@ -161,11 +163,14 @@ def llama_rope_with_position_map(  # pylint: 
disable=too-many-arguments
             }
         )
         seq_len = T.int64()
+        position_map_elem_offset = T.int64()
         qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
         q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
         k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
         v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
-        position_map = T.match_buffer(var_position_map, (seq_len,), "int32")
+        position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", 
elem_offset=position_map_elem_offset
+        )
         for iters in T.grid(seq_len, fused_heads, head_dim):
             with T.block("llama_fused_rope"):
                 s, h, d = T.axis.remap("SSS", iters)
@@ -200,9 +205,11 @@ def copy_cache(
     seqlen = T.SizeVar("seqlen", "int64")
     page_size = T.int64()
     num_pages = T.int64()
-
+    position_map_elem_offset = T.int64()
     pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, page_size, 
head_dim), "float16")
-    position_map = T.match_buffer(var_position_map, (seqlen,), "int32")
+    position_map = T.match_buffer(
+        var_position_map, (seqlen,), "int32", 
elem_offset=position_map_elem_offset
+    )
     k_data = T.match_buffer(var_k_data, (num_layers, seqlen, num_kv_heads, 
head_dim), "float16")
     v_data = T.match_buffer(var_v_data, (num_layers, seqlen, num_kv_heads, 
head_dim), "float16")
 
@@ -665,7 +672,7 @@ def 
test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):
     cached_v = {}
     batch = [(0, 35), (1, 88), (2, 17), (3, 4)]
     apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((4, 3), 35)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, 
cached_v)
 
     popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 19)]
     for seq_id, pop_length in popn_operations:
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 3ed89ecd0f..f7b01bb840 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -607,11 +607,13 @@ def kv_cache_transpose_append(head_dim, dtype):
     ):
         ntoken = T.SizeVar("ntoken", "int32")
         num_pages = T.int32()
-
+        position_map_elem_offset = T.int32()
         pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, 16, 
head_dim), dtype)
         k_data = T.match_buffer(var_k_data, (ntoken, num_kv_heads, head_dim), 
dtype)
         v_data = T.match_buffer(var_v_data, (ntoken, num_kv_heads, head_dim), 
dtype)
-        position_map = T.match_buffer(var_position_map, (ntoken,), "int32")
+        position_map = T.match_buffer(
+            var_position_map, (ntoken,), "int32", 
elem_offset=position_map_elem_offset
+        )
 
         for global_pos, h, f in T.grid(ntoken, num_kv_heads, head_dim):
             if position_map[global_pos] != T.int32(-1):
@@ -649,9 +651,11 @@ def copy_cache(head_dim, dtype):
         seqlen = T.SizeVar("seqlen", "int64")
         page_size = T.int64()
         num_pages = T.int64()
-
+        position_map_elem_offset = T.int64()
         pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, 
page_size, head_dim), dtype)
-        position_map = T.match_buffer(var_position_map, (seqlen,), "int32")
+        position_map = T.match_buffer(
+            var_position_map, (seqlen,), "int32", 
elem_offset=position_map_elem_offset
+        )
         k_data = T.match_buffer(var_k_data, (num_layers, seqlen, num_kv_heads, 
head_dim), dtype)
         v_data = T.match_buffer(var_v_data, (num_layers, seqlen, num_kv_heads, 
head_dim), dtype)
 
@@ -727,11 +731,14 @@ def llama_rope_with_position_map(  # pylint: 
disable=too-many-arguments
             }
         )
         seq_len = T.int64()
+        position_map_elem_offset = T.int64()
         qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
         q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
         k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
         v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
-        position_map = T.match_buffer(var_position_map, (seq_len,), "int32")
+        position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", 
elem_offset=position_map_elem_offset
+        )
         for iters in T.grid(seq_len, fused_heads, head_dim):
             with T.block("llama_fused_rope"):
                 s, h, d = T.axis.remap("SSS", iters)
@@ -819,11 +826,11 @@ def _causal_mask(causal, row, col, kv_len, qo_len):
     )
 
 
-def _declare_length_info(var_length_info, batch_size, sliding_window):
+def _declare_length_info(var_length_info, batch_size, sliding_window, 
elem_offset):
     return (
-        T.match_buffer(var_length_info, (3, batch_size), "int32")
+        T.match_buffer(var_length_info, (3, batch_size), "int32", 
elem_offset=elem_offset)
         if sliding_window
-        else T.match_buffer(var_length_info, (batch_size,), "int32")
+        else T.match_buffer(var_length_info, (batch_size,), "int32", 
elem_offset=elem_offset)
     )
 
 
@@ -912,14 +919,20 @@ def _attention_prefill(
         total_len = T.int32(is_size_var=True)
         nnz_pages = T.int32(is_size_var=True)
         max_num_pages = T.int32(is_size_var=True)
+        q_indptr_elem_offset = T.int32(is_size_var=True)
+        page_indptr_elem_offset = T.int32(is_size_var=True)
+        page_values_elem_offset = T.int32(is_size_var=True)
+        k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        length_info_elem_offset = T.int32(is_size_var=True)
 
         q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
-        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32")
+        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", 
elem_offset=q_indptr_elem_offset)
         pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), 
dtype)
-        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), 
"int32")
-        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32")
-        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, 
(batch_size,), "int32")
-        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), 
"int32")
+        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), 
"int32", elem_offset=page_indptr_elem_offset)
+        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", 
elem_offset=page_values_elem_offset)
+        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, 
(batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
+        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), 
"int32", elem_offset=q_rope_position_elem_offset)
         output = T.match_buffer(var_output, (total_len, h_q, d), dtype)
         lse = T.match_buffer(var_lse, (total_len, h_q), "float32")  # pylint: 
disable=unused-variable
         # The length information of the sequences.
@@ -930,7 +943,7 @@ def _attention_prefill(
         #   - "(2, i)" is the attn sink length of the sequence.
         # - It is in shape `(batch_size,)` when sliding window is disabled,
         #   denoting the "last_page_len".
-        length_info = _declare_length_info(var_length_info, batch_size, 
sliding_window)
+        length_info = _declare_length_info(var_length_info, batch_size, 
sliding_window, length_info_elem_offset)
 
         # kernel code
         for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"):
@@ -1273,15 +1286,20 @@ def _attention_decode(
         B = T.int32(is_size_var=True)
         nnz_pages = T.int32(is_size_var=True)
         max_num_pages = T.int32(is_size_var=True)
+        page_indptr_elem_offset = T.int32(is_size_var=True)
+        page_values_elem_offset = T.int32(is_size_var=True)
+        k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        length_info_elem_offset = T.int32(is_size_var=True)
 
         Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype)
         pages = T.match_buffer(
             pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype
         )
-        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), 
"int32")
-        page_table_values = T.match_buffer(page_table_values_handle, 
(nnz_pages,), "int32")
-        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), 
"int32")
-        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32")
+        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), 
"int32", elem_offset=page_indptr_elem_offset)
+        page_table_values = T.match_buffer(page_table_values_handle, 
(nnz_pages,), "int32", elem_offset=page_values_elem_offset)
+        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), 
"int32", elem_offset=k_rope_pos_offset_elem_offset)
+        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), 
"int32", elem_offset=q_rope_position_elem_offset)
         output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype)
         lse = T.match_buffer(lse_handle, (B, H_qo), "float32")  # pylint: 
disable=unused-variable
         # The length information of the sequences.
@@ -1292,7 +1310,7 @@ def _attention_decode(
         #   - "(2, i)" is the attn sink length of the sequence.
         # - It is in shape `(batch_size,)` when sliding window is disabled,
         #   denoting the "last_page_len".
-        length_info = _declare_length_info(var_length_info, B, sliding_window)
+        length_info = _declare_length_info(var_length_info, B, sliding_window, 
length_info_elem_offset)
 
         sm_scale = 1.0 / math.sqrt(float(D)) * log2e
 
@@ -1515,14 +1533,18 @@ def _attention_prefill_ragged(
         batch_size = T.int32(is_size_var=True)
         qo_len = T.int32(is_size_var=True)
         kv_len = T.int32(is_size_var=True)
+        q_indptr_elem_offset = T.int32(is_size_var=True)
+        kv_indptr_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
 
         q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
-        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32")
+        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", 
elem_offset=q_indptr_elem_offset)
         k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
         v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
-        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32")
-        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), 
"int32")
-        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, 
(batch_size,), "int32")
+        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", 
elem_offset=kv_indptr_elem_offset)
+        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), 
"int32", elem_offset=q_rope_position_elem_offset)
+        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, 
(batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
         output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
         lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: 
disable=unused-variable
 


Reply via email to