This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 567eeed38b [Runtime][Dist] Implementation of KV cache transfer (#17557)
567eeed38b is described below
commit 567eeed38bdbcefb68e36328af6ab1501a81d51e
Author: Hongyi Jin <[email protected]>
AuthorDate: Sun Dec 15 09:56:01 2024 -0500
[Runtime][Dist] Implementation of KV cache transfer (#17557)
This PR introduces the KV transfer kernel and the KV cache integration used in prefill-decode disaggregation.
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Charlie Ruan <[email protected]>
Co-authored-by: Yingcheng Wang <[email protected]>
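For context, below is a rough Python sketch of how the two builtins added by this PR are meant to be driven; it is illustrative only. `kv_cache` stands for a paged KV cache created with KV transfer enabled, the scalar values are placeholders, and the authoritative flow is in the new KV cache transfer tests added by this commit.

    import tvm

    # Builtins registered by this PR (see kv_state.cc below).
    prepare_recv = tvm.get_global_func("vm.builtin.kv_cache_disagg_prepare_recv")
    mark_send = tvm.get_global_func("vm.builtin.kv_cache_disagg_mark_send")

    kv_cache = ...  # placeholder: a PagedKVCache created with KV transfer enabled
    seq_id, append_length, begin, recver_pe_offset = 0, 16, 0, 0  # placeholder values

    # Receiver (decode) side: reserve room for `append_length` tokens of `seq_id` and
    # get back the compressed map of positions where the KV data will be written.
    compressed_positions = prepare_recv(kv_cache, seq_id, append_length)

    # Sender (prefill) side: mark the tokens from `begin` onwards so their K/V data is
    # pushed into the receiver's pages during the next forward pass.
    mark_send(kv_cache, seq_id, begin, compressed_positions, recver_pe_offset)
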
---
3rdparty/flashinfer | 2 +-
CMakeLists.txt | 4 +-
docs/how_to/tutorials/optimize_llm.py | 1 +
python/tvm/relax/frontend/nn/llm/kv_cache.py | 33 +-
src/runtime/contrib/nvshmem/init.cc | 58 ++-
src/runtime/contrib/nvshmem/kv_transfer.cu | 333 ++++++++++++
src/runtime/contrib/nvshmem/memory_allocator.cc | 3 +-
src/runtime/disco/nccl/nccl.cc | 10 +-
src/runtime/relax_vm/kv_state.cc | 4 +
src/runtime/relax_vm/kv_state.h | 8 +
src/runtime/relax_vm/paged_kv_cache.cc | 385 +++++++++++++-
tests/python/disco/test_nvshmem.py | 4 +-
.../test_runtime_builtin_kv_cache_transfer.py} | 565 +++++----------------
...est_runtime_builtin_kv_cache_transfer_kernel.py | 252 +++++++++
...runtime_builtin_paged_attention_kv_cache_tir.py | 1 +
15 files changed, 1198 insertions(+), 465 deletions(-)
diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index 1e379898a5..a76ceedb94 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 1e379898a589cdd4ff18a4621fcbe18d63501545
+Subproject commit a76ceedb9495d3d05648c29a8e6bb45baa265f6c
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8abdfad24c..757b0d1a89 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -478,7 +478,9 @@ if (USE_CUDA AND USE_NVSHMEM)
if (NOT NVSHMEM_FOUND)
message(FATAL_ERROR "Cannot find NVSHMEM, USE_NVSHMEM=" ${USE_NVSHMEM})
endif()
- tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc)
+ set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+ tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc src/runtime/contrib/nvshmem/*.cu)
list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
endif()
diff --git a/docs/how_to/tutorials/optimize_llm.py b/docs/how_to/tutorials/optimize_llm.py
index 9311c0557f..c30b2c381c 100644
--- a/docs/how_to/tutorials/optimize_llm.py
+++ b/docs/how_to/tutorials/optimize_llm.py
@@ -303,6 +303,7 @@ class LlamaForCasualLM(nn.Module):
rotary_dim=self.head_dim,
dtype=self.dtype,
target=target,
+ enable_disaggregation=False,
)
def get_default_spec(self):
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 18f3e19909..f60c40efa2 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -169,6 +169,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-me
rope_scaling: Dict[str, Any],
rope_ext_factors: rx.Expr,
rotary_dim: int,
+ enable_disaggregation: bool,
dtype: str,
target: Target,
name: str = "paged_kv_cache",
@@ -214,6 +215,8 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-me
The RoPE extension factors when "longrope" mode RoPE scaling is
enabled.
rotary_dim : int
The number of dimensions in the embedding that RoPE is applied to.
+ enable_disaggregation : bool
+ Whether to enable disaggregation in the KV cache.
"""
if rope_mode == RopeMode.INLINE:
assert rotary_dim == head_dim, "FlashInfer RoPE does not support
partial rotary dim."
@@ -259,6 +262,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-me
bb.add_func(tree_attn(num_key_value_heads, num_attention_heads,
head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads,
num_attention_heads, head_dim, dtype, rope_scaling, target),
"tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
rope_ext_factors,
+ rx.PrimValue(enable_disaggregation),
# fmt: on
# pylint: enable=line-too-long
]
@@ -293,6 +297,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-methods
rope_scaling: Dict[str, Any],
rope_ext_factors: rx.Expr,
rotary_dim: int,
+ enable_disaggregation: bool,
dtype: str,
target: Target,
name: str = "paged_kv_cache",
@@ -338,6 +343,8 @@ class TIRPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-methods
The RoPE extension factors when "longrope" mode RoPE scaling is
enabled.
rotary_dim : int
The number of dimensions in the embedding that RoPE is applied to.
+ enable_disaggregation : bool
+ Whether to enable disaggregation in the KV cache.
target : Target
The target to build the model to.
"""
@@ -377,6 +384,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-methods
bb.add_func(tree_attn(num_key_value_heads, num_attention_heads,
head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads,
num_attention_heads, head_dim, dtype, rope_scaling, target),
"tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
rope_ext_factors,
+ rx.PrimValue(enable_disaggregation),
# fmt: on
# pylint: enable=line-too-long
]
@@ -409,8 +417,9 @@ def _kv_cache_transpose_append(num_key_value_heads,
head_dim, dtype):
T.func_attr({"tir.noalias": T.bool(True)})
ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
num_pages = T.int64()
+ pages_elem_offset = T.int64()
position_map_elem_offset = T.int32()
- pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads,
16, head_dim), dtype)
+ pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads,
16, head_dim), dtype, elem_offset=pages_elem_offset)
k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads,
head_dim), dtype)
v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads,
head_dim), dtype)
position_map = T.match_buffer(
@@ -453,8 +462,9 @@ def _kv_cache_debug_get_kv(num_hidden_layers,
num_key_value_heads, head_dim, dty
seqlen = T.SizeVar("num_tokens_including_cache", "int64")
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
+ pages_elem_offset = T.int64()
position_map_elem_offset = T.int64()
- pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads,
page_size, head_dim), dtype)
+ pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads,
page_size, head_dim), dtype,elem_offset=pages_elem_offset)
position_map = T.match_buffer(
var_position_map, (seqlen,), "int32",
elem_offset=position_map_elem_offset
)
@@ -594,6 +604,7 @@ def _attention_prefill(
total_len = T.int32(is_size_var=True)
nnz_pages = T.int32(is_size_var=True)
max_num_pages = T.int32(is_size_var=True)
+ pages_elem_offset = T.int64(is_size_var=True)
q_indptr_elem_offset = T.int32(is_size_var=True)
page_indptr_elem_offset = T.int32(is_size_var=True)
page_values_elem_offset = T.int32(is_size_var=True)
@@ -603,7 +614,7 @@ def _attention_prefill(
q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32",
elem_offset=q_indptr_elem_offset)
- pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d),
dtype)
+ pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d),
dtype, elem_offset=pages_elem_offset)
page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,),
"int32", elem_offset=page_indptr_elem_offset)
page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32",
elem_offset=page_values_elem_offset)
k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset,
(batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
@@ -975,6 +986,7 @@ def _attention_decode(
B = T.int32(is_size_var=True)
nnz_pages = T.int32(is_size_var=True)
max_num_pages = T.int32(is_size_var=True)
+ pages_elem_offset = T.int64(is_size_var=True)
page_indptr_elem_offset = T.int32(is_size_var=True)
page_values_elem_offset = T.int32(is_size_var=True)
k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
@@ -983,7 +995,7 @@ def _attention_decode(
Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype)
pages = T.match_buffer(
- pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype
+ pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype,
elem_offset=pages_elem_offset
)
page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,),
"int32", elem_offset=page_indptr_elem_offset)
page_table_values = T.match_buffer(page_table_values_handle,
(nnz_pages,), "int32", elem_offset=page_values_elem_offset)
@@ -1949,7 +1961,13 @@ def _copy_single_page(num_heads, page_size, head_dim,
dtype, target: Target):
):
T.func_attr({"tir.is_scheduled": 1})
num_pages = T.int32()
- pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size,
head_dim), dtype)
+ pages_elem_offset = T.int64()
+ pages = T.match_buffer(
+ var_pages,
+ (num_pages, 2, num_heads, page_size, head_dim),
+ dtype,
+ elem_offset=pages_elem_offset,
+ )
for b in T.thread_binding(
(copy_length * num_heads * head_dim + tx - 1) // tx,
thread="blockIdx.x"
@@ -1993,7 +2011,10 @@ def _compact_kv_copy(num_heads, head_dim, dtype, target:
Target):
total_copy_length = T.int32()
copy_length_indptr_elem_offset = T.int32()
copy_src_dst_pos_elem_offset = T.int32()
- pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16,
head_dim), dtype)
+ pages_elem_offset = T.int64()
+ pages = T.match_buffer(
+ var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype,
elem_offset=pages_elem_offset
+ )
copy_length_indptr = T.match_buffer(
var_copy_length_indptr,
(batch_size + 1,),
diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc
index 50fdde4c49..33a787b5b9 100644
--- a/src/runtime/contrib/nvshmem/init.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -18,6 +18,7 @@
*/
#include <nvshmem.h>
#include <nvshmemx.h>
+#include <picojson.h>
#include <tvm/runtime/disco/disco_worker.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
@@ -38,9 +39,14 @@ ShapeTuple InitNVSHMEMUID() {
return ShapeTuple(uid_64);
}
-void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
- DiscoWorker* worker = DiscoWorker::ThreadLocal();
- ICHECK(worker != nullptr);
+void InitNVSHMEM(ShapeTuple uid_64, int num_workers, int worker_id_start) {
+ DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
+ int worker_id;
+ if (worker == nullptr) {
+ worker_id = worker_id_start;
+ } else {
+ worker_id = worker_id_start + worker->worker_id;
+ }
CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1)
<< "ValueError: The length of unique_id must be " << UNIQUEID_PADDING <<
", but got "
<< uid_64.size() << ".";
@@ -52,17 +58,61 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
for (int i = 0; i < UNIQUEID_PADDING; ++i) {
uid.internal[i] = static_cast<char>(uid_64[i + 1]);
}
- nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
+ // FIXME: this is a hack to work around NVSHMEM initializing with multiple processes per GPU.
+ cudaSetDevice(worker_id);
+ nvshmemx_set_attr_uniqueid_args(worker_id, num_workers, &uid, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
CUDA_CALL(cudaSetDevice(mype_node));
+ if (worker != nullptr) {
+ if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
+ worker->default_device = Device{DLDeviceType::kDLCUDA, mype_node};
+ } else {
+ ICHECK(worker->default_device.device_type == DLDeviceType::kDLCUDA &&
+ worker->default_device.device_id == mype_node)
+ << "The default device of the worker is inconsistent with the device
used for NVSHMEM. "
+ << "The default device is " << worker->default_device
+ << ", but the device used for NVSHMEM is " <<
Device{DLDeviceType::kDLCUDA, mype_node}
+ << ".";
+ }
+ }
LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
<< ", npes=" << nvshmem_n_pes();
}
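+// Initialize NVSHMEM from a JSON string; illustrative input: {"uid": [...], "npes": 2, "pe_start": 0}.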
+void InitNVSHMEMWrapper(String args) {
+ picojson::value v;
+ std::string err = picojson::parse(v, args);
+ if (!err.empty()) {
+ LOG(FATAL) << "JSON parse error: " << err;
+ }
+
+ if (!v.is<picojson::object>()) {
+ LOG(FATAL) << "JSON is not an object";
+ }
+
+ picojson::object& obj = v.get<picojson::object>();
+
+ picojson::array uid_array = obj["uid"].get<picojson::array>();
+ std::vector<int64_t> uid_vector;
+ for (const auto& elem : uid_array) {
+ uid_vector.push_back(elem.get<int64_t>());
+ }
+
+ ShapeTuple uid_64(uid_vector);
+
+ int num_workers = static_cast<int>(obj["npes"].get<int64_t>());
+ int worker_id_start = static_cast<int>(obj["pe_start"].get<int64_t>());
+
+ InitNVSHMEM(uid_64, num_workers, worker_id_start);
+}
+
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);
+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper")
+ .set_body_typed(InitNVSHMEMWrapper);
+
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/contrib/nvshmem/kv_transfer.cu b/src/runtime/contrib/nvshmem/kv_transfer.cu
new file mode 100644
index 0000000000..cf3a9958f8
--- /dev/null
+++ b/src/runtime/contrib/nvshmem/kv_transfer.cu
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <cuda_fp16.h>
+#include <dlpack/dlpack.h>
+#include <nvshmem.h>
+#include <tvm/runtime/disco/disco_worker.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/registry.h>
+
+template <int dim>
+__device__ int64_t calc_flattened_index(int shape[dim], int index[dim]) {
+ int64_t flattened_index = 0;
+#pragma unroll
+ for (int i = 0; i < dim; i++) {
+ flattened_index *= shape[i];
+ flattened_index += index[i];
+ }
+ return flattened_index;
+}
+
+template <typename T, int local_num_kv_head, int remote_num_kv_head, int
head_dim, int page_size>
+__global__ void KVTransfer(T* pages, T* k_data, T* v_data, int32_t*
remote_position_map,
+ int ntokens, int local_tp_rank, int32_t*
remote_tp_group_pe_offset,
+ int remote_num_pages) {
+ // launch grid: [num_blocks, 1, 1], [32, local_num_kv_head, 1]
+ // pages(remote): [remote_num_pages, 2, remote_num_kv_head, page_size,
head_dim]
+ // k_data: [ntokens, local_num_kv_head, head_dim]
+ // v_data: [ntokens, local_num_kv_head, head_dim]
+ int remote_pe;
+ int remote_kv_head_index;
+ int h = threadIdx.y; // local kv head index
+
+ for (int global_pos = blockIdx.x; global_pos < ntokens; global_pos +=
gridDim.x) {
+ int position = remote_position_map[global_pos];
+ if (position == -1) {
+ continue;
+ }
+ if (local_num_kv_head <= remote_num_kv_head) {
+ // gather
+ assert(remote_num_kv_head % local_num_kv_head == 0);
+ int gather_factor = remote_num_kv_head / local_num_kv_head;
+ remote_pe = remote_tp_group_pe_offset[global_pos] + local_tp_rank /
gather_factor;
+ remote_kv_head_index = (local_tp_rank % gather_factor) *
local_num_kv_head + h;
+ } else {
+ // scatter
+ assert(local_num_kv_head % remote_num_kv_head == 0);
+ int scatter_factor = local_num_kv_head / remote_num_kv_head;
+ remote_pe = remote_tp_group_pe_offset[global_pos] + local_tp_rank *
scatter_factor +
+ h / remote_num_kv_head;
+ remote_kv_head_index = h % remote_num_kv_head;
+ }
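+ // Illustrative example: with local_num_kv_head=2 and remote_num_kv_head=4, gather_factor is 2,
+ // so local TP rank 3 writes its heads h=0,1 into remote heads 2,3 on PE remote_tp_group_pe_offset+1.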
+ int page_id = position / page_size;
+ int offset_in_page = position % page_size;
+ int pages_shape[5] = {remote_num_pages, 2, remote_num_kv_head, page_size,
head_dim};
+ int k_page_index[5] = {page_id, 0, remote_kv_head_index, offset_in_page,
0};
+ int v_page_index[5] = {page_id, 1, remote_kv_head_index, offset_in_page,
0};
+ int k_v_shape[3] = {ntokens, local_num_kv_head, head_dim};
+ int k_v_index[3] = {global_pos, h, 0};
+ nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<5>(pages_shape,
k_page_index),
+ k_data + calc_flattened_index<3>(k_v_shape,
k_v_index),
+ head_dim * sizeof(T), remote_pe);
+ nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<5>(pages_shape,
v_page_index),
+ v_data + calc_flattened_index<3>(k_v_shape,
k_v_index),
+ head_dim * sizeof(T), remote_pe);
+ }
+ if (threadIdx.x == 0) {
+ nvshmem_quiet();
+ }
+}
+template <typename T, int local_num_kv_head, int remote_num_kv_head, int
head_dim, int page_size>
+__global__ void KVTransferPageToPage(T* remote_pages, T* local_pages, int32_t*
remote_position_map,
+ int32_t* local_position_map, int ntokens,
int local_tp_rank,
+ int32_t* remote_tp_group_pe_offset) {
+ // launch grid: [num_blocks, 1, 1], [32, local_num_kv_head, 1]
+ int remote_pe;
+ int remote_kv_head_index;
+ int h = threadIdx.y; // local kv head index
+ int is_k = threadIdx.z;
+
+ for (int global_pos = blockIdx.x; global_pos < ntokens; global_pos +=
gridDim.x) {
+ int remote_position = remote_position_map[global_pos];
+ int local_position = local_position_map[global_pos];
+ if (remote_position == -1 || local_position == -1) {
+ continue;
+ }
+ if (local_num_kv_head <= remote_num_kv_head) {
+ // gather
+ assert(remote_num_kv_head % local_num_kv_head == 0);
+ int gather_factor = remote_num_kv_head / local_num_kv_head;
+ remote_pe = remote_tp_group_pe_offset[global_pos] + local_tp_rank /
gather_factor;
+ remote_kv_head_index = (local_tp_rank % gather_factor) *
local_num_kv_head + h;
+ } else {
+ // scatter
+ assert(local_num_kv_head % remote_num_kv_head == 0);
+ int scatter_factor = local_num_kv_head / remote_num_kv_head;
+ remote_pe = remote_tp_group_pe_offset[global_pos] + local_tp_rank *
scatter_factor +
+ h / remote_num_kv_head;
+ remote_kv_head_index = h % remote_num_kv_head;
+ }
+
+ int remote_page_id = remote_position / page_size;
+ int remote_offset_in_page = remote_position % page_size;
+ int local_page_id = local_position / page_size;
+ int local_offset_in_page = local_position % page_size;
+ int remote_pages_shape[5] = {1, 2, remote_num_kv_head, page_size,
head_dim};
+ int local_pages_shape[5] = {1, 2, local_num_kv_head, page_size, head_dim};
+ int remote_page_index[5] = {remote_page_id, is_k, remote_kv_head_index,
remote_offset_in_page,
+ 0};
+ int local_page_index[5] = {local_page_id, is_k, h, local_offset_in_page,
0};
+ nvshmemx_putmem_nbi_warp(
+ remote_pages + calc_flattened_index<5>(remote_pages_shape,
remote_page_index),
+ local_pages + calc_flattened_index<5>(local_pages_shape,
local_page_index),
+ head_dim * sizeof(T), remote_pe);
+ }
+ if (threadIdx.x == 0) {
+ nvshmem_quiet();
+ }
+}
+
+#define DISPATCH_TVM_CUDA_DTYPE(dl_dtype, cuda_dtype, ...) \
+ if (dl_dtype.code == kDLFloat && dl_dtype.bits == 16) { \
+ using cuda_dtype = half; \
+ __VA_ARGS__ \
+ } else { \
+ LOG(FATAL) << "Unsupported data type " << dl_dtype.code; \
+ }
+
+#define DISPATCH_HEAD_DIM(head_dim, const_head_dim, ...) \
+ if (head_dim == 128) { \
+ constexpr int const_head_dim = 128; \
+ __VA_ARGS__ \
+ } else { \
+ LOG(FATAL) << "Unsupported head dim " << head_dim; \
+ }
+
+#define DISPATCH_PAGE_SIZE(page_size, const_page_size, ...) \
+ if (page_size == 16) { \
+ constexpr int const_page_size = 16; \
+ __VA_ARGS__ \
+ } else if (page_size == 4) { \
+ constexpr int const_page_size = 4; \
+ __VA_ARGS__ \
+ } else { \
+ LOG(FATAL) << "Unsupported page size " << page_size; \
+ }
+
+#define DISPATCH_NUM_KV_HEAD(num_kv_head, const_num_kv_head, ...) \
+ if (num_kv_head == 1) { \
+ constexpr int const_num_kv_head = 1; \
+ __VA_ARGS__ \
+ } else if (num_kv_head == 2) { \
+ constexpr int const_num_kv_head = 2; \
+ __VA_ARGS__ \
+ } else if (num_kv_head == 4) { \
+ constexpr int const_num_kv_head = 4; \
+ __VA_ARGS__ \
+ } else if (num_kv_head == 8) { \
+ constexpr int const_num_kv_head = 8; \
+ __VA_ARGS__ \
+ } else { \
+ LOG(FATAL) << "Unsupported num_kv_head " << num_kv_head; \
+ }
+
+int _KVTransfer(DLTensor* remote_pages, DLTensor* k, DLTensor* v, DLTensor*
remote_position_map,
+ DLTensor* remote_tp_group_pe_offset, TVMStreamHandle
transfer_stream) {
+ CHECK_EQ(remote_pages->device.device_type, kDLCUDA)
+ << "The device of remote_pages matrix must be CUDA.";
+ CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of k matrix must be
CUDA.";
+ CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of v matrix must be
CUDA.";
+ CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
+ << "The device of remote_position_map matrix must be CUDA.";
+ size_t dev_id = remote_pages->device.device_id;
+ CHECK_EQ(k->device.device_id, dev_id)
+ << "The device id of remote_pages and k matrix doesn't match.";
+ CHECK_EQ(v->device.device_id, dev_id)
+ << "The device id of remote_pages and v matrix doesn't match.";
+ CHECK_EQ(remote_position_map->device.device_id, dev_id)
+ << "The device id of remote_pages and remote_position_map matrix doesn't
match.";
+ CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id)
+ << "The device id of remote_pages and remote_tp_group_pe_offset matrix
doesn't match.";
+
+ CHECK_EQ(remote_pages->ndim, 5);
+ int remote_num_pages = remote_pages->shape[0];
+ int remote_num_kv_head = remote_pages->shape[2];
+ int page_size = remote_pages->shape[3];
+ int head_dim = remote_pages->shape[4];
+
+ CHECK_GE(k->ndim, 3);
+ int kv_len = k->shape[k->ndim - 3];
+ int local_num_kv_heads = k->shape[k->ndim - 2];
+ CHECK_EQ(head_dim, k->shape[k->ndim - 1]);
+
+ CHECK_GE(v->ndim, 3);
+ CHECK_EQ(kv_len, v->shape[v->ndim - 3]);
+ CHECK_EQ(local_num_kv_heads, v->shape[v->ndim - 2]);
+ CHECK_EQ(head_dim, v->shape[v->ndim - 1]);
+
+ CHECK(remote_pages->dtype.lanes == 1 && k->dtype.lanes == 1 &&
v->dtype.lanes == 1);
+ CHECK(remote_pages->dtype.bits == k->dtype.bits && remote_pages->dtype.code
== k->dtype.code);
+ CHECK(remote_pages->dtype.bits == v->dtype.bits && remote_pages->dtype.code
== v->dtype.code);
+ int local_tp_rank;
+ tvm::runtime::DiscoWorker* worker =
tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;
+ if (worker == nullptr) {
+ local_tp_rank = 0;
+ } else {
+ local_tp_rank = worker->worker_id;
+ }
+
+ dim3 blocks(8, 1, 1);
+ dim3 threads(32, local_num_kv_heads, 1);
+ DISPATCH_TVM_CUDA_DTYPE(
+ remote_pages->dtype, dtype_in,
+ {DISPATCH_HEAD_DIM(
+ head_dim, HEAD_DIM,
+ {DISPATCH_PAGE_SIZE(
+ page_size, PAGE_SIZE,
+ {DISPATCH_NUM_KV_HEAD(
+ remote_num_kv_head, REMOTE_NUM_KV_HEAD,
+ {DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD,
{
+ dtype_in* remote_pages_data = reinterpret_cast<dtype_in*>(
+ reinterpret_cast<char*>(remote_pages->data) +
remote_pages->byte_offset);
+ dtype_in* k_data = reinterpret_cast<dtype_in*>(
+ reinterpret_cast<char*>(k->data) + k->byte_offset);
+ dtype_in* v_data = reinterpret_cast<dtype_in*>(
+ reinterpret_cast<char*>(v->data) + v->byte_offset);
+ int32_t* remote_position_map_data =
reinterpret_cast<int32_t*>(
+ reinterpret_cast<char*>(remote_position_map->data) +
+ remote_position_map->byte_offset);
+ int32_t* remote_tp_group_pe_offset_data =
reinterpret_cast<int32_t*>(
+
reinterpret_cast<char*>(remote_tp_group_pe_offset->data) +
+ remote_tp_group_pe_offset->byte_offset);
+ KVTransfer<dtype_in, LOCAL_NUM_KV_HEAD,
REMOTE_NUM_KV_HEAD, HEAD_DIM, PAGE_SIZE>
+ <<<blocks, threads, 0,
static_cast<cudaStream_t>(transfer_stream)>>>(
+ remote_pages_data, k_data, v_data,
remote_position_map_data, kv_len,
+ local_tp_rank, remote_tp_group_pe_offset_data,
remote_num_pages);
+ })})})})})
+
+ return 0;
+}
+
+int _KVTransferPageToPage(DLTensor* remote_pages, DLTensor* local_pages,
+ DLTensor* remote_position_map, DLTensor*
local_position_map,
+ DLTensor* remote_tp_group_pe_offset, TVMStreamHandle
transfer_stream) {
+ CHECK_EQ(remote_pages->device.device_type, kDLCUDA)
+ << "The device of remote_pages matrix must be CUDA.";
+ CHECK_EQ(local_pages->device.device_type, kDLCUDA) << "The device of k
matrix must be CUDA.";
+ CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
+ << "The device of remote_position_map matrix must be CUDA.";
+ size_t dev_id = remote_pages->device.device_id;
+ CHECK_EQ(local_pages->device.device_id, dev_id)
+ << "The device id of remote_pages and k matrix doesn't match.";
+ CHECK_EQ(remote_position_map->device.device_id, dev_id)
+ << "The device id of remote_pages and remote_position_map matrix doesn't
match.";
+ CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id)
+ << "The device id of remote_pages and remote_tp_group_pe_offset matrix
doesn't match.";
+
+ CHECK_EQ(remote_pages->ndim, 5);
+ int remote_num_kv_head = remote_pages->shape[2];
+ int page_size = remote_pages->shape[3];
+ int head_dim = remote_pages->shape[4];
+
+ CHECK_GE(local_pages->ndim, 5);
+ int local_num_kv_heads = local_pages->shape[2];
+ CHECK_EQ(head_dim, local_pages->shape[4]);
+
+ CHECK_EQ(remote_position_map->ndim, 1);
+ int ntokens = remote_position_map->shape[0];
+
+ CHECK(remote_pages->dtype.lanes == 1 && local_pages->dtype.lanes == 1);
+ CHECK(remote_pages->dtype.bits == local_pages->dtype.bits &&
+ remote_pages->dtype.code == local_pages->dtype.code);
+
+ int local_tp_rank;
+ tvm::runtime::DiscoWorker* worker =
tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;
+ if (worker == nullptr) {
+ local_tp_rank = 0;
+ } else {
+ local_tp_rank = worker->worker_id;
+ }
+
+ dim3 blocks(8, 1, 1);
+ dim3 threads(32, local_num_kv_heads, 2);
+ DISPATCH_TVM_CUDA_DTYPE(
+ remote_pages->dtype, dtype_in,
+ {DISPATCH_HEAD_DIM(
+ head_dim, HEAD_DIM,
+ {DISPATCH_PAGE_SIZE(
+ page_size, PAGE_SIZE,
+ {DISPATCH_NUM_KV_HEAD(
+ remote_num_kv_head, REMOTE_NUM_KV_HEAD,
+ {DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD,
{
+ dtype_in* remote_pages_data = reinterpret_cast<dtype_in*>(
+ reinterpret_cast<char*>(remote_pages->data) +
remote_pages->byte_offset);
+ dtype_in* local_pages_data = reinterpret_cast<dtype_in*>(
+ reinterpret_cast<char*>(local_pages->data) +
local_pages->byte_offset);
+ int32_t* remote_position_map_data =
reinterpret_cast<int32_t*>(
+ reinterpret_cast<char*>(remote_position_map->data) +
+ remote_position_map->byte_offset);
+ int32_t* local_position_map_data =
reinterpret_cast<int32_t*>(
+ reinterpret_cast<char*>(local_position_map->data) +
+ local_position_map->byte_offset);
+ int32_t* remote_tp_group_pe_offset_data =
reinterpret_cast<int32_t*>(
+
reinterpret_cast<char*>(remote_tp_group_pe_offset->data) +
+ remote_tp_group_pe_offset->byte_offset);
+ KVTransferPageToPage<dtype_in, LOCAL_NUM_KV_HEAD,
REMOTE_NUM_KV_HEAD, HEAD_DIM,
+ PAGE_SIZE>
+ <<<blocks, threads, 0,
static_cast<cudaStream_t>(transfer_stream)>>>(
+ remote_pages_data, local_pages_data,
remote_position_map_data,
+ local_position_map_data, ntokens, local_tp_rank,
+ remote_tp_group_pe_offset_data);
+ })})})})})
+
+ return 0;
+}
+
+TVM_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer);
+TVM_REGISTER_GLOBAL("nvshmem.KVTransferPageToPage").set_body_typed(_KVTransferPageToPage);
diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc
index 89d56ed3dc..4380c7e65d 100644
--- a/src/runtime/contrib/nvshmem/memory_allocator.cc
+++ b/src/runtime/contrib/nvshmem/memory_allocator.cc
@@ -25,6 +25,7 @@
#include <thread>
#include "../../cuda/cuda_common.h"
+#include "../../disco/utils.h"
#include "../../memory/pooled_allocator.h"
namespace tvm {
@@ -88,7 +89,7 @@ class NVSHMEMAllocator final : public PooledAllocator {
};
NDArray NVSHMEMEmpty(ShapeTuple shape, DataType dtype, Device device) {
- return NVSHMEMAllocator::Global()->Empty(shape, dtype, device);
+ return NVSHMEMAllocator::Global()->Empty(shape, dtype, UseDefaultDeviceIfNone(device));
}
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty);
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 6ee54e14f3..75e7db483e 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -93,7 +93,15 @@ void InitCCLPerWorker(IntTuple device_ids, std::string
unique_id_bytes) {
StreamCreate(&ctx->default_stream);
#endif
Device device{TVM_DISCO_DEVICE_TYPE, device_id};
- worker->default_device = device;
+ if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
+ worker->default_device = device;
+ } else {
+ ICHECK(worker->default_device.device_type == device.device_type &&
+ worker->default_device.device_id == device.device_id)
+ << "The default device of the worker is inconsistent with the device
used for CCL. "
+ << "The default device is " << worker->default_device << ", but the
device used for CCL is "
+ << device << ".";
+ }
worker->ccl = TVM_DISCO_CCL_NAME;
ctx->worker = worker;
ctx->device_id = device_id;
diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
index b730a4eb07..67afda3bfd 100644
--- a/src/runtime/relax_vm/kv_state.cc
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -56,6 +56,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
.set_body_method<KVState>(&KVStateObj::EndForward);
// Attention KV Cache methods
+TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_prepare_recv")
+
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DisaggPrepareRecv);
+TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_mark_send")
+ .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DisaggMarkSend);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes")
diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index 6d30ce998a..7df3215d08 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -157,6 +157,14 @@ class AttentionKVCacheObj : public KVStateObj {
virtual void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids,
const IntTuple& leaf_indices) = 0;
+ /*! \brief Prepare to receive the disaggregated KV data for the specified sequence and length. */
+ virtual IntTuple DisaggPrepareRecv(int64_t seq_id, int length) = 0;
+
+ /*! \brief Mark which tokens' KV cache needs to be sent to other devices */
+ virtual void DisaggMarkSend(int64_t seq_id, int64_t begin,
+ const IntTuple& compressed_remote_position_map,
+ int32_t recver_pe_offset) = 0;
+
/************** Attention **************/
/*!
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index b6636ae1a7..81c55bfcb6 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -129,6 +129,13 @@ struct Block {
}
};
+struct KVTransferMetadata {
+ int64_t start = std::numeric_limits<int64_t>::max();
+ std::vector<int64_t> remote_position_map;
+ int32_t recver_pe_offset = -1;
+ std::vector<int64_t> local_position_map;
+};
+
/*!
* \brief The sequence structure in paged KV cache with common prefix support.
* Each sequence contains one or more blocks to support common prefix.
@@ -163,6 +170,8 @@ struct Sequence {
std::vector<int32_t> token_tree_parent_ptr;
/*! \brief The depth of each node in the token tree. */
std::vector<int32_t> token_tree_node_depths;
+ /*! \brief The metadata of KV transfer. */
+ KVTransferMetadata kv_transfer_metadata;
/*!
* \brief A boolean denoting whether the accepted token tree indices of
* this sequence are committed
@@ -231,7 +240,13 @@ class HostMemoryVector {
}
void push_back(int32_t value) {
- ICHECK_LT(current_size_, reserved_size_);
+ ICHECK_LE(current_size_, reserved_size_);
+ if (current_size_ == reserved_size_) {
+ reserved_size_ *= 2;
+ NDArray new_data = NDArray::Empty({reserved_size_}, data_->dtype,
data_->device);
+ std::memcpy(new_data->data, data_->data, current_size_ *
DataType(data_->dtype).bytes());
+ data_ = new_data;
+ }
static_cast<int32_t*>(data_->data)[current_size_++] = value;
}
@@ -255,6 +270,15 @@ class HostMemoryVector {
/*! \brief Return the vector as an NDArray. */
NDArray as_ndarray() { return data_.CreateView({current_size_},
data_->dtype); }
+ IntTuple as_int_tuple() const {
+ std::vector<int64_t> values;
+ values.reserve(current_size_);
+ for (int i = 0; i < current_size_; ++i) {
+ values.push_back(static_cast<int32_t*>(data_->data)[i]);
+ }
+ return IntTuple(values);
+ }
+
private:
int64_t reserved_size_ = 0;
int64_t current_size_ = 0;
@@ -331,6 +355,16 @@ class PagedKVCacheAuxDataManager {
* appending new K/V data.
*/
virtual NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) = 0;
+ /*! \brief Copy the remote position map for KV transfer. */
+ virtual NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data)
= 0;
+ /*! \brief Copy the receiver id for KV transfer. */
+ virtual NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) = 0;
+ /*! \brief Copy the local position map for KV page-to-page transfer. */
+ virtual NDArray
CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) = 0;
+ /*! \brief Copy the remote position map for KV page-to-page transfer. */
+ virtual NDArray
CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) = 0;
+ /*! \brief Copy the receiver id for KV page-to-page transfer. */
+ virtual NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data)
= 0;
/*! \brief Copy the tree attention mask. */
virtual NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int
depth) = 0;
/*! \brief Copy the mn indptr of the tree attention mask. */
@@ -390,7 +424,14 @@ class PlainPagedKVCacheAuxDataManager : public
PagedKVCacheAuxDataManager {
k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs},
dtype_aux_, device);
q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size},
dtype_aux_, device);
append_position_map_device_ = NDArray::Empty({prefill_chunk_size},
dtype_aux_, device);
-
+ kv_transfer_remote_position_map_device =
+ NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
+ kv_transfer_recver_id_device = NDArray::Empty({prefill_chunk_size},
dtype_aux_, device);
+ kv_transfer_page_to_page_local_position_map_device =
+ NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
+ kv_transfer_page_to_page_remote_position_map_device =
+ NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
+ kv_transfer_page_to_page_recver_id_device =
+ NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
commit_copy_length_indptr_device_ = NDArray::Empty({reserved_num_seqs +
1}, dtype_aux_, device);
commit_copy_src_dst_pos_in_page_table_device_ =
NDArray::Empty({2, std::min(kTreeAttnMaxTreeSize * reserved_num_seqs,
prefill_chunk_size)},
@@ -453,6 +494,37 @@ class PlainPagedKVCacheAuxDataManager : public
PagedKVCacheAuxDataManager {
CopyVecDataToArray(view, data->data());
return view;
}
+ NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final {
+ NDArray view = kv_transfer_remote_position_map_device.CreateView(
+ {static_cast<int64_t>(data->size())}, dtype_aux_);
+ CopyVecDataToArray(view, data->data());
+ return view;
+ }
+ NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) final {
+ NDArray view =
+
kv_transfer_recver_id_device.CreateView({static_cast<int64_t>(data->size())},
dtype_aux_);
+ CopyVecDataToArray(view, data->data());
+ return view;
+ }
+ NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data)
final {
+ NDArray view =
kv_transfer_page_to_page_local_position_map_device.CreateView(
+ {static_cast<int64_t>(data->size())}, dtype_aux_);
+ CopyVecDataToArray(view, data->data());
+ return view;
+ }
+ NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector*
data) final {
+ NDArray view =
kv_transfer_page_to_page_remote_position_map_device.CreateView(
+ {static_cast<int64_t>(data->size())}, dtype_aux_);
+ CopyVecDataToArray(view, data->data());
+ return view;
+ }
+ NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final {
+ NDArray view = kv_transfer_page_to_page_recver_id_device.CreateView(
+ {static_cast<int64_t>(data->size())}, dtype_aux_);
+ CopyVecDataToArray(view, data->data());
+ return view;
+ }
+
NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth)
final {
NDArray view =
tree_attn_mask_device_[depth].CreateView({static_cast<int64_t>(data->size())},
dtype_aux_);
@@ -566,6 +638,11 @@ class PlainPagedKVCacheAuxDataManager : public
PagedKVCacheAuxDataManager {
NDArray k_ragged_rope_pos_offset_device_;
NDArray q_rope_position_map_device_;
NDArray append_position_map_device_;
+ NDArray kv_transfer_remote_position_map_device;
+ NDArray kv_transfer_recver_id_device;
+ NDArray kv_transfer_page_to_page_local_position_map_device;
+ NDArray kv_transfer_page_to_page_remote_position_map_device;
+ NDArray kv_transfer_page_to_page_recver_id_device;
NDArray commit_copy_length_indptr_device_;
NDArray commit_copy_src_dst_pos_in_page_table_device_;
};
@@ -633,6 +710,21 @@ class CachedPagedKVCacheAuxDataManager : public
PagedKVCacheAuxDataManager {
NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final {
return CopyAttnAuxVecToCache(data);
}
+ NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final {
+ return CopyAttnAuxVecToCache(data);
+ }
+ NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) final {
+ return CopyAttnAuxVecToCache(data);
+ }
+ NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data)
final {
+ return CopyAttnAuxVecToCache(data);
+ }
+ NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector*
data) final {
+ return CopyAttnAuxVecToCache(data);
+ }
+ NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final {
+ return CopyAttnAuxVecToCache(data);
+ }
NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth)
final {
NDArray mask_1d = CopyAttnAuxVecToCache(data);
return mask_1d.CreateView({static_cast<int64_t>(data->size() / 2), 2},
mask_1d->dtype);
@@ -736,12 +828,22 @@ class CachedPagedKVCacheAuxDataManager : public
PagedKVCacheAuxDataManager {
// - k_ragged_rope_pos_offset
// - q_rope_position_map
// - append_position_map
+ // - kv_transfer_remote_position_map
+ // - kv_transfer_recver_id
+ // - kv_transfer_page_to_page_local_position_map
+ // - kv_transfer_page_to_page_remote_position_map
+ // - kv_transfer_page_to_page_recver_id
// - tree_attn_mask
// - tree_attn_mn_indptr
cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
cache_size += CeilDivElemAlignment(reserved_num_seqs);
cache_size += CeilDivElemAlignment(prefill_chunk_size);
cache_size += CeilDivElemAlignment(prefill_chunk_size);
+ cache_size += CeilDivElemAlignment(prefill_chunk_size);
+ cache_size += CeilDivElemAlignment(prefill_chunk_size);
+ cache_size += CeilDivElemAlignment(prefill_chunk_size);
+ cache_size += CeilDivElemAlignment(prefill_chunk_size);
+ cache_size += CeilDivElemAlignment(prefill_chunk_size);
cache_size +=
CeilDivElemAlignment(kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize *
reserved_num_seqs);
cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
@@ -862,11 +964,15 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
/*!
* \brief The KV data managed by the KV cache.
- * The array has `num_layers` NDArrays, each of them
+ * If the KV transfer function is specified, pages_ is allocated by NVSHMEM as a whole NDArray,
+ * and pages_ contains a tensor view of each layer.
+ * Otherwise, pages_ has `num_layers` NDArrays, each of them
* has layout (num_pages, 2, num_heads, page_size, head_dim).
* Along on the "2" dimension, index 0 stands for K and 1 stands for V.
*/
- Array<NDArray> pages_;
+ std::vector<NDArray> pages_;
+ /*! \brief The whole KV cache allocated by NVSHMEM. */
+ NDArray nvshmem_pages_;
/*! \brief The list of ids of released pages for page reuse. */
std::vector<int32_t> free_page_ids_;
/*! \brief The mapping from sequence ids to sequences. */
@@ -907,6 +1013,12 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
std::vector<bool> use_decode_kernel_;
/*! \brief Whether the attention request is a decode request, set in
BeginForwardFunction. */
bool is_decode_request_;
+ /*! \brief The KV transfer recver disco group's PE offset in this forward.
If no KV is transferred, the recver is -1.
Assume that all the KV are transferred to the same recver in this forward.
TODO: support multiple recvers. */
+ bool transfer_kv_;
+ bool page_to_page_transfer_kv_;
/*! \brief The auxiliary data manager for attention. */
std::unique_ptr<PagedKVCacheAuxDataManager> aux_data_manager_;
@@ -940,6 +1052,11 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
HostMemoryVector commit_copy_length_indptr_host_;
HostMemoryVector commit_copy_src_pos_in_page_table_host_;
HostMemoryVector commit_copy_dst_pos_in_page_table_host_;
+ HostMemoryVector kv_transfer_remote_position_map_host_;
+ HostMemoryVector kv_transfer_recver_id_host_;
+ HostMemoryVector kv_transfer_page_to_page_local_position_map_host_;
+ HostMemoryVector kv_transfer_page_to_page_remote_position_map_host_;
+ HostMemoryVector kv_transfer_page_to_page_recver_id_host_;
//-------------------------------------------
// For efficient memory management, the actual sizes of the arrays
@@ -952,6 +1069,11 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
NDArray k_ragged_rope_pos_offset_view_;
NDArray q_rope_position_map_view_;
NDArray append_position_map_view_;
+ NDArray kv_transfer_remote_position_map_view_;
+ NDArray kv_transfer_recver_id_view_;
+ NDArray kv_transfer_page_to_page_local_position_map_view_;
+ NDArray kv_transfer_page_to_page_remote_position_map_view_;
+ NDArray kv_transfer_page_to_page_recver_id_view_;
NDArray temp_attn_output_view_;
NDArray temp_attn_scores_view_;
NDArray merged_attn_scores_view_;
@@ -964,6 +1086,8 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
std::vector<NDArray> tree_attn_mn_indptr_view_;
PackedFunc f_transpose_append_;
+ Optional<PackedFunc> f_transfer_kv_;
+ Optional<PackedFunc> f_transfer_kv_page_to_page_ = NullOpt;
PackedFunc f_compact_copy_;
PackedFunc f_attention_prefill_;
PackedFunc f_attention_decode_;
@@ -989,6 +1113,8 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
TVMStreamHandle compute_stream_ = nullptr;
/*! \brief The device stream for copying auxiliary data structure to GPU. */
TVMStreamHandle copy_stream_ = nullptr;
+ /*! \brief The device stream for KV transfer */
+ TVMStreamHandle kv_transfer_stream_ = nullptr;
public:
/*! \brief Constructor. Take the cache configuration and initialize the
NDArrays. */
@@ -997,7 +1123,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t
reserved_num_seqs,
int64_t num_total_pages, int64_t prefill_chunk_size, bool
support_sliding_window,
RoPEMode rope_mode, double rotary_scale, double rotary_theta,
- Optional<NDArray> rope_ext_factors, DLDataType dtype, Device device,
+ Optional<NDArray> rope_ext_factors, bool enable_kv_transfer, DLDataType
dtype, Device device,
PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc
f_attention_prefill,
PackedFunc f_attention_decode, PackedFunc
f_attention_prefill_sliding_window,
PackedFunc f_attention_decode_sliding_window, PackedFunc
f_attention_prefill_ragged,
@@ -1047,10 +1173,35 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
f_debug_get_kv_(std::move(f_debug_get_kv)),
device_(device) {
pages_.reserve(num_layers);
- for (int i = 0; i < num_layers; ++i) {
- pages_.push_back(
- NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size,
head_dim}, dtype, device));
+ if (enable_kv_transfer) {
+ CHECK(Registry::Get("runtime.disco.nvshmem.init_nvshmem") != nullptr)
+ << "NVSHMEM is not enabled. Please make sure NVSHMEM is enabled when
compiling TVM.";
+ const PackedFunc* f_nvshmem_empty =
runtime::Registry::Get("runtime.disco.nvshmem.empty");
+ ICHECK_NOTNULL(f_nvshmem_empty);
+ nvshmem_pages_ = (*f_nvshmem_empty)(
+ ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size,
head_dim}), dtype,
+ device);
+ for (int i = 0; i < num_layers; ++i) {
+ pages_.push_back(nvshmem_pages_.CreateView(
+ {num_total_pages_, 2, num_kv_heads_, page_size_, head_dim_},
nvshmem_pages_->dtype,
+ i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * head_dim_ *
+ nvshmem_pages_.DataType().bytes()));
+ }
+
+ const PackedFunc* f_transfer_kv_ptr =
Registry::Get("nvshmem.KVTransfer");
+ const PackedFunc* f_transfer_kv_page_to_page_ptr =
+ Registry::Get("nvshmem.KVTransferPageToPage");
+ ICHECK_NOTNULL(f_transfer_kv_ptr);
+ ICHECK_NOTNULL(f_transfer_kv_page_to_page_ptr);
+ f_transfer_kv_ = *f_transfer_kv_ptr;
+ f_transfer_kv_page_to_page_ = *f_transfer_kv_page_to_page_ptr;
+ } else {
+ for (int i = 0; i < num_layers; ++i) {
+ pages_.push_back(
+ NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size,
head_dim}, dtype, device));
+ }
}
+
// Allocate the host memory.
Device preferred_host_device = GetPreferredHostDevice(device);
for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
@@ -1079,6 +1230,16 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
HostMemoryVector(prefill_chunk_size, dtype_aux_,
preferred_host_device);
append_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_,
preferred_host_device);
+ kv_transfer_remote_position_map_host_ =
+ HostMemoryVector(prefill_chunk_size, dtype_aux_,
preferred_host_device);
+ kv_transfer_recver_id_host_ =
+ HostMemoryVector(prefill_chunk_size, dtype_aux_,
preferred_host_device);
+ kv_transfer_page_to_page_local_position_map_host_ =
+ HostMemoryVector(prefill_chunk_size, dtype_aux_,
preferred_host_device);
+ kv_transfer_page_to_page_remote_position_map_host_ =
+ HostMemoryVector(prefill_chunk_size, dtype_aux_,
preferred_host_device);
+ kv_transfer_page_to_page_recver_id_host_ =
+ HostMemoryVector(prefill_chunk_size, dtype_aux_,
preferred_host_device);
cur_append_lengths_indptr_host_ =
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_,
preferred_host_device);
commit_copy_length_indptr_host_ =
@@ -1135,6 +1296,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
// The compute stream is the default stream.
compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
+ kv_transfer_stream_ = DeviceAPI::Get(device)->CreateStream(device);
}
// Create the auxiliary data manager for attention.
@@ -1162,12 +1324,14 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
if (copy_stream_ != nullptr) {
DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_);
}
+ if (kv_transfer_stream_ != nullptr) {
+ DeviceAPI::Get(device_)->FreeStream(device_, kv_transfer_stream_);
+ }
}
/*! \brief Reset the KV cache. */
void Clear() final {
seq_map_.clear();
- ICHECK(pages_.defined());
free_page_ids_.clear();
for (int64_t page_id = num_total_pages_ - 1; page_id >= 0; --page_id) {
free_page_ids_.push_back(page_id);
@@ -1331,7 +1495,8 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
DeviceAPI::Get(device_)->SetStream(device_, copy_stream_);
}
for (int layer = 0; layer < num_layers_; ++layer) {
- f_copy_single_page_(pages_[layer], src_page_id, tgt_page_id,
copy_length);
+ NDArray page_layer_view = pages_[layer];
+ f_copy_single_page_(page_layer_view, src_page_id, tgt_page_id,
copy_length);
}
if (copy_stream_ != compute_stream_) {
// Set the compute stream back.
@@ -1677,6 +1842,13 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
// in the global KV cache. The mapping is used when appending k/v values.
q_rope_position_map_host_.clear();
append_position_map_host_.clear();
+ kv_transfer_remote_position_map_host_.clear();
+ kv_transfer_recver_id_host_.clear();
+ kv_transfer_page_to_page_local_position_map_host_.clear();
+ kv_transfer_page_to_page_remote_position_map_host_.clear();
+ kv_transfer_page_to_page_recver_id_host_.clear();
+ transfer_kv_ = false;
+ page_to_page_transfer_kv_ = false;
for (int i = 0; i < cur_batch_size_; ++i) {
int64_t append_length = append_lengths[i];
const Block& block = global_block_pool_[sequences[i]->last_block_idx];
@@ -1710,11 +1882,40 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
page_size_ +
offset_in_block % page_size_);
}
+ int64_t pos_in_seq = sequences[i]->seq_length - append_length + pos;
+ int64_t seq_send_start = sequences[i]->kv_transfer_metadata.start;
+ if (pos_in_seq < seq_send_start) {
+ kv_transfer_remote_position_map_host_.push_back(-1);
+ kv_transfer_recver_id_host_.push_back(-1);
+ } else {
+ transfer_kv_ = true;
+ kv_transfer_remote_position_map_host_.push_back(
+
sequences[i]->kv_transfer_metadata.remote_position_map[pos_in_seq -
seq_send_start]);
+ kv_transfer_recver_id_host_.push_back(
+ sequences[i]->kv_transfer_metadata.recver_pe_offset);
+ }
+ }
+ if (!sequences[i]->kv_transfer_metadata.local_position_map.empty()) {
+ page_to_page_transfer_kv_ = true;
+ for (int pos = 0;
+ pos <
static_cast<int>(sequences[i]->kv_transfer_metadata.local_position_map.size());
+ ++pos) {
+ kv_transfer_page_to_page_local_position_map_host_.push_back(
+ sequences[i]->kv_transfer_metadata.local_position_map[pos]);
+ kv_transfer_page_to_page_remote_position_map_host_.push_back(
+ sequences[i]->kv_transfer_metadata.remote_position_map[pos]);
+ kv_transfer_page_to_page_recver_id_host_.push_back(
+ sequences[i]->kv_transfer_metadata.recver_pe_offset);
+ }
+ sequences[i]->kv_transfer_metadata.local_position_map.clear();
}
}
}
void EndForward() final {
+ if (kv_transfer_stream_ != nullptr) {
+ DeviceAPI::Get(device_)->SyncStreamFromTo(device_, kv_transfer_stream_,
compute_stream_);
+ }
if (!f_attention_prefill_end_forward_.defined() ||
!f_attention_decode_end_forward_.defined() ||
!f_attention_prefill_ragged_end_forward_.defined()) {
return;
@@ -1726,6 +1927,88 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
}
+ IntTuple DisaggPrepareRecv(int64_t seq_id, int append_length) final {
+ // No CPU to GPU copy is needed.
+ // Essentially we
+ // (step 1.) redirect the preparation to BeginForward.
+ BeginForward({seq_id}, {append_length},
/*opt_token_tree_parent_ptr=*/NullOpt);
+ // (step 2.) fetch the append_position_map, compress and return.
+ // Compression format: [n, begin_1, length_1, begin_2, length_2, ..., begin_n, length_n]
+ // The compressed format will be decompressed to:
+ // [begin_1, begin_1+1, ..., begin_1+length_1-1, ..., begin_n, ..., begin_n+length_n-1]
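+ // Illustrative example: positions [4, 5, 6, 10, 11] compress to [2, 4, 3, 10, 2].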
+ CHECK_EQ(append_position_map_host_.size(), append_length);
+ std::vector<int64_t> compressed_append_pos_map{/*num_segments=*/1,
+
append_position_map_host_[0]};
+ for (int i = 1; i < append_length; ++i) {
+ if (append_position_map_host_[i] != append_position_map_host_[i - 1] +
1) {
+ // Terminate the current segment.
+ compressed_append_pos_map.push_back(append_position_map_host_[i - 1] -
+ compressed_append_pos_map.back() +
1);
+ // Start a new segment.
+ ++compressed_append_pos_map[0];
+ compressed_append_pos_map.push_back(append_position_map_host_[i]);
+ }
+ }
+ // Terminate the last segment.
+ compressed_append_pos_map.push_back(append_position_map_host_.back() -
+ compressed_append_pos_map.back() + 1);
+ // The compressed array size should be "num_segments * 2 + 1".
+ CHECK_EQ(compressed_append_pos_map.size(), compressed_append_pos_map[0] *
2 + 1);
+ return IntTuple{compressed_append_pos_map};
+ }
+
+ void DisaggMarkSend(int64_t seq_id, int64_t begin, const IntTuple&
compressed_remote_position_map,
+ int32_t recver_pe_offset) {
+ ICHECK(f_transfer_kv_.defined());
+ auto it = seq_map_.find(seq_id);
+ CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot
be found in KV cache.";
+ Sequence* sequence = &it->second;
+ sequence->kv_transfer_metadata.start = begin;
+ int nsegments = compressed_remote_position_map[0];
+ sequence->kv_transfer_metadata.remote_position_map.clear();
+ for (int i = 0; i < nsegments; ++i) {
+ int begin = compressed_remote_position_map[2 * i + 1];
+ int length = compressed_remote_position_map[2 * i + 2];
+ for (int j = 0; j < length; ++j) {
+ sequence->kv_transfer_metadata.remote_position_map.push_back(begin +
j);
+ }
+ }
+ sequence->kv_transfer_metadata.recver_pe_offset = recver_pe_offset;
+
+ sequence->kv_transfer_metadata.local_position_map.clear();
+ if (begin >= sequence->seq_length) {
+ return;
+ }
+ // Need to send existing KV.
+
CHECK_GT(static_cast<int>(sequence->kv_transfer_metadata.remote_position_map.size()),
+ sequence->seq_length - begin)
+ << "Need at least one token to prefill";
+ std::vector<int32_t> trace = sequence->GetBlockTrace(global_block_pool_);
+
sequence->kv_transfer_metadata.local_position_map.reserve(sequence->seq_length
- begin);
+ bool done = false;
+ for (auto it_block_id = trace.rbegin(); it_block_id != trace.rend();
++it_block_id) {
+ const Block& block = global_block_pool_[*it_block_id];
+ for (int i = block.seq_length - 1; i >= 0; --i) {
+ int32_t offset =
+ i < block.sink_length ? i : i - block.sink_length +
block.sliding_window_offset;
+ int page_id = block.page_ids[offset / page_size_];
+ int page_offset = offset % page_size_;
+ sequence->kv_transfer_metadata.local_position_map.push_back(page_id *
page_size_ +
+
page_offset);
+ if
(static_cast<int>(sequence->kv_transfer_metadata.local_position_map.size()) ==
+ sequence->seq_length - begin) {
+ done = true;
+ break;
+ }
+ }
+ if (done) {
+ break;
+ }
+ }
+ std::reverse(sequence->kv_transfer_metadata.local_position_map.begin(),
+ sequence->kv_transfer_metadata.local_position_map.end());
+ }
+
void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data,
Optional<NDArray> mask,
NDArray o_data, double attn_score_scaling_factor)
final {
// Part 1. Shape and dtype check.
@@ -1777,6 +2060,10 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
o_data.CreateView({total_seq_length, num_qo_heads_, head_dim_},
qkv_data->dtype);
}
// Part 2. Split fused qkv and apply rotary embedding to q/k data.
+ if (transfer_kv_) {
+ // The compute stream needs to wait for the KV transfer stream.
+ DeviceAPI::Get(device_)->SyncStreamFromTo(device_, kv_transfer_stream_,
compute_stream_);
+ }
if (!rope_ext_factors_.defined()) {
f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data,
k_data, v_data,
static_cast<int>(rope_mode_ == RoPEMode::kNormal));
@@ -1789,9 +2076,30 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
if (append_before_attn_) {
f_transpose_append_(pages_[local_layer_id], k_data, v_data,
append_position_map_view_);
}
- // Part 4: perform attention
+ // Part 4: KV transfer
+ if (page_to_page_transfer_kv_) {
+ DeviceAPI::Get(device_)->SyncStreamFromTo(device_, copy_stream_,
kv_transfer_stream_);
+ // FIXME: if the sender and recver's PP/TP degree do not match, we will
need to first
+ // get the view of remote pages, and then take the specific remote layer.
+ // The KV transfer stream needs to wait for the compute stream.
+ f_transfer_kv_page_to_page_.value()(pages_[local_layer_id],
pages_[local_layer_id],
+
kv_transfer_page_to_page_remote_position_map_view_,
+
kv_transfer_page_to_page_local_position_map_view_,
+
kv_transfer_page_to_page_recver_id_view_,
+ kv_transfer_stream_);
+ }
+ if (transfer_kv_) {
+ // FIXME: if the sender and recver's PP/TP degree do not match, we will
need to first
+ // get the view of remote pages, and then take the specific remote layer.
+ // The KV transfer stream needs to wait for the compute stream.
+ DeviceAPI::Get(device_)->SyncStreamFromTo(device_, compute_stream_,
kv_transfer_stream_);
+ f_transfer_kv_.value()(pages_[local_layer_id], k_data, v_data,
+ kv_transfer_remote_position_map_view_,
kv_transfer_recver_id_view_,
+ kv_transfer_stream_);
+ }
+ // Part 5: perform attention
AttentionInternal(layer_id, q_data, k_data, v_data, o_data_view,
attn_score_scaling_factor);
- // Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not
set.
+ // Part 6. Append k/v data to kv-cache if flag "append_before_attn" is not
set.
if (!append_before_attn_) {
f_transpose_append_(pages_[local_layer_id], k_data, v_data,
append_position_map_view_);
}
@@ -2491,6 +2799,8 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
}
total_append_length = cur_append_lengths_indptr_host_.back();
ICHECK_EQ(total_append_length, append_position_map_host_.size());
+ ICHECK_EQ(total_append_length,
kv_transfer_remote_position_map_host_.size());
+ ICHECK_EQ(total_append_length, kv_transfer_recver_id_host_.size());
// - Reset the copy.
aux_data_manager_->ResetAttnAuxDataCopy();
@@ -2553,7 +2863,26 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
// 9. append_position_map
append_position_map_view_ =
aux_data_manager_->CopyAppendPositionMapAsync(&append_position_map_host_);
- // 10. tree_attn_mask and tree_attn_mn_indptr
+ // 10. kv_transfer_remote_position_map
+ kv_transfer_remote_position_map_view_ = aux_data_manager_->CopyKVTransferRemotePositionMapAsync(
+ &kv_transfer_remote_position_map_host_);
+ // 11. kv_transfer_recver_id
+ kv_transfer_recver_id_view_ =
+ aux_data_manager_->CopyKVTransferRecverIDAsync(&kv_transfer_recver_id_host_);
+
+ // 12. kv_transfer_page_to_page_local_position_map
+ kv_transfer_page_to_page_local_position_map_view_ =
+ aux_data_manager_->CopyKVTransferPage2PageLocalPositionMapAsync(
+ &kv_transfer_page_to_page_local_position_map_host_);
+ // 13. kv_transfer_page_to_page_remote_position_map
+ kv_transfer_page_to_page_remote_position_map_view_ =
+ aux_data_manager_->CopyKVTransferPage2PageRemotePositionMapAsync(
+ &kv_transfer_page_to_page_remote_position_map_host_);
+ // 14. kv_transfer_page_to_page_recver_id
+ kv_transfer_page_to_page_recver_id_view_ =
+ aux_data_manager_->CopyKVTransferPage2PageRecverIDAsync(
+ &kv_transfer_page_to_page_recver_id_host_);
+ // 15. tree_attn_mask and tree_attn_mn_indptr
for (int d = 0; d < num_depths_; ++d) {
if (!is_chain_on_depths_[d]) {
tree_attn_mask_view_[d] =
@@ -2562,7 +2891,7 @@ class PagedAttentionKVCacheObj : public
AttentionKVCacheObj {
aux_data_manager_->CopyTreeAttnMNIndptrOnDepthAsync(&tree_attn_mn_indptr_host_[d],
d);
}
}
- // 11. Create view for temporary arrays for attention computation.
+ // 16. Create view for temporary arrays for attention computation.
temp_attn_output_view_ = temp_attn_output_device_.CreateView(
{total_append_length, num_qo_heads_, head_dim_},
temp_attn_output_device_->dtype);
temp_attn_scores_view_ = temp_attn_scores_device_.CreateView(
@@ -2585,7 +2914,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
- CHECK(args.size() == 28 || args.size() == 29)
+ CHECK(args.size() == 29 || args.size() == 30)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
ShapeTuple layer_indptr_tuple = args[1];
@@ -2626,10 +2955,14 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
PackedFunc f_attention_prefill_with_tree_mask = args[26];
PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[27];
Optional<NDArray> rope_ext_factors = NullOpt;
+ bool enable_kv_transfer = false;
- if (args.size() >= 29 && args[28].IsObjectRef<NDArray>()) {
+ if (args[28].IsObjectRef<NDArray>()) {
rope_ext_factors = args[28].AsObjectRef<NDArray>();
}
+ if (args.size() >= 30) {
+ enable_kv_transfer = args[29];
+ }
CHECK_EQ(cache_config.size(), 5);
int64_t reserved_num_seqs = cache_config[0];
@@ -2646,9 +2979,9 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
page_size, num_layers, layer_id_begin_offset, num_qo_heads,
num_kv_heads, head_dim,
reserved_num_seqs, num_total_pages, prefill_chunk_size,
support_sliding_window,
RoPEMode(rope_mode), rotary_scale, rotary_theta,
std::move(rope_ext_factors), //
- init->dtype, init->device, std::move(f_transpose_append),
std::move(f_compact_copy),
- std::move(f_attention_prefill), std::move(f_attention_decode),
- std::move(f_attention_prefill_sliding_window),
+ enable_kv_transfer, init->dtype, init->device,
//
+ std::move(f_transpose_append), std::move(f_compact_copy),
std::move(f_attention_prefill),
+ std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window),
std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask),
std::move(f_attention_prefill_with_tree_mask_paged_kv),
@@ -2663,7 +2996,7 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
.set_body([](TVMArgs args, TVMRetValue* rv) {
- CHECK(args.size() == 22 || args.size() == 23)
+ CHECK(args.size() == 23 || args.size() == 24)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
ShapeTuple layer_indptr_tuple = args[1];
@@ -2698,10 +3031,14 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
PackedFunc f_attention_prefill_with_tree_mask = args[20];
PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[21];
Optional<NDArray> rope_ext_factors = NullOpt;
+ bool enable_kv_transfer = false;
- if (args.size() >= 23 && args[22].IsObjectRef<NDArray>()) {
+ if (args[22].IsObjectRef<NDArray>()) {
rope_ext_factors = args[22].AsObjectRef<NDArray>();
}
+ if (args.size() >= 24) {
+ enable_kv_transfer = args[23];
+ }
CHECK_EQ(cache_config.size(), 5);
int64_t reserved_num_seqs = cache_config[0];
@@ -2718,9 +3055,9 @@
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
page_size, num_layers, layer_id_begin_offset, num_qo_heads,
num_kv_heads, head_dim,
reserved_num_seqs, num_total_pages, prefill_chunk_size,
support_sliding_window,
RoPEMode(rope_mode), rotary_scale, rotary_theta,
std::move(rope_ext_factors), //
- init->dtype, init->device, std::move(f_transpose_append),
std::move(f_compact_copy),
- std::move(f_attention_prefill), std::move(f_attention_decode),
- std::move(f_attention_prefill_sliding_window),
+ enable_kv_transfer, init->dtype, init->device,
//
+ std::move(f_transpose_append), std::move(f_compact_copy),
std::move(f_attention_prefill),
+ std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window),
std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask), //
std::move(f_attention_prefill_with_tree_mask_paged_kv), //
diff --git a/tests/python/disco/test_nvshmem.py
b/tests/python/disco/test_nvshmem.py
index b304d145aa..1c4ffc9c4d 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -107,7 +107,7 @@ def test_nvshmem_init_finalize(session_kind: di.Session,
num_workers: int):
f_init_nvshmem_uid =
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
- init_dfunc(uid, num_workers)
+ init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
finalize_dfunc =
sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
finalize_dfunc()
@@ -123,7 +123,7 @@ def test_nvshmem_empty(session_kind: di.Session,
num_workers: int):
f_init_nvshmem_uid =
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
- init_dfunc(uid, num_workers)
+ init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)
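These two call-site updates reflect the initializer's new third parameter. Judging from how the new KV-transfer tests later in this patch use it (init(uid, 2, rank) with one GPU per MPI rank, init(uid, 4, rank * 2) when each MPI rank drives a two-worker disco session), it appears to be the first PE id owned by the calling process within the NVSHMEM world; 0 keeps the old single-group behaviour. A small wrapper illustrating that reading (names are mine, not part of the patch):

    import tvm

    def init_nvshmem_group(uid, total_pes, pe_offset):
        # Assumed meaning: pe_offset is the starting PE id of this process or
        # disco session inside an NVSHMEM world of total_pes workers.
        f = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
        f(uid, total_pes, pe_offset)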
diff --git
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
similarity index 57%
copy from
tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
copy to tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
index 82f85f4b17..483108ca83 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
@@ -40,6 +40,20 @@ from tvm.relax.frontend.nn.llm.kv_cache import (
)
from tvm.runtime import ShapeTuple
+
+def get_comm_rank():
+ try:
+ from mpi4py import MPI
+
+ comm = MPI.COMM_WORLD
+ rank = comm.Get_rank()
+ return comm, rank
+ except ImportError:
+ return None, 0
+
+
+comm, rank = get_comm_rank()
+
reserved_nseq = 32
maximum_total_seq_length = 2048
prefill_chunk_size = 512
@@ -52,7 +66,7 @@ rope_scale = 1.0
rope_theta = 1e4
rope_scaling = {}
dtype = None
-device = tvm.cuda()
+device = tvm.cuda(rank)
fclear = None
fadd_sequence = None
@@ -66,6 +80,10 @@ fcommit_accepted_token_tree_nodes = None
fattention_with_fuse_qkv = None
fis_empty = None
fdebug_get_kv = None
+fnvshmem_get_uid = None
+fnvshmem_init = None
+fdisagg_mark_send = None
+fdisagg_prepare_recv = None
ftranspose_append = None
fcopy_cache = None
@@ -91,6 +109,7 @@ def set_global_func(head_dim, dtype):
global fattn_prefill_ragged, fattn_prefill_with_tree_mask,
fattn_prefill_with_tree_mask_paged_kv_cache
global fattn_prefill_sliding_window, fattn_decode_sliding_window
global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page,
fcompact_copy
+ global fnvshmem_get_uid, fnvshmem_init, fdisagg_mark_send,
fdisagg_prepare_recv
fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
@@ -111,6 +130,11 @@ def set_global_func(head_dim, dtype):
fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
fdebug_get_kv =
tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
+ fnvshmem_get_uid =
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+ fnvshmem_init = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+ fdisagg_mark_send =
tvm.get_global_func("vm.builtin.kv_cache_disagg_mark_send")
+ fdisagg_prepare_recv =
tvm.get_global_func("vm.builtin.kv_cache_disagg_prepare_recv")
+
target = tvm.target.Target.from_device(device)
builts = []
for tir_func in [
@@ -193,6 +217,7 @@ def create_kv_cache(head_dim, dtype, rope_mode,
support_sliding_window):
fattn_prefill_with_tree_mask,
fattn_prefill_with_tree_mask_paged_kv_cache,
None,
+ True,
)
return cache
@@ -277,6 +302,8 @@ def apply_attention(
attn_sink_sizes: Optional[List[int]] = None,
token_tree_parent_ptr_list: Optional[List[List[int]]] = None,
accepted_leaf_indices: Optional[List[int]] = None,
+ only_update_host=False,
+ skip_add_sequence=False,
) -> None:
seq_ids = []
append_lengths = []
@@ -291,7 +318,8 @@ def apply_attention(
if fork_parent_id is not None:
assert fork_parent_id in cached_k
assert seq_id not in cached_k
- ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos)
+ if not only_update_host:
+ ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos)
if fork_pos == -1:
cached_k[seq_id] = cached_k[fork_parent_id]
cached_v[seq_id] = cached_v[fork_parent_id]
@@ -299,7 +327,8 @@ def apply_attention(
cached_k[seq_id] = cached_k[fork_parent_id][::, :fork_pos]
cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos]
elif seq_id not in cached_k:
- fadd_sequence(kv_cache, seq_id)
+ if not only_update_host and not skip_add_sequence:
+ fadd_sequence(kv_cache, seq_id)
cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads,
head_dim), dtype)
cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads,
head_dim), dtype)
@@ -324,17 +353,17 @@ def apply_attention(
)
# depth of each node in the tree (this contains more than the last
`append_length` nodes)
token_tree_node_depths_list[i] = token_tree_node_depths
-
- fbegin_forward(
- kv_cache,
- ShapeTuple(seq_ids),
- ShapeTuple(append_lengths),
- (
- ShapeTuple(flattened_token_tree_parent_ptr)
- if flattened_token_tree_parent_ptr is not None
- else None
- ),
- )
+ if not only_update_host:
+ fbegin_forward(
+ kv_cache,
+ ShapeTuple(seq_ids),
+ ShapeTuple(append_lengths),
+ (
+ ShapeTuple(flattened_token_tree_parent_ptr)
+ if flattened_token_tree_parent_ptr is not None
+ else None
+ ),
+ )
global_new_q = np.zeros((num_layers, 0, num_qo_heads, head_dim), dtype)
global_new_k = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
@@ -365,9 +394,11 @@ def apply_attention(
rope_offset,
rope_scale,
rope_theta,
- token_tree_node_depths_list[i][-append_length:]
- if token_tree_node_depths_list[i] is not None
- else None,
+ (
+
token_tree_node_depths_list[i][-append_length:]
+ if token_tree_node_depths_list[i] is not
None
+ else None
+ ),
)
)
for l in range(num_layers)
@@ -388,7 +419,8 @@ def apply_attention(
values_np = global_new_v[layer_id]
qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np],
axis=1), device)
outputs = tvm.nd.empty(queries_np.shape, dtype, device=device)
- fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
+ if not only_update_host:
+ fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs)
# Compute attention expected results.
outputs = np.expand_dims(outputs.numpy(), axis=0)
@@ -409,9 +441,11 @@ def apply_attention(
rope_offset,
rope_scale,
rope_theta,
- token_tree_node_depths_list[i][-append_length:]
- if token_tree_node_depths_list[i] is not None
- else None,
+ (
+ token_tree_node_depths_list[i][-append_length:]
+ if token_tree_node_depths_list[i] is not None
+ else None
+ ),
)
).transpose(1, 0, 2)
k_seq = (
@@ -464,21 +498,23 @@ def apply_attention(
),
axis=0,
).astype(dtype)
-
- tvm.testing.assert_allclose(
- outputs[:, sum_length : sum_length + append_length, ...],
- results,
- rtol=1e-3,
- atol=1e-3,
- )
+ if not only_update_host:
+ tvm.testing.assert_allclose(
+ outputs[:, sum_length : sum_length + append_length, ...],
+ results,
+ rtol=1e-3,
+ atol=1e-3,
+ )
sum_length += append_length
- fend_forward(kv_cache)
+ if not only_update_host:
+ fend_forward(kv_cache)
if accepted_leaf_indices is not None:
seq_ids = [seq_id for seq_id, _ in batch]
- fcommit_accepted_token_tree_nodes(
- kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices)
- )
+ if not only_update_host:
+ fcommit_accepted_token_tree_nodes(
+ kv_cache, ShapeTuple(seq_ids),
ShapeTuple(accepted_leaf_indices)
+ )
for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
zip(accepted_leaf_indices, batch)
):
@@ -530,11 +566,11 @@ def apply_attention(
assert cached_k[seq_id].shape[1] == sliding_window_size
# Verify
- verify_cached_kv(kv_cache, seq_ids, cached_k, cached_v)
+ if not only_update_host:
+ verify_cached_kv(kv_cache, seq_ids, cached_k, cached_v)
[email protected]_gpu
[email protected]_cuda
[email protected](reason="Require NVSHMEM")
def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config):
kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
if support_sliding_window and rope_mode == RopeMode.NORMAL:
@@ -558,413 +594,92 @@ def
test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config):
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config):
- kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
- if support_sliding_window and rope_mode == RopeMode.NORMAL:
- # Normal RoPE mode under sliding window settings is not supported.
- return
- fclear(kv_cache)
-
- num_sequences = 5
- batch = [(seq_id, 1) for seq_id in range(num_sequences)]
- cached_k = {}
- cached_v = {}
- for seq_id_to_remove in range(num_sequences):
- apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
- # Remove sequence.
- fremove_sequence(kv_cache, seq_id_to_remove)
- cached_k.pop(seq_id_to_remove)
- cached_v.pop(seq_id_to_remove)
- verify_cached_kv(
- kv_cache,
- seq_ids=[seq_id for seq_id in range(num_sequences) if seq_id !=
seq_id_to_remove],
- expected_k=cached_k,
- expected_v=cached_v,
- )
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
[email protected](reason="Require NVSHMEM")
+def test_paged_attention_kv_cache_transfer(kv_cache_and_config):
kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
- if support_sliding_window and rope_mode == RopeMode.NORMAL:
- # Normal RoPE mode under sliding window settings is not supported.
- return
- fclear(kv_cache)
-
- cached_k = {}
- cached_v = {}
- batch = [(0, 60), (1, 88), (2, 17), (3, 4)]
- apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
- # Fork existing sequences.
- apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71), ((9, 5, -1), 20)],
cached_k, cached_v)
- # 0 <- 5 <- 6,8,9
- # 0 <- 7
- # 3 <- 4
- # Mixture of decode and prefill.
- operation_seq = [
- [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)],
- [(7, 1), (6, 1), (8, 1), (9, 1)],
- [(7, 1), (1, 1), (6, 1), (2, 1), (8, 1), (4, 1), (9, 1)],
- [(7, 10), (6, 2), (8, 3), (9, 4)],
- ]
- for batch in operation_seq:
- apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
-
- apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45), ((12, 0, 15),
14)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19), ((14, 0, 17),
19)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8), ((16, 5, 80),
10)], cached_k, cached_v)
- apply_attention(
- kv_cache,
- rope_mode,
- [((17, 5, 75), 11), ((18, 5, 76), 45), ((19, 5, 77), 14)],
- cached_k,
- cached_v,
- )
-
- operation_seq = [
- [(6, 1), (11, 1), (13, 1), (9, 1)],
- [(10, 1), (16, 1), (18, 1), (19, 1)],
- [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)],
- [(10, 10), (6, 2), (8, 3), (19, 4)],
- ]
- for batch in operation_seq:
- apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
-
- num_sequence = 20
- for i in range(num_sequence):
- fremove_sequence(kv_cache, i)
- cached_k.pop(i)
- cached_v.pop(i)
- verify_cached_kv(
- kv_cache,
- seq_ids=list(range(i + 1, num_sequence)),
- expected_k=cached_k,
- expected_v=cached_v,
- )
-
- assert fis_empty(kv_cache), "The KV cache is not empty after removing all
sequences"
-
- # Test fork after page recycle
- apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
-
- apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k,
cached_v)
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_unlimited_depth(kv_cache_and_config):
- kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
- if support_sliding_window and rope_mode == RopeMode.NORMAL:
+ if support_sliding_window:
# Normal RoPE mode under sliding window settings is not supported.
return
+ np.random.seed(0)
fclear(kv_cache)
-
- cached_k = {}
- cached_v = {}
- apply_attention(kv_cache, rope_mode, [(0, 30)], cached_k, cached_v)
- # Fork existing sequences.
- apply_attention(kv_cache, rope_mode, [((1, 0, -1), 15)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((2, 1, -1), 5)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((3, 2, -1), 20)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((4, 3, -1), 26)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((5, 3, -1), 18)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((6, 5, -1), 22)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((7, 5, -1), 12)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((8, 7, -1), 29)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((9, 7, -1), 9)], cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((10, 9, -1), 31)], cached_k,
cached_v)
- apply_attention(kv_cache, rope_mode, [((11, 9, -1), 4)], cached_k,
cached_v)
- # 0 <- 1 <- 2 <- 3 <- 5 <- 7 <- 9 <- 11
- # | | | |
- # 4 6 8 10
- # Decode.
- operation_seq = [
- [(3, 1), (6, 1), (9, 1)],
- [(4, 1), (8, 1), (10, 1)],
- [(5, 1), (7, 1), (11, 1)],
- ]
- for batch in operation_seq:
- apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
-
- num_sequence = 12
- for i in range(num_sequence):
- fremove_sequence(kv_cache, i)
- cached_k.pop(i)
- cached_v.pop(i)
- verify_cached_kv(
- kv_cache,
- seq_ids=list(range(i + 1, num_sequence)),
- expected_k=cached_k,
- expected_v=cached_v,
- )
-
- assert fis_empty(kv_cache), "The KV cache is not empty after removing all
sequences"
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_popn(kv_cache_and_config):
- kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
- if support_sliding_window and rope_mode == RopeMode.NORMAL:
- return
- fclear(kv_cache)
-
- cached_k = {}
- cached_v = {}
- batch = [(0, 35), (1, 88), (2, 17), (3, 4)]
- apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
- apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k,
cached_v)
-
- popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 37)]
- for seq_id, pop_length in popn_operations:
- fpopn(kv_cache, seq_id, pop_length)
- if pop_length != 0:
- cached_k[seq_id] = cached_k[seq_id][:, :-pop_length, ...]
- cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...]
- verify_cached_kv(kv_cache, seq_ids=list(range(4)),
expected_k=cached_k, expected_v=cached_v)
-
- num_sequence = 5
- for seq_id in range(num_sequence):
- fremove_sequence(kv_cache, seq_id)
- verify_cached_kv(
- kv_cache,
- seq_ids=list(range(seq_id + 1, num_sequence)),
- expected_k=cached_k,
- expected_v=cached_v,
- )
-
- assert fis_empty(kv_cache), "The KV cache is not empty after removing all
sequences"
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
- kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
- if not support_sliding_window or rope_mode == RopeMode.NORMAL:
- return
- fclear(kv_cache)
-
- cached_k = {}
- cached_v = {}
- sliding_window_sizes = [20, 25, 30, 35, 40]
- attn_sink_sizes = [6, 4, 8, 3, 7]
- for seq_id, (sliding_window_size, attn_sink_size) in enumerate(
- zip(sliding_window_sizes, attn_sink_sizes)
- ):
- fadd_sequence(kv_cache, seq_id)
- fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size,
attn_sink_size)
- cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim),
dtype)
- cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim),
dtype)
-
# Prefill.
- operation_seq = [[(0, 4)], [(1, 6)], [(2, 6), (3, 7), (4, 7)]]
- operation_seq += [[(0, 20), (1, 19), (2, 30), (3, 35), (4, 40)]]
- operation_seq += [[(0, 6), (1, 5), (2, 4), (3, 3), (4, 2)]]
- for batch in operation_seq:
- apply_attention(
- kv_cache,
- rope_mode,
- batch,
- cached_k,
- cached_v,
- sliding_window_sizes,
- attn_sink_sizes,
- )
+ prefill_operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4,
19), (5, 20)]]
+ prefill_operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]]
+ prefill_operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8),
(5, 12), (7, 11)]]
+ prefill_len = {i: 0 for i in range(9)}
+ for batch in prefill_operation_seq:
+ for seq_id, append_length in batch:
+ prefill_len[seq_id] += append_length
# Decode
- batch = [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1)]
- for _ in range(20):
- apply_attention(
- kv_cache,
- rope_mode,
- batch,
- cached_k,
- cached_v,
- sliding_window_sizes,
- attn_sink_sizes,
- )
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config):
- kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
- if not support_sliding_window or rope_mode == RopeMode.NORMAL:
- return
- fclear(kv_cache)
-
- cached_k = {}
- cached_v = {}
- sliding_window_sizes = [30, 35, 40]
- attn_sink_sizes = [15, 20, 25]
- for seq_id, (sliding_window_size, attn_sink_size) in enumerate(
- zip(sliding_window_sizes, attn_sink_sizes)
- ):
- fadd_sequence(kv_cache, seq_id)
- fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size,
attn_sink_size)
- cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim),
dtype)
- cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim),
dtype)
- apply_attention(
- kv_cache,
- rope_mode,
- [(0, 12), (1, 18), (2, 28)],
- cached_k,
- cached_v,
- sliding_window_sizes,
- attn_sink_sizes,
- )
- # seq_len: [12, 18, 25+3]
- sliding_window_sizes += [0, 0, 0]
- attn_sink_sizes += [0, 0, 0]
- apply_attention(
- kv_cache,
- rope_mode,
- [((3, 0, 10), 8), ((4, 1, -1), 20), ((5, 2, 18), 18)],
- cached_k,
- cached_v,
- sliding_window_sizes,
- attn_sink_sizes,
- )
- # seq_len: [12, 18, 25+3, 18, 38, 36]
- apply_attention(
- kv_cache,
- rope_mode,
- [(0, 9), (1, 15), (2, 4), (3, 10), (4, 3), (5, 7)],
- cached_k,
- cached_v,
- sliding_window_sizes,
- attn_sink_sizes,
- )
- # seq_len: [15+6, 20+13, 25+7, 28, 41, 43]
- sliding_window_sizes += [25]
- attn_sink_sizes += [24]
- ffork_sequence(kv_cache, 3, 6, 18)
- fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1],
attn_sink_sizes[-1])
- cached_k[6] = cached_k[3][::, :18]
- cached_v[6] = cached_v[3][::, :18]
- apply_attention(
- kv_cache,
- rope_mode,
- [(3, 10), (6, 12)],
- cached_k,
- cached_v,
- sliding_window_sizes,
- attn_sink_sizes,
- )
- # seq_len: [15+6, 20+13, 25+7, 38, 41, 43, 24+6]
-
-
[email protected]_gpu
[email protected]_cuda
-def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config):
- kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
- if support_sliding_window:
- # Normal RoPE mode under sliding window settings is not supported.
- return
- if rope_mode == RopeMode.INLINE:
- # Inline RoPE mode is not supported for tree attention.
- return
- fclear(kv_cache)
+ decode_operation_seq = [
+ [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8,
1)]
+ ]
+ decode_operation_seq += [
+ [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8,
1)]
+ ]
+ decode_operation_seq += [[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1)]]
+ decode_operation_seq += [[(4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
cached_k = {}
cached_v = {}
- # Prefill 4 sequences
- apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)],
cached_k, cached_v)
- # Tree attention
- apply_attention(
- kv_cache,
- rope_mode,
- [(0, 7), (1, 15), (2, 10), (3, 14)],
- cached_k,
- cached_v,
- token_tree_parent_ptr_list=[
- [-1, 0, 0, 1, 1, 2, 2], # complete binary tree of height 3
- [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6], # complete binary
tree of height 4
- [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], # chain of length 10
- [-1, 0, 0, 1, 1, 2, 2, -1, 7, 7, 8, 8, 9, 9], # two complete
binary trees of height 3
- ],
- accepted_leaf_indices=[6, 11, 6, 13],
- )
- # Do 5 rounds of decode.
- for _ in range(5):
- apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)],
cached_k, cached_v)
+ if rank == 0:
+ for seq_id, _ in prefill_len.items():
+ fadd_sequence(kv_cache, seq_id)
+ remote_pos_maps = None
+ remote_pos_maps = comm.bcast(remote_pos_maps, root=1)
+ comm.Barrier()
+ for seq_id in prefill_len.keys():
+ fdisagg_mark_send(kv_cache, seq_id, 0,
ShapeTuple(remote_pos_maps[seq_id]), 1)
+ for batch in prefill_operation_seq:
+ apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v,
skip_add_sequence=True)
+ device.sync()
+ comm.Barrier()
+ else:
+ remote_pos_maps = []
+ for seq_id, len in prefill_len.items():
+ fadd_sequence(kv_cache, seq_id)
+ compressed_pos_map = list(fdisagg_prepare_recv(kv_cache, seq_id,
len))
+ remote_pos_maps.append(compressed_pos_map)
+ remote_pos_maps = comm.bcast(remote_pos_maps, root=1)
+ comm.Barrier()
+ for batch in prefill_operation_seq:
+ apply_attention(
+ kv_cache,
+ rope_mode,
+ batch,
+ cached_k,
+ cached_v,
+ only_update_host=True,
+ skip_add_sequence=True,
+ )
+ comm.Barrier()
+ for batch in decode_operation_seq:
+ apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v,
skip_add_sequence=True)
- # Test the cases where all trees are chains.
- fclear(kv_cache)
- cached_k = {}
- cached_v = {}
- # Prefill 4 sequences
- apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)],
cached_k, cached_v)
- # Tree attention
- apply_attention(
- kv_cache,
- rope_mode,
- [(0, 7), (1, 15), (2, 10), (3, 14)],
- cached_k,
- cached_v,
- token_tree_parent_ptr_list=[
- [-1, 0, 1, 2, 3, 4, 5], # complete binary tree of height 7
- [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], # chain of
length 15
- [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], # chain of length 10
- [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], # chain of length
14
- ],
- accepted_leaf_indices=[2, 6, -1, 4],
- )
- # Do 5 rounds of decode.
- for _ in range(5):
- apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)],
cached_k, cached_v)
- # Test the cases of tree attn with cached kv.
- fclear(kv_cache)
- cached_k = {}
- cached_v = {}
- # Prefill 4 sequences
- apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)],
cached_k, cached_v)
- # Do 5 rounds of tree decode.
- num_seq = 4
- for i in range(5):
- num_leaf_nodes = 2**i
- parent_ptr = [(k - 1) // 2 for k in range(0, 2 * num_leaf_nodes - 1)]
- apply_attention(
- kv_cache,
- rope_mode,
- [(seq_id, num_leaf_nodes) for seq_id in range(num_seq)],
- cached_k,
- cached_v,
- token_tree_parent_ptr_list=[parent_ptr for _ in range(num_seq)],
- accepted_leaf_indices=(
- None if i != 4 else [2, 6, -1, 4]
- ), # Leaf nodes are committed all at once at the end.
- )
+def init_nvshmem(num_workers, pe_offset):
+ if rank == 0:
+ f_init_nvshmem_uid =
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+ uid = f_init_nvshmem_uid()
+ else:
+ uid = None
+ uid = comm.bcast(uid, root=0)
+ init_func = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+ init_func(uid, num_workers, pe_offset)
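Read together with the rank branches above, the disaggregation handshake this test drives appears to be: the decode side (rank 1) registers each sequence and calls kv_cache_disagg_prepare_recv to reserve pages and obtain a compressed position map, which it broadcasts; the prefill side (rank 0) feeds that map to kv_cache_disagg_mark_send so its prefill also pushes K/V into the remote cache; after a barrier both ranks decode, and rank 1's results on top of the transferred K/V are checked against the host reference. A condensed sketch of the two roles (argument meanings hedged, not documented API semantics):

    from tvm.runtime import ShapeTuple

    def receiver_prepare(kv_cache, seq_id, expected_len, fadd_sequence, fdisagg_prepare_recv):
        # Decode side: reserve pages for the incoming sequence and return the
        # compressed position map the sender needs (e.g. shared via comm.bcast).
        fadd_sequence(kv_cache, seq_id)
        return list(fdisagg_prepare_recv(kv_cache, seq_id, expected_len))

    def sender_mark(kv_cache, seq_id, pos_map, fadd_sequence, fdisagg_mark_send):
        # Prefill side: mark the sequence so later prefill also writes K/V to the
        # remote cache. The trailing 0 and 1 follow the test verbatim (start
        # position and receiving worker, by the look of it).
        fadd_sequence(kv_cache, seq_id)
        fdisagg_mark_send(kv_cache, seq_id, 0, ShapeTuple(pos_map), 1)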
if __name__ == "__main__":
- HEAD_DIMS = [64, 128]
- DTYPES = ["float16", "float32"]
- ROPE_MODES = [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]
- SUPPORT_SLIDING_WINDOW = [False, True]
+ # To run this test, install mpi4py first, and then run
+ # mpirun -np 2 python
tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
+ HEAD_DIMS = [128]
+ DTYPES = ["float16"]
+ ROPE_MODES = [RopeMode.NONE]
+ SUPPORT_SLIDING_WINDOW = [False]
+ init_nvshmem(2, rank)
for head_dim, dtype, rope_mode, support_sliding_window in
itertools.product(
HEAD_DIMS, DTYPES, ROPE_MODES, SUPPORT_SLIDING_WINDOW
):
set_global_func(head_dim, dtype)
cache = create_kv_cache(head_dim, dtype, rope_mode,
support_sliding_window)
cache_and_config = (cache, rope_mode, support_sliding_window)
- test_paged_attention_kv_cache_prefill_and_decode(cache_and_config)
- test_paged_attention_kv_cache_remove_sequence(cache_and_config)
- test_paged_attention_kv_cache_fork_sequence(cache_and_config)
- test_paged_attention_kv_cache_popn(cache_and_config)
- test_paged_attention_kv_cache_sliding_window(cache_and_config)
- test_paged_attention_kv_cache_tree_attn(cache_and_config)
- test_paged_attention_kv_cache_unlimited_depth(cache_and_config)
+ test_paged_attention_kv_cache_transfer(cache_and_config)
diff --git
a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py
b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py
new file mode 100644
index 0000000000..2a1f0047fa
--- /dev/null
+++
b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py
@@ -0,0 +1,252 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm._ffi.runtime_ctypes import Device
+from tvm.runtime import ShapeTuple
+from tvm.runtime import disco as di
+
+
+page_size = 4
+num_layers = 4
+num_kv_heads = 4
+head_dim = 128
+num_pages = 100
+ntokens = 16
+
+
+def get_comm_rank():
+ from mpi4py import MPI
+
+ comm = MPI.COMM_WORLD
+ rank = comm.Get_rank()
+ return comm, rank
+
+
[email protected](reason="Require NVSHMEM")
+def test_kv_transfer_without_disco():
+ comm, rank = get_comm_rank()
+ layer_id = 1
+ dev = tvm.cuda(rank)
+ if rank == 0:
+ f_init_nvshmem_uid =
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+ uid = f_init_nvshmem_uid()
+ else:
+ uid = None
+ uid = comm.bcast(uid, root=0)
+ init_func = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+ init_func(uid, 2, rank)
+ empty_func = tvm.get_global_func("runtime.disco.nvshmem.empty")
+ pages = empty_func(
+ ShapeTuple((num_layers, num_pages, 2, num_kv_heads, page_size,
head_dim)), "float16", dev
+ )
+ position_map_array = [0, 1, 2, 3, 4, 5, 10, 11, 12, 15, 16, 17, 18, 19,
25, 27]
+ np.random.seed(0)
+ k_np = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+ v_np = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+ if rank == 0:
+ k = tvm.nd.array(k_np, dev)
+ v = tvm.nd.array(v_np, dev)
+ remote_position_map_np = np.array(position_map_array, dtype=np.int32)
+ remote_position_map = tvm.nd.array(remote_position_map_np, dev)
+ remote_tp_group_pe_offset_np = np.array([1] * len(position_map_array),
dtype=np.int32)
+ remote_tp_group_pe_offset = tvm.nd.array(remote_tp_group_pe_offset_np,
dev)
+ transfer_func = tvm.get_global_func("nvshmem.KVTransfer")
+ layer_view = pages._create_view(
+ [num_pages, 2, num_kv_heads, page_size, head_dim],
+ "float16",
+ relative_byte_offset=layer_id * num_pages * 2 * num_kv_heads *
page_size * head_dim * 2,
+ )
+ transfer_func(layer_view, k, v, remote_position_map,
remote_tp_group_pe_offset, None)
+ dev.sync()
+ comm.Barrier()
+ else:
+ comm.Barrier()
+ pages_np = pages.numpy()
+ for i, position in enumerate(position_map_array):
+ page_id = position // page_size
+ offset_in_page = position % page_size
+ original_k = k_np[i]
+ transferred_k = pages_np[layer_id, page_id, 0, :, offset_in_page,
:]
+ np.testing.assert_allclose(original_k, transferred_k)
+ original_v = v_np[i]
+ transferred_v = pages_np[layer_id, page_id, 1, :, offset_in_page,
:]
+ np.testing.assert_allclose(original_v, transferred_v)
+ finalize_func =
tvm.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+ finalize_func()
+ comm.Barrier()
+
+
[email protected](reason="Require NVSHMEM")
+def test_kv_transfer_page_to_page_without_disco():
+ comm, rank = get_comm_rank()
+ layer_id = 1
+ dev = tvm.cuda(rank)
+ if rank == 0:
+ f_init_nvshmem_uid =
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+ uid = f_init_nvshmem_uid()
+ else:
+ uid = None
+ uid = comm.bcast(uid, root=0)
+ init_func = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+ init_func(uid, 2, rank)
+ empty_func = tvm.get_global_func("runtime.disco.nvshmem.empty")
+ pages = empty_func(
+ ShapeTuple((num_layers, num_pages, 2, num_kv_heads, page_size,
head_dim)), "float16", dev
+ )
+ rank_1_position_map_array = [0, 1, 2, 3, 4, 5, 10, 11, 12, 15, 16, 17, 18,
19, 25, 27]
+ rank_0_position_map_array = list(reversed(rank_1_position_map_array))
+ np.random.seed(0)
+ pages_np = np.random.rand(num_layers, num_pages, 2, num_kv_heads,
page_size, head_dim).astype(
+ np.float16
+ )
+ if rank == 0:
+ pages.copyfrom(pages_np)
+ remote_position_map_np = np.array(rank_1_position_map_array,
dtype=np.int32)
+ remote_position_map = tvm.nd.array(remote_position_map_np, dev)
+ local_position_map_np = np.array(rank_0_position_map_array,
dtype=np.int32)
+ local_position_map = tvm.nd.array(local_position_map_np, dev)
+ remote_tp_group_pe_offset_np = np.array(
+ [1] * len(rank_0_position_map_array), dtype=np.int32
+ )
+ remote_tp_group_pe_offset = tvm.nd.array(remote_tp_group_pe_offset_np,
dev)
+ transfer_func = tvm.get_global_func("nvshmem.KVTransferPageToPage")
+ layer_view = pages._create_view(
+ [num_pages, 2, num_kv_heads, page_size, head_dim],
+ "float16",
+ relative_byte_offset=layer_id * num_pages * 2 * num_kv_heads *
page_size * head_dim * 2,
+ )
+ transfer_func(
+ layer_view,
+ layer_view,
+ remote_position_map,
+ local_position_map,
+ remote_tp_group_pe_offset,
+ None,
+ )
+ dev.sync()
+ comm.Barrier()
+ else:
+ comm.Barrier()
+ new_pages_np = pages.numpy()
+ for i, position in enumerate(rank_1_position_map_array):
+ page_id = position // page_size
+ offset_in_page = position % page_size
+ rank_0_position = rank_0_position_map_array[i]
+ rank_0_page_id = rank_0_position // page_size
+ rank_0_offset_in_page = rank_0_position % page_size
+ rank_0_entry = pages_np[layer_id, rank_0_page_id, :, :,
rank_0_offset_in_page, :]
+ transferred_entry = new_pages_np[layer_id, page_id, :, :,
offset_in_page, :]
+ np.testing.assert_allclose(rank_0_entry, transferred_entry)
+ finalize_func =
tvm.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+ finalize_func()
+ comm.Barrier()
+
+
[email protected](reason="Require NVSHMEM")
+def test_kv_transfer_with_disco():
+ comm, rank = get_comm_rank()
+ layer_id = 1
+ if rank == 0:
+ f_init_nvshmem_uid =
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+ uid = f_init_nvshmem_uid()
+ else:
+ uid = None
+ uid = comm.bcast(uid, root=0)
+ sess = di.ProcessSession(num_workers=2)
+ init_func = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+ init_func(uid, 4, rank * 2)
+ empty_func = sess.get_global_func("runtime.disco.nvshmem.empty")
+ pages = empty_func(
+ ShapeTuple((num_layers, num_pages, 2, num_kv_heads, page_size,
head_dim)),
+ "float16",
+ Device(device_type=0, device_id=0),
+ )
+ position_map_array = [0, 1, 2, 3, 4, 5, 10, 11, 12, 15, 16, 17, 18, 19,
25, 27]
+ np.random.seed(0)
+ k_np_0 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+ v_np_0 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+ np.random.seed(1)
+ k_np_1 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+ v_np_1 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+ if rank == 0:
+ k = sess.empty((ntokens, num_kv_heads, head_dim), "float16")
+ v = sess.empty((ntokens, num_kv_heads, head_dim), "float16")
+ k.debug_copy_from(0, k_np_0)
+ k.debug_copy_from(1, k_np_1)
+ v.debug_copy_from(0, v_np_0)
+ v.debug_copy_from(1, v_np_1)
+ remote_position_map_np = np.array(position_map_array, dtype=np.int32)
+ remote_position_map = sess.empty((len(position_map_array),), "int32")
+ remote_tp_group_pe_offset_np = np.array([2] * len(position_map_array),
dtype=np.int32)
+ remote_tp_group_pe_offset =
sess.empty((len(remote_tp_group_pe_offset_np),), "int32")
+ f_view_func = sess.get_global_func("runtime.TVMArrayCreateView")
+ layer_view = f_view_func(
+ pages,
+ ShapeTuple([num_pages, 2, num_kv_heads, page_size, head_dim]),
+ "float16",
+ layer_id * num_pages * 2 * num_kv_heads * page_size * head_dim * 2,
+ )
+ remote_position_map.debug_copy_from(0, remote_position_map_np)
+ remote_position_map.debug_copy_from(1, remote_position_map_np)
+ remote_tp_group_pe_offset.debug_copy_from(0,
remote_tp_group_pe_offset_np)
+ remote_tp_group_pe_offset.debug_copy_from(1,
remote_tp_group_pe_offset_np)
+ transfer_func = sess.get_global_func("nvshmem.KVTransfer")
+ transfer_func(layer_view, k, v, remote_position_map,
remote_tp_group_pe_offset, None)
+ for i in range(2):
+ sess._sync_worker(i)
+ for i in range(2):
+ tvm.cuda(i).sync()
+ comm.Barrier()
+ else:
+ comm.Barrier()
+ pages_np = pages.debug_get_from_remote(0).numpy()
+ for i, position in enumerate(position_map_array):
+ page_id = position // page_size
+ offset_in_page = position % page_size
+ original_k = k_np_0[i]
+ transferred_k = pages_np[layer_id, page_id, 0, :, offset_in_page,
:]
+ np.testing.assert_allclose(original_k, transferred_k)
+ original_v = v_np_0[i]
+ transferred_v = pages_np[layer_id, page_id, 1, :, offset_in_page,
:]
+ np.testing.assert_allclose(original_v, transferred_v)
+ pages_np = pages.debug_get_from_remote(1).numpy()
+ for i, position in enumerate(position_map_array):
+ page_id = position // page_size
+ offset_in_page = position % page_size
+ original_k = k_np_1[i]
+ transferred_k = pages_np[layer_id, page_id, 0, :, offset_in_page,
:]
+ np.testing.assert_allclose(original_k, transferred_k)
+ original_v = v_np_1[i]
+ transferred_v = pages_np[layer_id, page_id, 1, :, offset_in_page,
:]
+ np.testing.assert_allclose(original_v, transferred_v)
+ finalize_dfunc =
sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+ finalize_dfunc()
+ for i in range(2):
+ sess._sync_worker(i)
+
+
+if __name__ == "__main__":
+ # To run this test, install mpi4py first, and then run
+ # mpirun -np 2 python
tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer_kernel.py #
pylint: disable=line-too-long
+ # FIXME: only one test can be run at a time
+ test_kv_transfer_without_disco()
+ # test_kv_transfer_with_disco()
+ # test_kv_transfer_page_to_page_without_disco()
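One detail worth spelling out in these kernel tests is the relative_byte_offset used to view a single layer: pages is laid out [num_layers, num_pages, 2, num_kv_heads, page_size, head_dim] in float16, so the trailing "* 2" in the offset expression is simply the element size in bytes. With the test's constants:

    import numpy as np

    num_pages, num_kv_heads, page_size, head_dim, layer_id = 100, 4, 4, 128, 1
    elems_per_layer = num_pages * 2 * num_kv_heads * page_size * head_dim
    bytes_per_layer = elems_per_layer * np.dtype("float16").itemsize  # the "* 2"
    print(layer_id * bytes_per_layer)  # the layer-1 view starts 819200 bytes in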
diff --git
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 82f85f4b17..172eb20c26 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -193,6 +193,7 @@ def create_kv_cache(head_dim, dtype, rope_mode,
support_sliding_window):
fattn_prefill_with_tree_mask,
fattn_prefill_with_tree_mask_paged_kv_cache,
None,
+ False,
)
return cache